import {
  AgCartesianAxisOptions,
  AgCartesianChartOptions,
  AgChartLegendPosition,
  AgWaterfallSeriesOptions,
  AgWaterfallSeriesTooltipRendererParams,
  WaterfallSeriesTotalMeta,
} from 'ag-charts-community';
import { DateTime } from 'luxon';
import React, { useMemo } from 'react';

import AgChart from 'components/AgGridComponents/AgChart/AgChart';
import { renderTooltip } from 'components/AgGridComponents/AgChart/agChartTooltips';
import { useGroupColors } from 'components/CustomizeDriverChartsBlock/hooks';
import theme from 'config/theme';
import {
  ChartAxisType,
  ChartDisplay,
  ChartElementPosition,
  ChartSize,
  DriverFormat,
} from 'generated/graphql';
import { isNotNull, safeObjGet } from 'helpers/typescript';
import useAppSelector from 'hooks/useAppSelector';
import useBlockContext from 'hooks/useBlockContext';
import { Attribute } from 'reduxStore/models/dimensions';
import { DriverId } from 'reduxStore/models/drivers';
import { DEFAULT_DISPLAY_CONFIGURATION, DisplayConfiguration } from 'reduxStore/models/value';
import { layersSelector } from 'selectors/layerSelector';

import {
  CHART_TOOLTIP_CLASSNAME,
  CURRENCY_FORMAT,
  GREEN_1,
  NUMBER_FORMAT,
  PERCENT_FORMAT,
  RED_1,
} from './agCharts';
import { useChartDriverAggregatedData, useChartDriverConfig } from './chartDriverHooks';

const fonts = theme.fonts;

/**
 * Shared empty result returned while drivers are loading (or no data is available), so the
 * memoized waterfall data keeps the same array reference instead of allocating a new one.
 */
const EMPTY_DATA: WaterfallData[] = [];

/** One driver bar in the waterfall chart. */
interface WaterfallData {
  /** Id of the driver backing this bar. */
  id: string;
  /** Signed bar value; negative values render as decreases. */
  value: number;
  /** Category-axis label for the bar (used as the series `xKey`). */
  label: string;
}

/**
 * Datum shape received by the waterfall tooltip renderer. Subtotal/total bars additionally
 * carry the fields of the `totals` metadata (`totalType`, `index`, `axisLabel`).
 */
interface WaterfallDatum extends WaterfallData {
  /** Kind of aggregate bar; the tooltip checks this against null before use. */
  totalType: 'subtotal' | 'total';
  /** 0-based position in the data after which the aggregate bar is drawn. */
  index: number;
  /** Category-axis label of the aggregate bar. */
  axisLabel: string;
}

/**
 * AG Charts subtotal/total metadata, extended with the driver ids associated with that total
 * (used when computing tooltip values and formatting).
 */
type WaterfallSeriesTotalMetaWithDriverIds = { driverIds: DriverId[] } & WaterfallSeriesTotalMeta;

/** Props for the waterfall chart component. */
interface AgWaterfallChartProps {
  /** Chart configuration: groups, series, axes and legend settings. */
  chartDisplay: ChartDisplay;
  /** Drivers whose aggregated values feed the chart. */
  driverIds: DriverId[];
  /** Not currently read by `AgWaterfallChart`. */
  chartIndex?: number;
  /** Not currently read by `AgWaterfallChart`. */
  size?: ChartSize;
  /** Not currently read by `AgWaterfallChart`. */
  stacked?: boolean;
}

/**
 * Fetches aggregated driver values for the date range and turns them into waterfall rows.
 *
 * Rows are emitted group by group in `chartDisplay.groups` order; within a group they follow
 * the order of the aggregated data, keeping only drivers that back one of the group's series.
 * A row's value is the absolute aggregated value (a missing value counts as 0), negated unless
 * the group's `isPositive` is exactly `true`. Its label is the driver name followed by each
 * sub-driver attribute value in parentheses.
 *
 * @returns the current layer id, the waterfall rows (the shared `EMPTY_DATA` array while any
 *   driver is loading or no aggregated data is available), and the loading flag.
 */
