package org.deeplearning4j.nn.layers.recurrent;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.AbstractLSTM;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.class */
public class LSTMHelpers {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) LSTMHelpers.class);

    /** Private constructor: prevents instantiation from outside this class; the helper methods below are all static. */
    private LSTMHelpers() {
    }

    /**
     * Runs the LSTM forward pass over every time step of the input and returns the collected
     * activations and state in a {@link FwdPassReturn}.
     *
     * <p>If {@code lSTMHelper} is non-null (and has not previously failed, or fallback is disabled) the
     * work is delegated to it first; the built-in loop below is used only when the helper is absent,
     * returns {@code null}, or fails with fallback allowed.
     *
     * <p>Within the per-step pre-activation matrix, the column blocks are consumed as follows (block
     * names follow the {@code FwdPassReturn} field names they are stored in): {@code [0,n)} is "i",
     * {@code [n,2n)} is "f", {@code [2n,3n)} is "o" and {@code [3n,4n)} is "g", where
     * {@code n = iNDArray2.size(0)}.
     *
     * @param baseRecurrentLayer layer whose {@code layerConf().getActivationFn()} is applied to the "i"
     *        block and to the cell state; its {@code helperCountFail} counter is incremented when the
     *        helper fails and fallback is allowed
     * @param neuralNetConfiguration only forwarded to {@code lSTMHelper}
     * @param iActivation activation applied in place to the "f", "g" and "o" blocks
     * @param iNDArray input; must be non-null and non-empty. Rank below 3 is treated as a single time
     *        step, otherwise {@code size(2)} is the number of time steps. {@code size(1)} must equal
     *        {@code iNDArray3.size(0)}
     * @param iNDArray2 recurrent weights; the first {@code 4n} columns multiply the previous step output.
     *        When {@code z4} is set, columns {@code 4n}, {@code 4n+1}, {@code 4n+2} are additionally read
     *        as row vectors for the "f", "o" and "g" blocks respectively (presumably peephole weights)
     * @param iNDArray3 input weights; also determines the data type the input is cast to
     * @param iNDArray4 row vector added to every pre-activation row (bias)
     * @param z passed as the second argument of every {@code getActivation} call (presumably the
     *        "training" flag) and required for the FF_CACHE workspace to be used
     * @param iNDArray5 output of the step preceding this sequence, or {@code null} to start from zeros
     * @param iNDArray6 cell state preceding this sequence, or {@code null} to start from a newly created array
     * @param z2 when {@code true}, per-time-step intermediate arrays are stored in the result
     * @param z3 when {@code true} time steps are visited in increasing order, otherwise in reverse
     * @param str only forwarded to {@code lSTMHelper}
     * @param iNDArray7 optional mask; column {@code t} scales the step output and cell state of step {@code t}
     * @param z4 enables the three extra recurrent-weight columns described above
     * @param lSTMHelper optional accelerated implementation, may be {@code null}
     * @param cacheMode cache mode; anything other than {@code NONE} allows use of the FF_CACHE workspace
     * @param layerWorkspaceMgr workspace manager used for all allocations and scopes
     * @param z5 whether a helper failure may fall back to the built-in implementation
     * @return the populated forward-pass result
     * @throws IllegalArgumentException if the input is {@code null} or has length 0
     * @throws DL4JInvalidInputException if the input feature size does not match the input weights
     * @throws RuntimeException if the helper fails and {@code z5} is {@code false}
     */
    public static FwdPassReturn activateHelper(BaseRecurrentLayer baseRecurrentLayer, NeuralNetConfiguration neuralNetConfiguration, IActivation iActivation, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, boolean z, INDArray iNDArray5, INDArray iNDArray6, boolean z2, boolean z3, String str, INDArray iNDArray7, boolean z4, LSTMHelper lSTMHelper, CacheMode cacheMode, LayerWorkspaceMgr layerWorkspaceMgr, boolean z5) {
        if (iNDArray == null || iNDArray.length() == 0) {
            throw new IllegalArgumentException("Invalid input: not set or 0 length");
        }
        INDArray prevOutput = iNDArray5;
        INDArray mask = iNDArray7 == null ? null : iNDArray7.castTo(iNDArray2.dataType());
        boolean is2dInput = iNDArray.rank() < 3;
        INDArray input = iNDArray.castTo(iNDArray3.dataType());
        int timeSteps = (int) (is2dInput ? 1L : input.size(2));
        int hiddenSize = (int) iNDArray2.size(0);
        int miniBatch = (int) input.size(0);
        INDArray prevCellState = iNDArray6 == null ? Nd4j.create(iNDArray3.dataType(), new long[]{miniBatch, hiddenSize}, 'f') : iNDArray6.dup('f');
        INDArray recurrentGateWeights = iNDArray2.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenSize)).dup('f');

        // Extra recurrent-weight columns, only read when z4 is set.
        INDArray peepholeF = null;
        INDArray peepholeO = null;
        INDArray peepholeG = null;
        if (z4) {
            peepholeF = iNDArray2.get(NDArrayIndex.all(), NDArrayIndex.interval(4 * hiddenSize, (4 * hiddenSize) + 1)).reshape(1L, iNDArray2.size(0));
            peepholeO = iNDArray2.get(NDArrayIndex.all(), NDArrayIndex.interval((4 * hiddenSize) + 1, (4 * hiddenSize) + 2)).reshape(1L, iNDArray2.size(0));
            peepholeG = iNDArray2.get(NDArrayIndex.all(), NDArrayIndex.interval((4 * hiddenSize) + 2, (4 * hiddenSize) + 3)).reshape(1L, iNDArray2.size(0));
            if (timeSteps > 1 || z2) {
                peepholeF = Shape.toMmulCompatible(peepholeF);
                peepholeO = Shape.toMmulCompatible(peepholeO);
                peepholeG = Shape.toMmulCompatible(peepholeG);
            }
        }

        // With a sigmoid gate activation the pre-activations (fz/oz/gz) are not stored.
        boolean gateIsSigmoid = iActivation instanceof ActivationSigmoid;
        IActivation layerActivation = baseRecurrentLayer.layerConf().getActivationFn();
        FwdPassReturn result = new FwdPassReturn();
        if (z2) {
            result.fwdPassOutputAsArrays = new INDArray[timeSteps];
            result.memCellState = new INDArray[timeSteps];
            result.memCellActivations = new INDArray[timeSteps];
            result.iz = new INDArray[timeSteps];
            result.ia = new INDArray[timeSteps];
            result.fa = new INDArray[timeSteps];
            result.oa = new INDArray[timeSteps];
            result.ga = new INDArray[timeSteps];
            if (!gateIsSigmoid) {
                result.fz = new INDArray[timeSteps];
                result.oz = new INDArray[timeSteps];
                result.gz = new INDArray[timeSteps];
            }
        }
        if (z2 && shouldCache(z, cacheMode, layerWorkspaceMgr)) {
            // Allocate the combined output while the FF_CACHE scope is borrowed; the scope is always released.
            try (MemoryWorkspace ignored = layerWorkspaceMgr.notifyScopeBorrowed(ArrayType.FF_CACHE)) {
                result.fwdPassOutput = Nd4j.create(iNDArray3.dataType(), new long[]{miniBatch, hiddenSize, timeSteps}, 'f');
            }
        } else {
            result.fwdPassOutput = layerWorkspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), new long[]{miniBatch, hiddenSize, timeSteps}, 'f');
        }
        final INDArray output = result.fwdPassOutput;

        if (input.size(1) != iNDArray3.size(0)) {
            throw new DL4JInvalidInputException("Received input with size(1) = " + input.size(1) + " (input array shape = " + Arrays.toString(input.shape()) + "); input.size(1) must match layer nIn size (nIn = " + iNDArray3.size(0) + ")");
        }
        Preconditions.checkState(prevOutput == null || prevOutput.size(0) == input.size(0), "Invalid RNN previous state (last time step activations/initialization): rnnTimeStep with different minibatch size, or forgot to call rnnClearPreviousState between batches? Previous step output = [batch, nIn] = %ndShape, current input = [batch, nIn, seqLength] = %ndShape", prevOutput, input);
        if (prevOutput == null) {
            prevOutput = Nd4j.zeros(input.dataType(), miniBatch, hiddenSize);
        }

        if (lSTMHelper != null && (baseRecurrentLayer.helperCountFail == 0 || !z5)) {
            FwdPassReturn helperResult = null;
            try {
                helperResult = lSTMHelper.activate(baseRecurrentLayer, neuralNetConfiguration, iActivation, input, iNDArray2, iNDArray3, iNDArray4, z, prevOutput, prevCellState, z2, z3, str, mask, z4, layerWorkspaceMgr);
            } catch (ND4JOpProfilerException e) {
                throw e;
            } catch (Exception e) {
                // getMessage() may legitimately be null; treat that as "not an allocation failure".
                String message = e.getMessage();
                if (message != null && message.contains("Failed to allocate")) {
                    throw e;
                }
                if (!z5) {
                    throw new RuntimeException("Error during LSTM MKL/CuDNN helper forward pass - helperAllowFallback() is set to false", e);
                }
                baseRecurrentLayer.helperCountFail++;
                log.warn("MKL/CuDNN execution failed - falling back on built-in implementation", e);
            }
            if (helperResult != null) {
                return helperResult;
            }
        }

        for (int step = 0; step < timeSteps; step++) {
            // The loop workspace scope covers the whole iteration and is closed even if the body throws.
            try (MemoryWorkspace ignored = layerWorkspaceMgr.notifyScopeEntered(ArrayType.RNN_FF_LOOP_WORKING_MEM)) {
                int t = z3 ? step : (timeSteps - step) - 1;

                // Pre-activations for all four blocks: input * inputWeights + prevOutput * recurrentWeights + bias.
                INDArray stepInput = Shape.toMmulCompatible(is2dInput ? input : input.tensorAlongDimension(t, 1, 0));
                cacheEnter(z, cacheMode, layerWorkspaceMgr);
                INDArray preActivations = stepInput.mmul(iNDArray3);
                cacheExit(z, cacheMode, layerWorkspaceMgr);
                Nd4j.gemm(prevOutput, recurrentGateWeights, preActivations, false, false, 1.0d, 1.0d);
                preActivations.addiRowVector(iNDArray4);

                // "i" block: uses the layer activation rather than the gate activation.
                INDArray iBlock = preActivations.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenSize));
                if (z2) {
                    result.iz[t] = dupForBackprop(iBlock, z, cacheMode, layerWorkspaceMgr);
                }
                baseRecurrentLayer.layerConf().getActivationFn().getActivation(iBlock, z);
                if (z2) {
                    result.ia[t] = leverageForBackprop(iBlock, z, cacheMode, layerWorkspaceMgr);
                }

                // "f" block.
                INDArray fBlock = preActivations.get(NDArrayIndex.all(), NDArrayIndex.interval(hiddenSize, 2 * hiddenSize));
                if (z4) {
                    fBlock.addi(prevCellState.dup('f').muliRowVector(peepholeF));
                }
                if (z2 && !gateIsSigmoid) {
                    result.fz[t] = dupForBackprop(fBlock, z, cacheMode, layerWorkspaceMgr);
                }
                iActivation.getActivation(fBlock, z);
                if (z2) {
                    result.fa[t] = leverageForBackprop(fBlock, z, cacheMode, layerWorkspaceMgr);
                }

                // "g" block.
                INDArray gBlock = preActivations.get(NDArrayIndex.all(), NDArrayIndex.interval(3 * hiddenSize, 4 * hiddenSize));
                if (z4) {
                    gBlock.addi(prevCellState.dup('f').muliRowVector(peepholeG));
                }
                if (z2 && !gateIsSigmoid) {
                    cacheEnter(z, cacheMode, layerWorkspaceMgr);
                    result.gz[t] = layerWorkspaceMgr.dup(ArrayType.BP_WORKING_MEM, gBlock, 'f');
                    cacheExit(z, cacheMode, layerWorkspaceMgr);
                }
                iActivation.getActivation(gBlock, z);
                if (z2) {
                    result.ga[t] = leverageForBackprop(gBlock, z, cacheMode, layerWorkspaceMgr);
                }

                // New cell state = f * prevCellState + g * i. When storing for backprop, copies are
                // used so the stored fa/ga arrays are not overwritten by the in-place multiplies.
                INDArray cellState;
                INDArray inputContribution;
                if (z2) {
                    cacheEnter(z, cacheMode, layerWorkspaceMgr);
                    cellState = layerWorkspaceMgr.dup(ArrayType.BP_WORKING_MEM, prevCellState, 'f').muli(fBlock);
                    cacheExit(z, cacheMode, layerWorkspaceMgr);
                    inputContribution = gBlock.dup('f').muli(iBlock);
                } else {
                    cellState = layerWorkspaceMgr.leverageTo(ArrayType.FF_WORKING_MEM, fBlock.muli(prevCellState));
                    inputContribution = gBlock.muli(iBlock);
                }
                cellState.addi(inputContribution);

                // "o" block: its optional extra term uses the *new* cell state.
                INDArray oBlock = preActivations.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenSize, 3 * hiddenSize));
                if (z4) {
                    oBlock.addi(cellState.dup('f').muliRowVector(peepholeO));
                }
                if (z2 && !gateIsSigmoid) {
                    cacheEnter(z, cacheMode, layerWorkspaceMgr);
                    result.oz[t] = layerWorkspaceMgr.dup(ArrayType.BP_WORKING_MEM, oBlock, 'f');
                    cacheExit(z, cacheMode, layerWorkspaceMgr);
                }
                iActivation.getActivation(oBlock, z);
                if (z2) {
                    result.oa[t] = leverageForBackprop(oBlock, z, cacheMode, layerWorkspaceMgr);
                }

                // Step output = o * layerActivation(cellState).
                cacheEnter(z, cacheMode, layerWorkspaceMgr);
                INDArray cellActivation = layerActivation.getActivation(layerWorkspaceMgr.dup(ArrayType.FF_WORKING_MEM, cellState, 'f'), z);
                cacheExit(z, cacheMode, layerWorkspaceMgr);
                INDArray stepOutput;
                if (z2) {
                    cacheEnter(z, cacheMode, layerWorkspaceMgr);
                    stepOutput = layerWorkspaceMgr.dup(ArrayType.BP_WORKING_MEM, cellActivation, 'f').muli(oBlock);
                    cacheExit(z, cacheMode, layerWorkspaceMgr);
                } else {
                    stepOutput = cellActivation.muli(oBlock);
                }

                if (mask != null) {
                    INDArray stepMask = mask.getColumn(t, true);
                    stepOutput.muliColumnVector(stepMask);
                    cellState.muliColumnVector(stepMask);
                }
                INDArray leveragedCellState = layerWorkspaceMgr.leverageTo(ArrayType.FF_WORKING_MEM, cellState);

                if (z2) {
                    result.fwdPassOutputAsArrays[t] = stepOutput;
                    result.memCellState[t] = leveragedCellState;
                    result.memCellActivations[t] = cellActivation;
                    if (shouldCache(z, cacheMode, layerWorkspaceMgr)) {
                        result.memCellActivations[t] = layerWorkspaceMgr.leverageTo(ArrayType.FF_CACHE, result.memCellActivations[t]);
                        result.memCellState[t] = layerWorkspaceMgr.leverageTo(ArrayType.FF_CACHE, result.memCellState[t]);
                    }
                    // NOTE(review): with z2 set and CacheMode.NONE the combined output array is left unfilled;
                    // callers presumably read fwdPassOutputAsArrays in that case - confirm.
                    if (cacheMode != CacheMode.NONE) {
                        output.tensorAlongDimension(t, 1, 0).assign(stepOutput);
                    }
                } else {
                    output.tensorAlongDimension(t, 1, 0).assign(stepOutput);
                }

                prevOutput = stepOutput;
                prevCellState = leveragedCellState;
                result.lastAct = stepOutput;
                result.lastMemCell = leveragedCellState;
            }
        }
        result.prevAct = iNDArray5;
        result.prevMemCell = iNDArray6;
        return result;
    }

    /** Copies {@code array} ('f' order) under the borrowed FF_CACHE scope when caching applies, otherwise via the workspace manager into BP_WORKING_MEM. */
    private static INDArray dupForBackprop(INDArray array, boolean z, CacheMode cacheMode, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (shouldCache(z, cacheMode, layerWorkspaceMgr)) {
            cacheEnter(z, cacheMode, layerWorkspaceMgr);
            INDArray copy = array.dup('f');
            cacheExit(z, cacheMode, layerWorkspaceMgr);
            return copy;
        }
        return layerWorkspaceMgr.dup(ArrayType.BP_WORKING_MEM, array, 'f');
    }

    /** Copies {@code array} ('f' order) under the borrowed FF_CACHE scope when caching applies, otherwise leverages it to BP_WORKING_MEM. */
    private static INDArray leverageForBackprop(INDArray array, boolean z, CacheMode cacheMode, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (shouldCache(z, cacheMode, layerWorkspaceMgr)) {
            cacheEnter(z, cacheMode, layerWorkspaceMgr);
            INDArray copy = array.dup('f');
            cacheExit(z, cacheMode, layerWorkspaceMgr);
            return copy;
        }
        return layerWorkspaceMgr.leverageTo(ArrayType.BP_WORKING_MEM, array);
    }

    /**
     * Tells whether the FF_CACHE workspace should be used: requires {@code z} to be set, a cache mode
     * other than {@code NONE}, and an FF_CACHE workspace that is both configured and currently open.
     */
    private static boolean shouldCache(boolean z, CacheMode cacheMode, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (!z) {
            return false;
        }
        if (cacheMode == CacheMode.NONE) {
            return false;
        }
        if (!layerWorkspaceMgr.hasConfiguration(ArrayType.FF_CACHE)) {
            return false;
        }
        return layerWorkspaceMgr.isWorkspaceOpen(ArrayType.FF_CACHE);
    }

    /** Borrows the FF_CACHE workspace scope when {@code shouldCache} holds; otherwise does nothing. */
    private static void cacheEnter(boolean z, CacheMode cacheMode, LayerWorkspaceMgr layerWorkspaceMgr) {
        boolean useCache = shouldCache(z, cacheMode, layerWorkspaceMgr);
        if (!useCache) {
            return;
        }
        layerWorkspaceMgr.notifyScopeBorrowed(ArrayType.FF_CACHE);
    }

    /**
     * Counterpart of {@code cacheEnter}: when {@code shouldCache} holds, looks up the current thread's
     * FF_CACHE workspace by name and notifies it that the scope was left.
     */
    private static void cacheExit(boolean z, CacheMode cacheMode, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (!shouldCache(z, cacheMode, layerWorkspaceMgr)) {
            return;
        }
        MemoryWorkspace cacheWorkspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(layerWorkspaceMgr.getWorkspaceName(ArrayType.FF_CACHE));
        cacheWorkspace.notifyScopeLeft();
    }

    /**
     * Runs LSTM backpropagation through time and returns the parameter gradients together with the
     * gradient with respect to the layer input.
     *
     * <p>The gradient arrays looked up in {@code map} under {@code str}, {@code str2} and {@code str3}
     * are zeroed and then accumulated into in place. If {@code lSTMHelper} is non-null (and has not
     * previously failed, or fallback is disabled) the work is delegated to it first.
     *
     * <p>Column blocks of the per-step delta buffer follow the {@code FwdPassReturn} field names:
     * {@code [0,n)} is "i", {@code [n,2n)} is "f", {@code [2n,3n)} is "o", {@code [3n,4n)} is "g",
     * where {@code n = iNDArray2.size(0)}.
     *
     * @param baseRecurrentLayer layer whose {@code helperCountFail} counter is incremented when the helper
     *        fails and fallback is allowed
     * @param neuralNetConfiguration configuration; its layer supplies the activation used for the "i" block and cell state
     * @param iActivation activation whose {@code backprop} is used for the "f", "o" and "g" blocks
     * @param iNDArray layer input from the forward pass (cast to the data type of {@code iNDArray3})
     * @param iNDArray2 recurrent weights; columns {@code 4n..4n+2} are read only when {@code z3} is set
     * @param iNDArray3 input weights
     * @param iNDArray4 incoming gradient (epsilon); rank below 3 is treated as a single time step
     * @param z when {@code true}, only the last {@code i} time steps are processed
     * @param i number of time steps to process when {@code z} is set
     * @param fwdPassReturn stored forward-pass arrays; NOTE: {@code oa[t]} and {@code fa[t]} are modified in place
     * @param z2 when {@code true} the stored arrays are indexed in forward time order, otherwise reversed
     * @param str key of the input-weight gradient in {@code map}
     * @param str2 key of the recurrent-weight gradient in {@code map}
     * @param str3 key of the bias gradient in {@code map}
     * @param map gradient views to fill
     * @param iNDArray5 optional mask; column {@code t} scales the deltas and input gradient of step {@code t}
     * @param z3 enables the three extra recurrent-weight columns (presumably peephole connections)
     * @param lSTMHelper optional accelerated implementation, may be {@code null}
     * @param layerWorkspaceMgr workspace manager used for allocations and the per-step scope
     * @param z4 whether a helper failure may fall back to the built-in implementation
     * @return pair of (parameter gradients, gradient with respect to the input)
     * @throws RuntimeException if the helper fails and {@code z4} is {@code false}
     */
    public static Pair<Gradient, INDArray> backpropGradientHelper(BaseRecurrentLayer baseRecurrentLayer, NeuralNetConfiguration neuralNetConfiguration, IActivation iActivation, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, boolean z, int i, FwdPassReturn fwdPassReturn, boolean z2, String str, String str2, String str3, Map<String, INDArray> map, INDArray iNDArray5, boolean z3, LSTMHelper lSTMHelper, LayerWorkspaceMgr layerWorkspaceMgr, boolean z4) {
        INDArray input = iNDArray.castTo(iNDArray3.dataType());
        long hiddenSize = iNDArray2.size(0);
        long prevLayerSize = iNDArray3.size(0);
        long miniBatch = iNDArray4.size(0);
        boolean is2d = iNDArray4.rank() < 3;
        long timeSteps = is2d ? 1L : iNDArray4.size(2);

        INDArray peepholeF = null;
        INDArray peepholeO = null;
        INDArray peepholeG = null;
        if (z3) {
            peepholeF = iNDArray2.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenSize)).reshape(1L, iNDArray2.size(0));
            peepholeO = iNDArray2.get(NDArrayIndex.all(), NDArrayIndex.point((4 * hiddenSize) + 1)).reshape(1L, iNDArray2.size(0));
            peepholeG = iNDArray2.get(NDArrayIndex.all(), NDArrayIndex.point((4 * hiddenSize) + 2)).reshape(1L, iNDArray2.size(0));
        }
        INDArray recurrentGateWeights = iNDArray2.get(NDArrayIndex.all(), NDArrayIndex.interval(0L, 4 * hiddenSize));
        INDArray epsilonOut = layerWorkspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[]{miniBatch, prevLayerSize, timeSteps}, 'f');

        // Cell-state gradient carried from the later time step processed in the previous iteration.
        INDArray nablaCellStateNext = null;

        // One buffer holds the deltas of all four blocks; the four views below alias it.
        INDArray deltas = Nd4j.create(iNDArray3.dataType(), new long[]{miniBatch, 4 * hiddenSize}, 'f');
        INDArray deltaI = deltas.get(NDArrayIndex.all(), NDArrayIndex.interval(0L, hiddenSize));
        INDArray deltaF = deltas.get(NDArrayIndex.all(), NDArrayIndex.interval(hiddenSize, 2 * hiddenSize));
        INDArray deltaO = deltas.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenSize, 3 * hiddenSize));
        INDArray deltaG = deltas.get(NDArrayIndex.all(), NDArrayIndex.interval(3 * hiddenSize, 4 * hiddenSize));

        long endIdx = 0;
        if (z) {
            endIdx = Math.max(0L, timeSteps - i);
        }

        INDArray inputWeightGrad = map.get(str);
        INDArray recurrentWeightGrad = map.get(str2);
        INDArray biasGrad = map.get(str3);
        inputWeightGrad.assign((Number) 0);
        recurrentWeightGrad.assign((Number) 0);
        biasGrad.assign((Number) 0);
        INDArray recurrentGateGrad = recurrentWeightGrad.get(NDArrayIndex.all(), NDArrayIndex.interval(0L, 4 * hiddenSize));
        INDArray peepholeFGrad = null;
        INDArray peepholeOGrad = null;
        INDArray peepholeGGrad = null;
        if (z3) {
            peepholeFGrad = recurrentWeightGrad.get(NDArrayIndex.all(), NDArrayIndex.point(4 * hiddenSize)).reshape(1L, iNDArray2.size(0));
            peepholeOGrad = recurrentWeightGrad.get(NDArrayIndex.all(), NDArrayIndex.point((4 * hiddenSize) + 1)).reshape(1L, iNDArray2.size(0));
            peepholeGGrad = recurrentWeightGrad.get(NDArrayIndex.all(), NDArrayIndex.point((4 * hiddenSize) + 2)).reshape(1L, iNDArray2.size(0));
        }

        if (lSTMHelper != null && (baseRecurrentLayer.helperCountFail == 0 || !z4)) {
            Pair<Gradient, INDArray> helperResult = null;
            try {
                helperResult = lSTMHelper.backpropGradient(neuralNetConfiguration, iActivation, input, iNDArray2, iNDArray3, iNDArray4, z, i, fwdPassReturn, z2, str, str2, str3, map, iNDArray5, z3, layerWorkspaceMgr);
            } catch (ND4JOpProfilerException e) {
                throw e;
            } catch (Exception e) {
                // getMessage() may legitimately be null; treat that as "not an allocation failure".
                String message = e.getMessage();
                if (message != null && message.contains("Failed to allocate")) {
                    throw e;
                }
                if (!z4) {
                    throw new RuntimeException("Error during LSTM MKL/CuDNN helper backprop - helperAllowFallback() is set to false", e);
                }
                baseRecurrentLayer.helperCountFail++;
                log.warn("MKL/CuDNN execution failed - falling back on built-in implementation", e);
            }
            if (helperResult != null) {
                return helperResult;
            }
        }

        boolean gateIsSigmoid = iActivation instanceof ActivationSigmoid;
        IActivation layerActivation = ((BaseLayer) neuralNetConfiguration.getLayer()).getActivationFn();

        for (long step = timeSteps - 1; step >= endIdx; step--) {
            // The per-step workspace scope is closed at the end of each iteration, including on exceptions.
            try (MemoryWorkspace ignored = layerWorkspaceMgr.notifyScopeEntered(ArrayType.RNN_BP_LOOP_WORKING_MEM)) {
                // t indexes the stored forward arrays; next is the offset from t to the step processed just before.
                int t = (int) step;
                int next = 1;
                if (!z2) {
                    t = (int) ((timeSteps - step) - 1);
                    next = -1;
                }
                boolean isLastStep = step == timeSteps - 1;

                // Start the cell-state gradient from the previous iteration's f/g deltas *before* the
                // delta buffer is overwritten further down.
                INDArray nablaCellState;
                if (isLastStep || !z3) {
                    nablaCellState = Nd4j.create(iNDArray3.dataType(), new long[]{miniBatch, hiddenSize}, 'f');
                } else {
                    nablaCellState = deltaF.dup('f').muliRowVector(peepholeF);
                    nablaCellState.addi(deltaG.dup('f').muliRowVector(peepholeG));
                }

                INDArray prevMemCell = step == 0 ? fwdPassReturn.prevMemCell : fwdPassReturn.memCellState[t - next];
                INDArray prevAct = step == 0 ? fwdPassReturn.prevAct : fwdPassReturn.fwdPassOutputAsArrays[t - next];
                INDArray cellState = fwdPassReturn.memCellState[t];

                // Incoming gradient for this step plus the recurrent contribution from the previous iteration's deltas.
                INDArray epsilonStep = Shape.toOffsetZeroCopy(is2d ? iNDArray4 : iNDArray4.tensorAlongDimension(t, 1, 0), 'f');
                if (!isLastStep) {
                    Nd4j.gemm(deltas, recurrentGateWeights, epsilonStep, false, true, 1.0d, 1.0d);
                }

                // "o" block delta.
                INDArray cellActivation = fwdPassReturn.memCellActivations[t];
                INDArray oAct = fwdPassReturn.oa[t];
                Nd4j.getExecutioner().exec(new MulOp(epsilonStep, cellActivation, deltaO));
                if (gateIsSigmoid) {
                    deltaO.muli(Nd4j.getExecutioner().exec(new TimesOneMinus(oAct.dup('f'))));
                } else {
                    deltaO.assign(iActivation.backprop(fwdPassReturn.oz[t], deltaO).getFirst());
                }

                // Cell-state gradient; note oAct (fwdPassReturn.oa[t]) is multiplied in place here.
                nablaCellState.addi(layerActivation.backprop(cellState.dup('f'), oAct.muli(epsilonStep)).getFirst());
                if (z3) {
                    nablaCellState.addi(deltaO.dup('f').muliRowVector(peepholeO));
                }
                if (!isLastStep) {
                    // In-place multiply on the stored fa array of the step handled in the previous iteration.
                    nablaCellState.addi(fwdPassReturn.fa[t + next].muli(nablaCellStateNext));
                }
                nablaCellStateNext = layerWorkspaceMgr.leverageTo(ArrayType.BP_WORKING_MEM, nablaCellState);

                // "f" block delta; skipped when there is no previous cell state at the first step.
                INDArray fAct = fwdPassReturn.fa[t];
                INDArray deltaFStep = null;
                if (step > 0 || prevMemCell != null) {
                    deltaFStep = deltaF;
                    if (gateIsSigmoid) {
                        Nd4j.getExecutioner().exec(new TimesOneMinus(fAct, deltaFStep));
                        deltaFStep.muli(nablaCellState);
                        deltaFStep.muli(prevMemCell);
                    } else {
                        deltaFStep.assign(iActivation.backprop(fwdPassReturn.fz[t].dup('f'), nablaCellState.mul(prevMemCell)).getFirst());
                    }
                }

                // "g" block delta.
                INDArray gAct = fwdPassReturn.ga[t];
                INDArray iAct = fwdPassReturn.ia[t];
                if (gateIsSigmoid) {
                    Nd4j.getExecutioner().exec(new TimesOneMinus(gAct, deltaG));
                    deltaG.muli(iAct);
                    deltaG.muli(nablaCellState);
                } else {
                    INDArray gPreActivation = fwdPassReturn.gz[t];
                    INDArray gEpsilon = Nd4j.getExecutioner().exec(new MulOp(iAct, nablaCellState, Nd4j.createUninitialized(iNDArray3.dataType(), iAct.shape(), 'f')))[0];
                    deltaG.assign(iActivation.backprop(gPreActivation, gEpsilon).getFirst());
                }

                // "i" block delta (layer activation).
                INDArray iPreActivation = fwdPassReturn.iz[t];
                INDArray iEpsilon = Nd4j.getExecutioner().exec(new MulOp(gAct, nablaCellState, Nd4j.createUninitialized(iNDArray3.dataType(), deltaI.shape(), 'f')))[0];
                deltaI.assign(layerActivation.backprop(iPreActivation, iEpsilon).getFirst());

                INDArray stepMask = null;
                if (iNDArray5 != null) {
                    stepMask = iNDArray5.getColumn(t, true);
                    deltas.muliColumnVector(stepMask);
                }

                // When there is neither an earlier step nor a previous activation, the "f" block
                // [n,2n) is left out of the accumulations below.
                boolean hasPrevious = step > 0 || prevAct != null;
                INDArray stepInput = Shape.toMmulCompatible(is2d ? input : input.tensorAlongDimension(t, 1, 0));

                // Input-weight gradient.
                if (hasPrevious) {
                    Nd4j.gemm(stepInput, deltas, inputWeightGrad, true, false, 1.0d, 1.0d);
                } else {
                    Nd4j.gemm(stepInput, deltaI, inputWeightGrad.get(NDArrayIndex.all(), NDArrayIndex.interval(0L, hiddenSize)), true, false, 1.0d, 1.0d);
                    Nd4j.gemm(stepInput, deltas.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenSize, 4 * hiddenSize)), inputWeightGrad.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenSize, 4 * hiddenSize)), true, false, 1.0d, 1.0d);
                }

                // Recurrent-weight gradient (and the "f"/"g" extra columns).
                if (hasPrevious) {
                    Nd4j.gemm(prevAct, deltas, recurrentGateGrad, true, false, 1.0d, 1.0d);
                    if (z3) {
                        peepholeFGrad.addi(deltaFStep.dup('f').muli(prevMemCell).sum(true, 0));
                        peepholeGGrad.addi(deltaG.dup('f').muli(prevMemCell).sum(true, 0));
                    }
                }
                if (z3) {
                    peepholeOGrad.addi(deltaO.dup('f').muli(cellState).sum(true, 0));
                }

                // Bias gradient.
                if (hasPrevious) {
                    biasGrad.addi(deltas.sum(true, 0));
                } else {
                    biasGrad.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(0L, hiddenSize)).addi(deltaI.sum(true, 0));
                    biasGrad.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(2 * hiddenSize, 4 * hiddenSize)).addi(deltas.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenSize, 4 * hiddenSize)).sum(true, 0));
                }

                // Gradient with respect to this step's input.
                INDArray epsilonOutStep = epsilonOut.tensorAlongDimension(t, 1, 0);
                if (hasPrevious) {
                    Nd4j.gemm(deltas, iNDArray3, epsilonOutStep, false, true, 1.0d, 1.0d);
                } else {
                    Nd4j.gemm(deltaI, iNDArray3.get(NDArrayIndex.all(), NDArrayIndex.interval(0L, hiddenSize)), epsilonOutStep, false, true, 1.0d, 1.0d);
                    Nd4j.gemm(deltas.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenSize, 4 * hiddenSize)), iNDArray3.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * hiddenSize, 4 * hiddenSize)), epsilonOutStep, false, true, 1.0d, 1.0d);
                }
                if (iNDArray5 != null) {
                    epsilonOutStep.muliColumnVector(stepMask);
                }
            }
        }

        DefaultGradient gradient = new DefaultGradient();
        gradient.gradientForVariable().put(str, inputWeightGrad);
        gradient.gradientForVariable().put(str2, recurrentWeightGrad);
        gradient.gradientForVariable().put(str3, biasGrad);
        return new Pair<>(gradient, epsilonOut);
    }

    /**
     * Builds the memory report for an LSTM layer configuration, delegating to the
     * {@code (boolean, FeedForwardLayer, InputType)} overload with the flag set when the
     * configuration is a {@code GravesLSTM}.
     */
    public static LayerMemoryReport getMemoryReport(AbstractLSTM abstractLSTM, InputType inputType) {
        boolean isGravesLstm = abstractLSTM instanceof org.deeplearning4j.nn.conf.layers.GravesLSTM;
        LayerMemoryReport report = getMemoryReport(isGravesLstm, abstractLSTM, inputType);
        return report;
    }

    /**
     * Builds the memory report for a bidirectional Graves LSTM configuration by computing the report
     * of a single direction (with the flag of the three-argument overload set to {@code true}) and
     * doubling every size it contains.
     *
     * @param gravesBidirectionalLSTM layer configuration
     * @param inputType input type passed through to the single-direction report
     * @return a report whose parameter, updater, working and cache sizes are twice the single-direction values
     */
    public static LayerMemoryReport getMemoryReport(org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM gravesBidirectionalLSTM, InputType inputType) {
        LayerMemoryReport oneDirection = getMemoryReport(true, gravesBidirectionalLSTM, inputType);

        // Typed maps instead of raw HashMap so the builder calls are checked by the compiler.
        Map<CacheMode, Long> workingFixedTrain = new HashMap<>();
        Map<CacheMode, Long> workingVariableTrain = new HashMap<>();
        Map<CacheMode, Long> cacheFixed = new HashMap<>();
        Map<CacheMode, Long> cacheVariablePerEx = new HashMap<>();
        for (CacheMode mode : CacheMode.values()) {
            workingFixedTrain.put(mode, 2 * oneDirection.getWorkingMemoryFixedTrain().get(mode).longValue());
            workingVariableTrain.put(mode, 2 * oneDirection.getWorkingMemoryVariableTrain().get(mode).longValue());
            cacheFixed.put(mode, 2 * oneDirection.getCacheModeMemFixed().get(mode).longValue());
            cacheVariablePerEx.put(mode, 2 * oneDirection.getCacheModeMemVariablePerEx().get(mode).longValue());
        }

        // NOTE(review): getClass() yields the report's own class, not the layer's class, as the
        // builder's second argument - kept as-is to preserve behaviour; confirm whether intended.
        return new LayerMemoryReport.Builder(oneDirection.getLayerName(), oneDirection.getClass(), oneDirection.getInputType(), oneDirection.getOutputType())
                .standardMemory(2 * oneDirection.getParameterSize(), 2 * oneDirection.getUpdaterStateSize())
                .workingMemory(2 * oneDirection.getWorkingMemoryFixedInference(), 2 * oneDirection.getWorkingMemoryVariableInference(), workingFixedTrain, workingVariableTrain)
                .cacheMemory(cacheFixed, cacheVariablePerEx)
                .build();
    }

    /**
     * Builds the memory report for a single-direction LSTM-style layer.
     *
     * <p>Per-example sizes are multiples of {@code timeSeriesLength * nOut}: 4 for the forward pass,
     * 6 for the portion that moves to the cache when a cache mode other than {@code NONE} is used,
     * and 9 or 6 (depending on {@code z}) for backprop.
     *
     * @param z selects the larger backprop multiplier (9 rather than 6)
     * @param feedForwardLayer layer configuration; supplies nOut, parameter count and updater
     * @param inputType must be an {@code InputType.InputTypeRecurrent}
     * @return the memory report for the layer
     * @throws ClassCastException if {@code inputType} is not recurrent
     */
    public static LayerMemoryReport getMemoryReport(boolean z, FeedForwardLayer feedForwardLayer, InputType inputType) {
        long timeSeriesLength = ((InputType.InputTypeRecurrent) inputType).getTimeSeriesLength();
        InputType outputType = feedForwardLayer.getOutputType(-1, inputType);
        long numParams = feedForwardLayer.initializer().numParams(feedForwardLayer);
        // Kept as long: narrowing to int would silently truncate very large updater state sizes.
        long updaterStateSize = feedForwardLayer.getIUpdater().stateSize(numParams);

        long forwardPerEx = timeSeriesLength * 4 * feedForwardLayer.getNOut();
        long cacheablePerEx = timeSeriesLength * 6 * feedForwardLayer.getNOut();
        long backpropPerEx = (z ? 9 : 6) * timeSeriesLength * feedForwardLayer.getNOut();

        Map<CacheMode, Long> trainWorkingVariable = new HashMap<>();
        Map<CacheMode, Long> cacheVariablePerEx = new HashMap<>();
        for (CacheMode mode : CacheMode.values()) {
            boolean noCache = mode == CacheMode.NONE;
            // Without caching the cacheable portion counts as working memory instead.
            long working = noCache ? forwardPerEx + cacheablePerEx + backpropPerEx : forwardPerEx + backpropPerEx;
            long cached = noCache ? 0L : cacheablePerEx;
            trainWorkingVariable.put(mode, working);
            cacheVariablePerEx.put(mode, cached);
        }

        return new LayerMemoryReport.Builder(null, feedForwardLayer.getClass(), inputType, outputType)
                .standardMemory(numParams, updaterStateSize)
                .workingMemory(0L, forwardPerEx, MemoryReport.CACHE_MODE_ALL_ZEROS, trainWorkingVariable)
                .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, cacheVariablePerEx)
                .build();
    }
}
