package net.finmath.optimizer;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.logging.Level;
import java.util.logging.Logger;

import net.finmath.functions.LinearAlgebra;
import net.finmath.montecarlo.RandomVariableFromDoubleArray;
import net.finmath.stochastic.RandomVariable;

/* loaded from: input_file:net/finmath/optimizer/StochasticLevenbergMarquardt.class */
/**
 * Levenberg(-Marquardt) least-squares solver whose parameters, model values and target values are
 * {@link RandomVariable}s.
 *
 * <p>The objective is {@link #getMeanSquaredError(RandomVariable[])}: the mean, over all value
 * components, of {@code (value[i] - target[i]).squared().getAverage()}. Each iteration proceeds in
 * four steps:
 * <ul>
 *   <li>evaluates a trial parameter vector;</li>
 *   <li>accepts it only if it strictly lowers that error;</li>
 *   <li>adapts the damping parameter {@code lambda} (divide on success, multiply on failure);</li>
 *   <li>solves the damped normal equations, built from {@code getAverage()} of products of the
 *       derivatives, for the next trial step.</li>
 * </ul>
 *
 * <p>Subclasses implement {@link #setValues(RandomVariable[], RandomVariable[])}. They may override
 * {@link #setDerivatives(RandomVariable[], RandomVariable[][])}; the default uses forward finite
 * differences. Iteration state is kept in mutable fields without synchronization.
 */
public abstract class StochasticLevenbergMarquardt implements Serializable, Cloneable, StochasticOptimizer {
    private static final long serialVersionUID = 4560864869394838155L;

    /** Diagnostics logger; static because {@link Logger} is not serializable. */
    private static final Logger LOGGER = Logger.getLogger("net.finmath");

    /** Value {@code lambda} is reset to at the start of every {@link #run()}. */
    private static final double LAMBDA_INITIAL_VALUE = 0.001;

    /** Factor applied to {@code lambda} each time the damped linear system could not be solved. */
    private static final double LAMBDA_SOLVER_FAILURE_MULTIPLICATOR = 16.0;

    private final RegularizationMethod regularizationMethod;
    private RandomVariable[] initialParameters;
    /** Finite-difference shifts per parameter, or {@code null} for the adaptive default. */
    private RandomVariable[] parameterSteps;
    private RandomVariable[] targetValues;
    private final int maxIteration;
    private double lambda;
    private double lambdaDivisor = 3.0;
    private double lambdaMultiplicator = 2.0;
    private final double errorTolerance;
    private int iteration;
    private RandomVariable[] parameterTest;
    private RandomVariable[] valueTest;
    private RandomVariable[] parameterCurrent;
    private RandomVariable[] valueCurrent;
    /** {@code derivativeCurrent[parameterIndex][valueIndex]}. */
    private RandomVariable[][] derivativeCurrent;
    private double errorMeanSquaredCurrent = Double.POSITIVE_INFINITY;
    private double errorRootMeanSquaredChange = Double.POSITIVE_INFINITY;
    /** True while {@link #derivativeCurrent} holds the derivatives at {@link #parameterCurrent}. */
    private boolean isParameterCurrentDerivativeValid;
    private int numberOfThreads = 1;
    /** Transient: an executor is not serializable; {@code null} means "run derivatives inline". */
    private transient ExecutorService executor;
    private boolean executorShutdownWhenDone = true;

    /** How the diagonal of the normal-equation matrix is damped by {@code lambda}. */
    public enum RegularizationMethod {
        /** Adds {@code lambda} to each diagonal element. */
        LEVENBERG,
        /**
         * Multiplies each diagonal element by {@code (1 + lambda)}; a zero diagonal element is
         * replaced by {@code lambda}.
         */
        LEVENBERG_MARQUARDT
    }