function useChartDriverWaterfallTimeSeriesData(
  driverIds: DriverId[],
  [start, end]: [DateTime, DateTime],
  chartDisplay: ChartDisplay,
  attributesBySubDriverId: Record<string, Attribute[]>,
  driverNamesById: Record<string, string>,
) {
  const {
    currentLayerId,
    data: aggregatedRows,
    isAnyDriverLoading,
  } = useChartDriverAggregatedData(driverIds, [start, end], chartDisplay);

  const data = useMemo((): WaterfallData[] => {
    if (isAnyDriverLoading || !aggregatedRows) {
      return EMPTY_DATA;
    }

    const rows: WaterfallData[] = [];
    for (const group of chartDisplay.groups) {
      const seriesInGroup = chartDisplay.series.filter((series) =>
        group.seriesIds.includes(series.id),
      );
      const sign = group.isPositive === true ? 1 : -1;

      for (const { id: driverId, value } of aggregatedRows) {
        // Only drivers that back one of this group's series belong to the group.
        if (!seriesInGroup.some((series) => series.driverId === driverId)) {
          continue;
        }

        let label = driverNamesById[driverId];
        for (const attr of safeObjGet(attributesBySubDriverId[driverId]) ?? []) {
          label = `${label} (${attr.value})`;
        }

        rows.push({ id: driverId, label, value: Math.abs(value ?? 0) * sign });
      }
    }
    return rows;
  }, [
    aggregatedRows,
    attributesBySubDriverId,
    chartDisplay.groups,
    chartDisplay.series,
    driverNamesById,
    isAnyDriverLoading,
  ]);

  return { currentLayerId, data, isAnyDriverLoading };
}

/** Value and formatting shown in the tooltip of a waterfall subtotal/total bar. */
interface WaterfallTotalSummary {
  value: number;
  displayConfiguration: DisplayConfiguration;
}

/**
 * Builds the lookup key for a subtotal/total bar. The total type is part of the key because a
 * subtotal on the last group and the grand total can share the same data index.
 */
function getTotalKey(totalType: string, index: number): string {
  return `${totalType}:${index}`;
}

/**
 * Renders the chart's driver groups as a single AG Charts waterfall series.
 *
 * Bars come from `useChartDriverWaterfallTimeSeriesData`. Groups with `showSubtotal` get a
 * subtotal bar after their last row and, when the driver axis has `showTotals`, a grand-total
 * bar is drawn after the last row.
 */
