import { useQuery } from '@tanstack/react-query';
import ProjectAPI from '@apis/ProjectAPI';
import ProjectModel from '@models/ProjectModel';
import {
  AnalysisType,
  IProjectAnalysisModel,
} from '@models/ProjectAnalysisModel/ProjectAnalysisModelBase';
import { InternalBenchmarkAnalysisModel } from '@models/ProjectAnalysisModel/InternalBenchmarkAnalysisModel';
import { DriversOfOutcomeAnalysisModel } from '@models/ProjectAnalysisModel/DriversOfOutcomeAnalysisModel';
import { BenchmarkDataPoint } from '@components/charts/datamodels/BenchmarkDataPoint';
import { RawVoiceAnalysisModel } from '@models/ProjectAnalysisModel/RawVoiceAnalysisModel';
import { OvertimeComparisonAnalysisModel } from '@models/ProjectAnalysisModel/OvertimeComparisonAnalysisModel';

/**
 * Recomputes per-topic z-scores and the benchmark score on the front-end.
 *
 * When the server-side result is too large to send back (Lambda response
 * restrictions), the raw rows are received in chunks without final scores, so
 * the z-scores and benchmark score are recomputed here instead.
 *
 * Within each topic, every metric is standardised against all rows sharing
 * that topic using the population standard deviation (divide by n).
 * `benchmarkScore` is the z-score of `positiveIncidence - negativeIncidence`.
 *
 * When a topic has no spread for a metric (a single row, or identical values)
 * the z-scores for that metric are 0 instead of NaN.
 *
 * @param data Raw benchmark rows. Not mutated.
 * @returns New `BenchmarkDataPoint` instances grouped by topic: topics in
 *   first-seen order, rows within a topic in their input order.
 */
export const recomputeScores = (
  data: BenchmarkDataPoint[]
): BenchmarkDataPoint[] => {
  const mean = (values: number[]): number =>
    values.reduce((acc, val) => acc + val, 0) / values.length;

  // Builds a value -> z-score mapper for one metric within one topic.
  const standardizer = (values: number[]): ((value: number) => number) => {
    const mu = mean(values);
    const sigma = Math.sqrt(mean(values.map((val) => (val - mu) ** 2)));
    // With no spread the z-score would be 0/0 = NaN. Identical values are also
    // checked exactly: rounding in `mu` (e.g. for [0.1, 0.1, 0.1]) can leave a
    // tiny non-zero sigma that would otherwise turn every z-score into +/-1.
    const noSpread = sigma === 0 || values.every((val) => val === values[0]);
    return (value) => (noSpread ? 0 : (value - mu) / sigma);
  };

  // The quantity whose z-score becomes the benchmark score.
  const netIncidence = (point: BenchmarkDataPoint): number =>
    point.positiveIncidence - point.negativeIncidence;

  // Group rows by topic. A Map keeps first-seen order and, unlike a plain
  // object, cannot collide with inherited keys such as "constructor".
  const byTopic = new Map<BenchmarkDataPoint['topic'], BenchmarkDataPoint[]>();
  data.forEach((point) => {
    const rows = byTopic.get(point.topic);
    if (rows) {
      rows.push(point);
    } else {
      byTopic.set(point.topic, [point]);
    }
  });

  const result: BenchmarkDataPoint[] = [];
  byTopic.forEach((rows) => {
    const incidenceZ = standardizer(rows.map((p) => p.incidence));
    const sentimentZ = standardizer(rows.map((p) => p.sentiment));
    const positiveIncidenceZ = standardizer(
      rows.map((p) => p.positiveIncidence)
    );
    const negativeIncidenceZ = standardizer(
      rows.map((p) => p.negativeIncidence)
    );
    const benchmarkScore = standardizer(rows.map(netIncidence));

    rows.forEach((p) => {
      result.push(
        new BenchmarkDataPoint(
          p.benchmarkName,
          p.populationName,
          p.topic,
          p.isTheme,
          p.topicUserFriendlyName,
          p.incidence,
          p.sentiment,
          p.rowCount,
          p.vopicCount,
          p.positiveIncidence,
          p.negativeIncidence,
          incidenceZ(p.incidence),
          sentimentZ(p.sentiment),
          positiveIncidenceZ(p.positiveIncidence),
          negativeIncidenceZ(p.negativeIncidence),
          benchmarkScore(netIncidence(p))
        )
      );
    });
  });

  return result;
};

