package org.deeplearning4j.nn;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.activation.ActivationFunction;
import org.deeplearning4j.nn.activation.Sigmoid;
import org.deeplearning4j.optimize.MultiLayerNetworkOptimizer;
import org.deeplearning4j.transformation.MatrixTransform;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/BaseMultiLayerNetwork.class */
/**
 * Base class for a stacked multi-layer network.
 *
 * <p>The network is made of three parts:
 * <ul>
 *   <li>a sequence of {@link NeuralNetwork} layers ({@link #getLayers()});</li>
 *   <li>a parallel sequence of {@link HiddenLayer}s used for the forward pass
 *       ({@link #getSigmoidLayers()});</li>
 *   <li>a {@link LogisticRegression} output layer ({@link #getLogLayer()}).</li>
 * </ul>
 * When the layers are initialized, each {@link NeuralNetwork} layer is created from the weight
 * and bias matrices of the matching {@link HiddenLayer}.
 *
 * <p>Subclasses choose the concrete layer type through {@link #createLayer} and
 * {@link #createNetworkLayers}, and implement pretraining in {@link #trainNetwork}.
 *
 * <p>Most accessors are {@code synchronized} on the instance. Training methods such as
 * {@link #backProp} and {@link #initializeLayers} are not.
 */
public abstract class BaseMultiLayerNetwork implements Serializable, Persistable {
    private static final Logger log = LoggerFactory.getLogger(BaseMultiLayerNetwork.class);
    private static final long serialVersionUID = -5029161847383716484L;
    private int nIns;
    private int[] hiddenLayerSizes;
    private int nOuts;
    private int nLayers;
    private HiddenLayer[] sigmoidLayers;
    private LogisticRegression logLayer;
    private RandomGenerator rng;
    private RealDistribution dist;
    private double momentum;
    private DoubleMatrix input;
    private DoubleMatrix labels;
    private MultiLayerNetworkOptimizer optimizer;
    private ActivationFunction activation;
    private boolean toDecode;
    private double l2;
    private boolean shouldInit;
    private double fanIn;
    private int renderWeightsEveryNEpochs;
    private boolean useRegularization;
    // Per-layer-index transforms applied to layer weights by applyTransforms().
    private Map<Integer, MatrixTransform> weightTransforms;
    private boolean shouldBackProp;
    private boolean forceNumEpochs;
    private double sparsity;
    // Optional per-column statistics; when set, predict()/reconstruct() normalize their input.
    private DoubleMatrix columnSums;
    private DoubleMatrix columnMeans;
    private DoubleMatrix columnStds;
    public double learningRateUpdate;
    public NeuralNetwork[] layers;
    public double errorTolerance;

    /**
     * Fluent builder for concrete {@link BaseMultiLayerNetwork} subclasses.
     *
     * <p>The concrete class must be supplied with {@link #withClazz(Class)}. It is instantiated
     * reflectively, so it needs an accessible no-arg constructor.
     *
     * @param <E> the concrete network type produced
     */
    public static class Builder<E extends BaseMultiLayerNetwork> {
        protected Class<? extends BaseMultiLayerNetwork> clazz;
        private int nIns;
        private int[] hiddenLayerSizes;
        private int nOuts;
        private int nLayers;
        private DoubleMatrix input;
        private DoubleMatrix labels;
        private ActivationFunction activation;
        private double momentum;
        private RealDistribution dist;
        private RandomGenerator rng = new MersenneTwister(1234);
        private boolean decode = false;
        private double fanIn = -1.0d;
        private int renderWeightsEveryNEpochs = -1;
        private double l2 = 0.01d;
        private boolean useRegularization = true;
        protected Map<Integer, MatrixTransform> weightTransforms = new HashMap<>();
        protected boolean backProp = true;
        protected boolean shouldForceEpochs = false;
        private double sparsity = 0.0d;

        /** Sets the sparsity target passed to the built network. */
        public Builder<E> withSparsity(double d) {
            this.sparsity = d;
            return this;
        }

        /** Makes {@link BaseMultiLayerNetwork#backProp} run the full requested number of epochs. */
        public Builder<E> forceEpochs() {
            this.shouldForceEpochs = true;
            return this;
        }

        /** Turns off the network's back-propagation flag. */
        public Builder<E> disableBackProp() {
            this.backProp = false;
            return this;
        }

        /** Registers a transform to apply to the weights of layer {@code i}. */
        public Builder<E> transformWeightsAt(int i, MatrixTransform matrixTransform) {
            this.weightTransforms.put(Integer.valueOf(i), matrixTransform);
            return this;
        }

        /** Registers several per-layer weight transforms at once. */
        public Builder<E> transformWeightsAt(Map<Integer, MatrixTransform> map) {
            this.weightTransforms.putAll(map);
            return this;
        }

        /** Sets the distribution stored on the network; ignored when {@code null}. */
        public Builder<E> withDist(RealDistribution realDistribution) {
            this.dist = realDistribution;
            return this;
        }

        /** Sets the momentum. */
        public Builder<E> withMomentum(double d) {
            this.momentum = d;
            return this;
        }

        /** Enables or disables L2 regularization. */
        public Builder<E> useRegularization(boolean z) {
            this.useRegularization = z;
            return this;
        }

        /** Sets the L2 regularization coefficient. */
        public Builder<E> withL2(double d) {
            this.l2 = d;
            return this;
        }

        /** Sets how often (in epochs) weights are rendered; {@code -1} is the default. */
        public Builder<E> renderWeights(int i) {
            this.renderWeightsEveryNEpochs = i;
            return this;
        }

        /** Sets the fan-in; a negative value means {@code 1 / nIns} (see {@link BaseMultiLayerNetwork#fanIn()}). */
        public Builder<E> withFanIn(Double d) {
            this.fanIn = d.doubleValue();
            return this;
        }

