import {Node} from "@reactflow/core/dist/esm/types";
import {XYPosition} from "@reactflow/core/dist/esm/types/utils";
import {Position, Rect} from "reactflow";
import React from "react";
import {NodeInfo} from "./types";

/**
 * Finds where the straight line from the centre of `intersectionNode` towards the
 * centre of `targetNode` leaves `intersectionNode`'s bounding box.
 *
 * The box is mapped onto a unit "diamond" (rectangle rotated 45°). This makes
 * the exit point a simple L1 normalisation. The result equals
 * `centre + (w*dx, h*dy) / max(|dx|, |dy|)`, where `dx`/`dy` are the centre
 * offsets normalised by the half-width and half-height.
 *
 * Missing `width`, `height` or `positionAbsolute` values are treated as 0.
 *
 * @param intersectionNode node whose border is intersected
 * @param targetNode node the line points towards
 * @returns absolute point on the border of `intersectionNode`. Returns that node's
 *   centre when it has no area, or when both centres coincide (no defined direction).
 */
function getNodeIntersection<T extends Node>(intersectionNode: T, targetNode: T): XYPosition {
    // Half-extents and centre of the node whose border we intersect.
    const w = (intersectionNode.width ?? 0) / 2;
    const h = (intersectionNode.height ?? 0) / 2;
    const x2 = (intersectionNode.positionAbsolute?.x ?? 0) + w;
    const y2 = (intersectionNode.positionAbsolute?.y ?? 0) + h;

    // Centre of the target node. It must use the target's own size; using the
    // intersection node's half-extents skews the line when the sizes differ.
    const x1 = (targetNode.positionAbsolute?.x ?? 0) + (targetNode.width ?? 0) / 2;
    const y1 = (targetNode.positionAbsolute?.y ?? 0) + (targetNode.height ?? 0) / 2;

    // Degenerate inputs would otherwise divide by zero and yield NaN/Infinity.
    if (w === 0 || h === 0 || (x1 === x2 && y1 === y2)) {
        return {x: x2, y: y2};
    }

    // Rotate/scale the offset into diamond space, normalise it onto the diamond's
    // border (|xx3| + |yy3| === 1), then map back to absolute coordinates.
    const xx1 = (x1 - x2) / (2 * w) - (y1 - y2) / (2 * h);
    const yy1 = (x1 - x2) / (2 * w) + (y1 - y2) / (2 * h);
    const a = 1 / (Math.abs(xx1) + Math.abs(yy1));
    const xx3 = a * xx1;
    const yy3 = a * yy1;
    const x = w * (xx3 + yy3) + x2;
    const y = h * (-xx3 + yy3) + y2;

    return {x, y};
}

/**
 * Classifies which side of `node`'s bounding box a border point lies on.
 *
 * All coordinates are rounded to whole pixels, and each side is matched with a
 * 1px tolerance. Sides are tested in the order left, right, top, bottom. Missing
 * `width`, `height` or `positionAbsolute` values are treated as 0.
 *
 * @param node node whose box is tested
 * @param intersectionPoint absolute point, normally on the node's border
 * @returns the matching `Position`; `Position.Top` if no side matches
 *   (e.g. the point lies strictly inside the box)
 */
function getEdgePosition<T extends Node>(node: T, intersectionPoint: XYPosition) {
    const nx = Math.round(node.positionAbsolute?.x ?? 0);
    const ny = Math.round(node.positionAbsolute?.y ?? 0);
    const width = node.width ?? 0;
    const height = node.height ?? 0;
    const px = Math.round(intersectionPoint.x);
    const py = Math.round(intersectionPoint.y);

    if (px <= nx + 1) {
        return Position.Left;
    }
    if (px >= nx + width - 1) {
        return Position.Right;
    }
    if (py <= ny + 1) {
        return Position.Top;
    }
    // Use the rounded `ny` (as the top check does) so the tolerance is symmetric.
    if (py >= ny + height - 1) {
        return Position.Bottom;
    }

    return Position.Top;
}

/**
 * Computes the geometry of a floating edge between two nodes.
 *
 * For each end, it finds the point where the line joining the two nodes crosses
 * that node's border, and the side of the node that point lies on.
 *
 * @param source node the edge starts from
 * @param target node the edge ends at
 * @returns `sx`/`sy` and `tx`/`ty`, the absolute border points on the source and
 *   target; `sourcePos`/`targetPos`, the sides of each node those points lie on.
 */
export function getEdgeParams<T extends Node>(source: T, target: T) {
    const {x: sx, y: sy} = getNodeIntersection(source, target);
    const {x: tx, y: ty} = getNodeIntersection(target, source);

    return {
        sx,
        sy,
        tx,
        ty,
        sourcePos: getEdgePosition(source, {x: sx, y: sy}),
        targetPos: getEdgePosition(target, {x: tx, y: ty}),
    };
}

/**
 * Chooses the SVG stroke style for an edge between two nodes.
 *
 * The target node's `type` takes precedence: "selection" is primary and "order"
 * is info. Otherwise the source's `type` decides: "column" is warning, and
 * anything else is danger. Every edge uses a stroke width of 2.
 *
 * @param targetNode node the edge points to
 * @param sourceNode node the edge starts from
 * @returns a new style object with `strokeWidth` and `stroke`
 */
export function getStylePropsForConnection(targetNode: Node<NodeInfo>, sourceNode: Node<NodeInfo>): React.CSSProperties {
    let stroke: string;
    if (targetNode.type === "selection") {
        stroke = "var(--bs-primary)";
    } else if (targetNode.type === "order") {
        stroke = "var(--bs-info)";
    } else if (sourceNode.type === "column") {
        stroke = "var(--bs-warning)";
    } else {
        stroke = "var(--bs-danger)";
    }

    return {strokeWidth: 2, stroke};
}
