package org.deeplearning4j.nn.layers.recurrent;

import java.util.Arrays;
import java.util.Map;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.TimesOneMinus;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.class */
public class LSTMHelpers {
    /*
     * Layout conventions shared by both helpers (established by the slicing / gemm shapes below),
     * with layerSize = recurrentWeights.size(0):
     *
     *  - input weights:     [nIn, 4*layerSize], column blocks [i | f | o | g]
     *  - recurrent weights: [layerSize, 4*layerSize + 3]; columns [0, 4*layerSize) are the recurrent
     *                       weights for [i | f | o | g], followed by the peephole columns wFF, wOO, wGG
     *  - biases:            row vector added to the [miniBatch, 4*layerSize] pre-activations
     *
     * Per time step t (afn = layer activation function, gate = gate activation function):
     *   i_t = afn (x_t*Wi + h_{t-1}*Ri + bi)
     *   f_t = gate(x_t*Wf + h_{t-1}*Rf + bf + c_{t-1} .* wFF)
     *   g_t = gate(x_t*Wg + h_{t-1}*Rg + bg + c_{t-1} .* wGG)
     *   c_t = c_{t-1} .* f_t + g_t .* i_t
     *   o_t = gate(x_t*Wo + h_{t-1}*Ro + bo + c_t .* wOO)
     *   h_t = afn(c_t) .* o_t
     */

    private LSTMHelpers() {
    }