        /** Sets the activation function; ignored when {@code null}. */
        public Builder<E> withActivation(ActivationFunction activationFunction) {
            this.activation = activationFunction;
            return this;
        }

        /** Sets the number of input columns. */
        public Builder<E> numberOfInputs(int i) {
            this.nIns = i;
            return this;
        }

        /** Sets the decode flag. */
        public Builder<E> decodeNetwork(boolean z) {
            this.decode = z;
            return this;
        }

        /** Sets the hidden layer sizes; the number of layers is taken from the array length. */
        public Builder<E> hiddenLayerSizes(int[] iArr) {
            this.hiddenLayerSizes = iArr;
            this.nLayers = iArr.length;
            return this;
        }

        /** Sets the number of outputs. */
        public Builder<E> numberOfOutPuts(int i) {
            this.nOuts = i;
            return this;
        }

        /** Sets the random generator. */
        public Builder<E> withRng(RandomGenerator randomGenerator) {
            this.rng = randomGenerator;
            return this;
        }

        /** Sets the training input. */
        public Builder<E> withInput(DoubleMatrix doubleMatrix) {
            this.input = doubleMatrix;
            return this;
        }

        /** Sets the training labels. */
        public Builder<E> withLabels(DoubleMatrix doubleMatrix) {
            this.labels = doubleMatrix;
            return this;
        }

        /** Sets the concrete network class to instantiate. */
        public Builder<E> withClazz(Class<? extends BaseMultiLayerNetwork> cls) {
            this.clazz = cls;
            return this;
        }

        /** Fails fast with a descriptive message instead of an NPE inside reflection. */
        private void requireClazz() {
            if (this.clazz == null) {
                throw new IllegalStateException("No network class specified; call withClazz(...) first");
            }
        }

