package org.nd4j.linalg.api.ops.impl.controlflow;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.util.HashUtil;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/controlflow/If.class */
/**
 * Conditional ("if") control-flow operation.
 *
 * <p>An instance holds a predicate plus a true branch and a false branch, each
 * kept both as a {@link SameDiffFunctionDefinition} and as a {@link SameDiff}
 * instance. The op reports itself as {@link Op.Type#CONDITIONAL} under the op
 * name {@code "if"}.
 *
 * <p>All array/argument related {@link CustomOp} methods are inert here: the
 * adders and removers do nothing, the array accessors return empty arrays,
 * the counters return {@code 0} and the indexed getters return {@code null}.
 */
public class If extends DifferentialFunction implements CustomOp {
    private static final Logger log = LoggerFactory.getLogger(If.class);

    /** {@link SameDiff} instance holding the true branch. */
    protected SameDiff loopBodyExecution;
    /** {@link SameDiff} instance in which the predicate is evaluated. */
    protected SameDiff predicateExecution;
    /** {@link SameDiff} instance holding the false branch. */
    protected SameDiff falseBodyExecution;
    /** Conditional used to evaluate the predicate. */
    protected SameDiffConditional predicate;
    /** Definition of the true branch. */
    protected SameDiffFunctionDefinition trueBody;
    /** Definition of the false branch. */
    protected SameDiffFunctionDefinition falseBody;
    /** Name of this block; also the name under which the condition body is defined. */
    protected String blockName;
    /** Generated name under which the true branch is defined on the parent. */
    protected String trueBodyName;
    /** Generated name under which the false branch is defined on the parent. */
    protected String falseBodyName;
    /** Variables passed to the predicate and to both branches. */
    protected SDVariable[] inputVars;
    /** Which branch ran; {@code null} until {@link #exectedTrueOrFalse(boolean)} is called. */
    protected Boolean trueBodyExecuted;
    /** Variable returned by evaluating the predicate. */
    protected SDVariable targetBoolean;
    /** Placeholder variable returned by {@link #outputVariables(String)}. */
    protected SDVariable dummyResult;
    /** Output variables, assigned through {@link #setOutputVars(SDVariable[])}. */
    protected SDVariable[] outputVars;

    /**
     * Fluent builder for {@link If}; obtain one through {@link If#builder()}.
     * Every property that is not set stays {@code null} and is handed to the
     * constructor as such.
     */
    public static class IfBuilder {
        private String blockName;
        private SameDiff parent;
        private SDVariable[] inputVars;
        private SameDiffFunctionDefinition conditionBody;
        private SameDiffConditional predicate;
        private SameDiffFunctionDefinition trueBody;
        private SameDiffFunctionDefinition falseBody;

        IfBuilder() {
        }

        /**
         * @param str name of the block
         * @return this builder
         */
        public IfBuilder blockName(String str) {
            this.blockName = str;
            return this;
        }

        /**
         * @param sameDiff the {@link SameDiff} instance the op is registered on
         * @return this builder
         */
        public IfBuilder parent(SameDiff sameDiff) {
            this.parent = sameDiff;
            return this;
        }

        /**
         * @param sDVariableArr inputs for the predicate and both branches
         * @return this builder
         */
        public IfBuilder inputVars(SDVariable[] sDVariableArr) {
            this.inputVars = sDVariableArr;
            return this;
        }

        /**
         * @param sameDiffFunctionDefinition definition of the condition body
         * @return this builder
         */
        public IfBuilder conditionBody(SameDiffFunctionDefinition sameDiffFunctionDefinition) {
            this.conditionBody = sameDiffFunctionDefinition;
            return this;
        }

        /**
         * @param sameDiffConditional conditional used to evaluate the condition body
         * @return this builder
         */
        public IfBuilder predicate(SameDiffConditional sameDiffConditional) {
            this.predicate = sameDiffConditional;
            return this;
        }

        /**
         * @param sameDiffFunctionDefinition definition of the true branch
         * @return this builder
         */
        public IfBuilder trueBody(SameDiffFunctionDefinition sameDiffFunctionDefinition) {
            this.trueBody = sameDiffFunctionDefinition;
            return this;
        }

        /**
         * @param sameDiffFunctionDefinition definition of the false branch
         * @return this builder
         */
        public IfBuilder falseBody(SameDiffFunctionDefinition sameDiffFunctionDefinition) {
            this.falseBody = sameDiffFunctionDefinition;
            return this;
        }