/**
 * React Query hook that loads benchmark rows for `project` under `currentAnalysis`.
 *
 * - `DRIVERS_OF_OUTCOME`: calls `ProjectAPI.getBenchmarkData` with no benchmark,
 *   the analysis' `focalPopulation` and a `'*'` wildcard, and returns the
 *   response unchanged.
 * - `INTERNAL_BENCHMARK` / `RAW_VOICE` / `PROGRESS_OVER_TIME`: fetches the
 *   analysis' configured benchmark with `'*'` wildcards and recomputes the
 *   z-scores client-side via `recomputeScores`.
 *
 * The query stays disabled until both arguments are provided and is considered
 * fresh for 90 minutes. Any other analysis type fails the query with an
 * explicit error (TanStack Query does not allow a queryFn to resolve `undefined`).
 *
 * @param project Project whose benchmark data is loaded.
 * @param currentAnalysis Analysis that determines which benchmark is fetched.
 */
export const useBenchmarkData = (
  project?: ProjectModel,
  currentAnalysis?: IProjectAnalysisModel
) => {
  return useQuery({
    queryKey: [
      'benchmark-data',
      project?.projectId,
      project?.organizationId,
      currentAnalysis?.id,
    ],
    queryFn: async () => {
      // `enabled` stops automatic runs without inputs, but a manual refetch()
      // ignores `enabled`, so guard explicitly instead of dereferencing undefined.
      if (!project || !currentAnalysis) {
        throw new Error(
          'useBenchmarkData: project and currentAnalysis are required'
        );
      }
      const { projectId } = project;

      switch (currentAnalysis.type) {
        case AnalysisType.DRIVERS_OF_OUTCOME: {
          // TODO - Handle drivers analysis data...
          // Returned as-is: scores are not recomputed client-side for this type.
          const analysis = currentAnalysis as DriversOfOutcomeAnalysisModel;
          return ProjectAPI.getBenchmarkData(
            projectId,
            null,
            analysis.focalPopulation,
            '*'
          );
        }
        case AnalysisType.INTERNAL_BENCHMARK: {
          const analysis = currentAnalysis as InternalBenchmarkAnalysisModel;
          const data = await ProjectAPI.getBenchmarkData(
            projectId,
            analysis.internalBenchmark,
            '*',
            '*'
          );
          return recomputeScores(data);
        }
        case AnalysisType.RAW_VOICE: {
          const analysis = currentAnalysis as RawVoiceAnalysisModel;
          const data = await ProjectAPI.getBenchmarkData(
            projectId,
            analysis.populationBenchmark,
            '*',
            '*'
          );
          return recomputeScores(data);
        }
        case AnalysisType.PROGRESS_OVER_TIME: {
          const analysis = currentAnalysis as OvertimeComparisonAnalysisModel;
          const data = await ProjectAPI.getBenchmarkData(
            projectId,
            analysis.overtimeBenchmark,
            '*',
            '*'
          );
          return recomputeScores(data);
        }
        default:
          // Resolving undefined is rejected by TanStack Query with an opaque
          // message; name the unsupported type instead.
          throw new Error(
            `useBenchmarkData: unsupported analysis type "${String(
              currentAnalysis.type
            )}"`
          );
      }
    },
    enabled: !!project && !!currentAnalysis,
    staleTime: 1000 * 60 * 90, // 90 minutes
  });
};
