import PPNode from '../../../classes/NodeClass';
import Socket from '../../../classes/SocketClass';
import { NODE_TYPE_COLOR, SOCKET_TYPE } from '../../../utils/constants';
import { TRgba } from '../../../utils/interfaces';
import { GraphInputPointXY, GraphInputXYType } from '../../datatypes/graphInputType';
import { NumberType } from '../../datatypes/numberType';
import { inputDataName } from './scatterGraph';

/** Name of the output socket that carries the slope `m` of the fitted line `y = m * x + b`. */
export const outputTrendFunctionM = 'Trend Function M';
/** Name of the output socket that carries the y-intercept `b` of the fitted line `y = m * x + b`. */
export const outputTrendFunctionB = 'Trend Function B';

/**
 * Fits the ordinary least-squares straight line `y = m * x + b` through a set
 * of points, using `Value1` as x and `Value2` as y.
 *
 * The fit is computed from deviations about the means:
 * `m = Σ(x - x̄)(y - ȳ) / Σ(x - x̄)²` and `b = ȳ - m·x̄`.
 * This is algebraically identical to the textbook
 * `(n·Σxy - Σx·Σy) / (n·Σx² - (Σx)²)` form. Unlike that form, it does not suffer
 * catastrophic cancellation when the x values are large compared with their
 * spread (e.g. timestamps).
 *
 * Degenerate inputs yield a finite line instead of `NaN` / `Infinity`:
 * - no points (or a missing array): `[0, 0]`;
 * - every point has the same x (including a single point): the slope is
 *   undefined, so the horizontal line through the mean y, `[0, ȳ]`, is returned.
 *
 * @param inputData - Points to fit; `Value1` is x and `Value2` is y.
 * @returns `[m, b]`: the slope and y-intercept of the fitted line.
 */
export function calculateTrendLine(
  inputData: GraphInputPointXY[],
): [number, number] {
  if (!inputData || inputData.length === 0) {
    return [0, 0];
  }
  const n = inputData.length;

  // First pass: sums for the means, plus the x range to detect a vertical point set.
  let sumX = 0;
  let sumY = 0;
  let minX = Infinity;
  let maxX = -Infinity;
  for (const point of inputData) {
    sumX += point.Value1;
    sumY += point.Value2;
    minX = Math.min(minX, point.Value1);
    maxX = Math.max(maxX, point.Value1);
  }
  const meanX = sumX / n;
  const meanY = sumY / n;

  // Compare the range rather than testing `sxx === 0`. Rounding in meanX can
  // leave a tiny non-zero sxx for identical x values, which would produce a
  // huge bogus slope.
  if (minX === maxX) {
    return [0, meanY];
  }

  // Second pass: centred sums of squares / cross-products.
  let sxx = 0;
  let sxy = 0;
  for (const point of inputData) {
    const dx = point.Value1 - meanX;
    sxx += dx * dx;
    sxy += dx * (point.Value2 - meanY);
  }

  const calculatedM = sxy / sxx;
  const calculatedB = meanY - calculatedM * meanX;
  return [calculatedM, calculatedB];
}


/**
 * Graph node that fits a least-squares straight line `y = m * x + b` to the
 * points on its XY graph-input socket. It publishes the slope and intercept
 * as two number outputs.
 */
export class TrendLine extends PPNode {
  public getColor(): TRgba {
    return TRgba.fromString(NODE_TYPE_COLOR.DRAW);
  }

  public getName(): string {
    return 'Calculate Trend Line';
  }

  public getDescription(): string {
    return 'Uses graph input and calculates the minimal square distance linear trend line';
  }

  public getTags(): string[] {
    // NOTE(review): 'Input' may have been copied from a sibling node; confirm it is the intended category.
    return ['Input'].concat(super.getTags());
  }

  /**
   * Declares one XY graph input and two number outputs (slope M and intercept B).
   * The base class sockets are appended after them.
   */
  protected getDefaultIO(): Socket[] {
    return [
      new Socket(SOCKET_TYPE.IN, inputDataName, new GraphInputXYType()),
      new Socket(SOCKET_TYPE.OUT, outputTrendFunctionM, new NumberType()),
      new Socket(SOCKET_TYPE.OUT, outputTrendFunctionB, new NumberType()),
    ].concat(super.getDefaultIO());
  }

  /**
   * Reads the point list from the input socket, fits the trend line and
   * writes `m` and `b` to their output sockets.
   */
  public async onExecute(input, output): Promise<void> {
    // A missing socket value is treated as "no points". The fit then degrades
    // gracefully instead of throwing on `undefined.length`.
    const points: GraphInputPointXY[] = input[inputDataName] ?? [];
    const [m, b] = calculateTrendLine(points);
    output[outputTrendFunctionM] = m;
    output[outputTrendFunctionB] = b;
  }
}
