package org.nd4j.autodiff.samediff.internal;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.BaseTensorOp;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayRead;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayScatter;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySize;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySplit;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayWrite;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/samediff/internal/InferenceSession.class */
public class InferenceSession extends AbstractSession<INDArray, DifferentialFunction> {
    // SLF4J logger for this class; used below for trace-level messages during Merge execution.
    private static final Logger log = LoggerFactory.getLogger((Class<?>) InferenceSession.class);
    // Remediation hint appended to the workspace-validation errors raised for placeholder arrays.
    private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\nAlternatively, arrays defined in a workspace must be replaced after the workspace has been closed.";

    public InferenceSession(@NonNull SameDiff sameDiff) {
        super(sameDiff);
        if (sameDiff == null) {
            throw new NullPointerException("sameDiff is marked @NonNull but is null");
        }
    }

    /**
     * Validates user-supplied placeholder arrays and casts each one to the data type declared by
     * its SameDiff variable.
     *
     * @param map placeholder name to array; may be null or empty
     * @return {@code map} itself when it is null or empty, otherwise a new map holding the (possibly
     *         cast) arrays
     * @throws IllegalStateException if a key does not name a variable of this SameDiff instance
     * @throws NullPointerException if a placeholder array is null
     * @throws ND4JIllegalStateException if an attached array lives in a (non-circular) workspace that
     *         is closed or has been reopened since the array was created
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    protected Map<String, INDArray> preprocessPlaceholders(Map<String, INDArray> map) {
        if (map == null || map.isEmpty()) {
            return map;
        }
        Map<String, INDArray> result = new HashMap<>(map.size() * 2);
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            String name = entry.getKey();
            Preconditions.checkState(this.sameDiff.hasVariable(name), "Invalid placeholder passed for execution: No variable/placeholder with name %s exists", name);
            INDArray value = entry.getValue();
            // Fail with a descriptive message instead of an anonymous NPE on value.isAttached().
            Preconditions.checkNotNull(value, "Null array passed for placeholder \"%s\"", name);
            validatePlaceholderWorkspace(name, value);
            DataType expected = this.sameDiff.getVariable(name).dataType();
            if (value.dataType() != expected) {
                value = value.castTo(expected);
            }
            result.put(name, value);
        }
        return result;
    }

    /**
     * Rejects an attached placeholder array whose parent workspace is no longer open, or has been
     * closed and reopened since the array was allocated (generation id mismatch). Arrays in
     * circular workspaces, and arrays with no parent workspace, are not checked.
     */
    private static void validatePlaceholderWorkspace(String name, INDArray value) {
        if (!value.isAttached()) {
            return;
        }
        MemoryWorkspace parentWorkspace = value.data() == null ? null : value.data().getParentWorkspace();
        if (parentWorkspace == null || parentWorkspace.getWorkspaceType() == MemoryWorkspace.Type.CIRCULAR) {
            return;
        }
        if (!parentWorkspace.isScopeActive()) {
            throw new ND4JIllegalStateException("Placeholder \"" + name + "\" array uses leaked workspace pointer from workspace [" + parentWorkspace.getId() + "]: Workspace the array was defined in is no longer open.\nAll open workspaces: " + DefaultOpExecutioner.allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
        }
        if (parentWorkspace.getGenerationId() != value.data().getGenerationId()) {
            throw new ND4JIllegalStateException("Placeholder \"" + name + "\" array uses outdated workspace pointer from workspace [" + parentWorkspace.getId() + "]: Workspace array was defined in has been closed and reopened at least once since array creation. Array WS iteration: " + value.data().getGenerationId() + ". Workspace current iteration: " + parentWorkspace.getGenerationId() + "\nAll open workspaces: " + DefaultOpExecutioner.allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
        }
    }

    /* renamed from: getOutputs, reason: avoid collision after fix types in other method */
    /**
     * Executes one op via {@link #getOutputsHelper} and notifies listeners around the execution.
     *
     * <p>Each listener whose {@code isActive(at.operation())} is true receives
     * {@code preOpExecution} before the op runs. Afterwards it receives {@code opExecution}, followed
     * by one {@code activationAvailable} call per output, in output order. Null outputs (e.g. the
     * untaken Switch branch) are passed through as null.
     *
     * @return the op's output arrays, exactly as produced by {@link #getOutputsHelper}
     */
    public INDArray[] getOutputs2(DifferentialFunction differentialFunction, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3, List<Listener> list, At at, MultiDataSet multiDataSet) {
        boolean hasListeners = list != null && !list.isEmpty();
        // Look the op up once; it is needed both before and after execution.
        SameDiffOp sameDiffOp = hasListeners ? this.sameDiff.getOps().get(differentialFunction.getOwnName()) : null;
        if (hasListeners) {
            for (Listener listener : list) {
                if (listener.isActive(at.operation())) {
                    listener.preOpExecution(this.sameDiff, at, sameDiffOp);
                }
            }
        }
        INDArray[] outputs = getOutputsHelper(differentialFunction, frameIter, set, set2, set3);
        if (hasListeners) {
            for (Listener listener : list) {
                if (!listener.isActive(at.operation())) {
                    continue;
                }
                listener.opExecution(this.sameDiff, at, multiDataSet, sameDiffOp, outputs);
                // Iterate by index so activations are reported in deterministic output order.
                for (int i = 0; i < outputs.length; i++) {
                    listener.activationAvailable(this.sameDiff, at, multiDataSet, sameDiffOp, sameDiffOp.outputsOfOp.get(i), outputs[i]);
                }
            }
        }
        return outputs;
    }

    /**
     * Executes a single op in the given frame/iteration and returns its output arrays.
     *
     * <p>Control-flow ops (Identity, Switch, Enter, Exit, NextIteration, Merge, LoopCond) do no
     * computation; they forward arrays already stored in {@code nodeOutputs}. TensorArray ops
     * (subclasses of {@link BaseTensorOp}) read or mutate the lists in {@code tensorArrays}. Where
     * they have no meaningful output, they return a dummy scalar created outside any workspace. All
     * other ops are run through {@code Nd4j.getExecutioner()}.
     *
     * @param differentialFunction the op to execute
     * @param frameIter            frame and iteration in which the op is executed
     * @param set                  variable inputs of this op (may be null)
     * @param set2                 additional variable inputs (may be null). Enter/Exit/NextIteration
     *                             prefer it when non-empty; TensorArray lookups use it as a fallback.
     * @param set3                 names of inputs resolved in the outer frame at iteration 0 (may be
     *                             null). Presumably constants/placeholders; TODO confirm.
     * @return the op's outputs; for Switch exactly one of the two entries is null
     * @throws UnsupportedOperationException for {@code If} ops and for non-tensor ops that are
     *                                       neither {@link CustomOp} nor {@link Op}
     * @throws IllegalStateException if an input or TensorArray is missing or malformed, a Merge has
     *                               no available input, or the TensorArray op is not supported
     */
    public INDArray[] getOutputsHelper(DifferentialFunction differentialFunction, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3) {
        DifferentialFunction op = differentialFunction;

        if (op instanceof Identity) {
            String[] argNames = op.argNames();
            Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in identity op, got %s", (Object[]) argNames);
            return new INDArray[]{nodeOutput(newVarId(argNames[0], frameIter))};
        }

        if (op instanceof Switch) {
            // arg 0 = data, arg 1 = predicate. The data goes to output 0 when the predicate is false
            // and to output 1 otherwise; the other output is null.
            String[] argNames = op.argNames();
            INDArray predicate = nodeOutput(newVarId(argNames[1], frameIter));
            Preconditions.checkState(predicate.isScalar() && predicate.dataType() == DataType.BOOL, "Expected boolean predicate: got %ndSInfo", predicate);
            INDArray data = nodeOutput(newVarId(argNames[0], frameIter));
            return predicate.getDouble(0L) == 0.0d ? new INDArray[]{data, null} : new INDArray[]{null, data};
        }

        if (op instanceof Enter) {
            Enter enter = (Enter) op;
            String[] argNames = enter.argNames();
            Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name for enter op: got %s", (Object[]) argNames);
            int numInputs = nullSafeSize(set) + nullSafeSize(set3) + nullSafeSize(set2);
            Preconditions.checkState(numInputs == 1, "Expected exactly 1 op input for Enter op \"%s\", got %s+%s", enter.getOwnName(), set, set3);
            INDArray input = nodeOutput(singleInputVarId(set, set2, set3));
            Preconditions.checkNotNull(input, "Could not get enter op \"%s\" input: output variable %s - %s", enter.getOwnName(), enter.outputVariablesNames(), frameIter);
            return new INDArray[]{input};
        }

        if (op instanceof Exit) {
            return new INDArray[]{nodeOutput(singleInputVarId(set, set2, set3))};
        }

        if (op instanceof NextIteration) {
            int numInputs = nullSafeSize(set) + nullSafeSize(set3) + nullSafeSize(set2);
            Preconditions.checkState(numInputs == 1, "Expected exactly 1 op input for NextIteration: got %s+%s", set, set3);
            AbstractSession.VarId input = (set2 == null || set2.isEmpty()) ? set.iterator().next() : set2.iterator().next();
            Preconditions.checkState(frameIter.getFrame().equals(input.getFrame()), "Expected same frame for NextIteration input vs. output: got input %s, output %s", input, frameIter);
            Preconditions.checkState(frameIter.getIteration() == input.getIteration() + 1, "Expected output iteration for NextIteration output to be 1 larger than the input iteration. Input: %s, output %s", input, frameIter);
            return new INDArray[]{nodeOutput(input)};
        }

        if (op instanceof If) {
            throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName());
        }

        if (op instanceof Merge) {
            // Forward the first input that has already been computed in this frame/iteration.
            Merge merge = (Merge) op;
            String[] inputNames = this.sameDiff.getInputsForOp(op);
            for (String inputName : inputNames) {
                AbstractSession.VarId inputId = newVarId(inputName, frameIter);
                if (this.nodeOutputs.containsKey(inputId)) {
                    log.trace("Returning input \"{}\" for merge node \"{}\"", inputName, merge.getOwnName());
                    return new INDArray[]{nodeOutput(inputId)};
                }
            }
            throw new IllegalStateException("Merge node " + merge.getOwnName() + " has no available inputs (all inputs: " + Arrays.toString(inputNames) + ") - should not be executed at this point");
        }

        if (op instanceof LoopCond) {
            String[] argNames = op.argNames();
            Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in LoopCond op, got %s", (Object[]) argNames);
            INDArray condition = nodeOutput(newVarId(argNames[0], frameIter));
            Preconditions.checkNotNull(condition, "Input to LoopCond op must not be null");
            Preconditions.checkState(condition.isScalar() && condition.dataType() == DataType.BOOL, "LoopCond input must be a scalar boolean, got %ndShape", condition);
            return new INDArray[]{condition};
        }

        if (!(op instanceof BaseTensorOp)) {
            // Checked before CustomOp so the marker is never sent to the executioner.
            if (op instanceof GradientBackwardsMarker) {
                return new INDArray[]{Nd4j.scalar(1.0f)};
            }
            if (op instanceof CustomOp) {
                CustomOp customOp = (CustomOp) op;
                Nd4j.getExecutioner().exec(customOp);
                return customOp.outputArguments();
            }
            if (op instanceof Op) {
                Op legacyOp = (Op) op;
                Nd4j.getExecutioner().exec(legacyOp);
                return new INDArray[]{legacyOp.z()};
            }
            throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName());
        }

        return execTensorArrayOp(op, frameIter, set, set2);
    }