    /**
     * Small demonstration: fits {@code x1^2 = 25} and {@code (2 x0 + x1)^2 = 100}.
     *
     * @param args ignored
     * @throws SolverException if the solver fails
     */
    public static void main(String[] args) throws SolverException {
        StochasticLevenbergMarquardt optimizer = new StochasticLevenbergMarquardt(
                new RandomVariable[] { new RandomVariableFromDoubleArray(2.0), new RandomVariableFromDoubleArray(2.0) },
                new RandomVariable[] { new RandomVariableFromDoubleArray(25.0), new RandomVariableFromDoubleArray(100.0) },
                new RandomVariable[] { new RandomVariableFromDoubleArray(1.0), new RandomVariableFromDoubleArray(1.0) },
                100, 1.0E-12, null) {
            private static final long serialVersionUID = -282626938650139518L;

            @Override
            public void setValues(RandomVariable[] parameters, RandomVariable[] values) {
                values[0] = parameters[0].mult(0.0).add(parameters[1]).squared();
                values[1] = parameters[0].mult(2.0).add(parameters[1]).squared();
            }
        };
        optimizer.run();

        RandomVariable[] bestFitParameters = optimizer.getBestFitParameters();
        System.out.println("The solver for problem 1 required " + optimizer.getIterations() + " iterations. The best fit parameters are:");
        for (int i = 0; i < bestFitParameters.length; i++) {
            System.out.println("\tparameter[" + i + "]: " + bestFitParameters[i]);
        }
        System.out.println("The solver accuracy is " + optimizer.getRootMeanSquaredError());
    }

    /**
     * Creates the solver.
     *
     * @param regularizationMethod diagonal damping scheme; any value other than
     *     {@link RegularizationMethod#LEVENBERG} (including {@code null}) behaves as
     *     {@link RegularizationMethod#LEVENBERG_MARQUARDT}
     * @param initialParameters starting point (copied at the start of {@link #run()})
     * @param targetValues values the model is fitted to
     * @param parameterSteps finite-difference shift per parameter, or {@code null} to use
     *     {@code (|p| + 1) * 1e-8}
     * @param maxIteration the solver stops once the iteration count exceeds this
     * @param errorTolerance the solver stops once an accepted step improves the RMS error by at
     *     most this amount
     * @param executorService executor for the derivative evaluations, or {@code null}. A supplied
     *     executor is never shut down by this class.
     */
    public StochasticLevenbergMarquardt(RegularizationMethod regularizationMethod, RandomVariable[] initialParameters, RandomVariable[] targetValues, RandomVariable[] parameterSteps, int maxIteration, double errorTolerance, ExecutorService executorService) {
        this.regularizationMethod = regularizationMethod;
        this.initialParameters = initialParameters;
        this.targetValues = targetValues;
        this.parameterSteps = parameterSteps;
        this.maxIteration = maxIteration;
        this.errorTolerance = errorTolerance;
        this.executor = executorService;
        this.executorShutdownWhenDone = executorService == null;
    }

    /**
     * Creates a solver using {@link RegularizationMethod#LEVENBERG_MARQUARDT}.
     *
     * @see #StochasticLevenbergMarquardt(RegularizationMethod, RandomVariable[], RandomVariable[],
     *     RandomVariable[], int, double, ExecutorService)
     */
    public StochasticLevenbergMarquardt(RandomVariable[] initialParameters, RandomVariable[] targetValues, RandomVariable[] parameterSteps, int maxIteration, double errorTolerance, ExecutorService executorService) {
        this(RegularizationMethod.LEVENBERG_MARQUARDT, initialParameters, targetValues, parameterSteps, maxIteration, errorTolerance, executorService);
    }

    /**
     * Creates a solver that, if {@code numberOfThreads > 1}, creates its own thread pool for the
     * duration of each {@link #run()} and shuts it down afterwards.
     *
     * @param numberOfThreads number of threads used for derivative evaluations
     * @see #StochasticLevenbergMarquardt(RegularizationMethod, RandomVariable[], RandomVariable[],
     *     RandomVariable[], int, double, ExecutorService)
     */
    public StochasticLevenbergMarquardt(RegularizationMethod regularizationMethod, RandomVariable[] initialParameters, RandomVariable[] targetValues, RandomVariable[] parameterSteps, int maxIteration, double errorTolerance, int numberOfThreads) {
        this(regularizationMethod, initialParameters, targetValues, parameterSteps, maxIteration, errorTolerance, (ExecutorService) null);
        this.numberOfThreads = numberOfThreads;
    }