        /**
         * Instantiates the configured class without applying any of the builder settings.
         *
         * @throws RuntimeException wrapping any reflective failure
         */
        @SuppressWarnings("unchecked")
        public E buildEmpty() {
            requireClazz();
            try {
                return (E) this.clazz.newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Instantiates the configured class and applies every builder setting to it.
         *
         * <p>The hidden-layer array is allocated with one slot per layer. The layers themselves
         * are not created here; see {@link BaseMultiLayerNetwork#initializeLayers}.
         *
         * @throws RuntimeException wrapping an instantiation or access failure
         */
        @SuppressWarnings("unchecked")
        public E build() {
            requireClazz();
            try {
                E ret = (E) this.clazz.newInstance();
                ret.setInput(this.input);
                ret.setnOuts(this.nOuts);
                ret.setnIns(this.nIns);
                ret.setLabels(this.labels);
                ret.setHiddenLayerSizes(this.hiddenLayerSizes);
                ret.setnLayers(this.nLayers);
                ret.setRng(this.rng);
                ret.setShouldBackProp(this.backProp);
                ret.setSigmoidLayers(new HiddenLayer[ret.getnLayers()]);
                ret.setToDecode(this.decode);
                ret.setMomentum(this.momentum);
                ret.setFanIn(this.fanIn);
                ret.setSparsity(this.sparsity);
                ret.setRenderWeightsEveryNEpochs(this.renderWeightsEveryNEpochs);
                ret.setL2(this.l2);
                ret.setForceNumEpochs(this.shouldForceEpochs);
                ret.setUseRegularization(this.useRegularization);
                if (this.activation != null) {
                    ret.setActivation(this.activation);
                }
                if (this.dist != null) {
                    ret.setDist(this.dist);
                }
                ret.getWeightTransforms().putAll(this.weightTransforms);
                return ret;
            } catch (IllegalAccessException | InstantiationException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /** Creates an unconfigured network with default hyperparameters (sigmoid activation). */
    public BaseMultiLayerNetwork() {
        this.momentum = 0.1d;
        this.activation = new Sigmoid();
        this.l2 = 0.01d;
        this.shouldInit = true;
        this.fanIn = -1.0d;
        this.renderWeightsEveryNEpochs = -1;
        this.useRegularization = true;
        this.weightTransforms = new HashMap<>();
        this.shouldBackProp = true;
        this.forceNumEpochs = false;
        this.sparsity = 0.0d;
        this.learningRateUpdate = 0.95d;
        this.errorTolerance = 1.0E-4d;
    }

    /** Creates a network without input or labels; the layers are not initialized yet. */
    public BaseMultiLayerNetwork(int nIns, int[] hiddenLayerSizes, int nOuts, int nLayers, RandomGenerator rng) {
        this(nIns, hiddenLayerSizes, nOuts, nLayers, rng, null, null);
    }

    /**
     * Creates a network and, when {@code input} is non-null, initializes its layers from it.
     *
     * @param nIns             number of input columns
     * @param hiddenLayerSizes size of each hidden layer; its length must equal {@code nLayers}
     * @param nOuts            number of outputs
     * @param nLayers          number of hidden layers
     * @param rng              random generator; a {@code MersenneTwister(123)} is used when {@code null}
     * @param input            training input, copied; may be {@code null}
     * @param labels           training labels, copied; may be {@code null}
     * @throws IllegalArgumentException if {@code hiddenLayerSizes.length != nLayers}
     */
    public BaseMultiLayerNetwork(int nIns, int[] hiddenLayerSizes, int nOuts, int nLayers, RandomGenerator rng, DoubleMatrix input, DoubleMatrix labels) {
        this();
        if (hiddenLayerSizes.length != nLayers) {
            throw new IllegalArgumentException("The number of hidden layer sizes must be equivalent to the nLayers argument which is a value of " + nLayers);
        }
        this.nIns = nIns;
        this.hiddenLayerSizes = hiddenLayerSizes;
        // Both may be null (the 5-arg constructor passes null for each), so copy null-safely.
        this.input = input == null ? null : input.dup();
        this.labels = labels == null ? null : labels.dup();
        this.nOuts = nOuts;
        this.nLayers = nLayers;
        this.sigmoidLayers = new HiddenLayer[nLayers];
        // NOTE(review): calls an overridable method from a constructor; subclass state is not yet initialized here.
        this.layers = createNetworkLayers(nLayers);
        this.rng = rng == null ? new MersenneTwister(123) : rng;
        if (input != null) {
            initializeLayers(input);
        }
    }

    /** Returns the configured fan-in, or {@code 1 / nIns} when the configured value is negative. */
    public double fanIn() {
        return this.fanIn < 0.0d ? 1.0d / this.nIns : this.fanIn;
    }

    /**
     * Checks that the hidden layers, the network layers, and the output layer agree in size.
     *
     * @throws IllegalStateException if adjacent layers or the output layer do not line up
     */
    private void dimensionCheck() {
        for (int i = 0; i < this.nLayers; i++) {
            HiddenLayer hidden = this.sigmoidLayers[i];
            NeuralNetwork network = this.layers[i];
            hidden.getW().assertSameSize(network.getW());
            hidden.getB().assertSameSize(network.gethBias());
            if (i < this.nLayers - 1) {
                HiddenLayer nextHidden = this.sigmoidLayers[i + 1];
                NeuralNetwork nextNetwork = this.layers[i + 1];
                if (nextHidden.getnIn() != hidden.getnOut()) {
                    throw new IllegalStateException("Invalid structure: hidden layer in for " + (i + 1) + " not equal to number of ins " + i);
                }
                if (network.getnHidden() != nextNetwork.getnVisible()) {
                    throw new IllegalStateException("Invalid structure: network hidden for " + (i + 1) + " not equal to number of visible " + i);
                }
            }
        }
        if (this.sigmoidLayers[this.sigmoidLayers.length - 1].getnOut() != this.logLayer.getnIn()) {
            throw new IllegalStateException("Number of outputs for final hidden layer not equal to the number of logistic input units for output layer");
        }
    }

    /**
     * Reconfigures this network as the reverse of {@code baseMultiLayerNetwork}.
     *
     * <p>Layers are cloned in reverse order, and the input and output counts are swapped.
     * Sets {@code shouldInit} to {@code false}.
     *
     * <p>NOTE(review): the hidden-layer array is allocated but not filled here, and the output
     * layer is not rebuilt. Confirm callers populate them before using {@link #predict}.
     */
    public void asDecoder(BaseMultiLayerNetwork baseMultiLayerNetwork) {
        // Result unused; kept in case a subclass relies on the call.
        createNetworkLayers(baseMultiLayerNetwork.nLayers + 1);
        this.layers = new NeuralNetwork[baseMultiLayerNetwork.nLayers];
        this.sigmoidLayers = new HiddenLayer[baseMultiLayerNetwork.nLayers];
        this.hiddenLayerSizes = new int[baseMultiLayerNetwork.nLayers];
        this.nIns = baseMultiLayerNetwork.nOuts;
        this.nOuts = baseMultiLayerNetwork.nIns;
        this.nLayers = baseMultiLayerNetwork.nLayers;
        this.dist = baseMultiLayerNetwork.dist;
        int target = 0;
        for (int source = baseMultiLayerNetwork.nLayers - 1; source >= 0; source--) {
            this.layers[target] = baseMultiLayerNetwork.layers[source].m13clone();
            this.layers[target].setRng(baseMultiLayerNetwork.layers[source].getRng());
            this.hiddenLayerSizes[target] = baseMultiLayerNetwork.hiddenLayerSizes[source];
            target++;
        }
        this.rng = baseMultiLayerNetwork.rng;
        this.shouldInit = false;
    }

    /**
     * Builds every hidden layer, network layer, and the output layer from {@code doubleMatrix}.
     *
     * <p>Layer 0 sees the input directly. Each later layer sees
     * {@code sample_h_given_v()} of the previous hidden layer. The configured weight transforms
     * are applied afterwards.
     *
     * @param doubleMatrix training input; a copy is stored as this network's input
     * @throws IllegalArgumentException if the input is null, its column count differs from
     *     {@code nIns}, or any hidden layer size is below 1
     * @throws IllegalStateException if the hidden layer sizes do not cover {@code nLayers >= 1} layers
     */
    public void initializeLayers(DoubleMatrix doubleMatrix) {
        if (doubleMatrix == null) {
            throw new IllegalArgumentException("Unable to initialize layers with empty input");
        }
        if (doubleMatrix.columns != this.nIns) {
            throw new IllegalArgumentException(String.format("Unable to train on number of inputs; columns should be equal to number of inputs. Number of inputs was %d while number of columns was %d", Integer.valueOf(this.nIns), Integer.valueOf(doubleMatrix.columns)));
        }
        if (this.nLayers < 1 || this.hiddenLayerSizes == null || this.hiddenLayerSizes.length < this.nLayers) {
            throw new IllegalStateException("Hidden layer sizes must be specified for each of the " + this.nLayers + " layers (at least one)");
        }
        if (this.layers == null || this.layers.length < this.nLayers) {
            this.layers = new NeuralNetwork[this.nLayers];
        }
        // Every slot is rebuilt below and dimensionCheck() reads the last one, so the length must match exactly.
        if (this.sigmoidLayers == null || this.sigmoidLayers.length != this.nLayers) {
            this.sigmoidLayers = new HiddenLayer[this.nLayers];
        }
        for (int size : this.hiddenLayerSizes) {
            if (size < 1) {
                throw new IllegalArgumentException("All hidden layer sizes must be >= 1");
            }
        }
        this.input = doubleMatrix.dup();
        DoubleMatrix layerInput = doubleMatrix;
        for (int i = 0; i < this.nLayers; i++) {
            int layerIns = i == 0 ? this.nIns : this.hiddenLayerSizes[i - 1];
            if (i > 0) {
                layerInput = this.sigmoidLayers[i - 1].sample_h_given_v();
            }
            this.sigmoidLayers[i] = new HiddenLayer(layerIns, this.hiddenLayerSizes[i], null, null, this.rng, layerInput);
            this.sigmoidLayers[i].setActivationFunction(this.activation);
            this.layers[i] = createLayer(layerInput, layerIns, this.hiddenLayerSizes[i], this.sigmoidLayers[i].getW(), this.sigmoidLayers[i].getB(), null, this.rng, i);
        }
        this.logLayer = new LogisticRegression(layerInput, this.hiddenLayerSizes[this.nLayers - 1], this.nOuts);
        this.logLayer.setUseRegularization(isUseRegularization());
        this.logLayer.setL2(getL2());
        dimensionCheck();
        applyTransforms();
    }

    /** Returns the number of input columns. */
    public synchronized int getnIns() {
        return this.nIns;
    }

    /** Sets the number of input columns. */
    public synchronized void setnIns(int i) {
        this.nIns = i;
    }

    /** Returns the number of outputs. */
    public synchronized int getnOuts() {
        return this.nOuts;
    }

    /** Sets the number of outputs. */
    public synchronized void setnOuts(int i) {
        this.nOuts = i;
    }

    /** Returns the number of hidden layers. */
    public synchronized int getnLayers() {
        return this.nLayers;
    }

    /** Sets the number of hidden layers. */
    public synchronized void setnLayers(int i) {
        this.nLayers = i;
    }

    /** Returns the momentum. */
    public synchronized double getMomentum() {
        return this.momentum;
    }

    /** Sets the momentum. */
    public synchronized void setMomentum(double d) {
        this.momentum = d;
    }

    /** Returns the L2 regularization coefficient. */
    public synchronized double getL2() {
        return this.l2;
    }

    /** Sets the L2 regularization coefficient. */
    public synchronized void setL2(double d) {
        this.l2 = d;
    }

    /** Returns whether L2 regularization is applied. */
    public synchronized boolean isUseRegularization() {
        return this.useRegularization;
    }

    /** Enables or disables L2 regularization. */
    public synchronized void setUseRegularization(boolean z) {
        this.useRegularization = z;
    }

    /** Replaces the hidden-layer array (not copied). */
    public synchronized void setSigmoidLayers(HiddenLayer[] hiddenLayerArr) {
        this.sigmoidLayers = hiddenLayerArr;
    }

    /** Replaces the output layer. */
    public synchronized void setLogLayer(LogisticRegression logisticRegression) {
        this.logLayer = logisticRegression;
    }

    /** Sets the back-propagation flag. */
    public synchronized void setShouldBackProp(boolean z) {
        this.shouldBackProp = z;
    }

    /** Replaces the network-layer array (not copied). */
    public synchronized void setLayers(NeuralNetwork[] neuralNetworkArr) {
        this.layers = neuralNetworkArr;
    }

    /** Copies this network's fan-in and render interval onto {@code neuralNetwork}. */
    protected void initializeNetwork(NeuralNetwork neuralNetwork) {
        neuralNetwork.setFanIn(this.fanIn);
        neuralNetwork.setRenderEpochs(this.renderWeightsEveryNEpochs);
    }

    /** Fine-tunes against the stored labels; see {@link #finetune(DoubleMatrix, double, int)}. */
    public void finetune(double d, int i) {
        finetune(this.labels, d, i);
    }

    /** Returns the stored labels (not copied). */
    public synchronized DoubleMatrix getLabels() {
        return this.labels;
    }

    /** Returns the output layer. */
    public synchronized LogisticRegression getLogLayer() {
        return this.logLayer;
    }

    /** Sets the stored input (not copied). */
    public synchronized void setInput(DoubleMatrix doubleMatrix) {
        this.input = doubleMatrix;
    }

    /** Returns the stored input (not copied). */
    public synchronized DoubleMatrix getInput() {
        return this.input;
    }

    /** Returns the hidden-layer array (not copied). */
    public synchronized HiddenLayer[] getSigmoidLayers() {
        return this.sigmoidLayers;
    }

    /** Returns the network-layer array (not copied). */
    public synchronized NeuralNetwork[] getLayers() {
        return this.layers;
    }

    /**
     * Runs a forward pass and returns every intermediate activation.
     *
     * <p>Each network layer's input is set to the activation that feeds it. The returned list
     * holds the input, then one activation per hidden layer, then the output layer's prediction.
     *
     * @throws IllegalStateException if this network has no stored input
     */
    public synchronized List<DoubleMatrix> feedForward(DoubleMatrix doubleMatrix) {
        if (this.input == null) {
            throw new IllegalStateException("Unable to perform feed forward; no input found");
        }
        List<DoubleMatrix> activations = new ArrayList<>();
        DoubleMatrix current = doubleMatrix;
        activations.add(current);
        for (int i = 0; i < this.nLayers; i++) {
            getLayers()[i].setInput(current);
            current = getSigmoidLayers()[i].activate(current);
            activations.add(current);
        }
        activations.add(getLogLayer().predict(current));
        return activations;
    }

    /**
     * Computes (gradient, delta) pairs for every layer and appends them to {@code list}.
     *
     * <p>{@code list} receives {@code nLayers + 2} pairs. Index {@code nLayers} pairs with the
     * output layer's weights. The last pair's gradient is always {@code null}.
     */
    private synchronized void computeDeltas(List<Pair<DoubleMatrix, DoubleMatrix>> list) {
        DoubleMatrix[] gradients = new DoubleMatrix[this.nLayers + 2];
        DoubleMatrix[] deltas = new DoubleMatrix[this.nLayers + 2];
        ActivationFunction derivative = this.sigmoidLayers[0].getActivationFunction();
        List<DoubleMatrix> activations = feedForward(getInput());
        List<DoubleMatrix> weights = new ArrayList<>();
        for (int i = 0; i < getLayers().length; i++) {
            weights.add(getLayers()[i].getW());
        }
        weights.add(getLogLayer().getW());
        DoubleMatrix predicted = predict(getInput());

        int top = this.nLayers + 1;
        DoubleMatrix output = activations.get(top);
        // NOTE(review): `predicted` and `output` both come from forward passes over the stored
        // input, so this difference looks like it was meant to involve the labels. Confirm intent.
        deltas[top] = predicted.sub(output).neg().mul(derivative.applyDerivative(output));
        for (int i = top - 1; i >= 0; i--) {
            DoubleMatrix activation = activations.get(i);
            DoubleMatrix propagated = deltas[i + 1].mmul(weights.get(i).transpose());
            deltas[i] = propagated.mul(derivative.applyDerivative(activation));
            gradients[i] = deltas[i + 1].transpose().mmul(activation).div(getInput().rows);
        }
        for (int i = 0; i < gradients.length; i++) {
            list.add(new Pair<>(gradients[i], deltas[i]));
        }
    }

    /**
     * Returns a copy of this network: a new empty instance of the same class filled by
     * {@link #update(BaseMultiLayerNetwork)}.
     *
     * <p>This is the decompiled form of a {@code clone()} override; the name is kept for existing callers.
     */
    public BaseMultiLayerNetwork m12clone() {
        BaseMultiLayerNetwork copy = new Builder<BaseMultiLayerNetwork>().withClazz(getClass()).buildEmpty();
        copy.update(this);
        return copy;
    }

    /**
     * Runs back-propagation.
     *
     * <p>Without {@code forceNumEpochs}, it steps until {@link #backPropStep} reports
     * convergence or a worse score. With it, exactly {@code i} epochs are attempted.
     *
     * <p>Each step is given the score measured <em>before</em> the previous step. That lets it
     * compare the effect of the previous update instead of re-measuring an unchanged network.
     *
     * @param d learning rate
     * @param i number of epochs (only used when {@code forceNumEpochs} is set)
     */
    public void backProp(double d, int i) {
        BaseMultiLayerNetwork revert = m12clone();
        Double lastScore = null;
        if (this.forceNumEpochs) {
            for (int epoch = 0; epoch < i; epoch++) {
                double before = negativeLogLikelihood();
                backPropStep(lastScore, revert, d, epoch);
                lastScore = Double.valueOf(before);
            }
            return;
        }
        int epoch = 0;
        while (true) {
            double before = negativeLogLikelihood();
            if (!backPropStep(lastScore, revert, d, epoch)) {
                return;
            }
            lastScore = Double.valueOf(before);
            epoch++;
        }
    }

    /**
     * Performs one back-propagation update.
     *
     * <p>When {@code d} (the previous score) is non-null, the current score is compared with it:
     * <ul>
     *   <li>equal: stops without updating;</li>
     *   <li>worse, NaN, or infinite: restores this network from {@code baseMultiLayerNetwork} and stops;</li>
     *   <li>better: snapshots the current state into {@code baseMultiLayerNetwork} so a later
     *       regression can roll back to it.</li>
     * </ul>
     *
     * @param d                    score before the previous step, or {@code null} on the first step
     * @param baseMultiLayerNetwork rollback snapshot, updated in place
     * @param d2                   learning rate
     * @param i                    epoch number (for logging)
     * @return {@code true} if an update was applied
     */
    protected boolean backPropStep(Double d, BaseMultiLayerNetwork baseMultiLayerNetwork, double d2, int i) {
        double currentScore = negativeLogLikelihood();
        if (d != null) {
            double lastScore = d.doubleValue();
            if (currentScore == lastScore) {
                log.info("Converged; no more stepping appears to do anything");
                return false;
            }
            if (currentScore > lastScore || Double.isNaN(currentScore) || Double.isInfinite(currentScore)) {
                log.info("Error greater than previous; found global minima; converging");
                update(baseMultiLayerNetwork);
                return false;
            }
            if (currentScore < lastScore) {
                // Remember the improved parameters as the rollback point.
                baseMultiLayerNetwork.update(this);
                log.info("Found better error on epoch " + i + " " + Double.valueOf(currentScore));
            }
        }
        List<Pair<DoubleMatrix, DoubleMatrix>> deltas = new ArrayList<>();
        computeDeltas(deltas);
        for (int layer = 0; layer < this.nLayers; layer++) {
            // NOTE(review): the learning rate is applied twice and the gradient is divided by
            // the row count twice (computeDeltas already divides once). The L2 term multiplies
            // the gradient rather than adding to it. Kept as-is; confirm intent.
            DoubleMatrix step = deltas.get(layer).getFirst().div(this.input.rows).mul(d2);
            step.divi(this.input.rows);
            if (this.useRegularization) {
                step.muli(this.layers[layer].getW().mul(this.l2));
            }
            this.layers[layer].setW(this.layers[layer].getW().add(step.mul(d2)));
            this.sigmoidLayers[layer].setW(this.layers[layer].getW());
            DoubleMatrix biasStep = deltas.get(layer + 1).getSecond().columnSums();
            biasStep.divi(this.input.rows);
            this.layers[layer].gethBias().addi(biasStep.mul(d2));
            this.sigmoidLayers[layer].setB(getLayers()[layer].gethBias());
        }
        this.logLayer.getW().addi(deltas.get(this.nLayers).getFirst());
        return true;
    }

    /**
     * Fine-tunes the network with a {@link MultiLayerNetworkOptimizer}.
     *
     * @param doubleMatrix labels to train against; when {@code null}, the stored labels are used
     * @param d            learning rate passed to the optimizer
     * @param i            number of epochs passed to the optimizer
     */
    public void finetune(DoubleMatrix doubleMatrix, double d, int i) {
        if (doubleMatrix != null) {
            this.labels = doubleMatrix;
        }
        this.optimizer = new MultiLayerNetworkOptimizer(this, d);
        this.optimizer.optimize(this.labels, d, i);
    }

    /**
     * Returns a column-normalized version of {@code x}, or {@code x} itself when no column
     * statistics are set. The argument is never modified.
     *
     * <p>The order is: divide by the column sums, subtract the column means, then divide by the
     * column standard deviations. Each step only runs when its statistic is set.
     */
    private DoubleMatrix applyColumnNormalization(DoubleMatrix x) {
        if (this.columnSums == null && this.columnMeans == null && this.columnStds == null) {
            return x;
        }
        DoubleMatrix ret = x.dup();
        if (this.columnSums != null) {
            for (int c = 0; c < ret.columns; c++) {
                ret.putColumn(c, ret.getColumn(c).div(this.columnSums.get(0, c)));
            }
        }
        if (this.columnMeans != null) {
            for (int c = 0; c < ret.columns; c++) {
                ret.putColumn(c, ret.getColumn(c).sub(this.columnMeans.get(0, c)));
            }
        }
        if (this.columnStds != null) {
            for (int c = 0; c < ret.columns; c++) {
                ret.putColumn(c, ret.getColumn(c).div(this.columnStds.get(0, c)));
            }
        }
        return ret;
    }

    /**
     * Predicts outputs for {@code doubleMatrix}.
     *
     * <p>The input is normalized with the configured column statistics, working on a copy so the
     * caller's matrix is left untouched. It is then passed through every hidden layer and the
     * output layer. If no input has been stored yet, the layers are first initialized from the
     * normalized matrix.
     */
    public DoubleMatrix predict(DoubleMatrix doubleMatrix) {
        DoubleMatrix normalized = applyColumnNormalization(doubleMatrix);
        if (this.input == null) {
            initializeLayers(normalized);
        }
        DoubleMatrix current = normalized;
        for (int layer = 0; layer < this.nLayers; layer++) {
            current = this.sigmoidLayers[layer].activate(current);
        }
        return this.logLayer.predict(current);
    }

    /**
     * Returns the activation after the first {@code i} hidden layers.
     *
     * <p>The input is normalized first, on a copy; the argument is not modified.
     *
     * @throws IllegalArgumentException if {@code i} is negative or greater than {@code nLayers}
     */
    public DoubleMatrix reconstruct(DoubleMatrix doubleMatrix, int i) {
        if (i > this.nLayers || i < 0) {
            throw new IllegalArgumentException("Layer number " + i + " does not exist");
        }
        DoubleMatrix current = applyColumnNormalization(doubleMatrix);
        for (int layer = 0; layer < i; layer++) {
            current = this.sigmoidLayers[layer].activate(current);
        }
        return current;
    }

    /** Returns the activation after all hidden layers; see {@link #reconstruct(DoubleMatrix, int)}. */
    public DoubleMatrix reconstruct(DoubleMatrix doubleMatrix) {
        return reconstruct(doubleMatrix, this.sigmoidLayers.length);
    }

    /**
     * Serializes this network to {@code outputStream} with Java serialization.
     *
     * <p>The object stream is flushed so all bytes reach {@code outputStream}. It is not closed,
     * because the caller owns the underlying stream.
     */
    @Override // org.deeplearning4j.nn.Persistable
    public void write(OutputStream outputStream) {
        try {
            ObjectOutputStream out = new ObjectOutputStream(outputStream);
            out.writeObject(this);
            out.flush();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Replaces this network's state with a network deserialized from {@code inputStream}.
     *
     * <p>Java deserialization; only use with trusted input.
     */
    @Override // org.deeplearning4j.nn.Persistable
    public void load(InputStream inputStream) {
        try {
            update((BaseMultiLayerNetwork) new ObjectInputStream(inputStream).readObject());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deserializes a network from {@code inputStream}.
     *
     * <p>Java deserialization; only use with trusted input.
     */
    public static BaseMultiLayerNetwork loadFromFile(InputStream inputStream) {
        try {
            ObjectInputStream objectInputStream = new ObjectInputStream(inputStream);
            log.info("Loading network model...");
            return (BaseMultiLayerNetwork) objectInputStream.readObject();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Copies the state of {@code baseMultiLayerNetwork} into this network.
     *
     * <p>Network, hidden, and output layers are cloned. All other references are shared, not
     * copied: input, labels, column statistics, hidden layer sizes, weight transforms, rng, and
     * dist.
     */
    protected void update(BaseMultiLayerNetwork baseMultiLayerNetwork) {
        this.layers = new NeuralNetwork[baseMultiLayerNetwork.layers.length];
        for (int i = 0; i < this.layers.length; i++) {
            this.layers[i] = baseMultiLayerNetwork.layers[i].m13clone();
        }
        this.hiddenLayerSizes = baseMultiLayerNetwork.hiddenLayerSizes;
        this.logLayer = baseMultiLayerNetwork.logLayer.m15clone();
        this.nIns = baseMultiLayerNetwork.nIns;
        this.nLayers = baseMultiLayerNetwork.nLayers;
        this.nOuts = baseMultiLayerNetwork.nOuts;
        this.rng = baseMultiLayerNetwork.rng;
        this.dist = baseMultiLayerNetwork.dist;
        this.activation = baseMultiLayerNetwork.activation;
        this.useRegularization = baseMultiLayerNetwork.useRegularization;
        this.columnMeans = baseMultiLayerNetwork.columnMeans;
        this.columnStds = baseMultiLayerNetwork.columnStds;
        this.columnSums = baseMultiLayerNetwork.columnSums;
        this.errorTolerance = baseMultiLayerNetwork.errorTolerance;
        this.forceNumEpochs = baseMultiLayerNetwork.forceNumEpochs;
        this.input = baseMultiLayerNetwork.input;
        this.l2 = baseMultiLayerNetwork.l2;
        this.fanIn = baseMultiLayerNetwork.fanIn;
        this.labels = baseMultiLayerNetwork.labels;
        this.momentum = baseMultiLayerNetwork.momentum;
        this.learningRateUpdate = baseMultiLayerNetwork.learningRateUpdate;
        this.shouldBackProp = baseMultiLayerNetwork.shouldBackProp;
        this.weightTransforms = baseMultiLayerNetwork.weightTransforms;
        this.sparsity = baseMultiLayerNetwork.sparsity;
        this.toDecode = baseMultiLayerNetwork.toDecode;
        this.sigmoidLayers = new HiddenLayer[baseMultiLayerNetwork.sigmoidLayers.length];
        for (int i = 0; i < this.sigmoidLayers.length; i++) {
            this.sigmoidLayers[i] = baseMultiLayerNetwork.sigmoidLayers[i].m14clone();
        }
    }

    /**
     * Returns half the sum of squared weights over all network layers and the output layer.
     * Each term is multiplied by {@code l2} when regularization is enabled.
     *
     * <p>Despite the name, only the weights are read; neither the input nor the labels are used.
     */
    public synchronized double negativeLogLikelihood() {
        double total = 0.0d;
        for (int i = 0; i < this.nLayers; i++) {
            total += weightPenalty(this.layers[i].getW());
        }
        return total + weightPenalty(this.logLayer.getW());
    }

    /** Half the squared Frobenius norm of {@code w}, scaled by {@code l2} when regularization is on. */
    private double weightPenalty(DoubleMatrix w) {
        double halfSquaredNorm = MatrixFunctions.pow(w, 2.0d).sum() / 2.0d;
        return this.useRegularization ? halfSquaredNorm * this.l2 : halfSquaredNorm;
    }

    /** Pretrains the network; defined by subclasses. */
    public abstract void trainNetwork(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2, Object[] objArr);

    /**
     * Replaces each layer's weights with the registered transform's output, for every layer
     * index that has a transform in {@link #getWeightTransforms()}.
     *
     * @throws IllegalStateException if the layers have not been created
     */
    protected void applyTransforms() {
        if (this.layers == null || this.layers.length < 1) {
            throw new IllegalStateException("Layers not initialized");
        }
        for (int i = 0; i < this.layers.length; i++) {
            MatrixTransform transform = this.weightTransforms.get(Integer.valueOf(i));
            if (transform != null) {
                this.layers[i].setW((DoubleMatrix) transform.apply(this.layers[i].getW()));
            }
        }
    }

    /** Returns the back-propagation flag. */
    public boolean isShouldBackProp() {
        return this.shouldBackProp;
    }

    /**
     * Creates the network layer at {@code index}.
     *
     * @param input   input seen by this layer
     * @param nIn     number of inputs to the layer
     * @param nOut    layer size
     * @param W       weight matrix shared with the matching hidden layer
     * @param hBias   bias shared with the matching hidden layer
     * @param vBias   additional bias; {@link #initializeLayers} always passes {@code null}
     *                (presumably the visible bias; confirm in subclasses)
     * @param rng     random generator
     * @param index   layer index
     */
    public abstract NeuralNetwork createLayer(DoubleMatrix input, int nIn, int nOut, DoubleMatrix W, DoubleMatrix hBias, DoubleMatrix vBias, RandomGenerator rng, int index);

    /** Returns a new array to hold {@code numLayers} network layers. */
    public abstract NeuralNetwork[] createNetworkLayers(int numLayers);

    /**
     * Merges each layer of {@code baseMultiLayerNetwork} into the matching layer here, then
     * re-syncs the hidden layers' weights and biases.
     *
     * @param i batch size forwarded to each layer's {@code merge}
     * @throws IllegalArgumentException if the networks have different layer counts
     */
    public void merge(BaseMultiLayerNetwork baseMultiLayerNetwork, int i) {
        if (baseMultiLayerNetwork.nLayers != this.nLayers) {
            throw new IllegalArgumentException("Unable to merge networks that are not of equal length");
        }
        for (int layer = 0; layer < this.nLayers; layer++) {
            NeuralNetwork network = this.layers[layer];
            network.merge(baseMultiLayerNetwork.layers[layer], i);
            getSigmoidLayers()[layer].setB(network.gethBias());
            getSigmoidLayers()[layer].setW(network.getW());
        }
        getLogLayer().merge(baseMultiLayerNetwork.logLayer, i);
    }

    /**
     * Rebuilds this network from clones of {@code baseMultiLayerNetwork}'s layers, in reverse
     * order.
     *
     * <p>The new output layer maps the last reversed layer to the source's input width.
     * Every layer, including layer 0, is cloned, matching {@link #asDecoder}.
     */
    public void encode(BaseMultiLayerNetwork baseMultiLayerNetwork) {
        // Result unused; kept in case a subclass relies on the call.
        createNetworkLayers(baseMultiLayerNetwork.nLayers);
        this.nLayers = baseMultiLayerNetwork.nLayers;
        this.layers = new NeuralNetwork[this.nLayers];
        this.sigmoidLayers = new HiddenLayer[this.nLayers];
        this.hiddenLayerSizes = new int[this.nLayers];
        int target = 0;
        for (int source = this.nLayers - 1; source >= 0; source--) {
            NeuralNetwork layerCopy = baseMultiLayerNetwork.layers[source].m13clone();
            HiddenLayer hiddenCopy = baseMultiLayerNetwork.sigmoidLayers[source].m14clone();
            this.layers[target] = layerCopy;
            this.sigmoidLayers[target] = hiddenCopy;
            this.hiddenLayerSizes[target] = baseMultiLayerNetwork.hiddenLayerSizes[source];
            target++;
        }
        this.logLayer = new LogisticRegression(this.hiddenLayerSizes[this.nLayers - 1], baseMultiLayerNetwork.input.columns);
    }

    /** Returns whether {@link #backProp} runs the full number of epochs. */
    public boolean isForceNumEpochs() {
        return this.forceNumEpochs;
    }

    /** Returns the per-column sums used for normalization, or {@code null}. */
    public DoubleMatrix getColumnSums() {
        return this.columnSums;
    }

    /** Sets the per-column sums used for normalization ({@code null} disables). */
    public void setColumnSums(DoubleMatrix doubleMatrix) {
        this.columnSums = doubleMatrix;
    }

    /** Returns the hidden layer sizes (not copied). */
    public synchronized int[] getHiddenLayerSizes() {
        return this.hiddenLayerSizes;
    }

    /** Sets the hidden layer sizes (not copied). */
    public synchronized void setHiddenLayerSizes(int[] iArr) {
        this.hiddenLayerSizes = iArr;
    }

    /** Returns the random generator. */
    public synchronized RandomGenerator getRng() {
        return this.rng;
    }

    /** Sets the random generator. */
    public synchronized void setRng(RandomGenerator randomGenerator) {
        this.rng = randomGenerator;
    }

    /** Returns the stored distribution. */
    public synchronized RealDistribution getDist() {
        return this.dist;
    }

    /** Sets the stored distribution. */
    public synchronized void setDist(RealDistribution realDistribution) {
        this.dist = realDistribution;
    }

    /** Returns the optimizer created by the last {@code finetune} call, or {@code null}. */
    public synchronized MultiLayerNetworkOptimizer getOptimizer() {
        return this.optimizer;
    }

    /** Sets the optimizer. */
    public synchronized void setOptimizer(MultiLayerNetworkOptimizer multiLayerNetworkOptimizer) {
        this.optimizer = multiLayerNetworkOptimizer;
    }

    /** Returns the activation function given to hidden layers. */
    public synchronized ActivationFunction getActivation() {
        return this.activation;
    }

    /** Sets the activation function given to hidden layers on initialization. */
    public synchronized void setActivation(ActivationFunction activationFunction) {
        this.activation = activationFunction;
    }

    /** Returns the decode flag. */
    public synchronized boolean isToDecode() {
        return this.toDecode;
    }

    /** Sets the decode flag. */
    public synchronized void setToDecode(boolean z) {
        this.toDecode = z;
    }

    /** Returns the should-init flag. */
    public synchronized boolean isShouldInit() {
        return this.shouldInit;
    }

    /** Sets the should-init flag. */
    public synchronized void setShouldInit(boolean z) {
        this.shouldInit = z;
    }

    /** Returns the raw configured fan-in (may be negative; see {@link #fanIn()}). */
    public synchronized double getFanIn() {
        return this.fanIn;
    }

    /** Sets the fan-in. */
    public synchronized void setFanIn(double d) {
        this.fanIn = d;
    }

    /** Returns how often (in epochs) weights are rendered. */
    public synchronized int getRenderWeightsEveryNEpochs() {
        return this.renderWeightsEveryNEpochs;
    }

    /** Sets how often (in epochs) weights are rendered. */
    public synchronized void setRenderWeightsEveryNEpochs(int i) {
        this.renderWeightsEveryNEpochs = i;
    }

    /** Returns the live, mutable map of per-layer weight transforms. */
    public synchronized Map<Integer, MatrixTransform> getWeightTransforms() {
        return this.weightTransforms;
    }

    /** Replaces the map of per-layer weight transforms (not copied). */
    public synchronized void setWeightTransforms(Map<Integer, MatrixTransform> map) {
        this.weightTransforms = map;
    }

    /** Returns the sparsity target. */
    public synchronized double getSparsity() {
        return this.sparsity;
    }

    /** Sets the sparsity target. */
    public synchronized void setSparsity(double d) {
        this.sparsity = d;
    }

    /** Returns the learning-rate update factor. */
    public synchronized double getLearningRateUpdate() {
        return this.learningRateUpdate;
    }

    /** Sets the learning-rate update factor. */
    public synchronized void setLearningRateUpdate(double d) {
        this.learningRateUpdate = d;
    }

    /** Returns the error tolerance. */
    public synchronized double getErrorTolerance() {
        return this.errorTolerance;
    }

    /** Sets the error tolerance. */
    public synchronized void setErrorTolerance(double d) {
        this.errorTolerance = d;
    }

    /** Sets the stored labels (not copied). */
    public synchronized void setLabels(DoubleMatrix doubleMatrix) {
        this.labels = doubleMatrix;
    }

    /** Sets whether {@link #backProp} runs the full number of epochs. */
    public synchronized void setForceNumEpochs(boolean z) {
        this.forceNumEpochs = z;
    }

    /** Returns the per-column means used for normalization, or {@code null}. */
    public DoubleMatrix getColumnMeans() {
        return this.columnMeans;
    }

    /** Sets the per-column means used for normalization ({@code null} disables). */
    public void setColumnMeans(DoubleMatrix doubleMatrix) {
        this.columnMeans = doubleMatrix;
    }

    /** Returns the per-column standard deviations used for normalization, or {@code null}. */
    public DoubleMatrix getColumnStds() {
        return this.columnStds;
    }

    /** Sets the per-column standard deviations used for normalization ({@code null} disables). */
    public void setColumnStds(DoubleMatrix doubleMatrix) {
        this.columnStds = doubleMatrix;
    }
}