        /**
         * Creates the op from the values collected so far.
         *
         * @return a new {@link If}
         */
        public If build() {
            return new If(this.blockName, this.parent, this.inputVars, this.conditionBody, this.predicate, this.trueBody, this.falseBody);
        }

        @Override
        public String toString() {
            return "If.IfBuilder(blockName=" + this.blockName + ", parent=" + this.parent + ", inputVars=" + Arrays.deepToString(this.inputVars) + ", conditionBody=" + this.conditionBody + ", predicate=" + this.predicate + ", trueBody=" + this.trueBody + ", falseBody=" + this.falseBody + ")";
        }
    }

    /**
     * Copy constructor. The new instance shares the source's {@link SameDiff}
     * and all of its branch/predicate state, but gets its own freshly created
     * dummy result variable.
     *
     * @param r12 the op to copy; must not be {@code null} and must have a
     *            {@link SameDiff} instance attached
     */
    public If(If r12) {
        this.sameDiff = r12.sameDiff;
        this.blockName = r12.blockName;
        this.inputVars = r12.inputVars;
        this.outputVars = r12.outputVars;
        this.predicate = r12.predicate;
        this.predicateExecution = r12.predicateExecution;
        this.targetBoolean = r12.targetBoolean;
        // The true-branch fields were previously left null (the branch flag was
        // assigned twice instead); copy them alongside the false-branch ones.
        this.trueBody = r12.trueBody;
        this.trueBodyName = r12.trueBodyName;
        this.loopBodyExecution = r12.loopBodyExecution;
        this.falseBody = r12.falseBody;
        this.falseBodyName = r12.falseBodyName;
        this.falseBodyExecution = r12.falseBodyExecution;
        this.trueBodyExecuted = r12.trueBodyExecuted;
        // The dummy result is never shared with the source: a new uniquely
        // named variable is created on the shared SameDiff instance.
        this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(), new ZeroInitScheme(), DataType.FLOAT, 1);
        if (this.sameDiff.getShapeForVarName(this.dummyResult.getVarName()) == null) {
            this.sameDiff.putShapeForVarName(this.dummyResult.getVarName(), new long[]{1, 1});
        }
    }

    /**
     * Creates the op and registers it, its branches and its predicate on the
     * parent {@link SameDiff} instance.
     *
     * @param str                         name of the block; the condition body is defined under this name
     * @param sameDiff                    parent instance the op is registered on
     * @param sDVariableArr               inputs for the predicate and both branches
     * @param sameDiffFunctionDefinition  definition of the condition body
     * @param sameDiffConditional         conditional that evaluates the condition body
     * @param sameDiffFunctionDefinition2 definition of the true branch
     * @param sameDiffFunctionDefinition3 definition of the false branch
     */
    public If(String str, SameDiff sameDiff, SDVariable[] sDVariableArr, SameDiffFunctionDefinition sameDiffFunctionDefinition, SameDiffConditional sameDiffConditional, SameDiffFunctionDefinition sameDiffFunctionDefinition2, SameDiffFunctionDefinition sameDiffFunctionDefinition3) {
        this.sameDiff = sameDiff;
        sameDiff.putFunctionForId(getOwnName(), this);
        this.inputVars = sDVariableArr;
        this.predicate = sameDiffConditional;
        sameDiff.addArgsFor(sDVariableArr, this);
        this.trueBody = sameDiffFunctionDefinition2;
        this.falseBody = sameDiffFunctionDefinition3;
        this.blockName = str;

        this.dummyResult = sameDiff.var("dummyresult-" + UUID.randomUUID().toString(), new ZeroInitScheme('f'), DataType.FLOAT, 1);
        sameDiff.addOutgoingFor(new SDVariable[]{this.dummyResult}, this);

        // The predicate is evaluated in its own SameDiff instance.
        SameDiff predicateGraph = SameDiff.create();
        this.targetBoolean = sameDiffConditional.eval(predicateGraph, sameDiffFunctionDefinition, sDVariableArr);
        this.predicateExecution = predicateGraph;

        String generatedTrueName = "true-body-" + UUID.randomUUID().toString();
        this.trueBodyName = generatedTrueName;
        String generatedFalseName = "false-body-" + UUID.randomUUID().toString();
        // Record the name the false branch is actually defined under
        // (this used to store the true-branch name by mistake).
        this.falseBodyName = generatedFalseName;

        this.loopBodyExecution = sameDiff.defineFunction(generatedTrueName, sameDiffFunctionDefinition2, sDVariableArr);
        this.falseBodyExecution = sameDiff.defineFunction(generatedFalseName, sameDiffFunctionDefinition3, sDVariableArr);
        sameDiff.defineFunction(str, sameDiffFunctionDefinition, sDVariableArr);
        sameDiff.putSubFunction("predicate-eval-body-" + UUID.randomUUID().toString(), predicateGraph);
        // Kept from the original: re-read the true branch from the parent by name.
        this.loopBodyExecution = sameDiff.getFunction(generatedTrueName);
    }

    /**
     * Records which branch was executed.
     *
     * @param z {@code true} if the true branch ran, {@code false} if the false branch ran
     */
    public void exectedTrueOrFalse(boolean z) {
        this.trueBodyExecuted = z;
    }

    /**
     * @param str ignored
     * @return a new one-element array containing the dummy result variable
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public SDVariable[] outputVariables(String str) {
        return new SDVariable[]{this.dummyResult};
    }

    /**
     * Builds an {@link IfDerivative} around this op and returns its output
     * variables.
     *
     * @param list ignored
     * @return a new mutable list holding the derivative op's output variables
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> list) {
        return new ArrayList<>(Arrays.asList(new IfDerivative(this).outputVariables()));
    }

    /** @return the op name, see {@link #opName()} */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String toString() {
        return opName();
    }

    /** @return the constant {@code "if"} */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String opName() {
        return "if";
    }

    /** @return the long hash of {@link #opName()} as computed by {@link HashUtil} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public long opHash() {
        return HashUtil.getLongHash(opName());
    }

    /** @return always {@code false} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public boolean isInplaceCall() {
        return false;
    }

    /** @return always a new empty array */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public INDArray[] outputArguments() {
        return new INDArray[0];
    }

    /** @return always a new empty array */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public INDArray[] inputArguments() {
        return new INDArray[0];
    }

    /** @return always a new empty array */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public long[] iArgs() {
        return new long[0];
    }

    /** @return always a new empty array */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public double[] tArgs() {
        return new double[0];
    }

    /** @return always a new empty array */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public boolean[] bArgs() {
        return new boolean[0];
    }

    /** No-op: the arguments are discarded. */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public void addIArgument(int... iArr) {
    }

    /** No-op: the arguments are discarded. */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public void addIArgument(long... jArr) {
    }

    /** No-op: the arguments are discarded. */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public void addBArgument(boolean... zArr) {
    }

    /** No-op. */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public void removeIArgument(Integer num) {
    }

    /** @return always {@code null} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public Boolean getBArgument(int i) {
        return null;
    }

    /** @return always {@code null} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public Long getIArgument(int i) {
        return null;
    }

    /** @return always {@code 0} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public int numIArguments() {
        return 0;
    }

    /** No-op: the arguments are discarded. */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public void addTArgument(double... dArr) {
    }

    /** No-op. */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public void removeTArgument(Double d) {
    }

    /** @return always {@code null} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public Double getTArgument(int i) {
        return null;
    }

    /** @return always {@code 0} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public int numTArguments() {
        return 0;
    }

    /** @return always {@code 0} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public int numBArguments() {
        return 0;
    }

    /** No-op: the arrays are discarded. */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public void addInputArgument(INDArray... iNDArrayArr) {
    }

    /** No-op. */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public void removeInputArgument(INDArray iNDArray) {
    }

    /** @return always {@code null} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public INDArray getInputArgument(int i) {
        return null;
    }

    /** @return always {@code 0} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public int numInputArguments() {
        return 0;
    }

    /** No-op: the arrays are discarded. */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public void addOutputArgument(INDArray... iNDArrayArr) {
    }

    /** No-op. */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public void removeOutputArgument(INDArray iNDArray) {
    }

    /** @return always {@code null} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public INDArray getOutputArgument(int i) {
        return null;
    }

    /** @return always {@code 0} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public int numOutputArguments() {
        return 0;
    }

    /** @return {@link Op.Type#CONDITIONAL} */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Op.Type opType() {
        return Op.Type.CONDITIONAL;
    }

    /**
     * Initialises the three sub-graphs of this op from a TensorFlow graph.
     *
     * <p>Nodes whose name contains {@code "/cond/"} are skipped entirely. For
     * any other node the true, false and condition node lists are looked up
     * through {@link TFGraphMapper#nodesForIf}, each list is imported as its
     * own {@link SameDiff} instance, and the three instances are registered on
     * {@code sameDiff} under their respective scope names and stored on this op.
     *
     * @param nodeDef  the node being imported
     * @param sameDiff instance on which the sub-functions are registered
     * @param map      node attributes (not used here)
     * @param graphDef the graph containing {@code nodeDef}
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
        if (nodeDef.getName().contains("/cond/")) {
            return;
        }
        IfImportState importState = TFGraphMapper.getInstance().nodesForIf(nodeDef, graphDef);

        // Import order (true, false, condition) matches the original code.
        SameDiff trueGraph = importSubGraph(importState.getTrueNodes());
        SameDiff falseGraph = importSubGraph(importState.getFalseNodes());
        SameDiff conditionGraph = importSubGraph(importState.getCondNodes());

        sameDiff.putSubFunction(importState.getTrueBodyScopeName(), trueGraph);
        sameDiff.putSubFunction(importState.getFalseBodyScopeName(), falseGraph);
        sameDiff.putSubFunction(importState.getConditionBodyScopeName(), conditionGraph);

        this.loopBodyExecution = trueGraph;
        this.falseBodyExecution = falseGraph;
        this.predicateExecution = conditionGraph;
    }

    /** Packs the given nodes into a fresh {@link GraphDef} and imports it through {@link TFGraphMapper}. */
    private static SameDiff importSubGraph(Iterable<? extends NodeDef> nodes) {
        GraphDef.Builder graph = GraphDef.newBuilder();
        for (NodeDef node : nodes) {
            graph.addNode(node);
        }
        return TFGraphMapper.getInstance().importGraph(graph.build());
    }

    /** No-op: nothing is read from the ONNX node. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromOnnx(OnnxProto3.NodeProto nodeProto, SameDiff sameDiff, Map<String, OnnxProto3.AttributeProto> map, OnnxProto3.GraphProto graphProto) {
    }

    /**
     * @return a single-element list with a {@link DataType#BOOL} descriptor
     *         built from an empty shape array
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<LongShapeDescriptor> calculateOutputShape() {
        LongShapeDescriptor booleanShape = LongShapeDescriptor.fromShape(new long[0], DataType.BOOL);
        return Arrays.asList(booleanShape);
    }

    /** @return always {@code null} */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public CustomOpDescriptor getDescriptor() {
        return null;
    }

    /** No-op: no validation is performed. */
    @Override // org.nd4j.linalg.api.ops.CustomOp
    public void assertValidForExecution() {
    }

    /**
     * @return never returns normally
     * @throws NoOpNameFoundException always
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    /**
     * @return never returns normally
     * @throws NoOpNameFoundException always
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String tensorflowName() {
        throw new NoOpNameFoundException("This operation has no TF counterpart");
    }

    /** @return a new, empty {@link IfBuilder} */
    public static IfBuilder builder() {
        return new IfBuilder();
    }

    /** Creates an empty op; every field of this class starts out {@code null}. */
    public If() {
    }

    /** @return the {@link SameDiff} instance holding the true branch */
    public SameDiff getLoopBodyExecution() {
        return this.loopBodyExecution;
    }

    /** @return the {@link SameDiff} instance holding the predicate */
    public SameDiff getPredicateExecution() {
        return this.predicateExecution;
    }

    /** @return the {@link SameDiff} instance holding the false branch */
    public SameDiff getFalseBodyExecution() {
        return this.falseBodyExecution;
    }

    /** @return the conditional used to evaluate the predicate */
    public SameDiffConditional getPredicate() {
        return this.predicate;
    }

    /** @return the definition of the true branch */
    public SameDiffFunctionDefinition getTrueBody() {
        return this.trueBody;
    }

    /** @return the definition of the false branch */
    public SameDiffFunctionDefinition getFalseBody() {
        return this.falseBody;
    }

    /** @return the name of this block */
    public String getBlockName() {
        return this.blockName;
    }

    /** @return the generated name of the true branch */
    public String getTrueBodyName() {
        return this.trueBodyName;
    }

    /** @return the generated name of the false branch */
    public String getFalseBodyName() {
        return this.falseBodyName;
    }

    /** @return the input variables (the internal array, not a copy) */
    public SDVariable[] getInputVars() {
        return this.inputVars;
    }

    /** @return which branch ran, or {@code null} if none has been recorded */
    public Boolean getTrueBodyExecuted() {
        return this.trueBodyExecuted;
    }

    /** @return the variable produced by evaluating the predicate */
    public SDVariable getTargetBoolean() {
        return this.targetBoolean;
    }

    /** @return the output variables (the internal array, not a copy) */
    public SDVariable[] getOutputVars() {
        return this.outputVars;
    }

    /** @param sDVariableArr the output variables to store (kept by reference) */
    public void setOutputVars(SDVariable[] sDVariableArr) {
        this.outputVars = sDVariableArr;
    }
}