    /** @return the current damping parameter */
    public double getLambda() {
        return lambda;
    }

    /**
     * Sets the damping parameter. Note that {@link #run()} resets it to its initial value, so a
     * value set before {@code run()} is overwritten.
     *
     * @param lambda the damping parameter
     */
    public void setLambda(double lambda) {
        this.lambda = lambda;
    }

    /** @return factor by which {@code lambda} is multiplied after a rejected step */
    public double getLambdaMultiplicator() {
        return lambdaMultiplicator;
    }

    /**
     * @param lambdaMultiplicator factor by which {@code lambda} is multiplied after a rejected step
     * @throws IllegalArgumentException if {@code lambdaMultiplicator <= 1}
     */
    public void setLambdaMultiplicator(double lambdaMultiplicator) {
        if (lambdaMultiplicator <= 1.0) {
            throw new IllegalArgumentException("Parameter lambdaMultiplicator is required to be > 1.");
        }
        this.lambdaMultiplicator = lambdaMultiplicator;
    }

    /** @return divisor applied to {@code lambda} after an accepted step */
    public double getLambdaDivisor() {
        return lambdaDivisor;
    }

    /**
     * @param lambdaDivisor divisor applied to {@code lambda} after an accepted step
     * @throws IllegalArgumentException if {@code lambdaDivisor <= 1}
     */
    public void setLambdaDivisor(double lambdaDivisor) {
        if (lambdaDivisor <= 1.0) {
            throw new IllegalArgumentException("Parameter lambdaDivisor is required to be > 1.");
        }
        this.lambdaDivisor = lambdaDivisor;
    }

    /** @return the best (last accepted) parameters, or {@code null} before {@link #run()} */
    @Override
    public RandomVariable[] getBestFitParameters() {
        return parameterCurrent;
    }

    /** @return square root of the best mean squared error found so far */
    @Override
    public double getRootMeanSquaredError() {
        return Math.sqrt(errorMeanSquaredCurrent);
    }

    /**
     * Sets the mean squared error a trial point must strictly beat to be accepted.
     * {@link #run()} does not reset this value.
     *
     * @param errorMeanSquaredCurrent the error threshold
     */
    public void setErrorMeanSquaredCurrent(double errorMeanSquaredCurrent) {
        this.errorMeanSquaredCurrent = errorMeanSquaredCurrent;
    }

    /** @return number of value evaluations performed by the last {@link #run()} */
    @Override
    public int getIterations() {
        return iteration;
    }

    /**
     * Hook around {@link #setValues(RandomVariable[], RandomVariable[])}; the default simply
     * delegates.
     */
    protected void prepareAndSetValues(RandomVariable[] parameters, RandomVariable[] values) throws SolverException {
        setValues(parameters, values);
    }

    /**
     * Hook around {@link #setDerivatives(RandomVariable[], RandomVariable[][])}; the default
     * delegates and does not use {@code values}.
     */
    protected void prepareAndSetDerivatives(RandomVariable[] parameters, RandomVariable[] values, RandomVariable[][] derivatives) throws SolverException {
        setDerivatives(parameters, derivatives);
    }

    /**
     * Evaluates the model.
     *
     * @param parameters the parameters to evaluate at
     * @param values output array, one entry per target value, to be filled by the implementation
     * @throws SolverException if the evaluation fails
     */
    public abstract void setValues(RandomVariable[] parameters, RandomVariable[] values) throws SolverException;