    /** Returns the size of {@code c}, treating null as empty. */
    private static int nullSafeSize(Collection<?> c) {
        return c == null ? 0 : c.size();
    }

    /** Returns the array recorded in {@code nodeOutputs} for {@code id} (null if absent). */
    private INDArray nodeOutput(AbstractSession.VarId id) {
        return (INDArray) this.nodeOutputs.get(id);
    }

    /**
     * Picks the single input VarId for Enter/Exit.
     *
     * <p>If both variable-input sets are empty, the input is the first name in {@code set3} in the
     * outer frame at iteration 0. Otherwise it is the first entry of {@code set2} when non-empty,
     * else the first entry of {@code set}.
     */
    private AbstractSession.VarId singleInputVarId(Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3) {
        boolean noVariableInputs = (set == null || set.isEmpty()) && (set2 == null || set2.isEmpty());
        if (noVariableInputs) {
            return new AbstractSession.VarId(set3.iterator().next(), AbstractSession.OUTER_FRAME, 0, null);
        }
        return (set2 == null || set2.isEmpty()) ? set.iterator().next() : set2.iterator().next();
    }

    /** Dispatches a {@link BaseTensorOp} to the handler for its TensorArray op type. */
    private INDArray[] execTensorArrayOp(DifferentialFunction op, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        if (op instanceof TensorArray) {
            return execTensorArrayCreate(op, frameIter);
        }
        if (op instanceof TensorArrayRead) {
            return execTensorArrayRead(op, opInputs, allIterInputs);
        }
        if (op instanceof TensorArrayWrite) {
            return execTensorArrayWrite(op, opInputs, allIterInputs);
        }
        if (op instanceof TensorArraySize) {
            return execTensorArraySize(op, opInputs, allIterInputs);
        }
        if (op instanceof TensorArrayConcat) {
            return execTensorArrayConcat(op, opInputs, allIterInputs);
        }
        if (op instanceof TensorArrayGather) {
            return execTensorArrayGather(op, opInputs, allIterInputs);
        }
        if (op instanceof TensorArrayScatter) {
            return execTensorArrayScatter(op, opInputs, allIterInputs);
        }
        if (op instanceof TensorArraySplit) {
            return execTensorArraySplit(op, opInputs, allIterInputs);
        }
        throw new IllegalStateException("Execution support not yet implemented for: " + op.getClass().getName());
    }

