import { useMemo, type ComponentPropsWithoutRef } from "react";
import {
  Area,
  ComposedChart,
  Line,
  ResponsiveContainer,
  Tooltip,
  XAxis,
  YAxis,
} from "recharts";
import type { AxisDomain, BaseAxisProps } from "recharts/types/util/types";
import { match } from "ts-pattern";
import IconTooltip from "../icon-tooltip";
import { CHART_HEIGHT, CHART_MARGIN } from "./constants";
import { EmptyState } from "./empty-state";
import type { TimeDisplayPreference } from "./metrics";
import { formatXAxisTick } from "./util";

/**
 * One time bucket of GPU memory metrics, as plotted by `GpuMemoryChart`.
 * The plotted series are labelled with a "GB" unit in the chart; the unit of
 * the raw numbers is not established here — NOTE(review): confirm upstream.
 */
export type GpuMemoryChartDatum = {
  /** Bucket timestamp; the x-axis data key. */
  x: Date;
  /** Median memory used; drawn as the solid "Median used" area. */
  p50: number;
  /** Maximum memory used; drawn as the hatched "Max used" area. */
  p100: number;
  /** Total memory (upper value); drawn as the dashed "Total" line and used as the y-axis data key. */
  max_total: number;
  /** Total memory (lower value); not drawn as its own series — presumably the smallest total across instances. */
  min_total: number;
  /** Optional sample count; if any datum has one, an invisible series is added so the tooltip can show it. */
  count?: number;
};

/**
 * Summary statistics shown above the chart by `GpuMemoryAggregateStatsItem`.
 * Each value is divided by 1024 before being shown with a "GB" suffix, so the
 * inputs are presumably in MiB — NOTE(review): confirm against the producer.
 */
export interface GpuMemoryAggregateStats {
  /** Median memory used over the visible period; also shown as a percentage of `max_total`. */
  p50: number;
  /** Largest total memory; upper end of the displayed "Total" range. */
  max_total: number;
  /** Smallest total memory; lower end of the displayed "Total" range. */
  min_total: number;
}

/**
 * Summary row above the GPU memory chart: the median memory used (with its
 * share of `max_total`) and the total memory, shown as a "min-max" range when
 * the two totals differ at two-decimal precision.
 */
function GpuMemoryAggregateStatsItem({
  aggregateStats,
}: { aggregateStats: GpuMemoryAggregateStats }) {
  // Inputs are divided by 1024 for display with a "GB" suffix.
  const toGb = (value: number) => (value / 1024).toFixed(2);

  // The Set collapses the range to a single value when both ends format the same.
  const total = [
    ...new Set([aggregateStats.min_total, aggregateStats.max_total].map(toGb)),
  ].join("-");
  const median = toGb(aggregateStats.p50);

  // Check finiteness on the raw ratio (guards max_total === 0) rather than
  // re-parsing the localized string, which would treat grouped output such as
  // "1,234.5" as NaN and wrongly hide the percentage.
  const medianRatio = aggregateStats.p50 / aggregateStats.max_total;
  const medianPercentage = Number.isFinite(medianRatio)
    ? (medianRatio * 100).toLocaleString("en-US", {
        minimumFractionDigits: 0,
        maximumFractionDigits: 2,
      })
    : null;

  return (
    <dl className="flex items-center divide-x divide-r8-gray-6 *:px-4 first:*:pl-0 last:*:pr-0 mt-3 mb-2">
      <div className="space-y-0.5">
        <dt className="text-r8-xs uppercase tracking-wide text-r8-gray-11 space-x-1 flex items-center">
          <span>Median used</span>
          <IconTooltip
            tooltipText="This is the median GPU memory for all instances used over the visible time period"
            icon="Info"
            weight="regular"
          />
        </dt>
        <dd className="font-semibold flex gap-2 items-center">
          <span className="block size-3 rounded-full bg-r8-green-10" />
          <span className="text-r8-gray-12">
            {median}GB
            {medianPercentage !== null && (
              <>
                {" "}
                <span className="text-r8-sm text-r8-gray-11">
                  ({medianPercentage}%)
                </span>
              </>
            )}
          </span>
        </dd>
      </div>
      <div className="space-y-0.5">
        <dt className="text-r8-xs uppercase tracking-wide text-r8-gray-11">
          Total
        </dt>
        <dd className="font-semibold flex gap-2 items-center">
          <span className="block size-3 rounded-full border border-dashed border-r8-gray-12" />
          <span>{total}GB</span>
        </dd>
      </div>
    </dl>
  );
}

/**
 * GPU memory usage over time as a stepped composed chart:
 * - hatched area: max used (`p100`)
 * - solid area: median used (`p50`)
 * - dashed line: total memory (`max_total`)
 * - optional invisible `count` series, so the tooltip can report sample counts.
 *
 * Renders `EmptyState` instead when `error` or `displayEmptyState` is set.
 * `yDomain`, when provided, is forwarded to the y-axis `domain`; otherwise the
 * Recharts default domain applies.
 */
