import {
  useQuery,
  type QueryKey,
  type UseQueryResult,
} from "@tanstack/react-query";
import { downloadZip } from "client-zip";
import { useCallback, useEffect, useMemo, useState } from "react";
import { useForm } from "react-hook-form";
import toast from "react-hot-toast";
import { P, match } from "ts-pattern";
import type {
  CogInputSchema,
  ErrorWithJSON,
  Model,
  Prediction,
  Version,
} from "../../types";
import { route } from "../../urls";
import { fetchAndRenderMarkdown, getPrediction } from "./api";
import {
  PREDICTION_REFETCH_BACKOFF_FACTOR,
  PREDICTION_REFETCH_BASE_SETUP_INTERVAL,
  PREDICTION_REFETCH_MAX_INTERVAL,
  PREDICTION_REFETCH_PROCESSING_INTERVAL,
  TERMINAL_PREDICTION_STATUSES,
} from "./constants";

/**
 * React Query cache keys shared by the prediction and version hooks.
 * Each builder returns a fresh array on every call.
 */
export const queryKeys = {
  predictions: {
    /** Key for one prediction; `null` is accepted while the UUID is unknown. */
    uuid(uuid: string | null) {
      return ["predictions", uuid];
    },
  },
  versions: {
    /** Key for one model version. */
    uuid(uuid: string) {
      return ["versions", uuid];
    },
  },
};

/**
 * Refetch interval used by `usePrediction` while a prediction is in the
 * "starting" or "canceling" state.
 *
 * For `runCount` below 2 this returns 1, so the prediction is re-polled almost
 * immediately. From `runCount` 2 onward the result is
 * `PREDICTION_REFETCH_BASE_SETUP_INTERVAL * PREDICTION_REFETCH_BACKOFF_FACTOR ** (runCount - 2)`,
 * capped at `PREDICTION_REFETCH_MAX_INTERVAL`.
 *
 * @param runCount - Number of successful fetches so far (`usePrediction`
 *   passes `query.state.dataUpdateCount`).
 * @returns The interval to hand to react-query's `refetchInterval`.
 */
export function calculatePredictionRefetchSetupInterval(
  runCount: number
): number {
  const exponent = runCount - 2;

  // Early polls: come back right away before starting to back off.
  if (exponent < 0) {
    return 1;
  }

  const backoff = PREDICTION_REFETCH_BACKOFF_FACTOR ** exponent;
  const uncapped = PREDICTION_REFETCH_BASE_SETUP_INTERVAL * backoff;
  return Math.min(uncapped, PREDICTION_REFETCH_MAX_INTERVAL);
}

/**
 * Loads a prediction by UUID and keeps polling it while it is still running.
 *
 * - The query is disabled while `uuid` is null or empty.
 * - Polling cadence is chosen from the last fetched status:
 *   no data yet, "canceled", "failed" or "succeeded" → no interval;
 *   "processing" → `PREDICTION_REFETCH_PROCESSING_INTERVAL`;
 *   "starting" or "canceling" → `calculatePredictionRefetchSetupInterval`.
 * - Polling continues while the tab is in the background.
 * - Window-focus refetches happen only when a status is known and it is not
 *   in `TERMINAL_PREDICTION_STATUSES`.
 */
export function usePrediction({
  uuid,
}: {
  uuid: string | null;
}): UseQueryResult<Prediction, ErrorWithJSON> {
  return useQuery<Prediction, ErrorWithJSON, Prediction>({
    queryKey: queryKeys.predictions.uuid(uuid),
    queryFn: ({ signal }) => getPrediction({ uuid, signal }),
    enabled: Boolean(uuid),
    // Initial data is treated as freshly fetched. Matching staleTime to the
    // processing poll interval postpones the next refetch until that much time
    // has passed; after that, refetchInterval decides per status whether to poll.
    staleTime: PREDICTION_REFETCH_PROCESSING_INTERVAL,
    refetchIntervalInBackground: true,
    refetchInterval: (query) => {
      const { data, dataUpdateCount } = query.state;

      return match(data?.status)
        .with(P.nullish, () => false as const)
        .with(P.union("canceled", "failed", "succeeded"), () => false as const)
        .with("processing", () => PREDICTION_REFETCH_PROCESSING_INTERVAL)
        .with(P.union("canceling", "starting"), () =>
          calculatePredictionRefetchSetupInterval(dataUpdateCount)
        )
        .exhaustive();
    },
    refetchOnWindowFocus: (query) => {
      const status = query.state.data?.status;
      return status ? !TERMINAL_PREDICTION_STATUSES.includes(status) : false;
    },
  });
}

