package net.finmath.marketdata.model.curves.locallinearregression;

import java.time.LocalDate;
import net.finmath.marketdata.model.curves.Curve;
import net.finmath.marketdata.model.curves.CurveInterface;
import net.finmath.marketdata.model.curves.DiscountCurve;
import org.apache.commons.math3.distribution.AbstractRealDistribution;
import org.apache.commons.math3.distribution.CauchyDistribution;
import org.apache.commons.math3.distribution.LaplaceDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.jblas.DoubleMatrix;
import org.jblas.Solve;

/* loaded from: input_file:net/finmath/marketdata/model/curves/locallinearregression/CurveEstimation.class */
/**
 * Estimates a curve by kernel-weighted local linear regression of observations on a fixed
 * partition.
 *
 * <p>The estimated curve is piecewise linear with nodes at the partition points. For every
 * observation {@code (x, y)} and every partition interval {@code r}, the kernel density is
 * evaluated at {@code (partition.getIntervalReferencePoint(r) - x) / bandwidth}; these weights are
 * accumulated into a linear system (presumably the normal equations of the weighted least-squares
 * fit) whose solution holds the value at the first node followed by one slope per interval.
 *
 * <p>The curve is computed lazily on the first call to {@link #getRegressionCurve()} and cached;
 * this lazy initialization is not synchronized.
 */
public class CurveEstimation {
    private final LocalDate referenceDate;
    private final double bandwidth;
    private final double[] independentValues;
    private final double[] dependentValues;
    private final Partition partition;
    private final AbstractRealDistribution kernel;

    /** Result of {@link #getRegressionCurve()}; {@code null} until first requested. */
    private CurveInterface regressionCurve;

    /** Shape of the kernel used to weight observations relative to each partition interval. */
    public enum Distribution {
        /** Standard normal density. */
        NORMAL,
        /** Laplace density with location 0 and scale 1. */
        LAPLACE,
        /** Cauchy density with the commons-math default parameters. */
        CAUCHY
    }

    /**
     * Creates an estimator for the given observations.
     *
     * @param referenceDate reference date of the resulting curve
     * @param bandwidth kernel bandwidth; kernel arguments are divided by this value
     * @param independentValues observation abscissas (copied)
     * @param dependentValues observation values, one per abscissa (copied)
     * @param partitionValues partition points, passed as a copy to {@link Partition}
     * @param weight second argument of the {@link Partition} constructor (presumably positions the
     *        interval reference points inside each interval — confirm against {@code Partition})
     * @param distribution kernel shape; must not be {@code null}
     * @throws IllegalArgumentException if the two observation arrays differ in length
     * @throws NullPointerException if any array argument or {@code distribution} is {@code null}
     */
    public CurveEstimation(LocalDate referenceDate, double bandwidth, double[] independentValues,
            double[] dependentValues, double[] partitionValues, double weight,
            Distribution distribution) {
        if (independentValues.length != dependentValues.length) {
            throw new IllegalArgumentException("Number of independent values ("
                    + independentValues.length + ") does not match number of dependent values ("
                    + dependentValues.length + ").");
        }
        this.referenceDate = referenceDate;
        this.bandwidth = bandwidth;
        // Copies keep the cached curve consistent with the data it was computed from.
        this.independentValues = independentValues.clone();
        this.dependentValues = dependentValues.clone();
        this.partition = new Partition(partitionValues.clone(), weight);
        this.kernel = createKernel(distribution);
    }

    /**
     * Creates an estimator using the {@link Distribution#NORMAL} kernel.
     *
     * @see #CurveEstimation(LocalDate, double, double[], double[], double[], double, Distribution)
     */
    public CurveEstimation(LocalDate referenceDate, double bandwidth, double[] independentValues,
            double[] dependentValues, double[] partitionValues, double weight) {
        this(referenceDate, bandwidth, independentValues, dependentValues, partitionValues, weight,
                Distribution.NORMAL);
    }

    /**
     * Returns the estimated curve, solving the regression system on the first call and returning
     * the cached instance afterwards.
     *
     * @return a curve named {@code "RegressionCurve"} with linear interpolation on values and
     *         constant extrapolation, with nodes at the partition points
     */
    public CurveInterface getRegressionCurve() {
        if (regressionCurve == null) {
            regressionCurve = buildRegressionCurve();
        }
        return regressionCurve;
    }

    /** Maps the enum to the commons-math distribution whose density is used as the kernel. */
    private static AbstractRealDistribution createKernel(Distribution distribution) {
        switch (distribution) {
            case LAPLACE:
                return new LaplaceDistribution(0.0d, 1.0d);
            case CAUCHY:
                return new CauchyDistribution();
            case NORMAL:
            default:
                return new NormalDistribution();
        }
    }

    /** Solves the system and integrates the per-interval slopes into node values. */
    private CurveInterface buildRegressionCurve() {
        DoubleMatrix solution = solveEquationSystem();
        double[] nodeValues = new double[partition.getLength()];
        // solution[0] is the value at the first node; solution[i] (i >= 1) is multiplied by the
        // length of interval i-1 to step to the next node, i.e. it acts as that interval's slope.
        nodeValues[0] = solution.get(0);
        for (int i = 1; i < nodeValues.length; i++) {
            nodeValues[i] = nodeValues[i - 1] + solution.get(i) * partition.getIntervalLength(i - 1);
        }
        return new Curve("RegressionCurve", referenceDate, Curve.InterpolationMethod.LINEAR,
                Curve.ExtrapolationMethod.CONSTANT, Curve.InterpolationEntity.VALUE,
                partition.getPoints(), nodeValues);
    }

