package org.nd4j.imports.graphmapper.tf;

import com.beust.jcommander.Parameters;
import com.github.jaiimageio.plugins.tiff.EXIFGPSTagSet;
import com.github.os72.protobuf351.Message;
import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;
import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.BaseGraphMapper;
import org.nd4j.imports.graphmapper.ImportState;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.controlflow.IfImportState;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.PropertyAccessor;
import org.springframework.util.AntPathMatcher;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;

/* loaded from: input_file:org/nd4j/imports/graphmapper/tf/TFGraphMapper.class */
public class TFGraphMapper extends BaseGraphMapper<GraphDef, NodeDef, AttrValue, NodeDef> {
    /** Attribute key holding a node's tensor value (the attribute getNDArrayFromTensor reads). */
    public static final String VALUE_ATTR_KEY = "value";
    /** Attribute key holding a node's declared shape (used by hasShape/getShape). */
    public static final String SHAPE_KEY = "shape";
    private static final Logger log = LoggerFactory.getLogger((Class<?>) TFGraphMapper.class);
    // Shared instance returned by getInstance(); the constructor is private.
    private static TFGraphMapper MAPPER_INSTANCE = new TFGraphMapper();
    // Names of nodes already processed; consulted by alreadySeen(NodeDef).
    // NOTE(review): nothing in this part of the file adds to this set — confirm where it is populated.
    private Set<String> seenNodes = new LinkedHashSet();
    // TF control-flow / no-op op type names.
    // NOTE(review): the field name does not describe its contents and it is not referenced in this
    // part of the file — confirm its use elsewhere before renaming or removing.
    private Set<String> graphMapper = new HashSet<String>() { // from class: org.nd4j.imports.graphmapper.tf.TFGraphMapper.1
        {
            add("LoopCond");
            add("Merge");
            add("Exit");
            add("NextIteration");
            add("NoOp");
            add("Switch");
        }
    };

    /** Private: callers obtain the shared instance through {@link #getInstance()}. */
    private TFGraphMapper() {
    }

    /**
     * Returns the shared mapper instance held in a static field.
     *
     * @return the singleton {@code TFGraphMapper}
     */
    public static TFGraphMapper getInstance() {
        final TFGraphMapper shared = MAPPER_INSTANCE;
        return shared;
    }

    /**
     * Parses a binary {@link GraphDef} from {@code inputStream} and appends the text form of each
     * of its nodes to {@code file}.
     *
     * <p>The caller keeps ownership of {@code inputStream`}; it is not closed here. The output writer
     * is always closed, even when writing fails. I/O failures are logged, not rethrown.
     *
     * @param inputStream stream containing a serialized {@code GraphDef}
     * @param file        destination file, opened in append mode
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public void dumpBinaryProtoAsText(InputStream inputStream, File file) {
        try {
            GraphDef graphDef = GraphDef.parseFrom(inputStream);
            // Parse before opening the writer so malformed input never creates/touches the output file.
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(file, true))) {
                for (NodeDef node : graphDef.getNodeList()) {
                    writer.write(node.toString());
                }
            }
        } catch (IOException e) {
            log.error("Failed to dump GraphDef as text to {}", file, e);
        }
    }

    /**
     * Reports whether import failures for {@code nodeDef} may be ignored.
     *
     * @param nodeDef node being imported (not inspected)
     * @return always {@code true} for the TF mapper
     */
    @Override
    public boolean isOpIgnoreException(NodeDef nodeDef) {
        final boolean ignore = true;
        return ignore;
    }

    /**
     * Returns the mapping target for an op, which for TF is simply the function's op name.
     *
     * @param differentialFunction function whose name is returned
     * @param nodeDef              node being imported (not inspected)
     * @return {@code differentialFunction.opName()}
     */
    @Override
    public String getTargetMappingForOp(DifferentialFunction differentialFunction, NodeDef nodeDef) {
        final String target = differentialFunction.opName();
        return target;
    }