export function GpuMemoryChart({
  data,
  aggregateStats,
  displayEmptyState,
  error,
  loading,
  noDataReason,
  syncId,
  yDomain,
  yTickFormatter,
  timeDisplay = "utc",
}: {
  data: GpuMemoryChartDatum[];
  aggregateStats: GpuMemoryAggregateStats | undefined;
  displayEmptyState: boolean;
  error: boolean;
  loading: boolean;
  noDataReason: string;
  syncId: string;
  yDomain?: AxisDomain;
  yTickFormatter?: BaseAxisProps["tickFormatter"];
  timeDisplay?: TimeDisplayPreference;
}) {
  // Hooks must run unconditionally and in the same order on every render, so
  // this is computed before the early returns below (previously it ran after
  // them, and toggling `error`/`displayEmptyState` changed the hook count).
  const hasCount = useMemo(() => data.some((d) => d.count != null), [data]);

  if (error) {
    return <EmptyState error noDataReason={noDataReason} />;
  }

  if (displayEmptyState) {
    return <EmptyState loading={loading} noDataReason={noDataReason} />;
  }

  return (
    <div className="relative">
      {aggregateStats ? (
        <GpuMemoryAggregateStatsItem aggregateStats={aggregateStats} />
      ) : null}

      <div style={{ height: CHART_HEIGHT * 1.5 }}>
        <ResponsiveContainer width="100%" height="100%">
          <ComposedChart data={data} syncId={syncId} margin={CHART_MARGIN}>
            <defs>
              <linearGradient id="primary-gradient" x1="0" y1="0" x2="0" y2="1">
                <stop offset="0%" stopColor="var(--gray-8)" stopOpacity={1} />
                <stop offset="100%" stopColor="var(--gray-1)" stopOpacity={1} />
              </linearGradient>
              <pattern
                id="striped-gradient"
                patternUnits="userSpaceOnUse"
                width="8"
                height="8"
              >
                <path
                  d="M-2,2 l4,-4
                         M0,8 l8,-8
                         M6,10 l4,-4"
                  style={{
                    stroke: "var(--green-10)",
                    strokeWidth: 3,
                  }}
                />
              </pattern>
            </defs>
            <XAxis
              dataKey="x"
              interval="equidistantPreserveStart"
              className="text-xs font-sans"
              tickFormatter={(tick: Date) => {
                return formatXAxisTick(tick, timeDisplay) ?? "";
              }}
              tickMargin={8}
            />
            <YAxis
              allowDecimals={false}
              dataKey="max_total"
              // Only pass `domain` when the caller supplied one, so the
              // Recharts default still applies otherwise.
              {...(yDomain !== undefined ? { domain: yDomain } : {})}
              tickFormatter={yTickFormatter}
              tickLine={false}
              tickMargin={0}
              unit="GB"
              style={{
                fontSize: "0.7rem",
                fontFamily: "jetbrains-mono",
              }}
              width={48}
            />
            <Tooltip
              content={(props) => (
                <GpuMemoryMetricTooltip {...props} timeDisplay={timeDisplay} />
              )}
              isAnimationActive={false}
              cursor={{
                stroke: "var(--gray-12)",
              }}
            />
            <Area
              unit="GB"
              isAnimationActive={false}
              dot={false}
              type="step"
              dataKey="p100"
              activeDot={{
                fill: "var(--green-8)",
                stroke: "var(--green-9)",
                fillOpacity: 1,
                strokeWidth: 1,
              }}
              fill="url(#striped-gradient)"
              stroke="var(--green-8)"
            />
            <Area
              unit="GB"
              isAnimationActive={false}
              dot={false}
              type="step"
              dataKey="p50"
              activeDot={{
                fill: "var(--green-10)",
                stroke: "var(--green-5)",
                fillOpacity: 1,
                strokeWidth: 1,
              }}
              fill="var(--green-10)"
              fillOpacity="1"
              stroke="var(--green-2)"
            />
            <Line
              unit="GB"
              isAnimationActive={false}
              dot={false}
              type="step"
              activeDot={{
                fill: "var(--gray-12)",
                stroke: "0",
              }}
              dataKey="max_total"
              stroke="var(--gray-12)"
              fill="var(--gray-12)"
              strokeDasharray={4}
              strokeWidth={1.5}
            />
            {/* Invisible series: only present so the tooltip payload carries `count`. */}
            {hasCount && (
              <Line
                isAnimationActive={false}
                dot={false}
                type="step"
                fill="var(--gray-12)"
                activeDot={{
                  stroke: "0",
                }}
                dataKey="count"
                stroke="transparent"
              />
            )}
          </ComposedChart>
        </ResponsiveContainer>
      </div>
    </div>
  );
}

/** A single entry of the Recharts tooltip `payload` array. */
type TooltipPayloadEntry = NonNullable<
  ComponentPropsWithoutRef<typeof Tooltip>["payload"]
>[number];

/**
 * Formats a tooltip value.
 * - Several values that are not all equal render as a whole-number range ("40-80").
 * - A single value, or repeated identical values, renders with two decimals.
 * - An empty array renders as an empty string instead of throwing.
 */
