package org.openimaj.ml.linear.learner.perceptron;

import ch.akuhn.matrix.DenseMatrix;
import ch.akuhn.matrix.Matrix;
import ch.akuhn.matrix.Vector;
import org.openimaj.math.matrix.MatlibMatrixUtils;
import org.openimaj.ml.linear.kernel.VectorKernel;
import org.openimaj.util.pair.IndependentPair;

/* loaded from: input_file:org/openimaj/ml/linear/learner/perceptron/Projectron.class */
/**
 * Projectron: a kernel perceptron whose support set is kept small by projecting a mistaken
 * example onto the span of the existing support vectors whenever the projection error is small
 * enough, instead of always adding the example as a new support.
 *
 * <p>The inverse of the kernel matrix over the current supports ({@link #Kinv}) is maintained
 * incrementally so the optimal projection coefficients can be computed as {@code Kinv * k_t}.
 */
public class Projectron extends MatrixKernelPerceptron {
    /**
     * Default projection tolerance. This literal is the exact double value of
     * {@code (double) 0.01f}; it is kept bit-identical to preserve existing behaviour.
     */
    private static final double DEFAULT_ETA = 0.009999999776482582d;

    /**
     * Inverse of the kernel matrix over the current support vectors. Starts as a 0x0 matrix and
     * grows by one row and column each time a new support is added (see {@link #updateKinv}).
     */
    private Matrix Kinv;

    /**
     * Projection tolerance. A mistake whose squared projection error is {@code <= eta} is absorbed
     * by re-weighting the existing supports instead of adding a new one.
     */
    private double eta;

    /**
     * Creates a Projectron with an explicit projection tolerance.
     *
     * @param vectorKernel the kernel used to compare examples
     * @param d the projection tolerance {@code eta}
     */
    public Projectron(VectorKernel vectorKernel, double d) {
        super(vectorKernel);
        this.eta = d;
        this.Kinv = DenseMatrix.dense(0, 0);
    }

    /**
     * Creates a Projectron with the default projection tolerance ({@link #DEFAULT_ETA}).
     *
     * @param vectorKernel the kernel used to compare examples
     */
    public Projectron(VectorKernel vectorKernel) {
        this(vectorKernel, DEFAULT_ETA);
    }

    /**
     * Updates the model for the example {@code xt}.
     *
     * <p>The optimal coefficients {@code d = Kinv * k_t} of the projection of {@code k(xt, .)}
     * onto the current supports are computed. The projection error is
     * {@code delta = k(xt, xt) - d . k_t}, clamped at 0 to guard against round-off.
     * <ul>
     *   <li>If {@code delta <= eta}, the support weights are shifted by {@code yt * d}, and no new
     *       support is added.</li>
     *   <li>Otherwise, the superclass update is performed and {@link #Kinv} is grown to cover the
     *       new support.</li>
     * </ul>
     *
     * @param xt the example
     * @param yt the class used to sign the projected update
     * @param ytPrime passed through unchanged to the superclass update
     */
    @Override
    public void update(double[] xt, PerceptronClass yt, PerceptronClass ytPrime) {
        final double kxx = kernelValue(xt, xt);
        final Vector kt = calculatekt(xt);
        final Vector dOptimal = this.Kinv.mult(kt);
        final double delta = Math.max(kxx - dOptimal.dot(kt), 0.0d);

        if (delta <= this.eta) {
            updateWeights(yt, dOptimal);
        } else {
            // NOTE(review): updateKinv assumes super.update appends xt to `supports`;
            // that behaviour lives in MatrixKernelPerceptron and is not visible here.
            super.update(xt, yt, ytPrime);
            updateKinv(dOptimal, delta);
        }
    }

    /**
     * Adds {@code y.v() * d[i]} to the weight of each existing support {@code i}.
     *
     * @param y the class whose sign value scales the update
     * @param dOptimal the projection coefficients, one per current support
     */
    private void updateWeights(PerceptronClass y, Vector dOptimal) {
        for (int i = 0; i < dOptimal.size(); i++) {
            final double current = this.weights.get(i).doubleValue();
            this.weights.set(i, Double.valueOf(current + (y.v() * dOptimal.get(i))));
        }
    }

    /**
     * The Projectron has no bias term.
     *
     * @return always {@code 0}
     */
    @Override
    public double getBias() {
        return 0.0d;
    }

    /**
     * Grows {@link #Kinv} after a new support has been appended.
     *
     * <p>With a single support, the inverse is {@code 1 / k(x0, x0)}. Otherwise the block-inverse
     * rank-one update is used:
     * {@code Kinv' = [Kinv 0; 0 0] + (1/delta) * a * a^T}, where {@code a = [dOptimal; -1]}.
     *
     * @param dOptimal the projection coefficients computed before the support was added
     * @param delta the (positive) projection error for the new support
     */
    private void updateKinv(Vector dOptimal, double delta) {
        final Matrix updated;
        if (this.supports.size() > 1) {
            final int n = dOptimal.size();

            // a = [dOptimal; -1] as an (n+1) x 1 column. Previously a throwaway Vector was
            // allocated here just to read its size.
            final Matrix a = DenseMatrix.dense(n + 1, 1);
            MatlibMatrixUtils.setSubVector(a.column(0), 0, dOptimal);
            a.column(0).put(n, -1.0d);

            // Pad the old inverse with a zero row and column.
            updated = new DenseMatrix(this.Kinv.rowCount() + 1, this.Kinv.columnCount() + 1);
            MatlibMatrixUtils.setSubMatrix(updated, 0, 0, this.Kinv);

            // Add the scaled outer product (1/delta) * a * a^T.
            final Matrix outer = updated.newInstance();
            MatlibMatrixUtils.dotProductTranspose(a, a, outer);
            MatlibMatrixUtils.scaleInplace(outer, 1.0d / delta);
            MatlibMatrixUtils.plusInplace(updated, outer);
        } else {
            final double[] first = this.supports.get(0);
            updated = DenseMatrix.dense(1, 1);
            updated.put(0, 0, 1.0d / kernelValue(first, first));
        }
        this.Kinv = updated;
    }

    /**
     * Computes the vector {@code k_t} of kernel values between {@code x} and every support.
     *
     * @param x the example
     * @return a dense vector whose i-th entry is {@code k(x, supports[i])}
     */
    private Vector calculatekt(double[] x) {
        final int count = this.supports.size();
        final Vector kt = Vector.dense(count);
        for (int i = 0; i < count; i++) {
            kt.put(i, kernelValue(x, this.supports.get(i)));
        }
        return kt;
    }

    /**
     * Evaluates the configured kernel on a pair of examples.
     *
     * @param a the first example
     * @param b the second example
     * @return {@code k(a, b)}
     */
    private double kernelValue(double[] a, double[] b) {
        return ((Double) this.kernel.apply(IndependentPair.pair(a, b))).doubleValue();
    }
}