/**
 * Creates the react-hook-form instance for a Cog input schema.
 *
 * Each field's default value is resolved in this order:
 * 1. If the field appears in `initialInputs`, use that value. When
 *    `initialInputs` is not provided, the schema `default` is used instead.
 *    - Fields without a `type` (e.g. `anyOf`) take the value as-is.
 *    - "integer" and "number" fields are parsed.
 *    - String arrays are wrapped as `{ value }` objects.
 *    - Password fields are blanked.
 * 2. Otherwise, the raw schema `default`. This is only reachable when
 *    `initialInputs` was given but lacks the field.
 * 3. Otherwise, the first `enum` choice, mirroring a native `<select>`.
 * 4. Otherwise, `null`.
 *
 * @param properties - The schema's `properties` map.
 * @param initialInputs - Optional input values to prefill. When truthy, they
 *   replace the schema defaults as the source for step 1.
 * @returns The `useForm` result, validating in "onChange" mode.
 */
export function useInputsForm(
  properties: CogInputSchema["properties"],
  initialInputs?: Record<string, any>
) {
  // Loop-invariant: compute the schema defaults once instead of once per
  // property inside the reduce below, which made this O(n²) in field count.
  const defaultInputs = Object.fromEntries(
    Object.entries(properties)
      .filter(([_, value]) => "default" in value)
      .map(([key, value]) => [key, value.default])
  );

  const values: Record<string, any> = initialInputs || defaultInputs;

  return useForm<any>({
    mode: "onChange",
    defaultValues: Object.keys(properties).reduce<Record<string, any>>(
      (acc, name) => {
        const property = properties[name];

        if (name in values) {
          // "anyOf" fields won't have a type, so we naïvely just use the initial value.
          if (!("type" in property)) {
            acc[name] = values[name];
            return acc;
          }

          const propertyType = property.type;
          let val = values[name];

          // Make sure any integer or number initial inputs are cast to numbers.
          if (propertyType === "integer") {
            val = Number.parseInt(val);
          } else if (propertyType === "number") {
            val = Number.parseFloat(val);
          } else if (
            propertyType === "array" &&
            property.items.type === "string" &&
            Array.isArray(val)
          ) {
            val = val.map((v) => ({ value: v }));
          }

          // Initialize secret fields to be blank.
          if (propertyType === "string" && property.format === "password") {
            val = "";
          }

          acc[name] = val;
        } else if ("default" in property) {
          acc[name] = property.default;
        } else if ("enum" in property && property.enum?.length) {
          // To match native select behaviour, if there are any choices, pick the
          // first by default.
          acc[name] = property.enum[0];
        } else {
          acc[name] = null;
        }
        return acc;
      },
      {}
    ),
  });
}

/**
 * Returns the file extension (including the leading dot) of the last path
 * segment of `url`, or "" when that segment has none.
 *
 * Only the path is inspected. Dots in the host (`example.com`), in the query
 * string (`?sig=a.b`) or in the fragment are ignored. Relative URLs are
 * accepted. A segment whose only dot is its first character (e.g. `.env`) is
 * treated as having no extension.
 */