function formatTooltipTotal(total: number | readonly number[]): string {
  const values = typeof total === "number" ? [total] : total;
  const [first, ...rest] = values;
  if (first === undefined) return "";
  return rest.some((value) => value !== first)
    ? values.map((value) => value.toFixed(0)).join("-")
    : first.toFixed(2);
}

/**
 * One row of the GPU memory tooltip: a swatch matching the series style
 * (dashed ring for lines, filled/hatched dot for areas, blank for invisible
 * series), the series name in the series colour, and the formatted value with
 * its unit.
 */
function GpuMemoryTooltipLineItem({
  name,
  payload,
  type,
  total,
}: {
  name: string;
  payload: TooltipPayloadEntry;
  type: "line" | "area" | "invisible";
  total: number | readonly number[];
}) {
  // SVG pattern fills (`url(#...)`) are not valid CSS colours: use the series
  // colour instead and overlay CSS hatching to mimic the pattern.
  const hatching = payload.fill?.startsWith("url") ?? false;
  const color = hatching ? payload.color : payload.fill;

  return (
    <div className="text-xs flex gap-6 items-center justify-between">
      <div className="flex gap-2 items-center">
        {match(type)
          .with("line", () => (
            <div
              className="size-3 rounded-full flex-shrink-0 border border-dashed"
              style={{
                borderColor: color,
              }}
            />
          ))
          .with("area", () => (
            <div
              className="size-3 rounded-full flex-shrink-0 relative overflow-hidden"
              style={{
                backgroundColor: color,
              }}
            >
              {hatching ? (
                <div
                  className="absolute inset-0"
                  style={{
                    backgroundImage:
                      "repeating-linear-gradient(-45deg, var(--gray-1), var(--gray-1) 2px, transparent 2px, transparent 4px)",
                  }}
                />
              ) : null}
            </div>
          ))
          .with("invisible", () => (
            <div className="size-3 rounded-full flex-shrink-0" />
          ))
          .exhaustive()}
        <div className="first-letter:uppercase w-full" style={{ color }}>
          {name}
        </div>
      </div>
      <div className="text-right flex space-x-0.5">
        <span className="whitespace-nowrap">{formatTooltipTotal(total)}</span>
        {payload.unit ? (
          <span className="text-r8-gray-11">{payload.unit}</span>
        ) : null}
      </div>
    </div>
  );
}

/**
 * Custom Recharts tooltip content for `GpuMemoryChart`.
 *
 * Shows the hovered timestamp, then one row each for Total, Max used, Median
 * used and (when present) Data count. Renders nothing while inactive, with an
 * empty payload, or when the label cannot be formatted.
 */
export function GpuMemoryMetricTooltip({
  active,
  payload,
  label,
  timeDisplay = "utc",
}: ComponentPropsWithoutRef<typeof Tooltip> & {
  timeDisplay?: TimeDisplayPreference;
}) {
  // Hooks must be called on every render in the same order, so this runs
  // before any early return. Previously it sat inside the `active` branch and
  // the hook count changed whenever the tooltip toggled.
  const formattedDateLabel = useMemo(() => {
    if (!label) return null;
    return formatXAxisTick(label as Date, timeDisplay) ?? "";
  }, [label, timeDisplay]);

  if (!active || !payload?.length || !formattedDateLabel) return null;

  const entries = payload;
  const findEntry = (dataKey: keyof GpuMemoryChartDatum) =>
    entries.find((entry) => entry.dataKey === dataKey);

  const p50 = findEntry("p50");
  const p100 = findEntry("p100");
  const maxTotal = findEntry("max_total");
  const count = findEntry("count");

  // `min_total` is not plotted as a series, so Recharts never puts it in
  // `payload`. Fall back to the hovered datum, which each entry carries as
  // `.payload`. A series entry is still preferred if one is ever added.
  const minTotalEntry = findEntry("min_total");
  const hoveredDatum = maxTotal?.payload as
    | Partial<GpuMemoryChartDatum>
    | undefined;
  const minTotalValue =
    minTotalEntry != null
      ? Number(minTotalEntry.value)
      : hoveredDatum?.min_total;

  return (
    <div className="max-w-[12rem] border border-r8-gray-12 bg-white dark:bg-r8-gray-1 px-2 pb-2 pt-1 space-y-1">
      <span className="text-r8-gray-11 text-xs">{formattedDateLabel}</span>
      {maxTotal ? (
        <GpuMemoryTooltipLineItem
          name="Total"
          type="line"
          payload={maxTotal}
          total={
            minTotalValue != null
              ? [Number(minTotalValue), Number(maxTotal.value)]
              : Number(maxTotal.value)
          }
        />
      ) : null}
      {p100 ? (
        <GpuMemoryTooltipLineItem
          name="Max used"
          type="area"
          payload={p100}
          total={Number(p100.value)}
        />
      ) : null}
      {p50 ? (
        <GpuMemoryTooltipLineItem
          name="Median used"
          type="area"
          payload={p50}
          total={Number(p50.value)}
        />
      ) : null}
      {count ? (
        <GpuMemoryTooltipLineItem
          name="Data count"
          type="invisible"
          payload={count}
          total={Number(count.value)}
        />
      ) : null}
    </div>
  );
}