    /**
     * Computes forward finite-difference derivatives
     * {@code derivatives[p][v] = (value_v(current + h_p e_p) - valueCurrent_v) / h_p}.
     *
     * <p>The differences are taken against the stored current values, so the base point is always
     * the current parameter vector. The {@code parameters} argument is not used by this default
     * implementation; it exists for overrides supplying analytic derivatives. If a shifted
     * evaluation throws, that parameter's derivatives become NaN. Evaluations run on the executor
     * when one is present, otherwise inline.
     *
     * @param parameters unused by the default implementation
     * @param derivatives output, indexed {@code [parameterIndex][valueIndex]}; row arrays are
     *     filled in place
     * @throws SolverException if a derivative task is interrupted or fails unexpectedly
     */
    public void setDerivatives(RandomVariable[] parameters, RandomVariable[][] derivatives) throws SolverException {
        final RandomVariable[] basePoint = parameterCurrent;
        final List<Future<RandomVariable[]>> futures = new ArrayList<>(basePoint.length);

        for (int parameterIndex = 0; parameterIndex < basePoint.length; parameterIndex++) {
            final RandomVariable[] shiftedParameters = basePoint.clone();
            final RandomVariable[] derivative = derivatives[parameterIndex];
            final int index = parameterIndex;

            final Callable<RandomVariable[]> worker = new Callable<RandomVariable[]>() {
                @Override
                public RandomVariable[] call() {
                    final RandomVariable shift = parameterSteps != null
                            ? parameterSteps[index]
                            : shiftedParameters[index].abs().add(1.0).mult(1.0E-8);
                    shiftedParameters[index] = shiftedParameters[index].add(shift);
                    try {
                        prepareAndSetValues(shiftedParameters, derivative);
                    } catch (Exception ignored) {
                        // Deliberate: a failed evaluation yields NaN derivatives instead of aborting.
                        Arrays.fill(derivative, new RandomVariableFromDoubleArray(Double.NaN));
                    }
                    for (int valueIndex = 0; valueIndex < valueCurrent.length; valueIndex++) {
                        derivative[valueIndex] = derivative[valueIndex].sub(valueCurrent[valueIndex]).div(shift);
                    }
                    return derivative;
                }
            };

            if (executor != null) {
                futures.add(executor.submit(worker));
            } else {
                // Running through a FutureTask keeps error propagation identical to the executor path.
                final FutureTask<RandomVariable[]> task = new FutureTask<>(worker);
                task.run();
                futures.add(task);
            }
        }

        for (int parameterIndex = 0; parameterIndex < futures.size(); parameterIndex++) {
            try {
                derivatives[parameterIndex] = futures.get(parameterIndex).get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new SolverException(e);
            } catch (ExecutionException e) {
                throw new SolverException(e);
            }
        }
    }

    /**
     * Reports whether the solver should stop. It stops when the iteration count exceeds
     * {@code maxIteration}, or when the RMS improvement of the last accepted step is at most
     * {@code errorTolerance}.
     */
    boolean done() {
        return iteration > maxIteration || errorRootMeanSquaredChange <= errorTolerance;
    }

    /**
     * Runs the optimization from the initial parameters.
     *
     * @throws SolverException if a value evaluation fails, a derivative task fails, or the damped
     *     linear system cannot be solved for any finite positive {@code lambda}
     */
    @Override
    public void run() throws SolverException {
        if (numberOfThreads > 1 && executor == null) {
            executor = Executors.newFixedThreadPool(numberOfThreads);
            executorShutdownWhenDone = true;
        }
        try {
            initializeIterationState();

            while (true) {
                iteration++;

                prepareAndSetValues(parameterTest, valueTest);
                final double errorMeanSquaredTest = getMeanSquaredError(valueTest);
                final boolean isImproved = errorMeanSquaredCurrent > errorMeanSquaredTest;
                if (isImproved) {
                    parameterCurrent = parameterTest.clone();
                    valueCurrent = valueTest.clone();
                    errorRootMeanSquaredChange = Math.sqrt(errorMeanSquaredCurrent) - Math.sqrt(errorMeanSquaredTest);
                    errorMeanSquaredCurrent = errorMeanSquaredTest;
                    // The current point moved, so its derivatives must be recomputed.
                    isParameterCurrentDerivativeValid = false;
                }

                if (done()) {
                    break;
                }

                if (isImproved) {
                    lambda /= lambdaDivisor;
                } else {
                    lambda *= lambdaMultiplicator;
                }

                if (!isParameterCurrentDerivativeValid) {
                    // Reached only on the first iteration or right after an accepted step. In both
                    // cases parameterTest holds the same entries as parameterCurrent.
                    prepareAndSetDerivatives(parameterTest, valueTest, derivativeCurrent);
                    isParameterCurrentDerivativeValid = true;
                }

                final double[] increment = solveForParameterIncrement();
                for (int i = 0; i < parameterCurrent.length; i++) {
                    parameterTest[i] = parameterCurrent[i].add(increment[i]);
                }

                logIteration();
            }
        } finally {
            if (executor != null && executorShutdownWhenDone) {
                executor.shutdown();
                executor = null;
            }
        }
    }