    /** Returns the TensorArray list stored for {@code id}, or null if none exists. */
    @SuppressWarnings("unchecked")
    private List<INDArray> tensorArrayFor(AbstractSession.VarId id) {
        return (List<INDArray>) this.tensorArrays.get(id);
    }

    /** Finds the VarId of TensorArray handle {@code handle}: first in {@code opInputs}, then in {@code allIterInputs}; null if not found. */
    private AbstractSession.VarId findTensorArrayVarId(SDVariable handle, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId id = opInputs == null ? null : lookup(handle.getVarName(), opInputs, false);
        if (id == null && allIterInputs != null) {
            id = lookup(handle.getVarName(), allIterInputs, false);
        }
        return id;
    }

    /**
     * Like {@link #findTensorArrayVarId} but requires the handle to be found. It then walks back
     * through any chain of Enter ops that produced the handle. At each step the VarId is rebased
     * onto the parent frame, presumably ending at the frame where the TensorArray was created.
     */
    private AbstractSession.VarId findTensorArrayVarIdThroughEnter(SDVariable handle, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId id = findTensorArrayVarId(handle, opInputs, allIterInputs);
        Preconditions.checkState(id != null, "Could not find input %s", handle.getVarName());
        SDVariable current = handle;
        DifferentialFunction producer = this.sameDiff.getVariableOutputOp(current.getVarName());
        while (producer instanceof Enter) {
            current = producer.arg();
            id = newVarId(current.getVarName(), id.getParentFrame());
            producer = this.sameDiff.getVariableOutputOp(current.getVarName());
        }
        return id;
    }