    /**
     * Linear search of {@code graphDef} for the first node whose name equals {@code str}.
     *
     * @param graphDef graph to search
     * @param str      exact node name to match
     * @return the first matching node, or {@code null} when no node has that name
     */
    @Override
    public NodeDef getNodeWithNameFromGraph(GraphDef graphDef, String str) {
        for (NodeDef candidate : graphDef.getNodeList()) {
            if (candidate.getName().equals(str)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Maps a single TF property onto a field of {@code differentialFunction}.
     *
     * <p>The {@link PropertyMapping} for {@code str} is looked up in {@code map} under this node's op
     * type. If the mapping has no TF input position, or that position is not below the node's input
     * count, the value is read from the node attribute named by the mapping's TF attribute name (if
     * present). Otherwise the value is taken from the array of the referenced input node:
     * <ul>
     *   <li>if no array is available but the input node exists, the property is registered with
     *       {@code sameDiff} for later resolution;</li>
     *   <li>if the input node is not in the graph, the input name is registered as a placeholder;</li>
     *   <li>otherwise the array (or an element/dimension of it) is assigned to the field.</li>
     * </ul>
     *
     * @param str                  property name (key into the op type's property mappings)
     * @param differentialFunction function whose field is assigned
     * @param nodeDef              TF node being imported; must not be {@code null}
     * @param graphDef             graph containing {@code nodeDef}
     * @param sameDiff             SameDiff instance being populated
     * @param map                  property mappings keyed by TF op type, then by property name
     * @throws ND4JIllegalStateException if {@code nodeDef} is null, no mapping exists for the op type
     *                                   or property, or the function's fields cannot be resolved
     */
    public void mapProperty(String str, DifferentialFunction differentialFunction, NodeDef nodeDef, GraphDef graphDef, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> map) {
        if (nodeDef == null) {
            throw new ND4JIllegalStateException("No node found for name " + str);
        }
        String opType = getOpType(nodeDef);
        Map<String, PropertyMapping> opMappings = map.get(opType);
        if (opMappings == null) {
            throw new ND4JIllegalStateException("No property mappings found for TF op type [" + opType + "]");
        }
        PropertyMapping propertyMapping = opMappings.get(str);
        if (propertyMapping == null) {
            throw new ND4JIllegalStateException("No property mapping found for [" + str + "] in TF op type [" + opType + "]");
        }
        Map<String, Field> fieldsForFunction = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(differentialFunction);
        Integer tfInputPosition = propertyMapping.getTfInputPosition();

        if (tfInputPosition == null || tfInputPosition.intValue() >= nodeDef.getInputCount()) {
            // Attribute-based mapping.
            String tfAttrName = propertyMapping.getTfAttrName();
            if (tfAttrName == null || !nodeDef.containsAttr(tfAttrName)) {
                return;
            }
            AttrValue attrValue = nodeDef.getAttrOrThrow(tfAttrName);
            if (fieldsForFunction == null) {
                throw new ND4JIllegalStateException("No fields found for op [" + differentialFunction.opName() + "]");
            }
            String[] propertyNames = propertyMapping.getPropertyNames();
            if (propertyNames == null || propertyNames.length == 0) {
                throw new ND4JIllegalStateException("no property found for [" + str + "] in op [" + differentialFunction.opName() + "]");
            }
            Field field = fieldsForFunction.get(propertyNames[0]);
            Object value;
            switch (attrValue.getType()) {
                case DT_BOOL:
                    value = attrValue.getB();
                    break;
                case DT_INT8:
                case DT_INT16:
                case DT_INT32:
                case DT_INT64:
                    value = attrValue.getI();
                    break;
                case DT_FLOAT:
                case DT_DOUBLE:
                    // AttrValue only exposes a float scalar (getF), so doubles are read from it too.
                    value = attrValue.getF();
                    break;
                case DT_STRING:
                    value = attrValue.getS();
                    break;
                default:
                    value = null;
                    break;
            }
            if (field != null && value != null) {
                differentialFunction.setValueFor(field, value);
            }
            return;
        }

        // Input-based mapping: negative positions count back from the last input.
        int inputIndex = tfInputPosition.intValue();
        if (inputIndex < 0) {
            inputIndex += nodeDef.getInputCount();
        }
        String input = nodeDef.getInput(inputIndex);
        NodeDef inputNode = getInstance().getNodeWithNameFromGraph(graphDef, input);
        INDArray array = getArrayFrom(inputNode, graphDef);
        if (array == null) {
            array = sameDiff.getArrForVarName(input);
        }
        if (array == null && inputNode != null) {
            // Value not known yet: defer until the input variable is resolved.
            sameDiff.addPropertyToResolve(differentialFunction, str);
            sameDiff.addVariableMappingForField(differentialFunction, str, inputNode.getName());
            return;
        }
        if (inputNode == null) {
            sameDiff.addAsPlaceHolder(input);
            return;
        }
        if (fieldsForFunction == null) {
            throw new ND4JIllegalStateException("No fields found for op [" + differentialFunction.opName() + "]");
        }
        Field field = fieldsForFunction.get(str);
        if (field == null) {
            throw new ND4JIllegalStateException("No field [" + str + "] found for op [" + differentialFunction.opName() + "]");
        }
        Class<?> fieldType = field.getType();
        if (fieldType.equals(int[].class)) {
            differentialFunction.setValueFor(field, array.data().asInt());
        } else if (fieldType.equals(Integer.TYPE) || fieldType.equals(Long.TYPE) || fieldType.equals(Long.class) || fieldType.equals(Integer.class)) {
            if (propertyMapping.getShapePosition() != null) {
                differentialFunction.setValueFor(field, Long.valueOf(array.size(propertyMapping.getShapePosition().intValue())));
            } else {
                differentialFunction.setValueFor(field, Integer.valueOf(array.getInt(0)));
            }
        } else if (fieldType.equals(Float.TYPE) || fieldType.equals(Double.TYPE) || fieldType.equals(Float.class) || fieldType.equals(Double.class)) {
            differentialFunction.setValueFor(field, Double.valueOf(array.getDouble(0L)));
        }
    }

    /**
     * Reports whether the node is a placeholder, judged by a {@code "Placeholder"} op-type prefix.
     *
     * @param nodeDef node to classify
     * @return {@code true} when the op type starts with {@code "Placeholder"}
     */
    @Override
    public boolean isPlaceHolderNode(NodeDef nodeDef) {
        final String opType = nodeDef.getOp();
        return opType.startsWith("Placeholder");
    }

    /**
     * Parses a binary {@link GraphDef} from {@code file} and appends the text form of each of its
     * nodes to {@code file2}.
     *
     * <p>Both the input stream and the output writer are always closed, including on failure.
     * I/O failures are logged, not rethrown.
     *
     * @param file  file containing a serialized {@code GraphDef}
     * @param file2 destination file, opened in append mode
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public void dumpBinaryProtoAsText(File file, File file2) {
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            GraphDef graphDef = GraphDef.parseFrom(in);
            // Parse before opening the writer so malformed input never creates/touches the output file.
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(file2, true))) {
                for (NodeDef node : graphDef.getNodeList()) {
                    writer.write(node.toString());
                }
            }
        } catch (IOException e) {
            log.error("Failed to dump GraphDef from {} as text to {}", file, file2, e);
        }
    }

    /**
     * Converts the shape held by a shape-valued attribute into a {@code long[]}.
     *
     * @param attrValue attribute whose {@code shape} field is read
     * @return the shape as produced by {@code shapeFromShapeProto}
     */
    @Override
    public long[] getShapeFromAttr(AttrValue attrValue) {
        TensorShapeProto shapeProto = attrValue.getShape();
        return shapeFromShapeProto(shapeProto);
    }

    /**
     * Returns the attributes of {@code nodeDef}, keyed by attribute name.
     *
     * @param nodeDef node whose attributes are returned
     * @return the node's attribute map as exposed by the protobuf message
     */
    @Override
    public Map<String, AttrValue> getAttrMap(NodeDef nodeDef) {
        final Map<String, AttrValue> attributes = nodeDef.getAttrMap();
        return attributes;
    }

    /**
     * Returns the TF node name.
     *
     * @param nodeDef node to read
     * @return {@code nodeDef.getName()}
     */
    @Override
    public String getName(NodeDef nodeDef) {
        final String nodeName = nodeDef.getName();
        return nodeName;
    }

    /**
     * Reports whether a node with this name is recorded in {@code seenNodes}.
     *
     * @param nodeDef node to check
     * @return {@code true} if the node's name is in the seen set
     */
    @Override
    public boolean alreadySeen(NodeDef nodeDef) {
        final String nodeName = nodeDef.getName();
        return seenNodes.contains(nodeName);
    }

    /**
     * Reports whether the node is treated as a variable: its op type starts with
     * {@code "VariableV"} or equals {@code "const"} ignoring case.
     *
     * @param nodeDef node to classify
     * @return {@code true} for variable-like or constant nodes
     */
    @Override
    public boolean isVariableNode(NodeDef nodeDef) {
        final String opType = nodeDef.getOp();
        if (opType.startsWith("VariableV")) {
            return true;
        }
        return "const".equalsIgnoreCase(opType);
    }

    /**
     * Reports whether a node should be skipped during import.
     *
     * <p>A node is skipped when it is {@code null}, when its name ends with {@code "/read"}, or when
     * its op type ends with {@code "/reduction_indices"}.
     * NOTE(review): the second check tests the op <em>type</em>, not the node name. TF op types are
     * normally plain identifiers, so this check probably never matches — confirm whether
     * {@code getName()} was intended before changing it, as that would alter which nodes import.
     *
     * @param nodeDef node to check; may be {@code null}
     * @return {@code true} if the node should not be imported
     */
    @Override
    public boolean shouldSkip(NodeDef nodeDef) {
        return nodeDef == null
                || nodeDef.getName().endsWith("/read")
                || nodeDef.getOp().endsWith("/reduction_indices");
    }

    /**
     * Reports whether the node carries a {@link #SHAPE_KEY} attribute.
     *
     * @param nodeDef node to check
     * @return {@code true} if a {@code "shape"} attribute is present
     */
    @Override
    public boolean hasShape(NodeDef nodeDef) {
        final Map<String, AttrValue> attributes = nodeDef.getAttrMap();
        return attributes.containsKey(SHAPE_KEY);
    }

    /**
     * Returns the shape stored in the node's {@link #SHAPE_KEY} attribute.
     *
     * @param nodeDef node to read; must have a {@code "shape"} attribute
     * @return the declared shape
     * @throws IllegalArgumentException (from protobuf {@code getAttrOrThrow}) if the attribute is absent
     */
    @Override
    public long[] getShape(NodeDef nodeDef) {
        AttrValue shapeAttr = nodeDef.getAttrOrThrow(SHAPE_KEY);
        return getShapeFromAttr(shapeAttr);
    }

    /**
     * Returns the constant tensor held by {@code nodeDef}, if any.
     *
     * @param nodeDef  node to read; may be {@code null}
     * @param graphDef graph containing the node
     * @return the mapped array, or {@code null} when the node is {@code null} or has no value
     */
    @Override
    public INDArray getArrayFrom(NodeDef nodeDef, GraphDef graphDef) {
        return nodeDef == null ? null : getNDArrayFromTensor(nodeDef.getName(), nodeDef, graphDef);
    }

    /**
     * Returns the TF op type of the node.
     *
     * @param nodeDef node to read
     * @return {@code nodeDef.getOp()}
     */
    @Override
    public String getOpType(NodeDef nodeDef) {
        final String opType = nodeDef.getOp();
        return opType;
    }

    /**
     * Returns the nodes of the graph in declaration order.
     *
     * @param graphDef graph to read
     * @return the protobuf node list
     */
    @Override
    public List<NodeDef> getNodeList(GraphDef graphDef) {
        final List<NodeDef> nodes = graphDef.getNodeList();
        return nodes;
    }

    /**
     * Looks up the registered function for a TF op name.
     *
     * @param str TF op name
     * @return the registered function, or whatever the class holder returns when none is registered
     */
    @Override
    public DifferentialFunction getMappedOp(String str) {
        DifferentialFunctionClassHolder holder = DifferentialFunctionClassHolder.getInstance();
        return holder.getOpWithTensorflowName(str);
    }

    /**
     * Normalizes a TF input reference to the node name it refers to.
     *
     * <p>A leading {@code "^"} (presumably TF's control-input marker) is removed, and a single
     * trailing {@code "/read"} suffix is stripped. Occurrences of {@code "/read"} elsewhere in the
     * name are preserved; for example, {@code "a/readout/read"} becomes {@code "a/readout"}.
     *
     * @param str input reference as returned by {@code NodeDef.getInput(int)}
     * @return the normalized node name
     */
    public String getNodeName(String str) {
        final String readSuffix = "/read";
        String name = str.startsWith("^") ? str.substring(1) : str;
        if (name.endsWith(readSuffix)) {
            // Only strip the suffix; String.replace would also remove inner "/read" occurrences.
            name = name.substring(0, name.length() - readSuffix.length());
        }
        return name;
    }

    /**
     * Indexes the graph's nodes by their SameDiff name.
     *
     * <p>Despite the name, every node is included except those whose name ends with
     * {@code "/read"}. Keys come from {@link #translateToSameDiffName(String, NodeDef)}. Iteration
     * order follows the graph's node order; a later node with the same key replaces an earlier one.
     *
     * @param graphDef graph to index
     * @return insertion-ordered map from SameDiff name to node
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public Map<String, NodeDef> variablesForGraph(GraphDef graphDef) {
        Map<String, NodeDef> nodesByName = new LinkedHashMap<>();
        for (NodeDef nodeDef : graphDef.getNodeList()) {
            if (nodeDef.getName().endsWith("/read")) {
                continue;
            }
            nodesByName.put(translateToSameDiffName(nodeDef.getName(), nodeDef), nodeDef);
        }
        return nodesByName;
    }

    /**
     * Converts a TF tensor/node name to the name used in SameDiff.
     *
     * <p>Variable and placeholder nodes keep {@code str} unchanged. For all other nodes, everything
     * from the last {@code ':'} onward is dropped.
     *
     * @param str     TF name, possibly with a {@code ":N"} suffix
     * @param nodeDef node the name belongs to
     * @return the SameDiff name
     */
    @Override
    public String translateToSameDiffName(String str, NodeDef nodeDef) {
        if (isVariableNode(nodeDef) || isPlaceHolder(nodeDef)) {
            return str;
        }
        final int colon = str.lastIndexOf(':');
        return colon >= 0 ? str.substring(0, colon) : str;
    }

    /**
     * Creates an empty {@link GraphDef} builder.
     *
     * @return a new {@code GraphDef.Builder}
     */
    @Override
    public Message.Builder getNewGraphBuilder() {
        final Message.Builder builder = GraphDef.newBuilder();
        return builder;
    }

    /**
     * Parses a serialized {@link GraphDef} from a byte array.
     *
     * @param bArr serialized graph bytes
     * @return the parsed graph
     * @throws IOException if the bytes are not a valid {@code GraphDef}
     */
    @Override
    public GraphDef parseGraphFrom(byte[] bArr) throws IOException {
        final GraphDef graph = GraphDef.parseFrom(bArr);
        return graph;
    }

    /**
     * Parses a serialized {@link GraphDef} from a stream. The stream is not closed.
     *
     * @param inputStream stream of serialized graph bytes
     * @return the parsed graph
     * @throws IOException if reading fails or the data is not a valid {@code GraphDef}
     */
    @Override
    public GraphDef parseGraphFrom(InputStream inputStream) throws IOException {
        final GraphDef graph = GraphDef.parseFrom(inputStream);
        return graph;
    }

    /**
     * Hook for importing a condition node. Currently does nothing.
     * NOTE(review): no caller is visible in this part of the file — confirm whether it is still needed.
     */
    protected void importCondition(String str, NodeDef nodeDef, ImportState<GraphDef, NodeDef> importState) {
    }

    /**
     * Imports one TF node into the {@link SameDiff} instance held by {@code importState}.
     *
     * <p>Nodes that {@link #shouldSkip} rejects, that were already seen, or that
     * {@link #isVariableNode} classifies as variables are not imported here. A placeholder node
     * registers its existing SameDiff variable as a placeholder. Any other node is handled as
     * follows:
     * <ol>
     *   <li>a fresh function is instantiated from the class registered for its TF op type;</li>
     *   <li>it is wired to its input variables; an input missing from SameDiff is created as a
     *       zero-initialized variable and marked as a placeholder;</li>
     *   <li>it is initialized from the TF node and registered with SameDiff.</li>
     * </ol>
     *
     * @param nodeDef     TF node to import
     * @param importState import state providing the graph and SameDiff instance
     * @throws ND4JIllegalStateException if no op class is registered for the node's op type, or a
     *                                   placeholder node has no matching SameDiff variable
     * @throws RuntimeException          wrapping any failure while building or initializing the op
     */
    public void mapNodeType(NodeDef nodeDef, ImportState<GraphDef, NodeDef> importState) {
        if (shouldSkip(nodeDef) || alreadySeen(nodeDef) || isVariableNode(nodeDef)) {
            return;
        }
        SameDiff sameDiff = importState.getSameDiff();
        if (isPlaceHolder(nodeDef)) {
            String placeholderName = getName(nodeDef);
            SDVariable placeholder = sameDiff.getVariable(placeholderName);
            if (placeholder == null) {
                throw new ND4JIllegalStateException("No variable found for placeholder node [" + placeholderName + "]");
            }
            sameDiff.addAsPlaceHolder(placeholder.getVarName());
            return;
        }
        String op = nodeDef.getOp();
        DifferentialFunction prototype = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(op);
        if (prototype == null) {
            throw new ND4JIllegalStateException("No tensorflow op found for " + op + " possibly missing operation class?");
        }
        try {
            // The registered instance is only a prototype; each node gets its own function instance.
            DifferentialFunction differentialFunction = prototype.getClass().getDeclaredConstructor().newInstance();
            int inputCount = nodeDef.getInputCount();
            SDVariable[] args = new SDVariable[inputCount];
            differentialFunction.setOwnName(nodeDef.getName());
            for (int i = 0; i < inputCount; i++) {
                String inputName = getNodeName(nodeDef.getInput(i));
                SDVariable arg = sameDiff.getVariable(inputName);
                if (arg == null) {
                    // Unknown input: create a zero-initialized variable and treat it as a placeholder.
                    arg = sameDiff.var(inputName, (long[]) null, new ZeroInitScheme('f'));
                    sameDiff.addAsPlaceHolder(arg.getVarName());
                }
                if (sameDiff.isPlaceHolder(arg.getVarName())) {
                    sameDiff.putPlaceHolderForVariable(arg.getVarName(), inputName);
                }
                args[i] = arg;
            }
            sameDiff.addArgsFor(args, differentialFunction);
            differentialFunction.setSameDiff(importState.getSameDiff());
            differentialFunction.initFromTensorFlow(nodeDef, sameDiff, getAttrMap(nodeDef), importState.getGraph());
            mapProperties(differentialFunction, nodeDef, importState.getGraph(), importState.getSameDiff(), differentialFunction.mappingsForFunction());
            importState.getSameDiff().putFunctionForId(differentialFunction.getOwnName(), differentialFunction);
            sameDiff.setBaseNameForFunctionInstanceId(nodeDef.getName(), differentialFunction);
            sameDiff.addVarNameForImport(nodeDef.getName());
        } catch (Exception e) {
            log.error("Failed with [{}] for node [{}]", op, nodeDef.getName());
            throw new RuntimeException(e);
        }
    }

    /**
     * Initializes {@code differentialFunction} using the mappings registered under its own
     * TF op name ({@code differentialFunction.tensorflowName()}).
     *
     * @param differentialFunction function to initialize
     * @param map                  attributes of {@code nodeDef}
     * @param nodeDef              TF node being imported
     * @param graphDef             graph containing {@code nodeDef}
     */
    public void initFunctionFromProperties(DifferentialFunction differentialFunction, Map<String, AttrValue> map, NodeDef nodeDef, GraphDef graphDef) {
        final String tfOpName = differentialFunction.tensorflowName();
        initFunctionFromProperties(tfOpName, differentialFunction, map, nodeDef, graphDef);
    }

    /**
     * Initializes the fields of {@code differentialFunction} from a TF node, using the property
     * mappings and attribute adapters the function declares for TF op name {@code str}.
     *
     * <p>For each property mapping:
     * <ul>
     *   <li>If it names a TF attribute that is present in {@code map} and the property has a
     *       reflected field, the attribute value is applied. It goes through the property's
     *       {@link AttributeAdapter} when one is registered, and is otherwise assigned directly
     *       (see {@code applyAttrToField} for which value kinds are handled).</li>
     *   <li>If it instead names a TF input position, the array of that input is applied. When no
     *       array is available yet, the property is registered with the function's SameDiff for
     *       later resolution.</li>
     * </ul>
     * Mappings without an adapter are applied before mappings with one. Does nothing when no
     * mappings exist for {@code str}.
     *
     * @param str                  TF op name whose mappings and adapters are used
     * @param differentialFunction function to initialize
     * @param map                  attributes of {@code nodeDef}
     * @param nodeDef              TF node being imported
     * @param graphDef             graph containing {@code nodeDef}
     */
    public void initFunctionFromProperties(String str, DifferentialFunction differentialFunction, Map<String, AttrValue> map, NodeDef nodeDef, GraphDef graphDef) {
        Map<String, PropertyMapping> mappings = differentialFunction.mappingsForFunction().get(str);
        Map<String, Field> fieldsForFunction = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(differentialFunction);
        Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction = differentialFunction.attributeAdaptersForFunction();
        if (mappings == null) {
            return;
        }
        // Adapters for this op name only. Looking them up once avoids an NPE when the adapter map
        // is non-empty but has no entry for str.
        Map<String, AttributeAdapter> adapters = attributeAdaptersForFunction == null ? null : attributeAdaptersForFunction.get(str);

        Map<String, PropertyMapping> ordered;
        if (adapters == null) {
            ordered = mappings;
        } else {
            // Plain mappings first, then those routed through an adapter.
            ordered = new LinkedHashMap<>();
            for (Map.Entry<String, PropertyMapping> entry : mappings.entrySet()) {
                if (!adapters.containsKey(entry.getKey())) {
                    ordered.put(entry.getKey(), entry.getValue());
                }
            }
            for (Map.Entry<String, PropertyMapping> entry : mappings.entrySet()) {
                if (!ordered.containsKey(entry.getKey())) {
                    ordered.put(entry.getKey(), entry.getValue());
                }
            }
        }

        for (Map.Entry<String, PropertyMapping> entry : ordered.entrySet()) {
            String propertyName = entry.getKey();
            PropertyMapping mapping = entry.getValue();
            String tfAttrName = mapping.getTfAttrName();
            Field field = fieldsForFunction.get(propertyName);
            AttributeAdapter adapter = adapters == null ? null : adapters.get(propertyName);
            if (tfAttrName != null) {
                if (field != null && map.containsKey(tfAttrName)) {
                    applyAttrToField(differentialFunction, field, adapter, map.get(tfAttrName));
                }
            } else if (mapping.getTfInputPosition() != null) {
                applyInputArrayToField(differentialFunction, propertyName, field, adapter, mapping.getTfInputPosition().intValue(), nodeDef, graphDef);
            }
        }
    }

    /**
     * Applies one TF attribute value to {@code field}.
     *
     * <p>Booleans and data types are applied only through an adapter. Strings, ints, shapes,
     * int/float lists and tensors are adapted when an adapter exists and assigned directly otherwise.
     * NOTE(review): scalar float (F), func, placeholder and bool/func/tensor-list attributes are not
     * mapped here — confirm this is intended.
     */
    private void applyAttrToField(DifferentialFunction differentialFunction, Field field, AttributeAdapter adapter, AttrValue attrValue) {
        switch (attrValue.getValueCase()) {
            case B:
                if (adapter != null) {
                    adapter.mapAttributeFor(Boolean.valueOf(attrValue.getB()), field, differentialFunction);
                }
                break;
            case S:
                assignOrAdapt(differentialFunction, field, adapter, attrValue.getS().toStringUtf8());
                break;
            case I:
                assignOrAdapt(differentialFunction, field, adapter, Integer.valueOf((int) attrValue.getI()));
                break;
            case SHAPE: {
                List<TensorShapeProto.Dim> dims = attrValue.getShape().getDimList();
                int[] shape = new int[dims.size()];
                for (int i = 0; i < shape.length; i++) {
                    shape[i] = (int) dims.get(i).getSize();
                }
                assignOrAdapt(differentialFunction, field, adapter, shape);
                break;
            }
            case LIST: {
                AttrValue.ListValue list = attrValue.getList();
                if (!list.getIList().isEmpty()) {
                    assignOrAdapt(differentialFunction, field, adapter, Ints.toArray(list.getIList()));
                } else if (list.getBList().isEmpty() && !list.getFList().isEmpty()) {
                    assignOrAdapt(differentialFunction, field, adapter, Floats.toArray(list.getFList()));
                }
                break;
            }
            case TENSOR:
                assignOrAdapt(differentialFunction, field, adapter, getInstance().mapTensorProto(attrValue.getTensor()));
                break;
            case TYPE:
                if (adapter != null) {
                    adapter.mapAttributeFor(attrValue.getType(), field, differentialFunction);
                }
                break;
            default:
                break;
        }
    }

    /**
     * Applies the array of the input at {@code position} to {@code field}. A negative position
     * counts from the end of the inputs. Defers the property via SameDiff when no array is known
     * yet; does nothing when there is neither an adapter nor a field to assign.
     */
    private void applyInputArrayToField(DifferentialFunction differentialFunction, String propertyName, Field field, AttributeAdapter adapter, int position, NodeDef nodeDef, GraphDef graphDef) {
        int inputIndex = position < 0 ? position + nodeDef.getInputCount() : position;
        String input = nodeDef.getInput(inputIndex);
        NodeDef inputNode = getInstance().getNodeWithNameFromGraph(graphDef, input);
        INDArray array = inputNode != null ? getInstance().getNDArrayFromTensor("value", inputNode, graphDef) : null;
        if (array == null) {
            array = differentialFunction.getSameDiff().getArrForVarName(getNodeName(input));
        }
        if (array == null) {
            differentialFunction.getSameDiff().addPropertyToResolve(differentialFunction, propertyName);
            return;
        }
        if (adapter != null) {
            adapter.mapAttributeFor(array, field, differentialFunction);
            return;
        }
        if (field == null) {
            // No reflected field and no adapter: nothing to assign (previously an NPE).
            return;
        }
        Class<?> fieldType = field.getType();
        if (fieldType.equals(int[].class)) {
            differentialFunction.setValueFor(field, array.data().asInt());
        } else if (fieldType.equals(double[].class)) {
            differentialFunction.setValueFor(field, array.data().asDouble());
        } else if (fieldType.equals(float[].class)) {
            differentialFunction.setValueFor(field, array.data().asFloat());
        } else if (fieldType.equals(INDArray.class)) {
            differentialFunction.setValueFor(field, array);
        } else if (fieldType.equals(Integer.TYPE)) {
            differentialFunction.setValueFor(field, Integer.valueOf(array.getInt(0)));
        } else if (fieldType.equals(Double.TYPE)) {
            differentialFunction.setValueFor(field, Double.valueOf(array.getDouble(0L)));
        } else if (fieldType.equals(Float.TYPE)) {
            differentialFunction.setValueFor(field, Float.valueOf(array.getFloat(0L)));
        }
    }

    /** Routes {@code value} through {@code adapter} when present, otherwise assigns it to {@code field}. */
    private static void assignOrAdapt(DifferentialFunction differentialFunction, Field field, AttributeAdapter adapter, Object value) {
        if (adapter != null) {
            adapter.mapAttributeFor(value, field, differentialFunction);
        } else {
            differentialFunction.setValueFor(field, value);
        }
    }

    /**
     * Determines the ND4J buffer type of a node from its type attribute.
     *
     * <p>The type is read from the first attribute present among {@code "dtype"}, {@code "T"} and
     * {@code "Tidx"}, in that order. {@code UNKNOWN} is returned when none is present or the type
     * is unsupported.
     * NOTE(review): DT_INT64 is mapped to {@code INT}, and DT_BFLOAT16 to {@code HALF} even though
     * bfloat16 is not the IEEE half format. DT_HALF is not mapped at all — confirm intended.
     *
     * @param nodeDef node to inspect
     * @return the corresponding buffer type, or {@code DataBuffer.Type.UNKNOWN}
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public DataBuffer.Type dataTypeForTensor(NodeDef nodeDef) {
        // "T" is the TF type attribute. The decompiler had inlined it as an unrelated EXIF constant.
        DataType dataType;
        if (nodeDef.containsAttr("dtype")) {
            dataType = nodeDef.getAttrOrThrow("dtype").getType();
        } else if (nodeDef.containsAttr("T")) {
            dataType = nodeDef.getAttrOrThrow("T").getType();
        } else if (nodeDef.containsAttr("Tidx")) {
            dataType = nodeDef.getAttrOrThrow("Tidx").getType();
        } else {
            return DataBuffer.Type.UNKNOWN;
        }
        switch (dataType) {
            case DT_INT32:
            case DT_INT64:
                return DataBuffer.Type.INT;
            case DT_FLOAT:
                return DataBuffer.Type.FLOAT;
            case DT_DOUBLE:
                return DataBuffer.Type.DOUBLE;
            case DT_BFLOAT16:
                return DataBuffer.Type.HALF;
            case DT_STRING:
            default:
                return DataBuffer.Type.UNKNOWN;
        }
    }

    /**
     * Reports whether a node whose buffer type is unknown can still be imported. This is the case
     * when its type attribute is DT_BOOL.
     *
     * <p>The type is read from the first attribute present among {@code "dtype"}, {@code "T"} and
     * {@code "Tidx"}, in that order.
     *
     * @param nodeDef node to inspect
     * @return {@code true} if the resolved type is {@code DT_BOOL}
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public boolean unknownTypeNodeImportable(NodeDef nodeDef) {
        // "T" is the TF type attribute. The decompiler had inlined it as an unrelated EXIF constant.
        DataType dataType = null;
        if (nodeDef.containsAttr("dtype")) {
            dataType = nodeDef.getAttrOrThrow("dtype").getType();
        } else if (nodeDef.containsAttr("T")) {
            dataType = nodeDef.getAttrOrThrow("T").getType();
        } else if (nodeDef.containsAttr("Tidx")) {
            dataType = nodeDef.getAttrOrThrow("Tidx").getType();
        }
        return dataType == DataType.DT_BOOL;
    }

    /**
     * Returns the string value of attribute {@code str}, decoded as UTF-8.
     *
     * @param nodeDef node to read
     * @param str     attribute name; must be present
     * @return the attribute's {@code s} field as a UTF-8 string
     */
    @Override
    public String getAttrValueFromNode(NodeDef nodeDef, String str) {
        AttrValue attribute = nodeDef.getAttrOrThrow(str);
        return attribute.getS().toStringUtf8();
    }

    /**
     * Reads the dimension sizes of a shape-valued attribute.
     *
     * <p>Sizes are copied as {@code long} values without narrowing, so dimensions beyond the
     * {@code int} range are preserved.
     *
     * @param attrValue attribute whose {@code shape} field is read
     * @return one entry per dimension, in order
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public long[] getShapeFromAttribute(AttrValue attrValue) {
        // The decompiled body referenced an undefined local ("r0"); the shape is bound explicitly here.
        TensorShapeProto shape = attrValue.getShape();
        long[] dims = new long[shape.getDimCount()];
        for (int i = 0; i < dims.length; i++) {
            dims[i] = shape.getDim(i).getSize();
        }
        return dims;
    }

    /**
     * Reports whether the node's op type starts with {@code "Placeholder"}.
     *
     * @param nodeDef node to classify
     * @return {@code true} for placeholder nodes
     */
    @Override
    public boolean isPlaceHolder(NodeDef nodeDef) {
        final String opType = nodeDef.getOp();
        return opType.startsWith("Placeholder");
    }

    /**
     * Reports whether the node's op type starts with {@code "Const"}.
     *
     * @param nodeDef node to classify
     * @return {@code true} for constant nodes
     */
    @Override
    public boolean isConstant(NodeDef nodeDef) {
        final String opType = nodeDef.getOp();
        return opType.startsWith("Const");
    }

    /**
     * Converts the node's {@code "value"} tensor attribute into an {@link INDArray}.
     *
     * @param str      name argument required by the interface (not used here)
     * @param nodeDef  node to read
     * @param graphDef graph containing the node (not used here)
     * @return the mapped array, or {@code null} when the node has no {@code "value"} attribute
     */
    @Override
    public INDArray getNDArrayFromTensor(String str, NodeDef nodeDef, GraphDef graphDef) {
        if (!nodeDef.getAttrMap().containsKey("value")) {
            return null;
        }
        TensorProto tensor = nodeDef.getAttrOrThrow("value").getTensor();
        return mapTensorProto(tensor);
    }

    public INDArray mapTensorProto(TensorProto tensorProto) {
        int dimCount = tensorProto.getTensorShape().getDimCount();
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < dimCount; i++) {
            arrayList.add(Integer.valueOf((int) tensorProto.getTensorShape().getDim(i).getSize()));
        }
        long[] jArr = new long[arrayList.size()];
        for (int i2 = 0; i2 < jArr.length; i2++) {
            jArr[i2] = ((Integer) arrayList.get(i2)).intValue();
        }
        if (tensorProto.getDtype() == DataType.DT_INT32 || tensorProto.getDtype() == DataType.DT_INT16 || tensorProto.getDtype() == DataType.DT_INT8) {
            if (tensorProto.getIntValCount() == 1 || ArrayUtil.prod(jArr) == 1) {
                if (tensorProto.getIntValCount() < 1) {
                    return Nd4j.trueScalar(Double.valueOf(0.0d));
                }
                int intVal = tensorProto.getIntVal(0);
                return (jArr == null || jArr.length == 0) ? Nd4j.trueScalar(Double.valueOf(intVal)) : Nd4j.valueArrayOf(jArr, intVal);
            }
            if (tensorProto.getInt64ValCount() > 0) {
                double[] dArr = new double[tensorProto.getIntValCount()];
                for (int i3 = 0; i3 < tensorProto.getIntValCount(); i3++) {
                    dArr[i3] = tensorProto.getIntVal(i3);
                }
                return Nd4j.create(dArr, jArr, 0L, 'c');
            }
            ArrayUtil.prodLong(jArr);
            IntBuffer asIntBuffer = tensorProto.getTensorContent().asReadOnlyByteBuffer().order(ByteOrder.nativeOrder()).asIntBuffer();
            float[] fArr = new float[asIntBuffer.capacity()];
            for (int i4 = 0; i4 < asIntBuffer.capacity(); i4++) {
                fArr[i4] = asIntBuffer.get(i4);
            }
            return fArr.length == 0 ? Nd4j.empty() : fArr.length == 1 ? Nd4j.trueScalar(Float.valueOf(fArr[0])) : jArr.length == 1 ? Nd4j.trueVector(fArr) : Nd4j.create(fArr, jArr, 'c');
        }
        if (tensorProto.getDtype() == DataType.DT_FLOAT) {
            if (tensorProto.getFloatValCount() == 1 || ArrayUtil.prod(jArr) == 1) {
                if (tensorProto.getFloatValCount() < 1) {
                    return Nd4j.scalar(0.0d);
                }
                float floatVal = tensorProto.getFloatVal(0);
                if (jArr == null || jArr.length == 0) {
                    jArr = new long[0];
                }
                return Nd4j.valueArrayOf(jArr, floatVal);
            }
            if (tensorProto.getFloatValCount() > 0) {
                float[] fArr2 = new float[tensorProto.getFloatValCount()];
                for (int i5 = 0; i5 < tensorProto.getFloatValCount(); i5++) {
                    fArr2[i5] = tensorProto.getFloatVal(i5);
                }
                return Nd4j.create(Nd4j.createBuffer(fArr2), jArr);
            }
            if (tensorProto.getTensorContent().size() > 0) {
                FloatBuffer asFloatBuffer = tensorProto.getTensorContent().asReadOnlyByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer();
                float[] fArr3 = new float[asFloatBuffer.capacity()];
                for (int i6 = 0; i6 < asFloatBuffer.capacity(); i6++) {
                    fArr3[i6] = asFloatBuffer.get(i6);
                }
                if (fArr3.length == 0) {
                    throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?");
                }
                return fArr3.length == 1 ? Nd4j.trueScalar(Float.valueOf(fArr3[0])) : jArr.length == 1 ? Nd4j.trueVector(fArr3) : Nd4j.create(fArr3, jArr, 'c');
            }
        } else if (tensorProto.getDtype() == DataType.DT_DOUBLE) {
            if (tensorProto.getDoubleValCount() == 1 || ArrayUtil.prod(jArr) == 1) {
                return tensorProto.getDoubleValCount() < 1 ? Nd4j.trueScalar(Double.valueOf(0.0d)) : Nd4j.trueScalar(Double.valueOf(tensorProto.getDoubleVal(0)));
            }
            if (tensorProto.getDoubleValCount() > 0) {
                double[] dArr2 = new double[tensorProto.getDoubleValCount()];
                for (int i7 = 0; i7 < tensorProto.getDoubleValCount(); i7++) {
                    dArr2[i7] = tensorProto.getDoubleVal(i7);
                }
                return Nd4j.create(dArr2, jArr, 0L, 'c');
            }
            if (tensorProto.getTensorContent().size() > 0) {
                DoubleBuffer asDoubleBuffer = tensorProto.getTensorContent().asReadOnlyByteBuffer().order(ByteOrder.nativeOrder()).asDoubleBuffer();
                double[] dArr3 = new double[asDoubleBuffer.capacity()];
                for (int i8 = 0; i8 < asDoubleBuffer.capacity(); i8++) {
                    dArr3[i8] = asDoubleBuffer.get(i8);
                }
                if (dArr3.length == 0) {
                    throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?");
                }
                return dArr3.length == 1 ? Nd4j.trueScalar(Double.valueOf(dArr3[0])) : jArr.length == 1 ? Nd4j.trueVector(dArr3) : Nd4j.create(dArr3, jArr, 0L, 'c');
            }
        } else if (tensorProto.getDtype() == DataType.DT_INT64) {
            if (tensorProto.getInt64ValCount() == 1 || ArrayUtil.prod(jArr) == 1) {
                return tensorProto.getDoubleValCount() < 1 ? Nd4j.trueScalar(Double.valueOf(0.0d)) : Nd4j.trueScalar(Double.valueOf(tensorProto.getInt64Val(0)));
            }
            if (tensorProto.getInt64ValCount() > 0) {
                double[] dArr4 = new double[tensorProto.getInt64ValCount()];
                for (int i9 = 0; i9 < tensorProto.getInt64ValCount(); i9++) {
                    dArr4[i9] = tensorProto.getInt64Val(i9);
                }
                return Nd4j.create(dArr4, jArr, 0L, 'c');
            }
            if (tensorProto.getTensorContent().size() > 0) {
                LongBuffer asLongBuffer = tensorProto.getTensorContent().asReadOnlyByteBuffer().order(ByteOrder.nativeOrder()).asLongBuffer();
                float[] fArr4 = new float[asLongBuffer.capacity()];
                for (int i10 = 0; i10 < asLongBuffer.capacity(); i10++) {
                    fArr4[i10] = (float) asLongBuffer.get(i10);
                }
                if (fArr4.length == 0) {
                    throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?");
                }
                return fArr4.length == 1 ? Nd4j.trueScalar(Float.valueOf(fArr4[0])) : jArr.length == 1 ? Nd4j.trueVector(fArr4) : Nd4j.create(fArr4, jArr, 'c');
            }
        } else {
            if (tensorProto.getDtype() != DataType.DT_BOOL) {
                throw new UnsupportedOperationException("Unknown dataType found: [" + tensorProto.getDtype() + PropertyAccessor.PROPERTY_KEY_SUFFIX);
            }
            if (tensorProto.getBoolValCount() == 1 || ArrayUtil.prod(jArr) == 1) {
                if (tensorProto.getBoolValCount() < 1) {
                    return Nd4j.trueScalar(Double.valueOf(0.0d));
                }
                return Nd4j.trueScalar(Double.valueOf(tensorProto.getBoolVal(0) ? 1.0d : 0.0d));
            }
            if (tensorProto.getBoolValCount() > 0) {
                float[] fArr5 = new float[tensorProto.getBoolValCount()];
                for (int i11 = 0; i11 < tensorProto.getBoolValCount(); i11++) {
                    fArr5[i11] = tensorProto.getBoolVal(i11) ? 1.0f : 0.0f;
                }
                return Nd4j.create(fArr5, jArr, 'c');
            }
            if (tensorProto.getTensorContent().size() > 0) {
                throw new UnsupportedOperationException("Not yet implemented for DataType.DT_BOOL");
            }
        }
        throw new ND4JIllegalStateException("Invalid method state");
    }

    /**
     * Extracts a shape from a node's attributes.
     *
     * <p>The {@code SHAPE_KEY} attribute is consulted first. If it is absent, the shape of the
     * tensor stored in the {@code "value"} attribute is used instead.
     *
     * @param nodeDef the node to inspect
     * @return the dimension sizes, or {@code null} when the node has neither attribute
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public long[] getShapeFromTensor(NodeDef nodeDef) {
        if (!nodeDef.containsAttr(SHAPE_KEY)) {
            if (!nodeDef.containsAttr("value")) {
                // No shape information available on this node.
                return null;
            }
            AttrValue valueAttr = nodeDef.getAttrOrThrow("value");
            return shapeFromShapeProto(valueAttr.getTensor().getTensorShape());
        }
        AttrValue shapeAttr = nodeDef.getAttrOrThrow(SHAPE_KEY);
        return shapeFromShapeProto(shapeAttr.getShape());
    }

    /**
     * Returns the set of op names that this mapper is configured to skip.
     *
     * <p>NOTE(review): the backing field (named {@code graphMapper}, likely a decompiler rename) is
     * returned directly without a defensive copy, so callers share and could mutate it. Confirm
     * whether that is intended.
     *
     * @return the ignored-op set held by this mapper
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public Set<String> opsToIgnore() {
        final Set<String> ignoredOps = graphMapper;
        return ignoredOps;
    }

    /**
     * Returns the entry at position {@code i} in the node's input list.
     *
     * <p>The string is returned verbatim. No port suffix or control-dependency marker is stripped.
     *
     * @param nodeDef the node whose inputs are read
     * @param i zero-based input position; presumably an out-of-range index raises
     *     {@code IndexOutOfBoundsException} from the protobuf list (TODO confirm)
     * @return the raw input string at that position
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public String getInputFromNode(NodeDef nodeDef, int i) {
        final String inputName = nodeDef.getInputList().get(i);
        return inputName;
    }

    /**
     * Counts the entries in the node's input list.
     *
     * @param nodeDef the node whose inputs are counted
     * @return the number of input strings declared on the node
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public int numInputsFor(NodeDef nodeDef) {
        final int inputCount = nodeDef.getInputList().size();
        return inputCount;
    }

    /**
     * Converts a TensorFlow shape proto into an array of dimension sizes.
     *
     * <p>Each dim's {@code size} is copied verbatim, in order. No special handling is applied to
     * unknown dimensions or unknown rank. Presumably those arrive as negative sizes or as an empty
     * dim list respectively (TODO confirm against the TF proto definition).
     *
     * @param tensorShapeProto the shape proto to convert
     * @return one entry per dim in the proto
     */
    private long[] shapeFromShapeProto(TensorShapeProto tensorShapeProto) {
        final List<TensorShapeProto.Dim> dims = tensorShapeProto.getDimList();
        final long[] shape = new long[dims.size()];
        int axis = 0;
        for (TensorShapeProto.Dim dim : dims) {
            shape[axis++] = dim.getSize();
        }
        return shape;
    }

    /**
     * Partitions the nodes preceding {@code nodeDef} into the condition, true-body and false-body
     * node lists of an if/cond construct.
     *
     * <p>NOTE(review): this presumably expects {@code nodeDef} to be the node that closes a
     * TensorFlow cond construct, with input 1 naming the last node of the true body. Confirm
     * against the caller.
     *
     * <p>The walk starts at {@code nodeDef}'s index in {@code graphDef} and moves backwards towards
     * index 0, classifying each node:
     * <ul>
     *   <li><b>false nodes</b>: nodes visited before the node whose name equals
     *       {@code nodeDef.getInput(1)};</li>
     *   <li><b>true nodes</b>: nodes from that node onwards, until a node whose name contains
     *       {@code "pred_id"};</li>
     *   <li><b>condition nodes</b>: nodes visited after that, as long as each node's name either
     *       contains {@code "pred_id"} or was listed as an input of an already-collected condition
     *       node. The walk stops at the first node that fails this test.</li>
     * </ul>
     * Nodes equal to {@code nodeDef} are never collected. Each list is returned in graph order,
     * that is, reversed from the walk order.
     *
     * <p>Scope names are derived from the part of {@code nodeDef.getInput(1)} before its first
     * {@code '/'}, or from the whole name when it contains no {@code '/'}. That prefix is combined
     * with a random UUID so that repeated imports do not collide.
     *
     * @param nodeDef the node to start the backward walk from; must have at least two inputs
     * @param graphDef the graph containing {@code nodeDef}. If {@code nodeDef} is not found, every
     *     node list is empty.
     * @return the collected node lists and scope names
     */
    public IfImportState nodesForIf(NodeDef nodeDef, GraphDef graphDef) {
        final int startIndex = graphDef.getNodeList().indexOf(nodeDef);
        final String trueInputName = nodeDef.getInput(1);

        // Use everything before the first '/' as the scope prefix. Unscoped names are used whole
        // instead of failing with substring(0, -1).
        final int slash = trueInputName.indexOf('/');
        final String scopePrefix = slash >= 0 ? trueInputName.substring(0, slash) : trueInputName;
        final String scopeName = UUID.randomUUID().toString() + "-" + scopePrefix;
        final String trueScopeName = scopeName + "-true-scope";
        final String falseScopeName = scopeName + "-false-scope";

        // The walk begins in the false body and switches to the true body at the true input node.
        boolean onFalseBody = true;
        boolean onTrueBody = false;

        final List<NodeDef> falseNodes = new ArrayList<>();
        final List<NodeDef> trueNodes = new ArrayList<>();
        final List<NodeDef> conditionNodes = new ArrayList<>();
        // Names reachable (as inputs) from the condition nodes collected so far.
        final Set<String> seenNames = new LinkedHashSet<>();

        for (int i = startIndex; i >= 0; i--) {
            final NodeDef node = graphDef.getNode(i);
            final String name = node.getName();
            final boolean isPredicate = name.contains("pred_id");

            if (name.equals(trueInputName)) {
                onFalseBody = false;
                onTrueBody = true;
            }
            if (isPredicate) {
                onTrueBody = false;
            }
            // Never re-add the starting node itself.
            if (node.equals(nodeDef)) {
                continue;
            }

            if (onTrueBody) {
                trueNodes.add(node);
            } else if (onFalseBody) {
                falseNodes.add(node);
            } else {
                // Condition scope: a node belongs to it only if it is a predicate node or feeds a
                // node already in the scope. The first node that does neither ends the scope.
                if (!seenNames.contains(name) && !isPredicate) {
                    break;
                }
                seenNames.addAll(node.getInputList());
                seenNames.add(name);
                conditionNodes.add(node);
            }
        }

        // The lists were filled walking backwards; restore graph order.
        Collections.reverse(falseNodes);
        Collections.reverse(trueNodes);
        Collections.reverse(conditionNodes);

        return IfImportState.builder()
                .condNodes(conditionNodes)
                .falseNodes(falseNodes)
                .trueNodes(trueNodes)
                .conditionBodyScopeName(scopeName)
                .falseBodyScopeName(falseScopeName)
                .trueBodyScopeName(trueScopeName)
                .build();
    }

    /**
     * Erased-signature entry point that casts its arguments and forwards to the typed
     * {@code mapNodeType(NodeDef, ImportState<GraphDef, NodeDef>)} overload.
     *
     * <p>NOTE(review): this is marked bridge/synthetic and looks like decompiler output of a
     * compiler-generated bridge. In hand-written source javac would emit it itself, so confirm
     * whether it should be deleted. A non-{@code NodeDef} {@code obj} fails with
     * {@code ClassCastException}.
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public /* bridge */ /* synthetic */ void mapNodeType(Object obj, ImportState importState) {
        final NodeDef typedNode = (NodeDef) obj;
        @SuppressWarnings("unchecked")
        final ImportState<GraphDef, NodeDef> typedState = (ImportState<GraphDef, NodeDef>) importState;
        mapNodeType(typedNode, typedState);
    }

    /**
     * Erased-signature entry point that casts its arguments and forwards to the typed
     * {@code mapProperty(String, DifferentialFunction, NodeDef, GraphDef, SameDiff, Map)} overload.
     *
     * <p>The casts run in argument order, so an invalid {@code obj} is reported before an invalid
     * {@code obj2}. NOTE(review): this is marked bridge/synthetic and looks like decompiler output
     * of a compiler-generated bridge. Confirm whether it belongs in hand-written source.
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public /* bridge */ /* synthetic */ void mapProperty(String str, DifferentialFunction differentialFunction, Object obj, Object obj2, SameDiff sameDiff, Map map) {
        final NodeDef typedNode = (NodeDef) obj;
        final GraphDef typedGraph = (GraphDef) obj2;
        @SuppressWarnings("unchecked")
        final Map<String, Map<String, PropertyMapping>> typedMappings = (Map<String, Map<String, PropertyMapping>>) map;
        mapProperty(str, differentialFunction, typedNode, typedGraph, sameDiff, typedMappings);
    }
}
