package org.deeplearning4j.optimize;

import cc.mallet.optimize.Optimizable;
import java.io.Serializable;
import org.deeplearning4j.nn.LogisticRegression;
import org.deeplearning4j.nn.gradient.LogisticRegressionGradient;

/* loaded from: input_file:org/deeplearning4j/optimize/LogisticRegressionOptimizer.class */
/**
 * Adapts a {@code LogisticRegression} model to MALLET's {@code Optimizable.ByGradientValue}
 * interface.
 *
 * <p>The model is exposed as one flat parameter vector. Indices {@code [0, W.length)} map to
 * the weight matrix returned by {@code getW()}. Indices {@code [W.length, W.length + B.length)}
 * map to the bias returned by {@code getB()}.
 *
 * <p>Writes go through {@code getW().put} / {@code getB().put}. NOTE(review): this presumably
 * mutates the model's own matrices in place. Confirm that {@code getW()} / {@code getB()} do not
 * return copies.
 */
public class LogisticRegressionOptimizer implements Optimizable.ByGradientValue, Serializable {
    private static final long serialVersionUID = 5229426347154854746L;

    /** The model whose parameters are read and written. Never reassigned. */
    private final LogisticRegression logReg;

    /** Learning rate forwarded to {@code logReg.getGradient(...)}. Never reassigned. */
    private final double lr;

    /**
     * Creates an optimizable view over the given model.
     *
     * @param logisticRegression the model to expose to the optimizer
     * @param learningRate       learning rate passed to the model's gradient computation
     */
    public LogisticRegressionOptimizer(LogisticRegression logisticRegression, double learningRate) {
        this.logReg = logisticRegression;
        this.lr = learningRate;
    }

    /**
     * Returns the total number of parameters.
     *
     * @return the number of weight entries plus the number of bias entries
     */
    public int getNumParameters() {
        return this.logReg.getW().length + this.logReg.getB().length;
    }

    /**
     * Copies the flattened parameters into {@code buffer}, filling every slot of the buffer.
     *
     * <p>Callers are expected to size the buffer to {@link #getNumParameters()}. A longer buffer
     * would request bias indices beyond the length of {@code B}.
     *
     * @param buffer destination array, written in place
     */
    public void getParameters(double[] buffer) {
        for (int i = 0; i < buffer.length; i++) {
            buffer[i] = getParameter(i);
        }
    }

    /**
     * Returns a single parameter by its flat index.
     *
     * @param index flat index. Values below {@code W.length} address the weights; the remaining
     *              values address the bias, offset by {@code W.length}.
     * @return the parameter value
     */
    public double getParameter(int index) {
        final int weightCount = this.logReg.getW().length;
        return index < weightCount
                ? this.logReg.getW().get(index)
                : this.logReg.getB().get(index - weightCount);
    }

    /**
     * Writes the flattened parameters from {@code buffer} into the model, one slot per index.
     *
     * <p>Callers are expected to size the buffer to {@link #getNumParameters()}.
     *
     * @param buffer source array of new parameter values
     */
    public void setParameters(double[] buffer) {
        for (int i = 0; i < buffer.length; i++) {
            setParameter(i, buffer[i]);
        }
    }

    /**
     * Writes a single parameter by its flat index.
     *
     * @param index flat index, using the same layout as {@link #getParameter(int)}
     * @param value the new value
     */
    public void setParameter(int index, double value) {
        final int weightCount = this.logReg.getW().length;
        if (index < weightCount) {
            this.logReg.getW().put(index, value);
        } else {
            this.logReg.getB().put(index - weightCount, value);
        }
    }

    /**
     * Fills {@code buffer} with the gradient, using the same flat layout as the parameters.
     *
     * <p>The gradient is computed once per call via {@code logReg.getGradient(lr)}.
     * NOTE(review): confirm that the returned gradient is the gradient of {@link #getValue()}
     * (same sign, not pre-scaled by {@code lr}), which a line-search optimizer would expect.
     *
     * @param buffer destination array, written in place; expected length is
     *               {@link #getNumParameters()}
     */
    public void getValueGradient(double[] buffer) {
        LogisticRegressionGradient gradient = this.logReg.getGradient(this.lr);
        // Loop-invariant: the boundary between the weight and bias sections of the flat vector.
        final int weightCount = this.logReg.getW().length;
        for (int i = 0; i < buffer.length; i++) {
            buffer[i] = i < weightCount
                    ? gradient.getwGradient().get(i)
                    : gradient.getbGradient().get(i - weightCount);
        }
    }

    /**
     * Returns the objective value: the negation of {@code logReg.negativeLogLikelihood()}.
     *
     * <p>That is presumably the log-likelihood. It is negated here, presumably because MALLET
     * optimizers maximize {@code getValue()}.
     *
     * @return the negated negative log-likelihood of the model
     */
    public double getValue() {
        return -this.logReg.negativeLogLikelihood();
    }
}