    /** Reads {@code varName} as an integer vector, validating rank and dtype for the named TensorArray op. */
    private int[] readIntVectorInput(String varName, String opName, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        INDArray arr = getArray(this.sameDiff.getVariable(varName), opInputs, allIterInputs);
        Preconditions.checkState(arr.isVector(), "Indices variable for " + opName + " should be a vector, got %ndShape for %s", arr, varName);
        Preconditions.checkState(arr.dataType().isIntType(), "Indices variable for " + opName + " should be an integer type, got %s for array %s", arr.dataType(), varName);
        return arr.toIntVector();
    }

    /** Placeholder output for TensorArray ops with no real result, allocated outside any workspace. */
    private static INDArray[] dummyScalarOutput() {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return new INDArray[]{Nd4j.scalar(0.0f)};
        }
    }

    /** TensorArrayV3: registers an empty list for this op's output VarId. */
    private INDArray[] execTensorArrayCreate(DifferentialFunction op, AbstractSession.FrameIter frameIter) {
        AbstractSession.VarId id = newVarId(op.outputVariable().getVarName(), frameIter);
        Preconditions.checkState(!this.tensorArrays.containsKey(id), "TensorArray already exists for %s when executing TensorArrayV3", id);
        this.tensorArrays.put(id, new ArrayList<INDArray>());
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return new INDArray[]{Nd4j.scalar(true), Nd4j.scalar(0.0f)};
        }
    }

    /** TensorArrayRead: arg 0 = TensorArray handle, arg 1 = scalar index. */
    private INDArray[] execTensorArrayRead(DifferentialFunction op, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        INDArray indexArr = getArray(op.arg(1), opInputs, allIterInputs);
        Preconditions.checkState(indexArr.isScalar(), "TensorArrayRead input argument 1 should be scalar - has shape %ndShape", indexArr);
        int index = indexArr.getInt(0);
        AbstractSession.VarId id = findTensorArrayVarIdThroughEnter(op.arg(0), opInputs, allIterInputs);
        List<INDArray> list = tensorArrayFor(id);
        Preconditions.checkState(list != null, "Could not find TensorList for %s", id);
        Preconditions.checkState(list.size() > index, "Cannot get index %s from TensorList of size %s (array not present?) - VarId=%s", Integer.valueOf(index), Integer.valueOf(list.size()), id);
        return new INDArray[]{list.get(index)};
    }

    /** TensorArrayWrite: arg 0 = handle, arg 1 = scalar index, arg 2 = value to store. */
    private INDArray[] execTensorArrayWrite(DifferentialFunction op, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId id = findTensorArrayVarIdThroughEnter(op.arg(0), opInputs, allIterInputs);
        INDArray indexArr = getArray(this.sameDiff.getVariable(op.arg(1).getVarName()), opInputs, allIterInputs);
        Preconditions.checkState(indexArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", indexArr);
        int index = indexArr.getInt(0);
        String valueName = op.arg(2).getVarName();
        INDArray value = getArray(this.sameDiff.getVariable(valueName), opInputs, allIterInputs);
        Preconditions.checkState(value != null, "Could not find array for %s", valueName);
        Preconditions.checkState(this.tensorArrays.containsKey(id), "Tensor array does not exist for %s", id);
        List<INDArray> list = tensorArrayFor(id);
        // List.set requires index < size, so grow with nulls first.
        while (list.size() <= index) {
            list.add(null);
        }
        list.set(index, value);
        return dummyScalarOutput();
    }

    /** TensorArraySize: returns the current list length as an INT scalar. */
    private INDArray[] execTensorArraySize(DifferentialFunction op, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId id = findTensorArrayVarId(op.arg(0), opInputs, allIterInputs);
        List<INDArray> list = tensorArrayFor(id);
        Preconditions.checkState(list != null, "Could not find TensorArray: %s", id);
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return new INDArray[]{Nd4j.scalar(DataType.INT, Integer.valueOf(list.size()))};
        }
    }

    /** TensorArrayConcat: concatenates all stored arrays along dimension 0. */
    private INDArray[] execTensorArrayConcat(DifferentialFunction op, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId id = findTensorArrayVarId(op.arg(0), opInputs, allIterInputs);
        List<INDArray> list = tensorArrayFor(id);
        Preconditions.checkState(list != null, "Could not find TensorArray: %s", id);
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return new INDArray[]{Nd4j.concat(0, list.toArray(new INDArray[list.size()]))};
        }
    }

    /** TensorArrayGather: arg 1 = integer indices ([-1] means all); the selected arrays are piled. */
    private INDArray[] execTensorArrayGather(DifferentialFunction op, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId id = findTensorArrayVarId(op.arg(0), opInputs, allIterInputs);
        List<INDArray> list = tensorArrayFor(id);
        Preconditions.checkState(list != null, "Could not find TensorArray: %s", id);
        int[] indices = readIntVectorInput(op.arg(1).getVarName(), "TensorArrayGather", opInputs, allIterInputs);
        List<INDArray> gathered = new ArrayList<>();
        if (indices.length == 1 && indices[0] == -1) {
            gathered.addAll(list);
        } else {
            for (int index : indices) {
                Preconditions.checkState(index >= 0, "Index for TensorArrayGather must be >= 0, got %s", index);
                gathered.add(list.get(index));
            }
        }
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return new INDArray[]{Nd4j.pile(gathered)};
        }
    }

    /**
     * TensorArraySplit: arg 1 = array to split, arg 2 = integer lengths.
     * Slices of those lengths along dimension 0 are stored at positions 0..n-1.
     */
    private INDArray[] execTensorArraySplit(DifferentialFunction op, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId id = findTensorArrayVarId(op.arg(0), opInputs, allIterInputs);
        List<INDArray> list = tensorArrayFor(id);
        Preconditions.checkState(list != null, "Could not find TensorArray: %s", id);
        INDArray source = getArray(this.sameDiff.getVariable(op.arg(1).getVarName()), opInputs, allIterInputs);
        int[] lengths = readIntVectorInput(op.arg(2).getVarName(), "TensorArraySplit", opInputs, allIterInputs);
        // Positions 0..lengths.length-1 are written, so size lengths.length suffices. The previous
        // "<=" loop appended a spurious trailing null.
        while (list.size() < lengths.length) {
            list.add(null);
        }
        INDArrayIndex[] slice = (INDArrayIndex[]) ArrayUtil.nTimes(source.rank(), NDArrayIndex.all(), INDArrayIndex.class);
        int offset = 0;
        for (int i = 0; i < lengths.length; i++) {
            slice[0] = NDArrayIndex.interval(offset, offset + lengths[i]);
            list.set(i, source.get(slice).dup());
            offset += lengths[i];
        }
        return dummyScalarOutput();
    }

    /**
     * TensorArrayScatter: arg 1 = integer target indices ([-1] means 0..values.size(0)-1),
     * arg 2 = values. Slice i along dimension 0 is stored at target index i.
     */
    private INDArray[] execTensorArrayScatter(DifferentialFunction op, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId id = findTensorArrayVarId(op.arg(0), opInputs, allIterInputs);
        List<INDArray> list = tensorArrayFor(id);
        Preconditions.checkState(list != null, "Could not find TensorArray: %s", id);
        int[] indices = readIntVectorInput(op.arg(1).getVarName(), "TensorArrayScatter", opInputs, allIterInputs);
        INDArray values = getArray(this.sameDiff.getVariable(op.arg(2).getVarName()), opInputs, allIterInputs);
        if (indices.length == 1 && indices[0] == -1) {
            indices = ArrayUtil.range(0, (int) values.size(0));
        }
        // Grow the list so every target index is valid for List.set. Size it by the largest target
        // index, after the [-1] expansion, not by the number of indices.
        int maxIndex = -1;
        for (int target : indices) {
            Preconditions.checkState(target >= 0, "Index for TensorArrayScatter must be >= 0, got %s", target);
            maxIndex = Math.max(maxIndex, target);
        }
        while (list.size() <= maxIndex) {
            list.add(null);
        }
        INDArrayIndex[] slice = (INDArrayIndex[]) ArrayUtil.nTimes(values.rank(), NDArrayIndex.all(), INDArrayIndex.class);
        for (int i = 0; i < indices.length; i++) {
            slice[0] = NDArrayIndex.point(i);
            INDArray element = values.get(slice).dup();
            if (values.rank() == 2 && element.rank() == 2) {
                element = element.reshape(element.length());
            }
            if (values.rank() == 1 && element.rank() > 0) {
                element = element.reshape(new long[0]);
            }
            list.set(indices[i], element);
        }
        return dummyScalarOutput();
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Returns the array held by the SameDiff instance for a constant or VARIABLE-type variable.
     *
     * @param str name of the variable
     * @return {@code sameDiff.getArrForVarName(str)}
     * @throws NullPointerException if no variable named {@code str} is found
     * @throws IllegalStateException if the variable is neither a constant nor of type VARIABLE
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public INDArray getConstantOrVariable(String str) {
        SDVariable variable = this.sameDiff.getVariable(str);
        Preconditions.checkNotNull(variable, "No variable with name \"%s\" exists", str);
        Preconditions.checkState(variable.isConstant() || variable.getVariableType() == VariableType.VARIABLE, "Variable %s is not a constant or variable (type: %s)", str, variable.getVariableType());
        return this.sameDiff.getArrForVarName(str);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Looks up an op by id and prepares it for execution.
     *
     * <p>Control-flow ops (LoopCond/Enter/Exit/NextIteration/Merge/Switch/If/While) and
     * {@link BaseTensorOp}s are returned unchanged. For all other ops this method:
     * <ul>
     *   <li>validates the number of arg names against the available inputs;</li>
     *   <li>binds the input arrays;</li>
     *   <li>for custom ops, (re)allocates output arrays whose shape or data type no longer match;</li>
     *   <li>for legacy ops, (re)allocates the z array under the same rule.</li>
     * </ul>
     * Outputs are always reallocated when executing in a non-outer frame at an iteration greater
     * than 0.
     *
     * @param str       id of the op to parameterize
     * @param frameIter frame/iteration the op is being executed in
     * @param set       VarIds searched first when resolving a non-constant, non-placeholder input
     *                  (may be null)
     * @param set2      VarIds searched second when resolving inputs; also counted toward the
     *                  expected input count (may be null). Presumably inputs available across
     *                  all iterations — confirm against callers.
     * @param set3      variable names counted toward the expected input count (may be null).
     *                  Presumably constant/placeholder inputs — confirm against callers.
     * @param map       placeholder arrays, keyed by placeholder name
     * @return the parameterized op
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public DifferentialFunction getAndParameterizeOp(String str, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3, Map<String, INDArray> map) {
        DifferentialFunction opById = this.sameDiff.getOpById(str);
        Preconditions.checkNotNull(opById, "No differential function fond with name %s", str);
        if (isControlFlowOrTensorArrayOp(opById)) {
            return opById;
        }

        String[] argNames = opById.argNames();
        int numArgs = argNames == null ? 0 : argNames.length;
        int numOpInputs = set == null ? 0 : set.size();
        int numAllIterInputs = set2 == null ? 0 : set2.size();
        int numConstPhInputs = set3 == null ? 0 : set3.size();

        // Inputs produced by constant Enter ops. These are only identified when the counts above
        // don't line up; they are then read from iteration 0 of their frame during input resolution.
        Set<String> constEnterInputs = null;
        if (numArgs != numOpInputs + numConstPhInputs + numAllIterInputs) {
            constEnterInputs = findConstantEnterInputs(opById);
            int constEnterCount = countConstantEnterInputsNotProvided(constEnterInputs, set2, set3);
            if (numArgs > 1) {
                // The same variable may be used for several args, so compare unique names.
                Set<String> uniqueArgNames = new HashSet<>();
                Collections.addAll(uniqueArgNames, argNames);
                Preconditions.checkState(uniqueArgNames.size() == numOpInputs + numConstPhInputs + numAllIterInputs + constEnterCount, "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", opById.getClass().getSimpleName(), str, uniqueArgNames, set, set3);
            } else {
                // NOTE(review): unlike the branch above, set2's size is not included here; kept as-is.
                Preconditions.checkState(numArgs == numOpInputs + numConstPhInputs + constEnterCount, "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", opById.getClass().getSimpleName(), str, argNames, set, set3);
            }
        }

        INDArray[] inputs = resolveInputArrays(str, argNames, frameIter, set, set2, map, constEnterInputs);

        // Inside a loop body past the first iteration, always allocate fresh outputs (presumably so
        // arrays produced by earlier iterations are not overwritten).
        boolean forceAllocate = !frameIter.getFrame().equals(AbstractSession.OUTER_FRAME) && frameIter.getIteration() > 0;
        if (opById instanceof CustomOp) {
            parameterizeCustomOp(opById, inputs, forceAllocate);
        } else if (opById instanceof Op) {
            parameterizeLegacyOp(opById, inputs, numArgs, set, set2, forceAllocate);
        }
        return opById;
    }

    /** Returns true for ops that getAndParameterizeOp hands back without binding any arrays. */
    private static boolean isControlFlowOrTensorArrayOp(DifferentialFunction df) {
        return df instanceof LoopCond || df instanceof Enter || df instanceof Exit
                || df instanceof NextIteration || df instanceof Merge || df instanceof Switch
                || df instanceof If || df instanceof While || df instanceof BaseTensorOp;
    }

    /**
     * Returns the names of {@code df}'s inputs that are outputs of constant Enter ops, or null if
     * there are none.
     */
    private Set<String> findConstantEnterInputs(DifferentialFunction df) {
        Set<String> result = null;
        for (SDVariable input : df.args()) {
            Variable variable = this.sameDiff.getVariables().get(input.getVarName());
            DifferentialFunction producer = variable.getOutputOfOp() == null ? null : this.sameDiff.getOps().get(variable.getOutputOfOp()).getOp();
            if (producer instanceof Enter && ((Enter) producer).isConstant()) {
                if (result == null) {
                    result = new HashSet<>();
                }
                result.add(input.getVarName());
            }
        }
        return result;
    }

    /**
     * Counts the constant-Enter inputs that appear neither in {@code constPhInputs} nor (by variable
     * name) in {@code allIterInputs}. A null {@code constEnterInputs} counts as zero.
     */
    private static int countConstantEnterInputsNotProvided(Set<String> constEnterInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constPhInputs) {
        if (constEnterInputs == null) {
            return 0;
        }
        int count = 0;
        for (String name : constEnterInputs) {
            if (constPhInputs != null && constPhInputs.contains(name)) {
                continue;
            }
            boolean provided = false;
            if (allIterInputs != null) {
                for (AbstractSession.VarId id : allIterInputs) {
                    if (name.equals(id.getVariable())) {
                        provided = true;
                        break;
                    }
                }
            }
            if (!provided) {
                count++;
            }
        }
        return count;
    }

    /**
     * Resolves one array per arg name, in arg order. Returns null when the op has no arg names.
     *
     * <p>Each arg is resolved from the first applicable source:
     * <ol>
     *   <li>the constant's own array;</li>
     *   <li>the placeholder map;</li>
     *   <li>the iteration-0 output of a constant Enter;</li>
     *   <li>the node output of the first matching VarId in {@code opInputs}, then in
     *       {@code allIterInputs}.</li>
     * </ol>
     * Fails with a Preconditions error if any resolved array is null.
     */
    private INDArray[] resolveInputArrays(String opName, String[] argNames, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Map<String, INDArray> placeholderValues, Set<String> constEnterInputs) {
        if (argNames == null || argNames.length == 0) {
            return null;
        }
        INDArray[] arrays = new INDArray[argNames.length];
        for (int i = 0; i < argNames.length; i++) {
            String argName = argNames[i];
            SDVariable variable = this.sameDiff.getVariable(argName);
            INDArray arr;
            if (variable.isConstant()) {
                arr = variable.getArr();
            } else if (variable.isPlaceHolder()) {
                Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(argName), "No array provided for placeholder %s", argName);
                arr = placeholderValues.get(argName);
            } else if (constEnterInputs != null && constEnterInputs.contains(argName)) {
                arr = constantEnterOutput(argName, frameIter);
            } else {
                arr = outputOfFirstMatch(opInputs, argName);
                if (arr == null) {
                    // Previously this second search never terminated when no VarId matched.
                    arr = outputOfFirstMatch(allIterInputs, argName);
                }
            }
            Preconditions.checkNotNull(arr, "Could not parameterize op %s: array %s (variable %s) is null", opName, Integer.valueOf(i), variable.getVarName());
            arrays[i] = arr;
        }
        return arrays;
    }

    /**
     * Returns the node output for the first VarId in {@code candidates} whose variable name is
     * {@code varName}. Returns null if {@code candidates} is null or contains no such VarId.
     */
    private INDArray outputOfFirstMatch(Set<AbstractSession.VarId> candidates, String varName) {
        if (candidates == null) {
            return null;
        }
        for (AbstractSession.VarId id : candidates) {
            if (id.getVariable().equals(varName)) {
                return (INDArray) this.nodeOutputs.get(id);
            }
        }
        return null;
    }

    /**
     * Looks up the output of a constant-Enter input by building a VarId for {@code varName} in a
     * copy of {@code frameIter}, with the iteration of that frame and all its parent frames reset
     * to 0.
     */
    private INDArray constantEnterOutput(String varName, AbstractSession.FrameIter frameIter) {
        AbstractSession.VarId id = newVarId(varName, frameIter.m7787clone());
        id.setIteration(0);
        for (AbstractSession.FrameIter parent = id.getParentFrame(); parent != null; parent = parent.getParentFrame()) {
            parent.setIteration(0);
        }
        return (INDArray) this.nodeOutputs.get(id);
    }

    /** Allocates an array of the given shape while scoped out of any active workspace. */
    private static INDArray createArrayOutsideWorkspace(LongShapeDescriptor shape) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return Nd4j.create(shape, false);
        }
    }

    /**
     * Binds inputs to a custom op and allocates output arrays in either of two cases:
     * <ul>
     *   <li>the existing output is missing, or differs in shape, data type or emptiness;</li>
     *   <li>{@code forceAllocate} is set.</li>
     * </ul>
     */
    private void parameterizeCustomOp(DifferentialFunction df, INDArray[] inputs, boolean forceAllocate) {
        DynamicCustomOp customOp = (DynamicCustomOp) df;
        if (inputs != null) {
            customOp.setInputArguments(inputs);
        }
        df.resolvePropertiesFromSameDiffBeforeExecution();
        List<LongShapeDescriptor> outShapes = customOp.calculateOutputShape();
        Preconditions.checkState(outShapes != null && outShapes.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
        String[] outNames = df.outputVariablesNames();
        Preconditions.checkState(outNames.length == outShapes.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation with %s outputs (number of shapes and outputs must be equal)", df.opName(), Integer.valueOf(outShapes.size()), Integer.valueOf(outNames.length));
        for (int i = 0; i < outShapes.size(); i++) {
            INDArray existing = customOp.numOutputArguments() <= i ? null : customOp.getOutputArgument(i);
            LongShapeDescriptor required = outShapes.get(i);
            // The SameDiff variable's declared data type takes precedence over the calculated one.
            DataType dataType = this.sameDiff.getVariable(outNames[i]).dataType();
            if (dataType != required.dataType()) {
                required = required.asDataType(dataType);
            }
            if (existing == null || !existing.shapeDescriptor().equals(required) || existing.isEmpty() != required.isEmpty() || forceAllocate) {
                customOp.setOutputArgument(i, createArrayOutsideWorkspace(required));
            }
        }
    }

    /**
     * Binds inputs to a legacy (non-custom) op. For a non-REDUCE3 reduce op with two args, the
     * second arg supplies the reduction axes; for a scalar op with two args, it supplies the
     * scalar. The z array is then (re)allocated when missing, mis-shaped, or when
     * {@code forceAllocate} is set.
     */
    private void parameterizeLegacyOp(DifferentialFunction df, INDArray[] inputs, int numArgs, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, boolean forceAllocate) {
        Op op = (Op) df;
        boolean axisFromInput = false;
        boolean emptyReduce = false;
        if (op instanceof ReduceOp && ((ReduceOp) op).getOpType() != Op.Type.REDUCE3 && numArgs == 2) {
            SDVariable axisVar = df.arg(1);
            Preconditions.checkState(axisVar.dataType().isIntType(), "Legacy op %s input 1 (axis) was expected to be an integer type, is %s", df.getClass(), axisVar.dataType());
            INDArray axis = getArray(axisVar, opInputs, allIterInputs);
            Preconditions.checkState(axis != null, "Could not get axis argument for op %s: %s", df.getOwnName(), df.getClass());
            if (axis.isEmpty()) {
                df.setDimensions(null);
                emptyReduce = true;
                ((BaseReduceOp) op).setEmptyReduce(true);
            } else {
                df.setDimensions(Shape.normalizeAxis(inputs[0].rank(), axis.toIntVector()));
                ((BaseReduceOp) op).setEmptyReduce(false);
            }
            axisFromInput = true;
        } else if (op instanceof ScalarOp && numArgs == 2) {
            INDArray scalar = getArray(df.arg(1), opInputs, allIterInputs);
            Preconditions.checkState(scalar != null, "Could not get scalar argument for op %s: %s", df.getOwnName(), df.getClass());
            Preconditions.checkState(scalar.isScalar(), "Scalar argument for op %s (%s) is not a scalar: has shape %ndShape", df.getOwnName(), df.getClass(), scalar);
            ((ScalarOp) op).setScalar(scalar);
        }

        if (inputs != null && inputs.length > 0) {
            op.setX(inputs[0]);
            // When the second input is the axis array it is not a y operand.
            if (inputs.length == 2 && !axisFromInput) {
                op.setY(inputs[1]);
            }
        }

        if (emptyReduce) {
            INDArray existingZ = op.z();
            if (existingZ == null || !op.x().equalShapes(existingZ) || forceAllocate) {
                op.setZ(op.x().ulike());
            }
        } else {
            List<LongShapeDescriptor> outShapes = ((BaseOp) op).calculateOutputShape();
            Preconditions.checkState(outShapes != null && outShapes.size() == 1, "Could not calculate output shape for op: %s", op.getClass());
            INDArray existingZ = op.z();
            LongShapeDescriptor required = outShapes.get(0);
            if (existingZ == null || !required.equals(existingZ.shapeDescriptor()) || forceAllocate) {
                if (log.isTraceEnabled()) {
                    log.trace("Existing op result (z) array shape for op {} was {}, allocating new array of shape {}", new Object[]{op.getClass().getSimpleName(), existingZ == null ? null : Arrays.toString(existingZ.shape()), required.toString()});
                }
                op.setZ(createArrayOutsideWorkspace(required));
            }
        }
        df.resolvePropertiesFromSameDiffBeforeExecution();
    }

    /**
     * Fetches the current array for {@code sDVariable}.
     *
     * <p>CONSTANT and VARIABLE-type variables are served by {@link #getConstantOrVariable(String)}.
     * Any other variable is looked up in {@code collection} first. If that finds nothing and
     * {@code collection2} is non-null and non-empty, it is looked up there next. The matching
     * node output is then returned.
     *
     * @throws IllegalStateException if no matching VarId is found in either collection
     */
    protected INDArray getArray(SDVariable sDVariable, Collection<AbstractSession.VarId> collection, Collection<AbstractSession.VarId> collection2) {
        final String name = sDVariable.getVarName();
        final VariableType type = sDVariable.getVariableType();
        if (type == VariableType.CONSTANT || type == VariableType.VARIABLE) {
            return getConstantOrVariable(name);
        }

        AbstractSession.VarId found = (collection == null) ? null : lookup(name, collection, false);
        boolean trySecondary = found == null && collection2 != null && !collection2.isEmpty();
        if (trySecondary) {
            found = lookup(name, collection2, false);
        }

        Preconditions.checkState(found != null, "Could not find array for variable %s", name);
        return (INDArray) this.nodeOutputs.get(found);
    }

    /**
     * Decompiler-generated bridge (see the bridge/synthetic markers): overrides the raw-typed
     * AbstractSession#getOutputs. It forwards every argument unchanged to getOutputs2, casting
     * the raw Set/List parameters to their generic types.
     * NOTE(review): getOutputs2 is presumably the generically typed implementation, renamed by the
     * decompiler to avoid a signature collision — confirm. The casts are unchecked.
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public /* bridge */ /* synthetic */ INDArray[] getOutputs(DifferentialFunction differentialFunction, AbstractSession.FrameIter frameIter, Set set, Set set2, Set set3, List list, At at, MultiDataSet multiDataSet) {
        return getOutputs2(differentialFunction, frameIter, (Set<AbstractSession.VarId>) set, (Set<AbstractSession.VarId>) set2, (Set<String>) set3, (List<Listener>) list, at, multiDataSet);
    }

    // NOTE(review): the decompiler emitted a raw-typed "bridge" overload of getAndParameterizeOp
    // here. Its signature, getAndParameterizeOp(String, FrameIter, Set, Set, Set, Map), has the
    // same erasure as the generic getAndParameterizeOp override declared earlier in this class, so
    // javac rejects the pair as a duplicate method. The stub only forwarded to that same-erasure
    // method and added no behavior. javac generates any bridge method that is actually needed,
    // so the declaration has been removed.
}