    /**
     * Runs the LSTM forward pass over every time step of {@code input}.
     *
     * @param layer                         layer passed to {@link Dropout#applyDropConnect} when DropConnect is active
     * @param conf                          configuration supplying the DropConnect flag, dropout value and the layer
     *                                      activation function (used for the block input and the cell output)
     * @param gateActivationFn              activation applied to the forget, output and input-modulation gates; when
     *                                      it is an {@link ActivationSigmoid}, gate pre-activations are not stored
     *                                      because backprop can use the {@code a*(1-a)} shortcut instead
     * @param input                         input of shape [miniBatch, nIn, T], or [miniBatch, nIn] for a single step
     * @param recurrentWeights              recurrent + peephole weights (see layout note on the class body)
     * @param originalInputWeights          input weights [nIn, 4*layerSize]; replaced by a DropConnect version when
     *                                      DropConnect is enabled, {@code training} is true and dropout > 0
     * @param biases                        bias row vector added to the 4*layerSize pre-activations
     * @param training                      forwarded to the activation functions; also gates DropConnect
     * @param originalPrevOutputActivations previous hidden state h; {@code null} means zeros. Its size(0) must equal
     *                                      the mini-batch size
     * @param originalPrevMemCellState      previous cell state c; {@code null} means a freshly created array
     *                                      (presumably zero-filled by {@code Nd4j.create} — TODO confirm). It is
     *                                      duplicated, never modified
     * @param forBackprop                   if true, keep every per-step intermediate in the returned
     *                                      {@link FwdPassReturn} arrays instead of filling {@code fwdPassOutput}
     * @param forwards                      iterate t = 0..T-1 when true, t = T-1..0 otherwise
     * @param inputWeightKey                parameter key handed to {@link Dropout#applyDropConnect}
     * @param maskArray                     optional mask; column {@code t} multiplies both h_t and c_t row-wise, so
     *                                      masked steps carry zero state forward
     * @return the forward-pass results; {@code lastAct}/{@code lastMemCell} hold the state of the last processed step
     * @throws IllegalArgumentException   if {@code input} is null or empty
     * @throws DL4JInvalidInputException  if input.size(1) does not match nIn, or the stored previous activations have
     *                                    a different number of examples than {@code input}
     */
    public static FwdPassReturn activateHelper(Layer layer, NeuralNetConfiguration conf, IActivation gateActivationFn,
                    INDArray input, INDArray recurrentWeights, INDArray originalInputWeights, INDArray biases,
                    boolean training, INDArray originalPrevOutputActivations, INDArray originalPrevMemCellState,
                    boolean forBackprop, boolean forwards, String inputWeightKey, INDArray maskArray) {
        if (input == null || input.length() == 0) {
            throw new IllegalArgumentException("Invalid input: not set or 0 length");
        }

        INDArray inputWeights = originalInputWeights;
        INDArray prevOutputActivations = originalPrevOutputActivations;

        // A single time step may arrive as a 2d [miniBatch, nIn] array.
        boolean is2dInput = input.rank() < 3;
        int timeSeriesLength = is2dInput ? 1 : input.size(2);
        int hiddenLayerSize = recurrentWeights.size(0);
        int miniBatchSize = input.size(0);

        INDArray prevMemCellState = originalPrevMemCellState == null
                        ? Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize}, 'f')
                        : originalPrevMemCellState.dup('f');

        INDArray recurrentWeightsIFOG = recurrentWeights
                        .get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize)).dup('f');

        // DropConnect applies to the input (non-recurrent) weights only.
        if (conf.isUseDropConnect() && training && conf.getLayer().getDropOut() > 0.0) {
            inputWeights = Dropout.applyDropConnect(layer, inputWeightKey);
        }

        // Peephole weights, one column each, transposed to row vectors.
        INDArray wFFTranspose = recurrentWeights.get(NDArrayIndex.all(),
                        NDArrayIndex.interval(4 * hiddenLayerSize, 4 * hiddenLayerSize + 1)).transpose();
        INDArray wOOTranspose = recurrentWeights.get(NDArrayIndex.all(),
                        NDArrayIndex.interval(4 * hiddenLayerSize + 1, 4 * hiddenLayerSize + 2)).transpose();
        INDArray wGGTranspose = recurrentWeights.get(NDArrayIndex.all(),
                        NDArrayIndex.interval(4 * hiddenLayerSize + 2, 4 * hiddenLayerSize + 3)).transpose();
        if (timeSeriesLength > 1 || forBackprop) {
            wFFTranspose = Shape.toMmulCompatible(wFFTranspose);
            wOOTranspose = Shape.toMmulCompatible(wOOTranspose);
            wGGTranspose = Shape.toMmulCompatible(wGGTranspose);
        }

        boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
        IActivation afn = conf.getLayer().getActivationFn();
        INDArray outputActivations = null;

        FwdPassReturn toReturn = new FwdPassReturn();
        if (forBackprop) {
            toReturn.fwdPassOutputAsArrays = new INDArray[timeSeriesLength];
            toReturn.memCellState = new INDArray[timeSeriesLength];
            toReturn.memCellActivations = new INDArray[timeSeriesLength];
            toReturn.iz = new INDArray[timeSeriesLength];
            toReturn.ia = new INDArray[timeSeriesLength];
            toReturn.fa = new INDArray[timeSeriesLength];
            toReturn.oa = new INDArray[timeSeriesLength];
            toReturn.ga = new INDArray[timeSeriesLength];
            if (!sigmoidGates) {
                // Gate pre-activations are only needed when backprop cannot use the sigmoid shortcut.
                toReturn.fz = new INDArray[timeSeriesLength];
                toReturn.oz = new INDArray[timeSeriesLength];
                toReturn.gz = new INDArray[timeSeriesLength];
            }
        } else {
            // 'f' order keeps the data of each time step together.
            outputActivations = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize, timeSeriesLength}, 'f');
            toReturn.fwdPassOutput = outputActivations;
        }

        Level1 l1BLAS = Nd4j.getBlasWrapper().level1();

        if (input.size(1) != inputWeights.size(0)) {
            throw new DL4JInvalidInputException("Received input with size(1) = " + input.size(1)
                            + " (input array shape = " + Arrays.toString(input.shape())
                            + "); input.size(1) must match layer nIn size (nIn = " + inputWeights.size(0) + ")");
        }
        if (prevOutputActivations != null && prevOutputActivations.size(0) != input.size(0)) {
            throw new DL4JInvalidInputException("Previous activations (stored state) number of examples = "
                            + prevOutputActivations.size(0) + " but input array number of examples = "
                            + input.size(0)
                            + ". Possible cause: using rnnTimeStep() without calling rnnClearPreviousState() between different sequences?");
        }
        if (prevOutputActivations == null) {
            prevOutputActivations = Nd4j.zeros(miniBatchSize, hiddenLayerSize);
        }

        for (int iTimeIndex = 0; iTimeIndex < timeSeriesLength; iTimeIndex++) {
            int time = forwards ? iTimeIndex : timeSeriesLength - iTimeIndex - 1;

            // Pre-activations for all four blocks: x_t*W + h_{t-1}*R + b, shape [miniBatch, 4*layerSize].
            INDArray miniBatchData = Shape.toMmulCompatible(is2dInput ? input : input.tensorAlongDimension(time, 1, 0));
            INDArray ifogActivations = miniBatchData.mmul(inputWeights);
            Nd4j.gemm(prevOutputActivations, recurrentWeightsIFOG, ifogActivations, false, false, 1.0, 1.0);
            ifogActivations.addiRowVector(biases);

            // Block input i: layer activation function (applied in place on the view).
            INDArray inputActivations = ifogActivations.get(NDArrayIndex.all(),
                            NDArrayIndex.interval(0, hiddenLayerSize));
            if (forBackprop) {
                toReturn.iz[time] = inputActivations.dup('f');
            }
            conf.getLayer().getActivationFn().getActivation(inputActivations, training);
            if (forBackprop) {
                toReturn.ia[time] = inputActivations;
            }

            // Forget gate f: adds the c_{t-1} peephole. axpy treats the f-order matrices as flat vectors.
            INDArray forgetGateActivations = ifogActivations.get(NDArrayIndex.all(),
                            NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));
            INDArray pmcellWFF = prevMemCellState.dup('f').muliRowVector(wFFTranspose);
            l1BLAS.axpy(pmcellWFF.length(), 1.0, pmcellWFF, forgetGateActivations);
            if (forBackprop && !sigmoidGates) {
                toReturn.fz[time] = forgetGateActivations.dup('f');
            }
            gateActivationFn.getActivation(forgetGateActivations, training);
            if (forBackprop) {
                toReturn.fa[time] = forgetGateActivations;
            }

            // Input modulation gate g: adds the c_{t-1} peephole.
            INDArray inputModGateActivations = ifogActivations.get(NDArrayIndex.all(),
                            NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));
            INDArray pmcellWGG = prevMemCellState.dup('f').muliRowVector(wGGTranspose);
            l1BLAS.axpy(pmcellWGG.length(), 1.0, pmcellWGG, inputModGateActivations);
            if (forBackprop && !sigmoidGates) {
                toReturn.gz[time] = inputModGateActivations.dup('f');
            }
            gateActivationFn.getActivation(inputModGateActivations, training);
            if (forBackprop) {
                toReturn.ga[time] = inputModGateActivations;
            }

            // c_t = c_{t-1} .* f + g .* i. For backprop, copies keep the stored gate activations intact.
            INDArray currentMemoryCellState;
            INDArray inputModMulInput;
            if (forBackprop) {
                currentMemoryCellState = prevMemCellState.dup('f').muli(forgetGateActivations);
                inputModMulInput = inputModGateActivations.dup('f').muli(inputActivations);
            } else {
                currentMemoryCellState = forgetGateActivations.muli(prevMemCellState);
                inputModMulInput = inputModGateActivations.muli(inputActivations);
            }
            l1BLAS.axpy(currentMemoryCellState.length(), 1.0, inputModMulInput, currentMemoryCellState);

            // Output gate o: its peephole uses the *current* cell state c_t.
            INDArray outputGateActivations = ifogActivations.get(NDArrayIndex.all(),
                            NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));
            INDArray pmcellWOO = currentMemoryCellState.dup('f').muliRowVector(wOOTranspose);
            l1BLAS.axpy(pmcellWOO.length(), 1.0, pmcellWOO, outputGateActivations);
            if (forBackprop && !sigmoidGates) {
                toReturn.oz[time] = outputGateActivations.dup('f');
            }
            gateActivationFn.getActivation(outputGateActivations, training);
            if (forBackprop) {
                toReturn.oa[time] = outputGateActivations;
            }

            // h_t = afn(c_t) .* o
            INDArray currMemoryCellActivation = afn.getActivation(currentMemoryCellState.dup('f'), training);
            INDArray currHiddenUnitActivations = forBackprop
                            ? currMemoryCellActivation.dup('f').muli(outputGateActivations)
                            : currMemoryCellActivation.muli(outputGateActivations);

            if (maskArray != null) {
                // Zero masked examples in both h_t and c_t so that no state is carried past a masked step.
                INDArray timeStepMaskColumn = maskArray.getColumn(time);
                currHiddenUnitActivations.muliColumnVector(timeStepMaskColumn);
                currentMemoryCellState.muliColumnVector(timeStepMaskColumn);
            }

            if (forBackprop) {
                toReturn.fwdPassOutputAsArrays[time] = currHiddenUnitActivations;
                toReturn.memCellState[time] = currentMemoryCellState;
                toReturn.memCellActivations[time] = currMemoryCellActivation;
            } else {
                outputActivations.tensorAlongDimension(time, 1, 0).assign(currHiddenUnitActivations);
            }

            prevOutputActivations = currHiddenUnitActivations;
            prevMemCellState = currentMemoryCellState;

            toReturn.lastAct = currHiddenUnitActivations;
            toReturn.lastMemCell = currentMemoryCellState;
        }

        return toReturn;
    }

    /**
     * Backpropagates {@code epsilon} through the LSTM and accumulates the parameter gradients.
     *
     * <p>Side effects:
     * <ul>
     *   <li>the three gradient views taken from {@code gradientViews} are zeroed, then filled in place;</li>
     *   <li>{@code fwdPass.oa[t]} and {@code fwdPass.fa[t]} are multiplied in place (to avoid allocations), so
     *       {@code fwdPass} must not be reused afterwards. The activation functions' {@code backprop} may also
     *       modify the stored pre-activations — depends on the IActivation implementation.</li>
     * </ul>
     *
     * @param conf                configuration supplying the layer activation function
     * @param gateActivationFn    gate activation used in the forward pass; sigmoid enables the {@code a*(1-a)} shortcut
     * @param input               layer input, [miniBatch, nIn, T] or [miniBatch, nIn]
     * @param recurrentWeights    recurrent + peephole weights (see layout note on the class body)
     * @param inputWeights        input weights [nIn, 4*layerSize]
     * @param epsilon             dL/dh, [miniBatch, layerSize, T] or [miniBatch, layerSize] for a single step
     * @param truncatedBPTT       if true, stop after {@code tbpttBackwardLength} steps counted from the end
     * @param tbpttBackwardLength number of steps to backpropagate when {@code truncatedBPTT} is set
     * @param fwdPass             result of {@link #activateHelper} run with {@code forBackprop = true}
     * @param forwards            must match the direction used for the forward pass
     * @param inputWeightKey      key of the input-weight gradient view in {@code gradientViews}
     * @param recurrentWeightKey  key of the recurrent-weight gradient view in {@code gradientViews}
     * @param biasWeightKey       key of the bias gradient view in {@code gradientViews}
     * @param gradientViews       preallocated gradient arrays, overwritten by this call
     * @param maskArray           optional mask; column {@code t} zeroes the deltas and the returned epsilon of
     *                            masked examples at step {@code t}
     * @return the gradients (backed by the views from {@code gradientViews}) and dL/dx of shape [miniBatch, nIn, T]
     */
    public static Pair<Gradient, INDArray> backpropGradientHelper(NeuralNetConfiguration conf,
                    IActivation gateActivationFn, INDArray input, INDArray recurrentWeights, INDArray inputWeights,
                    INDArray epsilon, boolean truncatedBPTT, int tbpttBackwardLength, FwdPassReturn fwdPass,
                    boolean forwards, String inputWeightKey, String recurrentWeightKey, String biasWeightKey,
                    Map<String, INDArray> gradientViews, INDArray maskArray) {
        int hiddenLayerSize = recurrentWeights.size(0);
        int prevLayerSize = inputWeights.size(0);
        int miniBatchSize = epsilon.size(0);
        boolean is2dInput = epsilon.rank() < 3;
        int timeSeriesLength = is2dInput ? 1 : epsilon.size(2);

        INDArray wFFTranspose = recurrentWeights.get(NDArrayIndex.all(),
                        NDArrayIndex.point(4 * hiddenLayerSize)).transpose();
        INDArray wOOTranspose = recurrentWeights.get(NDArrayIndex.all(),
                        NDArrayIndex.point(4 * hiddenLayerSize + 1)).transpose();
        INDArray wGGTranspose = recurrentWeights.get(NDArrayIndex.all(),
                        NDArrayIndex.point(4 * hiddenLayerSize + 2)).transpose();
        INDArray wIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));

        // dL/dx for every step; 'f' order keeps each time step's slice contiguous.
        INDArray epsilonNext = Nd4j.create(new int[] {miniBatchSize, prevLayerSize, timeSeriesLength}, 'f');
        INDArray nablaCellStateNext = null;

        // One delta buffer reused across steps. At the start of an iteration it holds the deltas of the step
        // processed previously (t+1 in processing order); during the iteration it is overwritten with step t.
        INDArray deltaifogNext = Nd4j.create(new int[] {miniBatchSize, 4 * hiddenLayerSize}, 'f');
        INDArray deltaiNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
        INDArray deltafNext = deltaifogNext.get(NDArrayIndex.all(),
                        NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));
        INDArray deltaoNext = deltaifogNext.get(NDArrayIndex.all(),
                        NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));
        INDArray deltagNext = deltaifogNext.get(NDArrayIndex.all(),
                        NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));

        Level1 l1BLAS = Nd4j.getBlasWrapper().level1();
        int endIdx = truncatedBPTT ? Math.max(0, timeSeriesLength - tbpttBackwardLength) : 0;

        // The views may hold data from an earlier call, so they are zeroed before accumulating.
        INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
        INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey);
        INDArray bGradientsOut = gradientViews.get(biasWeightKey);
        iwGradientsOut.assign((Number) 0);
        rwGradientsOut.assign((Number) 0);
        bGradientsOut.assign((Number) 0);

        INDArray rwGradientsIFOG = rwGradientsOut.get(NDArrayIndex.all(),
                        NDArrayIndex.interval(0, 4 * hiddenLayerSize));
        INDArray rwGradientsFF = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize));
        INDArray rwGradientsOO = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 1));
        INDArray rwGradientsGG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenLayerSize + 2));

        boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
        IActivation afn = conf.getLayer().getActivationFn();

        INDArray timeStepMaskColumn = null;
        for (int iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; iTimeIndex--) {
            int time = forwards ? iTimeIndex : timeSeriesLength - iTimeIndex - 1;
            int inext = forwards ? 1 : -1;
            boolean lastStep = iTimeIndex == timeSeriesLength - 1;

            // dL/dc_t contributions via the next step's f and g peepholes; computed before the deltas are overwritten.
            INDArray nablaCellState;
            if (!lastStep) {
                nablaCellState = deltafNext.dup('f').muliRowVector(wFFTranspose);
                l1BLAS.axpy(nablaCellState.length(), 1.0, deltagNext.dup('f').muliRowVector(wGGTranspose),
                                nablaCellState);
            } else {
                nablaCellState = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize}, 'f');
            }

            INDArray prevMemCellState = iTimeIndex == 0 ? null : fwdPass.memCellState[time - inext];
            INDArray prevHiddenUnitActivation = iTimeIndex == 0 ? null : fwdPass.fwdPassOutputAsArrays[time - inext];
            INDArray currMemCellState = fwdPass.memCellState[time];

            // dL/dh_t: external epsilon plus the recurrent contribution from the next step's deltas.
            INDArray nablaOut = Shape.toOffsetZeroCopy(
                            is2dInput ? epsilon : epsilon.tensorAlongDimension(time, 1, 0), 'f');
            if (!lastStep) {
                Nd4j.gemm(deltaifogNext, wIFOG, nablaOut, false, true, 1.0, 1.0);
            }

            // Output gate delta: dL/dh .* afn(c_t) .* gate'(z_o)
            INDArray sigmahOfS = fwdPass.memCellActivations[time];
            INDArray ao = fwdPass.oa[time];
            INDArray deltao = deltaoNext;
            Nd4j.getExecutioner().exec(new MulOp(nablaOut, sigmahOfS, deltao));
            if (sigmoidGates) {
                deltao.muli(Nd4j.getExecutioner().execAndReturn((TransformOp) new TimesOneMinus(ao.dup('f'))));
            } else {
                deltao.assign(gateActivationFn.backprop(fwdPass.oz[time], deltao).getFirst());
            }

            // Remaining dL/dc_t terms: through h_t (note: ao is modified in place here), through the wOO
            // peephole, and through c_{t+1} = c_t .* f_{t+1} + ... (modifies fwdPass.fa[time + inext] in place).
            l1BLAS.axpy(nablaCellState.length(), 1.0,
                            afn.backprop(currMemCellState.dup('f'), ao.muli(nablaOut)).getFirst(), nablaCellState);
            l1BLAS.axpy(nablaCellState.length(), 1.0, deltao.dup('f').muliRowVector(wOOTranspose), nablaCellState);
            if (!lastStep) {
                l1BLAS.axpy(nablaCellState.length(), 1.0, fwdPass.fa[time + inext].muli(nablaCellStateNext),
                                nablaCellState);
            }
            nablaCellStateNext = nablaCellState;

            // Forget gate delta. Not computed at iTimeIndex == 0 (no previous cell state is available), so the
            // f slice of deltaifogNext is stale there and is excluded from every gradient below.
            INDArray af = fwdPass.fa[time];
            INDArray deltaf = null;
            if (iTimeIndex > 0) {
                deltaf = deltafNext;
                if (sigmoidGates) {
                    Nd4j.getExecutioner().exec(new TimesOneMinus(af, deltaf));
                    deltaf.muli(nablaCellState);
                    deltaf.muli(prevMemCellState);
                } else {
                    deltaf.assign(gateActivationFn.backprop(fwdPass.fz[time].dup('f'),
                                    nablaCellState.mul(prevMemCellState)).getFirst());
                }
            }

            // Input modulation gate delta: dL/dc_t .* i_t .* gate'(z_g)
            INDArray ag = fwdPass.ga[time];
            INDArray ai = fwdPass.ia[time];
            INDArray deltag = deltagNext;
            if (sigmoidGates) {
                Nd4j.getExecutioner().exec(new TimesOneMinus(ag, deltag));
                deltag.muli(ai);
                deltag.muli(nablaCellState);
            } else {
                deltag.assign(gateActivationFn.backprop(fwdPass.gz[time], Nd4j.getExecutioner().execAndReturn(
                                (TransformOp) new MulOp(ai, nablaCellState, Nd4j.createUninitialized(ai.shape(), 'f'))))
                                .getFirst());
            }

            // Block input delta: dL/dc_t .* g_t .* afn'(z_i)
            INDArray deltai = deltaiNext;
            deltai.assign(afn.backprop(fwdPass.iz[time], Nd4j.getExecutioner().execAndReturn(
                            (TransformOp) new MulOp(ag, nablaCellState, Nd4j.createUninitialized(deltai.shape(), 'f'))))
                            .getFirst());

            if (maskArray != null) {
                // Masked examples must not contribute to any parameter gradient at this step.
                timeStepMaskColumn = maskArray.getColumn(time);
                deltaifogNext.muliColumnVector(timeStepMaskColumn);
            }

            // Input weight gradients: x_t^T * delta (forget-gate columns skipped at iTimeIndex == 0).
            INDArray prevLayerActivationSlice =
                            Shape.toMmulCompatible(is2dInput ? input : input.tensorAlongDimension(time, 1, 0));
            if (iTimeIndex > 0) {
                Nd4j.gemm(prevLayerActivationSlice, deltaifogNext, iwGradientsOut, true, false, 1.0, 1.0);
            } else {
                Nd4j.gemm(prevLayerActivationSlice, deltai,
                                iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize)),
                                true, false, 1.0, 1.0);
                Nd4j.gemm(prevLayerActivationSlice,
                                deltaifogNext.get(NDArrayIndex.all(),
                                                NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)),
                                iwGradientsOut.get(NDArrayIndex.all(),
                                                NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)),
                                true, false, 1.0, 1.0);
            }

            // Recurrent and f/g peephole gradients need h_{t-1} / c_{t-1}, so they are skipped at iTimeIndex == 0.
            if (iTimeIndex > 0) {
                Nd4j.gemm(prevHiddenUnitActivation, deltaifogNext, rwGradientsIFOG, true, false, 1.0, 1.0);
                // Peepholes are element-wise (unit j -> j), hence mul + sum over the mini-batch, not mmul.
                l1BLAS.axpy(hiddenLayerSize, 1.0, deltaf.dup('f').muli(prevMemCellState).sum(0), rwGradientsFF);
                l1BLAS.axpy(hiddenLayerSize, 1.0, deltag.dup('f').muli(prevMemCellState).sum(0), rwGradientsGG);
            }
            l1BLAS.axpy(hiddenLayerSize, 1.0, deltao.dup('f').muli(currMemCellState).sum(0), rwGradientsOO);

            // Bias gradients: delta summed over the mini-batch (forget-gate slice skipped at iTimeIndex == 0).
            if (iTimeIndex > 0) {
                l1BLAS.axpy(4 * hiddenLayerSize, 1.0, deltaifogNext.sum(0), bGradientsOut);
            } else {
                // With n = hiddenLayerSize, axpy only touches the leading i-slice of the bias gradient.
                l1BLAS.axpy(hiddenLayerSize, 1.0, deltai.sum(0), bGradientsOut);
                l1BLAS.axpy(2 * hiddenLayerSize, 1.0,
                                deltaifogNext.get(NDArrayIndex.all(),
                                                NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)).sum(0),
                                bGradientsOut.get(NDArrayIndex.point(0),
                                                NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)));
            }

            // dL/dx_t = delta * W^T (forget-gate columns skipped at iTimeIndex == 0).
            INDArray epsilonNextSlice = epsilonNext.tensorAlongDimension(time, 1, 0);
            if (iTimeIndex > 0) {
                Nd4j.gemm(deltaifogNext, inputWeights, epsilonNextSlice, false, true, 1.0, 1.0);
            } else {
                Nd4j.gemm(deltai, inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize)),
                                epsilonNextSlice, false, true, 1.0, 1.0);
                Nd4j.gemm(deltaifogNext.get(NDArrayIndex.all(),
                                                NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)),
                                inputWeights.get(NDArrayIndex.all(),
                                                NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize)),
                                epsilonNextSlice, false, true, 1.0, 1.0);
            }

            if (maskArray != null) {
                epsilonNextSlice.muliColumnVector(timeStepMaskColumn);
            }
        }

        DefaultGradient retGradient = new DefaultGradient();
        retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut);
        retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut);
        retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut);
        return new Pair<>(retGradient, epsilonNext);
    }
}
