import { TimeMetrics } from "../images/CrossValidationImageKeypointsStats";
import * as d3 from "d3";
import { getCircleStatsColor } from "../../helpers/CrossValidation";
import React from "react";

/**
 * Props for the Gaussian (normal-distribution) cross-validation plot component below.
 */
interface GausianProps {
    /** Samples to plot; each entry is drawn as one circle and gets one hover tooltip. */
    timeMetrics: TimeMetrics[];
    /** Total width of the <svg>, margins included. */
    width: number;
    /** Total height of the <svg>, margins included. */
    height: number;
    /** Space kept free around the plot area; the axes are drawn on the left/bottom margin edges. */
    margin: { top: number; right: number; bottom: number; left: number };
    /** Metric of each `TimeMetrics` entry to plot; only "fatigue" and "completionTime" produce a plot. */
    key_name: string;
}

/**
 * Plots a normal (Gaussian) density curve fitted to one metric of `timeMetrics`, with one
 * circle per sample on the x-axis baseline, hover tooltips, and wheel/drag zoom.
 *
 * Drawing is done imperatively with d3 into the <svg> rendered by this component; React only
 * renders the empty container.
 */
export class CrossValidationGaussianPlot extends React.Component<GausianProps> {
    /** Counter giving every plot instance its own clip-path id (SVG ids are document-global). */
    private static instanceCount = 0;

    private svgGaussian = React.createRef<SVGSVGElement>();

    /** Id of the clip path that keeps the (possibly zoomed) curve inside the plot area. */
    private readonly clipPathId = `gaussian-plot-clip-${++CrossValidationGaussianPlot.instanceCount}`;

    componentDidMount(): void {
        this.renderGaussian(this.props.key_name, this.svgGaussian);
    }

    componentDidUpdate(): void {
        // The d3 drawing is not diffed; it is rebuilt from the current props on every update.
        this.renderGaussian(this.props.key_name, this.svgGaussian);
    }