    /**
     * Assembles and solves the {@code nodeCount x nodeCount} linear system.
     *
     * <p>Unknown 0 is the first node value; unknown {@code i >= 1} is the slope on interval
     * {@code i - 1}.
     */
    private DoubleMatrix solveEquationSystem() {
        final int nodeCount = partition.getLength();
        final int intervalCount = nodeCount - 1;

        DoubleMatrix nodes = new DoubleMatrix(partition.getPoints());
        // previousNodes[i] = node i-1 for i >= 1, and 0 for i = 0.
        DoubleMatrix previousNodes = new DoubleMatrix(nodeCount);
        for (int i = 1; i < nodeCount; i++) {
            previousNodes.put(i, partition.getPoint(i - 1));
        }
        // increments[i] = node i - node i-1 for i >= 1; entry 0 is fixed to 1 so that it multiplies
        // the level unknown.
        DoubleMatrix increments = nodes.sub(previousNodes).put(0, 1.0d);

        // Slices that do not depend on the observation; hoisted out of the loop.
        DoubleMatrix leftNodes = nodes.getRange(0, nodeCount - 1);
        DoubleMatrix intervalIncrements = increments.getRange(1, nodeCount);

        DoubleMatrix rightHandSide = new DoubleMatrix(nodeCount);
        DoubleMatrix kernelValues = new DoubleMatrix(intervalCount);
        DoubleMatrix firstColumn = new DoubleMatrix(intervalCount);
        DoubleMatrix diagonal = new DoubleMatrix(intervalCount);
        double totalWeight = 0.0d;

        for (int obs = 0; obs < independentValues.length; obs++) {
            double x = independentValues[obs];
            DoubleMatrix indicator = new DoubleMatrix(nodeCount);
            // After the loop, tailSums[j] = sum of kernelValues[r] for r >= j (and 0 at the last
            // index).
            DoubleMatrix tailSums = new DoubleMatrix(nodeCount);
            DoubleMatrix shiftedKernelValues = new DoubleMatrix(nodeCount);
            for (int r = 0; r < intervalCount; r++) {
                double k = kernel.density((partition.getIntervalReferencePoint(r) - x) / bandwidth);
                indicator.put(r, 1.0d);
                kernelValues.put(r, k);
                shiftedKernelValues.put(r + 1, k);
                tailSums = tailSums.add(indicator.mmul(k));
            }

            rightHandSide = rightHandSide.add(previousNodes.neg().add(x).mul(shiftedKernelValues)
                    .add(increments.mul(tailSums)).mul(dependentValues[obs]));
            totalWeight += tailSums.get(0);

            DoubleMatrix offsets = leftNodes.neg().add(x);
            DoubleMatrix upperTailSums = tailSums.getRange(1, nodeCount);
            firstColumn = firstColumn.add(
                    offsets.mul(kernelValues).add(intervalIncrements.mul(upperTailSums)));
            diagonal = diagonal.add(offsets.mul(offsets).mul(kernelValues)
                    .add(intervalIncrements.mul(intervalIncrements.mul(upperTailSums))));
        }

        DoubleMatrix intervalBlock = buildIntervalBlock(increments, firstColumn, diagonal);

        int[] intervalIndices = new int[intervalCount];
        for (int i = 0; i < intervalCount; i++) {
            intervalIndices[i] = i + 1;
        }
        DoubleMatrix system = new DoubleMatrix(nodeCount, nodeCount);
        system.put(0, 0, totalWeight);
        system.put(intervalIndices, 0, firstColumn);
        system.put(0, intervalIndices, firstColumn.transpose());
        system.put(intervalIndices, intervalIndices, intervalBlock);
        return Solve.solve(system, rightHandSide);
    }

    /**
     * Builds the lower-right {@code intervalCount x intervalCount} block of the system matrix.
     *
     * <p>Entry {@code (r, c)} with {@code r > c} is {@code increments[c + 1] * firstColumn[r]};
     * the block is made symmetric and {@code diagonal} is placed on its diagonal.
     */
    private static DoubleMatrix buildIntervalBlock(DoubleMatrix increments,
            DoubleMatrix firstColumn, DoubleMatrix diagonal) {
        int size = firstColumn.length;
        DoubleMatrix strictlyLower = new DoubleMatrix(size, size);
        DoubleMatrix mask = DoubleMatrix.ones(size);
        for (int c = 0; c < size - 1; c++) {
            // Zero rows 0..c so column c is non-zero only strictly below the diagonal.
            mask.put(c, 0.0d);
            strictlyLower.putColumn(c, mask.mul(increments.get(c + 1)));
        }
        // Scale row r of every column by firstColumn[r].
        DoubleMatrix lower = strictlyLower.mulColumnVector(firstColumn);
        return lower.add(lower.transpose()).add(DoubleMatrix.diag(diagonal));
    }
}