    /** Resets per-run state: copies the initial parameters, clears values, resets lambda. */
    private void initializeIterationState() {
        final int numberOfParameters = initialParameters.length;
        final int numberOfValues = targetValues.length;

        parameterTest = initialParameters.clone();
        parameterCurrent = initialParameters.clone();
        valueTest = new RandomVariable[numberOfValues];
        valueCurrent = new RandomVariable[numberOfValues];
        Arrays.fill(valueCurrent, new RandomVariableFromDoubleArray(Double.NaN));
        derivativeCurrent = new RandomVariable[numberOfParameters][numberOfValues];
        iteration = 0;
        lambda = LAMBDA_INITIAL_VALUE;
        isParameterCurrentDerivativeValid = false;
    }

    /**
     * Solves the damped normal equations for the parameter increment.
     *
     * <p>The equations use {@code (J^T J + damping) dx = J^T (target - valueCurrent)}, with products
     * reduced by {@code getAverage()}. Null derivative entries are skipped. When the solver throws,
     * {@code lambda} is multiplied by {@value #LAMBDA_SOLVER_FAILURE_MULTIPLICATOR} and the solve is
     * retried.
     *
     * @return the increment to add to the current parameters
     * @throws SolverException if {@code lambda} stops being finite and positive while retrying,
     *     which would otherwise loop forever
     */
    private double[] solveForParameterIncrement() throws SolverException {
        final int numberOfParameters = parameterCurrent.length;
        final int numberOfValues = valueCurrent.length;

        // Undamped J^T J. It does not depend on lambda, so it is built once per solve.
        final double[][] curvature = new double[numberOfParameters][numberOfParameters];
        for (int i = 0; i < numberOfParameters; i++) {
            for (int j = i; j < numberOfParameters; j++) {
                double sum = 0.0;
                for (int k = 0; k < numberOfValues; k++) {
                    final RandomVariable derivativeI = derivativeCurrent[i][k];
                    final RandomVariable derivativeJ = derivativeCurrent[j][k];
                    if (derivativeI != null && derivativeJ != null) {
                        sum += derivativeI.mult(derivativeJ).getAverage();
                    }
                }
                curvature[i][j] = sum;
                curvature[j][i] = sum;
            }
        }

        // J^T (target - current value).
        final double[] gradient = new double[numberOfParameters];
        for (int i = 0; i < numberOfParameters; i++) {
            final RandomVariable[] derivative = derivativeCurrent[i];
            double sum = 0.0;
            for (int k = 0; k < numberOfValues; k++) {
                if (derivative[k] != null) {
                    sum += targetValues[k].sub(valueCurrent[k]).mult(derivative[k]).getAverage();
                }
            }
            gradient[i] = sum;
        }

        while (true) {
            // Fresh copies each attempt, in case the solver modifies its arguments.
            final double[][] dampedCurvature = new double[numberOfParameters][];
            for (int i = 0; i < numberOfParameters; i++) {
                dampedCurvature[i] = curvature[i].clone();
                dampedCurvature[i][i] = dampDiagonal(curvature[i][i]);
            }
            try {
                return LinearAlgebra.solveLinearEquationSymmetric(dampedCurvature, gradient.clone());
            } catch (Exception e) {
                // Broad catch kept from the original: the solver's failure types are not visible here.
                lambda *= LAMBDA_SOLVER_FAILURE_MULTIPLICATOR;
                if (!(lambda > 0.0 && lambda < Double.POSITIVE_INFINITY)) {
                    throw new SolverException(e);
                }
            }
        }
    }