    /**
     * Clears `svgRef`'s <svg> and redraws the plot for metric `key`:
     * - a normal density curve using the sample mean and standard deviation of the values;
     * - one circle per `timeMetrics` entry on the x-axis baseline, with a hover tooltip;
     * - a dashed vertical line at x = 0 (shown only while 0 is inside the visible x-range);
     * - zoom on the whole <svg>, which rescales both axes.
     *
     * @param key Metric to plot. Only "fatigue" and "completionTime" are drawn; any other
     *            value leaves the <svg> empty.
     * @param svgRef Ref to the target <svg>. Nothing happens while it is not mounted.
     */
    renderGaussian = (key: string, svgRef: React.RefObject<SVGSVGElement | null>): void => {
        const svgElement = svgRef.current;
        if (!svgElement) return;

        const { timeMetrics, width, height, margin } = this.props;

        // Clear first so that switching to an unsupported key does not leave a stale plot behind.
        d3.select(svgElement).selectAll("*").remove();
        if (key !== "fatigue" && key !== "completionTime") return;
        // Copy the narrowed key into a const so the narrowing also holds inside the callbacks below.
        const metricKey: "fatigue" | "completionTime" = key;

        const values = timeMetrics.map((d) => d[metricKey]);
        const mean = d3.mean(values) ?? 0;
        // `||` rather than `??` on purpose: a deviation of 0 (all values equal) would divide by zero
        // below, and undefined (fewer than two values) has no spread; both fall back to 1.
        const stdDev = d3.deviation(values) || 1;

        const svg = d3.select(svgElement).attr("width", width).attr("height", height);

        // Pixel bounds of the plot area inside the margins.
        const plotLeft = margin.left;
        const plotRight = width - margin.right;
        const plotTop = margin.top;
        const plotBottom = height - margin.bottom;
        const plotWidth = Math.max(0, plotRight - plotLeft);
        const plotHeight = Math.max(0, plotBottom - plotTop);

        // x spans mean ± 4σ; y leaves 50% headroom above the density peak 1 / (σ√(2π)).
        const domainMin = mean - 4 * stdDev;
        const domainMax = mean + 4 * stdDev;
        const xScale = d3.scaleLinear().domain([domainMin, domainMax]).range([plotLeft, plotRight]);
        const peakDensity = 1 / (stdDev * Math.sqrt(2 * Math.PI));
        const yScale = d3
            .scaleLinear()
            .domain([0, peakDensity * 1.5])
            .range([plotBottom, plotTop]);

        type Scale = d3.ScaleLinear<number, number>;
        type CurvePoint = { x: number; y: number };

        // Axis groups (their ticks are drawn by `redraw`) and axis titles.
        const xAxis = svg
            .append("g")
            .attr("class", "x-axis")
            .attr("transform", `translate(0,${plotBottom})`);
        const yAxis = svg
            .append("g")
            .attr("class", "y-axis")
            .attr("transform", `translate(${plotLeft},0)`);

        // NOTE(review): this title is also used for the "fatigue" metric; confirm its unit.
        svg.append("text")
            .attr("class", "x-axis-label")
            .attr("x", width / 2)
            .attr("y", plotBottom + 40)
            .attr("text-anchor", "middle")
            .text("Time (min)");

        svg.append("text")
            .attr("class", "y-axis-label")
            .attr("x", -height / 2)
            .attr("y", plotLeft - 40)
            .attr("text-anchor", "middle")
            .attr("transform", "rotate(-90)")
            .text("Probability Density");

        // Clip region used for the curve, so zooming in cannot draw it over the axes.
        svg.append("defs")
            .append("clipPath")
            .attr("id", this.clipPathId)
            .append("rect")
            .attr("x", plotLeft)
            .attr("y", plotTop)
            .attr("width", plotWidth)
            .attr("height", plotHeight);

        // Dashed vertical reference line at x = 0; its position and visibility are set in `redraw`.
        const zeroLine = svg
            .append("line")
            .attr("class", "vertical-dash-line")
            .attr("y1", plotTop)
            .attr("y2", plotBottom)
            .style("stroke", "black")
            .style("stroke-width", 1)
            .style("stroke-dasharray", "5,5"); // Dash pattern

        // Transparent hit area over the plot region, so gestures on empty space reach the zoom
        // behaviour bound on the <svg> (events bubble up to it).
        svg.append("rect")
            .attr("x", plotLeft)
            .attr("y", plotTop)
            .attr("width", plotWidth)
            .attr("height", plotHeight)
            .style("fill", "none")
            .style("pointer-events", "all");

        // Sample the density at a fixed number of evenly spaced points (both ends included) so the
        // curve stays smooth whatever the magnitude of σ. A fixed x-step would give a single point
        // for tiny σ and tens of thousands of points for large σ.
        const sampleCount = 200;
        const curvePoints: CurvePoint[] = d3.range(sampleCount + 1).map((i) => {
            const x = domainMin + ((domainMax - domainMin) * i) / sampleCount;
            const y = peakDensity * Math.exp(-0.5 * Math.pow((x - mean) / stdDev, 2));
            return { x, y };
        });

        // Builds a line generator that maps curve points through the given (possibly zoomed) scales.
        const curveLine = (x: Scale, y: Scale) =>
            d3
                .line<CurvePoint>()
                .x((p) => x(p.x))
                .y((p) => y(p.y));

        const path = svg
            .append("path")
            .datum(curvePoints)
            .attr("clip-path", `url(#${this.clipPathId})`)
            .attr("fill", "none")
            .attr("stroke", "red")
            .attr("stroke-width", 2);

        // One circle per sample on the x-axis baseline; `cx` is set in `redraw`.
        const circles = svg
            .selectAll("circle")
            .data(timeMetrics)
            .enter()
            .append("circle")
            .attr("cy", plotBottom)
            .attr("r", 4)
            .attr("fill", (d) => getCircleStatsColor(d.annotationType))
            .attr("stroke", "black")
            .attr("stroke-width", 1);

        // One hidden tooltip per sample, revealed while its circle is hovered. Tooltips ignore
        // pointer events so that, while invisible, they cannot intercept hovers or gestures.
        const tooltipText = svg
            .selectAll("g.tooltip")
            .data(timeMetrics)
            .enter()
            .append("g")
            .style("opacity", 0)
            .style("pointer-events", "none")
            .attr("class", "tooltip")
            .attr("id", (_d, i) => `tooltipText-${i}`);

        tooltipText
            .append("rect")
            .attr("width", 140)
            .attr("height", 30)
            .attr("x", width - 150)
            .attr("y", 20)
            .attr("rx", 5)
            .attr("ry", 5)
            .attr("fill", "white")
            .attr("stroke", "purple")
            .attr("stroke-width", 1);

        tooltipText
            .append("text")
            .attr("x", width - 80)
            .attr("y", 40)
            .attr("text-anchor", "middle")
            .attr("font-size", "12px")
            .attr("fill", "black")
            .text((d) => `${d.user}: ${d[metricKey]}`);

        circles
            .on("mouseover", function (_event, d) {
                // On hover, make the circle purple and show its tooltip.
                d3.select(this).style("fill", "purple");
                tooltipText.filter((t) => t === d).style("opacity", 1);
            })
            .on("mouseout", function (_event, d) {
                // Restore the annotation colour and hide the tooltip.
                d3.select(this).style("fill", getCircleStatsColor(d.annotationType));
                tooltipText.filter((t) => t === d).style("opacity", 0);
            });

        // Positions everything that depends on the current (possibly zoomed) scales.
        const redraw = (x: Scale, y: Scale): void => {
            xAxis.call(d3.axisBottom(x));
            yAxis.call(d3.axisLeft(y));
            path.attr("d", curveLine(x, y));

            // Visible x-range, in data units, of the plot area.
            const visibleMin = Math.min(x.invert(plotLeft), x.invert(plotRight));
            const visibleMax = Math.max(x.invert(plotLeft), x.invert(plotRight));
            const visibility = (value: number) =>
                value >= visibleMin && value <= visibleMax ? "visible" : "hidden";

            circles
                .attr("cx", (d) => x(d[metricKey]))
                .style("visibility", (d) => visibility(d[metricKey]));
            zeroLine
                .attr("x1", x(0))
                .attr("x2", x(0))
                .style("visibility", visibility(0));
        };
        redraw(xScale, yScale);

        const zoom = d3
            .zoom<SVGSVGElement, unknown>()
            .scaleExtent([0.1, 5])
            .on("zoom", (event: d3.D3ZoomEvent<SVGSVGElement, unknown>) => {
                redraw(event.transform.rescaleX(xScale), event.transform.rescaleY(yScale));
            });

        // Bind zoom on the <svg> only. Every element a zoom behaviour is bound to stores its own
        // transform, so binding it on several elements lets their states diverge and the view jump.
        svg.call(zoom);
        // The <svg> element survives redraws and keeps its last zoom transform. Reset it so the
        // stored state matches the freshly drawn, unzoomed scales. This also emits a "zoom" event,
        // which calls `redraw` again with identity-rescaled (unchanged) scales.
        svg.call(zoom.transform, d3.zoomIdentity);
    };

    render() {
        return (
            <div>
                <svg ref={this.svgGaussian} style={{ marginLeft: "50px" }}></svg>
            </div>
        );
    }
}
