package org.deeplearning4j.nn.layers.normalization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/normalization/BatchNormalization.class */
/**
 * Batch normalization layer.
 *
 * <p>The forward pass ({@link #preOutput(INDArray, Layer.TrainingMode)}) subtracts a mean from the
 * input, divides by a standard deviation, multiplies by {@code gamma} and adds {@code beta}. When
 * the layer is in {@link Layer.TrainingMode#TRAIN} mode and the configuration enables batch
 * statistics, mean and variance are computed from the current input and folded into the running
 * {@link #mean} / {@link #var} fields; otherwise the running values are used as-is.
 *
 * <p>If a {@link BatchNormalizationHelper} could be loaded reflectively, it is offered the work
 * first; a {@code null} result from the helper means "fall back to the built-in implementation".
 *
 * <p>Only rank 2 and rank 4 activations are supported by the built-in forward/backward passes.
 */
public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.layers.BatchNormalization> {
    protected static final Logger log = LoggerFactory.getLogger(BatchNormalization.class);

    /** Fully-qualified name of the optional helper implementation that is loaded reflectively. */
    private static final String CUDNN_HELPER_CLASS =
            "org.deeplearning4j.nn.layers.normalization.CudnnBatchNormalizationHelper";

    /** Message used when the activations have a rank the built-in implementation cannot handle. */
    private static final String UNSUPPORTED_INPUT_MESSAGE =
            "The layer prior to BatchNorm in the configuration is not currently supported.";

    /** Optional accelerated implementation; {@code null} when none could be loaded. */
    BatchNormalizationHelper helper;
    protected int index;
    protected List<IterationListener> listeners;
    /** Shape returned by {@link #getShape(INDArray)} for the most recent forward/backward input. */
    protected int[] shape;
    /** Running mean; {@code null} until the first forward pass that uses batch statistics. */
    protected INDArray mean;
    /** Running variance; {@code null} until the first forward pass that uses batch statistics. */
    protected INDArray var;
    /** Square root of the variance used in the most recent forward pass. */
    protected INDArray std;
    /** Input minus mean, cached by the built-in forward pass for use in backprop. */
    protected INDArray xMu;
    /** {@link #xMu} divided by {@link #std}, cached by the built-in forward pass for use in backprop. */
    protected INDArray xHat;
    /** Training mode passed to the most recent forward pass. */
    protected Layer.TrainingMode trainingMode;
    /** {@code true} until the running mean/variance have been lazily initialised. */
    protected boolean setMeanVar;

    /**
     * Creates the layer and attempts to load the optional helper.
     *
     * @param neuralNetConfiguration configuration forwarded to {@link BaseLayer}
     */
    public BatchNormalization(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
        this.helper = null;
        this.index = 0;
        this.listeners = new ArrayList<>();
        this.setMeanVar = true;
        initializeHelper();
    }

    /**
     * Tries to instantiate the helper class by name. A missing class is silently ignored (the
     * built-in implementation is used instead); any other failure is logged as a warning and the
     * helper stays {@code null}.
     */
    void initializeHelper() {
        try {
            this.helper = Class.forName(CUDNN_HELPER_CLASS)
                    .asSubclass(BatchNormalizationHelper.class)
                    .newInstance();
        } catch (ClassNotFoundException ignored) {
            // Helper is optional: when it is not on the classpath we simply run without it.
        } catch (Throwable t) {
            // Deliberately broad: class initialisation may fail with an Error rather than an Exception.
            log.warn("Could not load CudnnBatchNormalizationHelper", t);
        }
    }

    /** @return always {@code 0.0}; this layer contributes no L2 term */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public double calcL2() {
        return 0.0d;
    }

    /** @return always {@code 0.0}; this layer contributes no L1 term */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public double calcL1() {
        return 0.0d;
    }

    /** @return {@link Layer.Type#NORMALIZATION} */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Layer.Type type() {
        return Layer.Type.NORMALIZATION;
    }

    /** Not implemented for this layer. @return always {@code null} */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Gradient error(INDArray iNDArray) {
        return null;
    }

    /** Not implemented for this layer. @return always {@code null} */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Gradient calcGradient(Gradient gradient, INDArray iNDArray) {
        return null;
    }

    /**
     * Backward pass.
     *
     * <p>Writes the gamma and beta gradients into the corresponding gradient views and returns them
     * together with the array to propagate to the preceding layer. The built-in implementation
     * relies on {@link #xHat}, {@link #xMu} and {@link #std} as cached by the built-in forward pass.
     *
     * @param iNDArray incoming array for the backward pass (presumably the gradient with respect to
     *                 this layer's output - confirm against the {@code Layer} contract); must have
     *                 rank 2 or 4
     * @return pair of (parameter gradients, array for the preceding layer); when the helper handles
     *         the call, its result is returned unchanged
     * @throws IllegalStateException if the rank is not supported
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray) {
        final INDArray epsilon = iNDArray;
        this.shape = getShape(epsilon);
        final int batchSize = epsilon.size(0);
        final org.deeplearning4j.nn.conf.layers.BatchNormalization conf = layerConf();

        // With locked gamma/beta the scale is fixed to one instead of being read from the parameters.
        final INDArray gamma = conf.isLockGammaBeta()
                ? Nd4j.ones(this.shape)
                : getParam(BatchNormalizationParamInitializer.GAMMA);
        final INDArray gammaGradView = this.gradientViews.get(BatchNormalizationParamInitializer.GAMMA);
        final INDArray betaGradView = this.gradientViews.get(BatchNormalizationParamInitializer.BETA);

        if (this.helper != null) {
            final Pair<Gradient, INDArray> helperResult = this.helper.backpropGradient(
                    this.input, epsilon, this.shape, gamma, gammaGradView, betaGradView, conf.getEps());
            if (helperResult != null) {
                return helperResult;
            }
            // null from the helper: fall through to the built-in implementation.
        }

        final int rank = epsilon.rank();
        if (rank != 2 && rank != 4) {
            throw new IllegalStateException(UNSUPPORTED_INPUT_MESSAGE);
        }

        final INDArray dGamma;
        final INDArray dBeta;
        final INDArray nextEpsilon;
        if (rank == 2) {
            dGamma = epsilon.mul(this.xHat).sum(new int[]{0});
            dBeta = epsilon.sum(new int[]{0});

            final INDArray scaledEps = epsilon.mulRowVector(gamma);
            final INDArray varTerm = scaledEps.mul(this.xMu).sum(new int[]{0})
                    .mul(0.5d)
                    .div(Transforms.pow(this.std, 3))
                    .neg()
                    .div(batchSize);
            final INDArray dXMu = scaledEps.divRowVector(this.std)
                    .add(this.xMu.mul(2).mulRowVector(varTerm));
            final INDArray meanTerm = dXMu.sum(new int[]{0}).neg().div(batchSize);
            nextEpsilon = Nd4j.getExecutioner().execAndReturn(
                    new BroadcastAddOp(dXMu, meanTerm, dXMu.dup(), new int[]{-1}));
        } else {
            dGamma = epsilon.mul(this.xHat).sum(new int[]{0, 2, 3});
            dBeta = epsilon.sum(new int[]{0, 2, 3});

            final INDArray scaledEps = Nd4j.getExecutioner().execAndReturn(
                    new BroadcastMulOp(epsilon, gamma, epsilon.dup(), new int[]{1}));
            // The divide writes its result into scaledEps itself (input and output are the same
            // array), so varTerm below is computed from the already-divided values. Statement
            // order therefore matters here.
            // NOTE(review): the rank-2 branch computes varTerm from the undivided array; confirm
            // whether this difference is intended.
            final INDArray dXMu1 = Nd4j.getExecutioner().execAndReturn(
                    new BroadcastDivOp(scaledEps, this.std, scaledEps, new int[]{1, 2, 3}));
            final INDArray varTerm = scaledEps.mul(this.xMu).sum(new int[]{0})
                    .mul(0.5d)
                    .div(Transforms.pow(this.std, 3))
                    .neg()
                    .div(batchSize);
            final INDArray dXMu2 = Nd4j.getExecutioner().execAndReturn(
                    new BroadcastMulOp(this.xMu.mul(2), varTerm, this.xMu.mul(2), new int[]{1, 2, 3}));
            final INDArray dXMu = dXMu1.add(dXMu2);
            final INDArray meanTerm = dXMu.sum(new int[]{0}).neg().div(batchSize);
            nextEpsilon = Nd4j.getExecutioner().execAndReturn(
                    new BroadcastAddOp(dXMu, meanTerm, dXMu.dup(), new int[]{1, 2, 3}));
        }

        // Copy into the pre-existing views so the gradient lands in the views held by gradientViews.
        gammaGradView.assign(dGamma);
        betaGradView.assign(dBeta);

        final DefaultGradient paramGradients = new DefaultGradient();
        paramGradients.setGradientFor(BatchNormalizationParamInitializer.GAMMA, gammaGradView);
        paramGradients.setGradientFor(BatchNormalizationParamInitializer.BETA, betaGradView);
        return new Pair<Gradient, INDArray>(paramGradients, nextEpsilon);
    }

    /** Not supported. @throws UnsupportedOperationException always */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public void merge(Layer layer, int i) {
        throw new UnsupportedOperationException();
    }

    /** No-op: this layer has no stand-alone fitting procedure. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public void fit(INDArray iNDArray) {
    }

    /**
     * Runs the forward pass on the layer's current {@code input}.
     *
     * @param z {@code true} for {@link Layer.TrainingMode#TRAIN}, {@code false} for
     *          {@link Layer.TrainingMode#TEST}
     * @return the forward-pass output
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean z) {
        return preOutput(this.input, z ? Layer.TrainingMode.TRAIN : Layer.TrainingMode.TEST);
    }

    /** @return the {@code gradient} field inherited from the base layer */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        return this.gradient;
    }

    /**
     * Forward pass in {@link Layer.TrainingMode#TRAIN} mode.
     *
     * @param iNDArray activations to normalise
     * @return the forward-pass output
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray) {
        return preOutput(iNDArray, Layer.TrainingMode.TRAIN);
    }

    /**
     * Forward pass.
     *
     * <p>Side effects: updates {@link #trainingMode}, {@link #shape} and {@link #std}; the built-in
     * path additionally caches {@link #xMu} and {@link #xHat}, and, when batch statistics are in
     * use, blends them into the running {@link #mean} and {@link #var} using the configured decay.
     *
     * @param iNDArray     activations to normalise; rank 2 or 4 for the built-in path
     * @param trainingMode mode that, together with the configuration, decides whether batch
     *                     statistics or the running statistics are used
     * @return the forward-pass output (the helper's output when a helper handled the call)
     * @throws IllegalStateException if running statistics are required but none are available
     *                               yet, or if the rank is not supported
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray, Layer.TrainingMode trainingMode) {
        final INDArray x = iNDArray;
        this.trainingMode = trainingMode;
        final org.deeplearning4j.nn.conf.layers.BatchNormalization conf = layerConf();
        final int batchSize = x.size(0);
        this.shape = getShape(x);

        final boolean useBatchStats = trainingMode == Layer.TrainingMode.TRAIN && conf.isUseBatchMean();
        final double eps = conf.getEps();

        final INDArray currentMean;
        final INDArray currentVar;
        if (useBatchStats) {
            currentMean = x.mean(new int[]{0});
            currentVar = x.var(false, new int[]{0});
            // eps is added to the variance before the square root is taken below.
            currentVar.addi(eps);
        } else {
            if (this.var == null) {
                // Without this check the square root below fails with an uninformative NPE.
                throw new IllegalStateException(
                        "Batch normalization has no running mean/variance yet: run a forward pass in "
                                + "TRAIN mode with batch statistics enabled (or set them) first");
            }
            currentMean = this.mean;
            currentVar = this.var;
        }
        this.std = Transforms.sqrt(currentVar);

        final INDArray gamma;
        final INDArray beta;
        if (conf.isLockGammaBeta()) {
            // Locked parameters: fixed scale of one and shift of zero.
            gamma = Nd4j.ones(this.shape);
            beta = Nd4j.zeros(this.shape);
        } else {
            gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
            beta = getParam(BatchNormalizationParamInitializer.BETA);
        }

        if (this.helper != null) {
            // The decay must be chosen before the flag is cleared by the lazy initialisation:
            // the very first call passes 1.0 instead of the configured decay.
            final double helperDecay = this.setMeanVar ? 1.0d : conf.getDecay();
            initRunningStatsIfNeeded(currentMean, currentVar, eps);
            final INDArray helperOutput = this.helper.preOutput(
                    x, useBatchStats, this.shape, gamma, beta, this.mean, this.var, helperDecay, eps);
            if (helperOutput != null) {
                return helperOutput;
            }
            // null from the helper: fall through to the built-in implementation.
        }

        final INDArray output;
        final int rank = x.rank();
        if (rank == 2) {
            this.xMu = Nd4j.getExecutioner().execAndReturn(
                    new BroadcastSubOp(x, currentMean, x.dup(), new int[]{-1}));
            this.xHat = Nd4j.getExecutioner().execAndReturn(
                    new BroadcastDivOp(this.xMu, this.std, this.xMu.dup(), new int[]{-1}));
            // dup() keeps the cached xHat untouched for the backward pass.
            output = this.xHat.dup().mulRowVector(gamma).addRowVector(beta);
        } else if (rank == 4) {
            this.xMu = Nd4j.getExecutioner().execAndReturn(
                    new BroadcastSubOp(x, currentMean, x.dup(), new int[]{1, 2, 3}));
            this.xHat = Nd4j.getExecutioner().execAndReturn(
                    new BroadcastDivOp(this.xMu, this.std, this.xMu.dup(), new int[]{1, 2, 3}));
            final INDArray scaled = Nd4j.getExecutioner().execAndReturn(
                    new BroadcastMulOp(this.xHat, gamma, this.xHat.dup(), new int[]{1}));
            // The shift is written into 'scaled' itself; it is a private copy, so this is safe.
            output = Nd4j.getExecutioner().execAndReturn(
                    new BroadcastAddOp(scaled, beta, scaled, new int[]{1}));
        } else {
            throw new IllegalStateException(UNSUPPORTED_INPUT_MESSAGE);
        }

        if (useBatchStats) {
            initRunningStatsIfNeeded(currentMean, currentVar, eps);
            final double decay = conf.getDecay();
            // batchSize / (batchSize - 1), guarded against division by zero for a single example.
            final double adjust = batchSize / Math.max(batchSize - 1.0d, 1.0d);
            this.mean = currentMean.mul(decay).add(this.mean.mul(1.0d - decay));
            this.var = currentVar.mul(decay).add(this.var.mul((1.0d - decay) * adjust));
        }
        return output;
    }

    /**
     * Lazily creates the running statistics the first time they are needed: a zero-filled mean and
     * an eps-filled variance, each only if currently {@code null}. Does nothing once
     * {@link #setMeanVar} has been cleared.
     *
     * @param meanLike array whose shape is used for a newly created running mean
     * @param varLike  array whose shape is used for a newly created running variance
     * @param eps      fill value for a newly created running variance
     */
    private void initRunningStatsIfNeeded(INDArray meanLike, INDArray varLike, double eps) {
        if (!this.setMeanVar) {
            return;
        }
        if (this.mean == null) {
            this.mean = Nd4j.zeros(meanLike.shape());
        }
        if (this.var == null) {
            this.var = Nd4j.valueArrayOf(varLike.shape(), eps);
        }
        this.setMeanVar = false;
    }

    /** Not supported. @throws UnsupportedOperationException always */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(Layer.TrainingMode trainingMode) {
        throw new UnsupportedOperationException();
    }

    /**
     * Forward pass; identical to {@link #preOutput(INDArray, Layer.TrainingMode)}.
     *
     * @param iNDArray     activations to normalise
     * @param trainingMode training mode for the pass
     * @return the forward-pass output
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray, Layer.TrainingMode trainingMode) {
        return preOutput(iNDArray, trainingMode);
    }

    /**
     * Forward pass with a boolean training flag.
     *
     * @param iNDArray activations to normalise
     * @param z        {@code true} for {@link Layer.TrainingMode#TRAIN}, {@code false} for
     *                 {@link Layer.TrainingMode#TEST}
     * @return the forward-pass output
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray, boolean z) {
        return preOutput(iNDArray, z ? Layer.TrainingMode.TRAIN : Layer.TrainingMode.TEST);
    }

    /** Not supported. @throws UnsupportedOperationException always */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Layer transpose() {
        throw new UnsupportedOperationException();
    }

    /** Not supported. @throws UnsupportedOperationException always */
    @Override // org.deeplearning4j.nn.layers.BaseLayer
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Layer mo79clone() {
        throw new UnsupportedOperationException();
    }

    /** @return the live listener list held by this layer (not a copy) */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Collection<IterationListener> getListeners() {
        return this.listeners;
    }

    /**
     * Replaces the listeners with a fresh, mutable list containing the given ones.
     *
     * @param iterationListenerArr listeners to install
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public void setListeners(IterationListener... iterationListenerArr) {
        this.listeners = new ArrayList<>(Arrays.asList(iterationListenerArr));
    }

    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public void setIndex(int i) {
        this.index = i;
    }

    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public int getIndex() {
        return this.index;
    }

    /**
     * Derives the two-element shape used for this layer's per-feature arrays.
     *
     * <ul>
     *   <li>rank 2 or 4: {@code {1, size(1)}}</li>
     *   <li>rank 3: {@code {1, size(1) * size(2)}}</li>
     * </ul>
     *
     * @param iNDArray array whose shape is inspected
     * @return a new two-element shape array
     * @throws IllegalArgumentException for rank 3 input when {@code size(0) > 1} and
     *                                  {@code size(1) * size(2)} equals the array length
     * @throws IllegalStateException    for any rank other than 2, 3 or 4
     */
    public int[] getShape(INDArray iNDArray) {
        final int rank = iNDArray.rank();
        if (rank == 2 || rank == 4) {
            return new int[]{1, iNDArray.size(1)};
        }
        if (rank != 3) {
            throw new IllegalStateException("Unable to process input of rank " + iNDArray.rank());
        }
        final int dim1 = iNDArray.size(1);
        final int dim2 = iNDArray.size(2);
        if (iNDArray.size(0) > 1 && dim1 * dim2 == iNDArray.length()) {
            throw new IllegalArgumentException("Illegal input for batch size");
        }
        return new int[]{1, dim1 * dim2};
    }
}