    /** Applies the configured damping to one diagonal element of {@code J^T J}. */
    private double dampDiagonal(double diagonal) {
        if (regularizationMethod == RegularizationMethod.LEVENBERG) {
            return diagonal + lambda;
        }
        // Multiplicative damping would leave a zero diagonal at zero, so fall back to lambda there.
        return diagonal == 0.0 ? lambda : diagonal * (1.0 + lambda);
    }

    /** Logs the iteration state at {@link Level#FINE}. */
    private void logIteration() {
        if (!LOGGER.isLoggable(Level.FINE)) {
            return;
        }
        final StringBuilder message = new StringBuilder()
                .append("Iteration: ").append(iteration)
                .append("\tLambda=").append(lambda)
                .append("\tError Current (RMS):").append(Math.sqrt(errorMeanSquaredCurrent))
                .append("\tError Change:").append(errorRootMeanSquaredChange)
                .append("\t");
        for (int i = 0; i < parameterCurrent.length; i++) {
            message.append("[").append(i).append("] = ").append(parameterCurrent[i].doubleValue()).append("\t");
        }
        LOGGER.fine(message.toString());
    }

    /**
     * Computes the objective: the mean over components of
     * {@code (values[i] - target[i]).squared().getAverage()}.
     *
     * @param values model values, index-aligned with the target values
     * @return the mean squared error (NaN if {@code values} is empty)
     */
    public double getMeanSquaredError(RandomVariable[] values) {
        double sum = 0.0;
        for (int i = 0; i < values.length; i++) {
            sum += values[i].sub(targetValues[i]).squared().getAverage();
        }
        return sum / values.length;
    }

    /**
     * Not supported by this abstract class; subclasses that can be copied override this.
     *
     * @throws CloneNotSupportedException always
     */
    @Override
    public StochasticLevenbergMarquardt clone() throws CloneNotSupportedException {
        throw new CloneNotSupportedException();
    }

    /**
     * Decompiler-generated alias of {@link #clone()}, kept for source compatibility.
     *
     * @deprecated Use {@link #clone()}.
     * @throws CloneNotSupportedException if {@link #clone()} throws it
     */
    @Deprecated
    public StochasticLevenbergMarquardt mo107clone() throws CloneNotSupportedException {
        return clone();
    }

    /**
     * Returns a copy of this solver with new target values.
     *
     * @param newTargetValues the new targets (the array is copied)
     * @param newWeights ignored; this solver has no weights
     * @param isUseBestParametersAsInitialParameters if true and this solver is {@link #done()},
     *     the copy starts from this solver's best-fit parameters
     * @return the modified copy
     * @throws CloneNotSupportedException if the solver cannot be cloned (always, unless a
     *     subclass overrides {@link #clone()})
     */
    public StochasticLevenbergMarquardt getCloneWithModifiedTargetValues(RandomVariable[] newTargetValues, RandomVariable[] newWeights, boolean isUseBestParametersAsInitialParameters) throws CloneNotSupportedException {
        // Calling the alias honours overrides of either clone() or mo107clone().
        final StochasticLevenbergMarquardt cloned = mo107clone();
        cloned.targetValues = newTargetValues.clone();
        if (isUseBestParametersAsInitialParameters && done()) {
            cloned.initialParameters = getBestFitParameters();
        }
        return cloned;
    }

    /**
     * List-based variant of
     * {@link #getCloneWithModifiedTargetValues(RandomVariable[], RandomVariable[], boolean)}.
     * {@code newWeights} is ignored.
     */
    public StochasticLevenbergMarquardt getCloneWithModifiedTargetValues(List<RandomVariable> newTargetValues, List<RandomVariable> newWeights, boolean isUseBestParametersAsInitialParameters) throws CloneNotSupportedException {
        final StochasticLevenbergMarquardt cloned = mo107clone();
        cloned.targetValues = newTargetValues.toArray(new RandomVariable[0]);
        if (isUseBestParametersAsInitialParameters && done()) {
            cloned.initialParameters = getBestFitParameters();
        }
        return cloned;
    }
}