const AgWaterfallChart: React.FC<AgWaterfallChartProps> = ({ driverIds, chartDisplay }) => {
  const { blockId } = useBlockContext();
  const layersById = useAppSelector(layersSelector);
  const colorsByGroupId = useGroupColors(chartDisplay);
  const {
    attributesBySubDriverId,
    dateRange,
    driverDisplayConfigurationsById,
    driverNamesById,
    height,
    totalDisplayConfiguration,
    width,
  } = useChartDriverConfig(blockId, driverIds, chartDisplay);
  const { isAnyDriverLoading, data } = useChartDriverWaterfallTimeSeriesData(
    driverIds,
    dateRange,
    chartDisplay,
    attributesBySubDriverId,
    driverNamesById,
  );

  /**
   * AG Chart totals are used to display the subtotal and total values for the waterfall chart.
   *
   * The index indicates the subtotal/total position within the data using a 0-based index, so a
   * subtotal's index is the position of its group's last row in the *full* data array. It must
   * therefore accumulate the series counts of every preceding group, including groups that do
   * not show a subtotal. You can read more about this functionality in the AG Charts
   * documentation at:
   * https://www.ag-grid.com/charts/react/waterfall-series/#total--subtotal-values
   */
  const totals = useMemo<WaterfallSeriesTotalMetaWithDriverIds[]>(() => {
    // TODO: we can't reliably match on driver id.
    const driverIdsForGroup = (group: (typeof chartDisplay.groups)[number]) =>
      group.seriesIds
        .map((seriesId) => chartDisplay.series.find((s) => s.id === seriesId)?.driverId)
        .filter(isNotNull);

    const totalMetas: WaterfallSeriesTotalMetaWithDriverIds[] = [];

    // Index of the last data row of the group being visited.
    // NOTE(review): assumes each series in a group yields exactly one data row; a series whose
    // driver has no aggregated data would shift every later index — confirm.
    let lastRowIndex = -1;
    chartDisplay.groups.forEach((group, groupIndex) => {
      lastRowIndex += group.seriesIds.length;
      if (!group.showSubtotal) {
        return;
      }
      totalMetas.push({
        totalType: 'subtotal',
        // Number groups by their position among all groups, not among subtotal groups only.
        axisLabel: group.name ?? `Group ${groupIndex + 1} Subtotal`,
        index: lastRowIndex,
        driverIds: driverIdsForGroup(group),
      });
    });

    const driverAxis = chartDisplay.axes.find((axis) => axis.type === ChartAxisType.Driver);
    // With no rows there is nothing to total, and `data.length - 1` would be -1.
    if (driverAxis?.driver?.showTotals && data.length > 0) {
      totalMetas.push({
        totalType: 'total',
        axisLabel: 'Total',
        index: data.length - 1,
        // The grand total spans every group, not only the groups that show a subtotal.
        driverIds: chartDisplay.groups.flatMap((group) => driverIdsForGroup(group)),
      });
    }
    return totalMetas;
  }, [chartDisplay.axes, chartDisplay.groups, chartDisplay.series, data.length]);

  /** Tooltip value and formatting for each subtotal/total, keyed by `getTotalKey`. */
  const totalsByKey = useMemo(() => {
    const out: Record<string, WaterfallTotalSummary> = {};

    for (const total of totals) {
      let value: number;
      if (total.totalType === 'total') {
        // The grand total is the running sum of every plotted row.
        value = data.reduce((sum, item) => sum + item.value, 0);
      } else {
        // TODO: we can't reliably match on driver id.
        value = total.driverIds
          .map((driverId) => data.find((d) => d.id === driverId))
          .filter(isNotNull)
          .reduce((sum, item) => sum + item.value, 0);
      }

      // Format with the first associated driver that has a display configuration.
      const driverId = total.driverIds.find((id) => driverDisplayConfigurationsById[id] != null);
      const displayConfiguration =
        driverId != null
          ? driverDisplayConfigurationsById[driverId]
          : DEFAULT_DISPLAY_CONFIGURATION;

      out[getTotalKey(total.totalType, total.index)] = { value, displayConfiguration };
    }

    return out;
  }, [data, driverDisplayConfigurationsById, totals]);

  const series: AgWaterfallSeriesOptions[] = useMemo(() => {
    const tooltipRenderer = (params: AgWaterfallSeriesTooltipRendererParams<WaterfallDatum>) => {
      const displayConfiguration =
        driverDisplayConfigurationsById[params.datum.id] ?? DEFAULT_DISPLAY_CONFIGURATION;

      // There is currently not an efficient method to render a tooltip for the subtotal or total in AG Charts.
      // Support ticket has been created and a new item in the backlog/roadmap has been created.
      // See: AG-12764 - [Charts] Add computed value to waterfall total/subtotal tooltip renderer params
      if (params.datum.totalType != null) {
        // These bars presumably come from the `totals` metadata, which carries no driver `id`,
        // so use the value and formatting precomputed for this total.
        const total: WaterfallTotalSummary | undefined =
          totalsByKey[getTotalKey(params.datum.totalType, params.datum.index)];
        return renderTooltip(
          params.datum.label,
          {
            id: params.datum.id,
            key: params.itemId,
            value: total?.value ?? 0,
          },
          total?.displayConfiguration ?? displayConfiguration,
        );
      }

      return renderTooltip(
        params.datum.label,
        { id: params.datum.id, key: params.itemId, value: params.datum.value },
        displayConfiguration,
      );
    };

    const positiveGroup = chartDisplay.groups.find((group) => group.isPositive);
    const negativeGroup = chartDisplay.groups.find((group) => group.isPositive === false);

    return [
      {
        type: 'waterfall',
        xKey: 'label',
        yKey: 'value',
        totals,
        item: {
          positive: {
            fill: positiveGroup ? colorsByGroupId[positiveGroup.id].value : GREEN_1.value,
          },
          negative: {
            fill: negativeGroup ? colorsByGroupId[negativeGroup.id].value : RED_1.value,
          },
          total: {
            fill: '#d9d9d9',
          },
        },
        highlightStyle: {
          series: {
            enabled: true,
            dimOpacity: 0.5,
          },
        },
        tooltip: {
          enabled: true,
          showArrow: false,
          renderer: tooltipRenderer,
        },
      },
    ];
  }, [
    chartDisplay.groups,
    colorsByGroupId,
    driverDisplayConfigurationsById,
    totals,
    totalsByKey,
  ]);

  const axes = useMemo(() => {
    const driverAxis = chartDisplay.axes.find((axis) => axis.type === ChartAxisType.Driver);
    const yAxis: AgCartesianAxisOptions = {
      gridLine: {
        enabled: false,
      },
      label: {
        color: 'gray.500',
        fontSize: 10,
        fontFamily: fonts.body,
        format:
          totalDisplayConfiguration.format === DriverFormat.Currency
            ? CURRENCY_FORMAT
            : totalDisplayConfiguration.format === DriverFormat.Percentage
              ? PERCENT_FORMAT
              : NUMBER_FORMAT,
      },
      line: { enabled: false },
      min: driverAxis?.driver?.min ?? undefined,
      max: driverAxis?.driver?.max ?? undefined,
      nice: driverAxis?.driver?.round !== false,
      position: driverAxis?.position === ChartElementPosition.Left ? 'left' : 'right',
      title: {
        enabled: Boolean(driverAxis?.showLabel),
        text: driverAxis?.name ?? '',
        fontFamily: fonts.body,
        fontSize: 10,
        color: 'gray.500',
        spacing: 8,
      },
      type: 'number',
    };

    const categoryAxis = chartDisplay.axes.find((axis) => axis.type === ChartAxisType.Category);
    const xAxis: AgCartesianAxisOptions = {
      gridLine: {
        enabled: false,
      },
      label: {
        color: 'gray.500',
        fontSize: 10,
        fontFamily: fonts.body,
      },
      line: { enabled: false },
      paddingInner: 0.2,
      paddingOuter: 0.1,
      position: categoryAxis?.position === ChartElementPosition.Top ? 'top' : 'bottom',
      title: {
        enabled: Boolean(categoryAxis?.showLabel),
        text: categoryAxis?.name ?? '',
        fontFamily: fonts.body,
        fontSize: 10,
        color: 'gray.500',
        spacing: 12,
      },
      type: 'category',
    };

    return [xAxis, yAxis];
  }, [chartDisplay.axes, totalDisplayConfiguration.format]);

  const options: AgCartesianChartOptions = useMemo(() => {
    return {
      axes,
      data,
      tooltip: {
        class: CHART_TOOLTIP_CLASSNAME,
        position: {
          type: 'pointer',
        },
      },
      legend: {
        enabled:
          chartDisplay.legend?.showLegend != null
            ? chartDisplay.legend.showLegend
            : driverIds.length > 1,
        reverseOrder: false,
        position:
          (chartDisplay.legend?.position?.toLowerCase() as AgChartLegendPosition) ?? 'bottom',
        preventHidingAll: true,
        maxWidth: chartDisplay.legend?.container?.maxWidth ?? undefined,
        maxHeight: chartDisplay.legend?.container?.maxHeight ?? undefined,
        item: {
          showSeriesStroke: false,
          marker: {
            size: 8,
            shape: 'square',
          },
          maxWidth: chartDisplay.legend?.item?.maxWidth ?? undefined,
          paddingX: chartDisplay.legend?.item?.paddingX ?? undefined,
          paddingY: chartDisplay.legend?.item?.paddingY ?? undefined,
          label: {
            fontFamily: fonts.body,
            fontSize: 10,
            color: 'gray.600',
            // Appends sub-driver attribute values and the group's layer name to the legend label.
            formatter: ({ itemId, value }) => {
              let base = value;

              const found = chartDisplay.series.find((s) => s.id === itemId);
              if (!found) {
                return base;
              }

              const attributes = safeObjGet(attributesBySubDriverId[found.driverId ?? found.id]);
              if (attributes != null && attributes.length > 0) {
                for (const attr of attributes) {
                  base += ` [${attr.value}]`;
                }
              }

              const group = chartDisplay.groups.find((g) =>
                g.seriesIds.includes(itemId as string),
              );
              // The referenced layer may no longer exist; skip the suffix instead of throwing.
              const layerName =
                group?.layerId != null ? layersById[group.layerId]?.name : undefined;
              if (layerName != null) {
                base += ` [${layerName}]`;
              }

              return base;
            },
          },
        },
      },
      series,
      width,
      height,
    };
  }, [
    attributesBySubDriverId,
    axes,
    chartDisplay,
    data,
    driverIds.length,
    height,
    layersById,
    series,
    width,
  ]);

  return <AgChart isLoading={isAnyDriverLoading} options={options} />;
};

/** Memoized so the chart re-renders only when its props change (shallow comparison). */
const MemoizedAgWaterfallChart = React.memo(AgWaterfallChart);

export default MemoizedAgWaterfallChart;
