import * as d3 from "d3";
import React, {useState} from "react";
import { AxisLeft } from "./AxisLeft";
import { AxisBottom } from "./AxisBottom";
import { Tooltip } from "./Tooltip";
import ZoomInDepth from "./ZoomInDepth";
import {InteractionData, ScatterplotProps} from "./types";

/**
 * Pixels reserved on each side of the SVG around the inner plot area.
 * The space holds the axis title text and offsets the plot group and tooltip overlay.
 */
const MARGIN = {
    top: 60,
    right: 60,
    bottom: 60,
    left: 60,
};

/**
 * Scatterplot of points (read fields: `x`, `y`, `group`, `subGroup`) drawn with d3 linear scales.
 *
 * - Plots `filteredData` when `filterEnabled` is true, otherwise `data`; a missing array
 *   renders as an empty plot.
 * - Both axes use `.nice()` linear domains spanning the min/max of the plotted points,
 *   falling back to 0 when there are no points.
 * - Points are coloured by `group`. Hovering a point enlarges it (r 6 -> 10) and shows a
 *   `Tooltip` with its scaled position, `subGroup` name and raw x/y values.
 * - When `zoomInDepthEnabled` is true, the plot is wrapped in `ZoomInDepth`; otherwise it
 *   is rendered in a plain `<svg>`.
 *
 * NOTE(review): `zoomWithAxisEnabled` is part of `ScatterplotProps` but is not read by this
 * component.
 * NOTE(review): the tooltip overlay is positioned from unzoomed scale coordinates, so it may
 * not follow a `ZoomInDepth` transform. Confirm this is intended.
 */
export const Scatterplot = ({ width, height, data, filteredData, xAxisLabel = 'X', yAxisLabel = 'Y', zoomInDepthEnabled = false, filterEnabled = false }: ScatterplotProps) => {
    // Inner plot area, excluding the margins reserved for axis titles.
    const boundsWidth = width - MARGIN.right - MARGIN.left;
    const boundsHeight = height - MARGIN.top - MARGIN.bottom;

    const [hovered, setHovered] = useState<InteractionData | null>(null);
    const [hoveredIndex, setHoveredIndex] = useState<number | null>(null);

    // A missing array (e.g. filtering enabled before filteredData arrives) renders as an empty plot.
    const currentData = (filterEnabled ? filteredData : data) ?? [];

    // d3.min/max return undefined for empty input; fall back to a [0, 0] domain.
    const xMin = d3.min(currentData, d => d.x) ?? 0;
    const xMax = d3.max(currentData, d => d.x) ?? 0;
    const yMin = d3.min(currentData, d => d.y) ?? 0;
    const yMax = d3.max(currentData, d => d.y) ?? 0;

    const xScale = d3.scaleLinear()
        .domain([xMin, xMax]).nice()
        .range([0, boundsWidth]);
    // Inverted range so larger y values are drawn higher (SVG y grows downward).
    const yScale = d3.scaleLinear()
        .domain([yMin, yMax]).nice()
        .range([boundsHeight, 0]);

    // Groups are keyed by their string form both when building the domain and when looking
    // up a colour. Otherwise a numeric group id would miss its string domain entry and be
    // implicitly appended, shifting the colour assignment.
    const allGroups = Array.from(new Set(currentData.map((d) => String(d.group))));
    const colorScale = d3.scaleOrdinal<string>()
        .domain(allGroups)
        .range(["#e0ac2b", "#e85252", "#6689c6", "#9a6fb0", "#a53253"]);

    const allShapes = currentData.map((d, i) => {
        // Compute each point's screen position once; it is reused by the tooltip on hover.
        const cx = xScale(d.x);
        const cy = yScale(d.y);
        const color = colorScale(String(d.group));
        return (
            <circle
                key={i}
                r={hoveredIndex === i ? 10 : 6}
                cx={cx}
                cy={cy}
                stroke={color}
                fill={color}
                fillOpacity={0.7}
                onMouseEnter={() => {
                    setHovered({
                        xPos: cx,
                        yPos: cy,
                        name: d.subGroup,
                        x: d.x,
                        y: d.y,
                    });
                    setHoveredIndex(i);
                }}
                onMouseLeave={() => {
                    setHovered(null);
                    setHoveredIndex(null);
                }}
                style={{ transition: 'all 0.3s ease-in-out' }}
            />
        );
    });

    const plotContents = (
        <g transform={`translate(${MARGIN.left},${MARGIN.top})`}>
            {/* Bottom axis group is shifted to the baseline of the plot area. */}
            <g transform={`translate(0, ${boundsHeight})`} className="stroke-black">
                <AxisBottom xScale={xScale} pixelsPerTick={40} height={boundsHeight} />
                <line
                    x1={0}
                    x2={boundsWidth}
                    y1={0}
                    y2={0}
                    className="stroke-black"
                    strokeWidth={2}
                />
                <text
                    x={boundsWidth / 2}
                    y={MARGIN.bottom / 1.2}
                    textAnchor="middle"
                    style={{ fontSize: "16px", fill: "#666" }}
                >
                    {xAxisLabel}
                </text>
            </g>
            <g className="stroke-black">
                <AxisLeft yScale={yScale} pixelsPerTick={40} width={boundsWidth} />
                <line
                    x1={0}
                    x2={0}
                    y1={0}
                    y2={boundsHeight}
                    className="stroke-gray-800"
                    strokeWidth={2}
                />
                {/* Y title sits inside the left margin, rotated to run along the axis. */}
                <text
                    transform={`translate(${-MARGIN.left * 0.75},${boundsHeight / 2}) rotate(-90)`}
                    textAnchor="middle"
                    style={{ fontSize: "16px", fill: "#666" }}
                >
                    {yAxisLabel}
                </text>
            </g>
            {allShapes}
        </g>
    );

    return (
        <div style={{ position: "relative", margin: "2rem" }}>
            {zoomInDepthEnabled ? (
                <ZoomInDepth width={width} height={height} margin={MARGIN}>
                    {plotContents}
                </ZoomInDepth>
            ) : (
                <svg width={width} height={height}>
                    {plotContents}
                </svg>
            )}
            {/* Overlay aligned with the inner plot area so tooltip coordinates match the scales. */}
            <div
                style={{
                    width: boundsWidth,
                    height: boundsHeight,
                    position: "absolute",
                    top: 0,
                    left: 0,
                    pointerEvents: "none",
                    marginLeft: MARGIN.left,
                    marginTop: MARGIN.top,
                }}
            >
                <Tooltip interactionData={hovered} />
            </div>
        </div>
    );
};