function parseExtension(url: string): string {
  let pathname: string;
  try {
    // The dummy base lets relative URLs parse; only the pathname is used.
    pathname = new URL(url, "http://localhost").pathname;
  } catch {
    // Unparseable input: fall back to stripping any query/fragment by hand.
    pathname = url.split(/[?#]/)[0];
  }

  const lastSegment = pathname.slice(pathname.lastIndexOf("/") + 1);
  const dot = lastSegment.lastIndexOf(".");
  return dot > 0 ? lastSegment.slice(dot) : "";
}

// How long to keep a download's object URL alive before revoking it. Revoking
// synchronously right after click() may break the download in some browsers,
// so wait a generous grace period (FileSaver.js uses the same 40s).
const OBJECT_URL_REVOKE_DELAY_MS = 40 * 1000;

/** Fetches `url`, rejecting when the response status is not 2xx. */
async function fetchOrThrow(url: string): Promise<Response> {
  const res = await fetch(url);
  if (!res.ok) {
    throw new Error(`request failed with status ${res.status}`);
  }
  return res;
}

/** Asks the browser to save `blob` as `filename`, then releases the object URL. */
function saveBlobAs(blob: Blob, filename: string): void {
  const objectUrl = URL.createObjectURL(blob);
  const anchor = document.createElement("a");
  anchor.href = objectUrl;
  anchor.download = filename;
  // Attach before clicking: some browsers ignore clicks on detached anchors.
  document.body.appendChild(anchor);
  anchor.click();
  anchor.remove();
  setTimeout(() => URL.revokeObjectURL(objectUrl), OBJECT_URL_REVOKE_DELAY_MS);
}

/**
 * Provides a "download outputs" action for a prediction.
 *
 * File naming:
 * - A single output file is saved as `replicate-prediction-<id><ext>`.
 * - Several files are bundled with client-zip into
 *   `replicate-prediction-<id>.zip`, with entries named
 *   `replicate-prediction-<id>-<index><ext>`.
 *
 * @returns
 * - `isEmpty`: true unless the prediction is canceled, failed or succeeded
 *   AND has at least one non-empty string output URL.
 * - `isPreparing`: true while files are being fetched or zipped.
 * - `download`: starts the download. It is a no-op while empty or already
 *   preparing. Fetch errors (including non-2xx responses) are reported via
 *   a toast.
 */
export function useDownloadOutput(prediction: Prediction) {
  const outputFiles = prediction._extras.output_files;
  const id = prediction.id;

  // Serialize the usable URLs so the callback and effect below depend on the
  // URL list's contents rather than on the identity of the output_files array.
  const stringifiedUrls = useMemo(
    () =>
      JSON.stringify(
        outputFiles.filter((url) => url && typeof url === "string")
      ),
    [outputFiles]
  );

  const [isEmpty, setIsEmpty] = useState(true);
  const [isPreparing, setIsPreparing] = useState(false);

  const download = useCallback(async () => {
    if (isEmpty || isPreparing) {
      return;
    }

    const urls: string[] = JSON.parse(stringifiedUrls);

    if (!urls || urls.length === 0) {
      return;
    }

    setIsPreparing(true);

    const prefix = `replicate-prediction-${id}`;

    try {
      let blob: Blob;
      let filename: string;

      if (urls.length === 1) {
        const url = urls[0];
        const res = await fetchOrThrow(url);
        blob = await res.blob();
        filename = `${prefix}${parseExtension(url)}`;
      } else {
        const filesToZip = await Promise.all(
          urls.map(async (url, i) => ({
            name: `${prefix}-${i}${parseExtension(url)}`,
            input: await fetchOrThrow(url),
          }))
        );
        blob = await downloadZip(filesToZip).blob();
        filename = `${prefix}.zip`;
      }

      saveBlobAs(blob, filename);
    } catch (e) {
      const message = e instanceof Error ? e.message : String(e);
      toast.error(`Download failed: ${message}`);
    } finally {
      // Single reset point so the success and failure paths cannot diverge.
      setIsPreparing(false);
    }
  }, [isEmpty, isPreparing, id, stringifiedUrls]);

  // Downloads are only offered once the prediction has finished and produced
  // at least one usable URL.
  useEffect(() => {
    if (!["canceled", "failed", "succeeded"].includes(prediction.status)) {
      setIsEmpty(true);
      return;
    }

    const urls = JSON.parse(stringifiedUrls);

    if (!urls || urls.length === 0) {
      setIsEmpty(true);
      return;
    }

    setIsEmpty(false);
  }, [prediction.status, stringifiedUrls]);

  return {
    isEmpty,
    isPreparing,
    download,
  };
}

/**
 * Fetches and renders the markdown document at `url` via
 * `fetchAndRenderMarkdown`.
 *
 * Cached under `["markdown", url]`; window-focus refetching is disabled.
 */
export function useRemoteMarkdown({ url }: { url: string }) {
  const queryKey = ["markdown", url];

  return useQuery({
    queryKey,
    refetchOnWindowFocus: false,
    queryFn: () => fetchAndRenderMarkdown({ url }),
  });
}

/**
 * Reports whether a prediction's output should be streamed, and from where.
 *
 * Streaming is enabled exactly when the prediction has a truthy `urls.stream`.
 * The result is a discriminated union, so checking `predictionOutputShouldStream`
 * narrows `predictionOutputStreamUrl` to `string` or `null`.
 * Despite the `use` prefix, this calls no React hooks.
 */
export function usePredictionOutputShouldStream(prediction: Prediction | null):
  | {
      predictionOutputShouldStream: true;
      predictionOutputStreamUrl: string;
    }
  | {
      predictionOutputShouldStream: false;
      predictionOutputStreamUrl: null;
    } {
  const streamUrl = prediction?.urls.stream;

  if (streamUrl) {
    return {
      predictionOutputShouldStream: true,
      predictionOutputStreamUrl: streamUrl,
    };
  }

  return {
    predictionOutputShouldStream: false,
    predictionOutputStreamUrl: null,
  };
}

/** Response body of the `api_version_docker_image` endpoint. */
interface DockerImage {
  docker_image_name: string;
}

/**
 * Fetches the Docker image name for a model version.
 *
 * Cached per version under `["api_version_docker_image", version.id]`.
 */
export function useDockerImage(
  version: Version
): UseQueryResult<DockerImage, ErrorWithJSON> {
  const { owner, name } = version._extras.model;
  const url = route("api_version_docker_image", {
    username: owner,
    name,
    docker_image_id: version.id,
  });

  return useQueryOnce<DockerImage>({
    url,
    queryKey: ["api_version_docker_image", version.id],
    purpose: "fetch Docker image",
  });
}

/** Response body of the `api_version_prediction_price` endpoint (price may be null). */
interface PredictionPrice {
  average_prediction_price: string | null;
}

/**
 * Fetches the average prediction price for a model version.
 *
 * Cached per version under `["api_version_prediction_price", version.id]`.
 */
export function usePredictionPrice(
  version: Version
): UseQueryResult<PredictionPrice, ErrorWithJSON> {
  const { owner, name } = version._extras.model;
  const url = route("api_version_prediction_price", {
    username: owner,
    name,
    docker_image_id: version.id,
  });

  return useQueryOnce<PredictionPrice>({
    url,
    queryKey: ["api_version_prediction_price", version.id],
    purpose: "fetch version prediction price",
  });
}

/** Feature flags returned by the version/model capabilities endpoints. */
export interface Capabilities {
  hotswap: boolean;
  run: boolean;
  stream: boolean;
  train: boolean;
}

/**
 * Fetches the capabilities of a specific model version.
 *
 * Cached per version under `["api_version_capabilities", version.id]`.
 */
export function useVersionCapabilities(
  version: Version
): UseQueryResult<Capabilities, ErrorWithJSON> {
  const { owner, name } = version._extras.model;
  const url = route("api_version_capabilities", {
    username: owner,
    name,
    docker_image_id: version.id,
  });

  return useQueryOnce<Capabilities>({
    url,
    queryKey: ["api_version_capabilities", version.id],
    purpose: "fetch version capabilities",
  });
}

/**
 * Fetches the capabilities of a model (not tied to a specific version).
 *
 * Cached under `["api_model_capabilities", owner, name]`.
 */
export function useModelCapabilities(
  model: Model
): UseQueryResult<Capabilities, ErrorWithJSON> {
  const { owner, name } = model;

  return useQueryOnce<Capabilities>({
    url: route("api_model_capabilities", { username: owner, name }),
    queryKey: ["api_model_capabilities", owner, name],
    purpose: "fetch model capabilities",
  });
}

/** Response body of the input-hints endpoints (trigger word may be null). */
interface InputHints {
  trigger_word: string | null;
}

/**
 * Fetches input hints for a specific model version.
 *
 * Cached per version under `["api_version_input_hints", version.id]`.
 */
export function useInputHints(
  version: Version
): UseQueryResult<InputHints, ErrorWithJSON> {
  const { owner, name } = version._extras.model;
  const url = route("api_version_input_hints", {
    username: owner,
    name,
    docker_image_id: version.id,
  });

  return useQueryOnce<InputHints>({
    url,
    queryKey: ["api_version_input_hints", version.id],
    purpose: "fetch version input hints",
  });
}

/**
 * Fetches input hints at the model level (used instead of the per-version
 * endpoint, e.g. for official models).
 *
 * Cached under `["api_model_input_hints", owner, name]`.
 */
export function useInputHintsOfficialModel(
  model: Model
): UseQueryResult<InputHints, ErrorWithJSON> {
  const { owner, name } = model;

  return useQueryOnce<InputHints>({
    url: route("api_model_input_hints", { username: owner, name }),
    queryKey: ["api_model_input_hints", owner, name],
    purpose: "fetch model input hints",
  });
}

/**
 * Thin `useQuery` wrapper for a JSON GET endpoint.
 *
 * Behavior:
 * - On a 2xx response, resolves with the parsed JSON body.
 * - Otherwise, rejects with an `ErrorWithJSON`-shaped object. Its `detail`
 *   defaults to `Failed to ${purpose}` and its `status` to the HTTP status;
 *   both are overlaid by the server's JSON error body when one is sent.
 * - Window-focus refetching is disabled.
 *
 * @param purpose - Human-readable action used in the fallback error message.
 * @param queryKey - React Query cache key.
 * @param url - Endpoint to GET.
 */
function useQueryOnce<T>({
  purpose,
  queryKey,
  url,
}: {
  purpose: string;
  queryKey: QueryKey;
  url: string;
}): UseQueryResult<T, ErrorWithJSON> {
  // Explicit generics so the error type matches the declared return type
  // (same convention as usePrediction) instead of react-query's default.
  return useQuery<T, ErrorWithJSON>({
    async queryFn({ signal }) {
      // Forward react-query's AbortSignal so cancelled queries abort the request.
      const res = await fetch(url, { signal });

      if (res.ok) {
        return res.json<T>();
      }

      let error = {
        detail: `Failed to ${purpose}`,
        status: res.status,
      };
      try {
        error = {
          ...error,
          ...(await res.json<ErrorWithJSON>()),
        };
      } catch {
        // Best effort: the error body may be empty or not JSON; keep the
        // generic error built above.
      }

      return Promise.reject(error);
    },
    queryKey,
    refetchOnWindowFocus: false,
  });
}
