package org.deeplearning4j.nn.modelimport.keras;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasLoss;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasLstm;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.shade.jackson.core.type.TypeReference;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/KerasModel.class */
public class KerasModel {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) KerasModel.class);
    public static final String MODEL_FIELD_CLASS_NAME = "class_name";
    public static final String MODEL_CLASS_NAME_SEQUENTIAL = "Sequential";
    public static final String MODEL_CLASS_NAME_MODEL = "Model";
    public static final String MODEL_FIELD_CONFIG = "config";
    public static final String MODEL_CONFIG_FIELD_LAYERS = "layers";
    public static final String MODEL_CONFIG_FIELD_INPUT_LAYERS = "input_layers";
    public static final String MODEL_CONFIG_FIELD_OUTPUT_LAYERS = "output_layers";
    public static final String TRAINING_CONFIG_FIELD_LOSS = "loss";
    public static final String HDF5_MODEL_WEIGHTS_ROOT = "model_weights";
    public static final String HDF5_MODEL_CONFIG_ATTRIBUTE = "model_config";
    public static final String HDF5_TRAINING_CONFIG_ATTRIBUTE = "training_config";
    protected String className;
    protected boolean enforceTrainingConfig;
    protected List<KerasLayer> layersOrdered;
    protected Map<String, KerasLayer> layers;
    protected Map<String, InputType> outputTypes;
    protected ArrayList<String> inputLayerNames;
    protected ArrayList<String> outputLayerNames;
    protected boolean useTruncatedBPTT;
    protected int truncatedBPTT;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/KerasModel$ModelBuilder.class */
    public static class ModelBuilder implements Cloneable {
        protected String modelJson = null;
        protected String modelYaml = null;
        protected String trainingJson = null;
        protected Hdf5Archive weightsArchive = null;
        protected String weightsRoot = null;
        protected Hdf5Archive trainingArchive = null;
        protected boolean enforceTrainingConfig = false;

        public ModelBuilder modelJson(String str) {
            this.modelJson = str;
            return this;
        }

        public ModelBuilder modelJsonFilename(String str) throws IOException {
            this.modelJson = new String(Files.readAllBytes(Paths.get(str, new String[0])));
            return this;
        }

        public ModelBuilder modelJsonInputStream(InputStream inputStream) throws IOException {
            ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
            IOUtils.copy(inputStream, byteArrayOutputStream);
            this.modelJson = new String(byteArrayOutputStream.toByteArray());
            return this;
        }

        public ModelBuilder modelYaml(String str) {
            this.modelYaml = str;
            return this;
        }

        public ModelBuilder modelYamlFilename(String str) throws IOException {
            this.modelJson = new String(Files.readAllBytes(Paths.get(str, new String[0])));
            return this;
        }

        public ModelBuilder modelYamlInputStream(InputStream inputStream) throws IOException {
            ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
            IOUtils.copy(inputStream, byteArrayOutputStream);
            this.modelJson = new String(byteArrayOutputStream.toByteArray());
            return this;
        }

        public ModelBuilder trainingJson(String str) {
            this.trainingJson = str;
            return this;
        }

        public ModelBuilder trainingJsonInputStream(InputStream inputStream) throws IOException {
            ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
            IOUtils.copy(inputStream, byteArrayOutputStream);
            this.trainingJson = new String(byteArrayOutputStream.toByteArray());
            return this;
        }

        public ModelBuilder modelHdf5Filename(String str) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
            Hdf5Archive hdf5Archive = new Hdf5Archive(str);
            this.trainingArchive = hdf5Archive;
            this.weightsArchive = hdf5Archive;
            this.weightsRoot = KerasModel.HDF5_MODEL_WEIGHTS_ROOT;
            if (!this.weightsArchive.hasAttribute(KerasModel.HDF5_MODEL_CONFIG_ATTRIBUTE, new String[0])) {
                throw new InvalidKerasConfigurationException("Model configuration attribute missing from " + str + " archive.");
            }
            this.modelJson = this.weightsArchive.readAttributeAsJson(KerasModel.HDF5_MODEL_CONFIG_ATTRIBUTE, new String[0]);
            if (this.trainingArchive.hasAttribute(KerasModel.HDF5_TRAINING_CONFIG_ATTRIBUTE, new String[0])) {
                this.trainingJson = this.trainingArchive.readAttributeAsJson(KerasModel.HDF5_TRAINING_CONFIG_ATTRIBUTE, new String[0]);
            }
            return this;
        }

        public ModelBuilder weightsHdf5Filename(String str) {
            this.weightsArchive = new Hdf5Archive(str);
            return this;
        }

        public ModelBuilder enforceTrainingConfig(boolean z) {
            this.enforceTrainingConfig = z;
            return this;
        }

        public static ModelBuilder builder() {
            return new ModelBuilder();
        }

        public KerasModel buildModel() throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
            return new KerasModel(this);
        }

        public KerasSequentialModel buildSequential() throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
            return new KerasSequentialModel(this);
        }

        public String getModelJson() {
            return this.modelJson;
        }

        public String getModelYaml() {
            return this.modelYaml;
        }

        public String getTrainingJson() {
            return this.trainingJson;
        }

        public Hdf5Archive getWeightsArchive() {
            return this.weightsArchive;
        }

        public String getWeightsRoot() {
            return this.weightsRoot;
        }

        public Hdf5Archive getTrainingArchive() {
            return this.trainingArchive;
        }

        public boolean isEnforceTrainingConfig() {
            return this.enforceTrainingConfig;
        }

        public void setModelJson(String str) {
            this.modelJson = str;
        }

        public void setModelYaml(String str) {
            this.modelYaml = str;
        }

        public void setTrainingJson(String str) {
            this.trainingJson = str;
        }

        public void setWeightsArchive(Hdf5Archive hdf5Archive) {
            this.weightsArchive = hdf5Archive;
        }

        public void setWeightsRoot(String str) {
            this.weightsRoot = str;
        }

        public void setTrainingArchive(Hdf5Archive hdf5Archive) {
            this.trainingArchive = hdf5Archive;
        }

        public void setEnforceTrainingConfig(boolean z) {
            this.enforceTrainingConfig = z;
        }

        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof ModelBuilder)) {
                return false;
            }
            ModelBuilder modelBuilder = (ModelBuilder) obj;
            if (!modelBuilder.canEqual(this)) {
                return false;
            }
            String modelJson = getModelJson();
            String modelJson2 = modelBuilder.getModelJson();
            if (modelJson == null) {
                if (modelJson2 != null) {
                    return false;
                }
            } else if (!modelJson.equals(modelJson2)) {
                return false;
            }
            String modelYaml = getModelYaml();
            String modelYaml2 = modelBuilder.getModelYaml();
            if (modelYaml == null) {
                if (modelYaml2 != null) {
                    return false;
                }
            } else if (!modelYaml.equals(modelYaml2)) {
                return false;
            }
            String trainingJson = getTrainingJson();
            String trainingJson2 = modelBuilder.getTrainingJson();
            if (trainingJson == null) {
                if (trainingJson2 != null) {
                    return false;
                }
            } else if (!trainingJson.equals(trainingJson2)) {
                return false;
            }
            Hdf5Archive weightsArchive = getWeightsArchive();
            Hdf5Archive weightsArchive2 = modelBuilder.getWeightsArchive();
            if (weightsArchive == null) {
                if (weightsArchive2 != null) {
                    return false;
                }
            } else if (!weightsArchive.equals(weightsArchive2)) {
                return false;
            }
            String weightsRoot = getWeightsRoot();
            String weightsRoot2 = modelBuilder.getWeightsRoot();
            if (weightsRoot == null) {
                if (weightsRoot2 != null) {
                    return false;
                }
            } else if (!weightsRoot.equals(weightsRoot2)) {
                return false;
            }
            Hdf5Archive trainingArchive = getTrainingArchive();
            Hdf5Archive trainingArchive2 = modelBuilder.getTrainingArchive();
            if (trainingArchive == null) {
                if (trainingArchive2 != null) {
                    return false;
                }
            } else if (!trainingArchive.equals(trainingArchive2)) {
                return false;
            }
            return isEnforceTrainingConfig() == modelBuilder.isEnforceTrainingConfig();
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof ModelBuilder;
        }

        public int hashCode() {
            String modelJson = getModelJson();
            int hashCode = (1 * 59) + (modelJson == null ? 43 : modelJson.hashCode());
            String modelYaml = getModelYaml();
            int hashCode2 = (hashCode * 59) + (modelYaml == null ? 43 : modelYaml.hashCode());
            String trainingJson = getTrainingJson();
            int hashCode3 = (hashCode2 * 59) + (trainingJson == null ? 43 : trainingJson.hashCode());
            Hdf5Archive weightsArchive = getWeightsArchive();
            int hashCode4 = (hashCode3 * 59) + (weightsArchive == null ? 43 : weightsArchive.hashCode());
            String weightsRoot = getWeightsRoot();
            int hashCode5 = (hashCode4 * 59) + (weightsRoot == null ? 43 : weightsRoot.hashCode());
            Hdf5Archive trainingArchive = getTrainingArchive();
            return (((hashCode5 * 59) + (trainingArchive == null ? 43 : trainingArchive.hashCode())) * 59) + (isEnforceTrainingConfig() ? 79 : 97);
        }

        public String toString() {
            return "KerasModel.ModelBuilder(modelJson=" + getModelJson() + ", modelYaml=" + getModelYaml() + ", trainingJson=" + getTrainingJson() + ", weightsArchive=" + getWeightsArchive() + ", weightsRoot=" + getWeightsRoot() + ", trainingArchive=" + getTrainingArchive() + ", enforceTrainingConfig=" + isEnforceTrainingConfig() + ")";
        }
    }

    public KerasModel(ModelBuilder modelBuilder) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        this(modelBuilder.modelJson, modelBuilder.modelYaml, modelBuilder.weightsArchive, modelBuilder.weightsRoot, modelBuilder.trainingJson, modelBuilder.trainingArchive, modelBuilder.enforceTrainingConfig);
    }

    protected KerasModel(String str, String str2, Hdf5Archive hdf5Archive, String str3, String str4, Hdf5Archive hdf5Archive2, boolean z) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> parseYamlString;
        this.useTruncatedBPTT = false;
        this.truncatedBPTT = 0;
        if (str != null) {
            parseYamlString = parseJsonString(str);
        } else {
            if (str2 == null) {
                throw new InvalidKerasConfigurationException("Requires model configuration not found.");
            }
            parseYamlString = parseYamlString(str2);
        }
        this.enforceTrainingConfig = z;
        if (!parseYamlString.containsKey("class_name")) {
            throw new InvalidKerasConfigurationException("Could not determine Keras model class (no class_name field found)");
        }
        this.className = (String) parseYamlString.get("class_name");
        if (!this.className.equals(MODEL_CLASS_NAME_MODEL)) {
            throw new InvalidKerasConfigurationException("Expected model class name Model (found " + this.className + ")");
        }
        if (!parseYamlString.containsKey("config")) {
            throw new InvalidKerasConfigurationException("Could not find model configuration details (no config in model config)");
        }
        Map map = (Map) parseYamlString.get("config");
        if (!map.containsKey(MODEL_CONFIG_FIELD_INPUT_LAYERS)) {
            throw new InvalidKerasConfigurationException("Could not find list of input layers (no input_layers field found)");
        }
        this.inputLayerNames = new ArrayList<>();
        Iterator it2 = ((List) map.get(MODEL_CONFIG_FIELD_INPUT_LAYERS)).iterator();
        while (it2.hasNext()) {
            this.inputLayerNames.add((String) ((List) it2.next()).get(0));
        }
        if (!map.containsKey(MODEL_CONFIG_FIELD_OUTPUT_LAYERS)) {
            throw new InvalidKerasConfigurationException("Could not find list of output layers (no output_layers field found)");
        }
        this.outputLayerNames = new ArrayList<>();
        Iterator it3 = ((List) map.get(MODEL_CONFIG_FIELD_OUTPUT_LAYERS)).iterator();
        while (it3.hasNext()) {
            this.outputLayerNames.add((String) ((List) it3.next()).get(0));
        }
        if (!map.containsKey(MODEL_CONFIG_FIELD_LAYERS)) {
            throw new InvalidKerasConfigurationException("Could not find layer configurations (no layers field found)");
        }
        helperPrepareLayers((List) map.get(MODEL_CONFIG_FIELD_LAYERS));
        if (str4 != null) {
            helperImportTrainingConfiguration(str4);
        }
        helperInferOutputTypes();
        if (hdf5Archive != null) {
            helperImportWeights(hdf5Archive, str3);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public void helperPrepareLayers(List<Object> list) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.layersOrdered = new ArrayList();
        this.layers = new HashMap();
        KerasLayer.DimOrder dimOrder = KerasLayer.DimOrder.NONE;
        Iterator<Object> it2 = list.iterator();
        while (it2.hasNext()) {
            KerasLayer kerasLayerFromConfig = KerasLayer.getKerasLayerFromConfig((Map) it2.next(), this.enforceTrainingConfig);
            if (dimOrder == KerasLayer.DimOrder.NONE && kerasLayerFromConfig.getDimOrder() != KerasLayer.DimOrder.NONE) {
                dimOrder = kerasLayerFromConfig.getDimOrder();
            }
            this.layersOrdered.add(kerasLayerFromConfig);
            this.layers.put(kerasLayerFromConfig.getLayerName(), kerasLayerFromConfig);
            if (kerasLayerFromConfig instanceof KerasLstm) {
                this.useTruncatedBPTT = this.useTruncatedBPTT || ((KerasLstm) kerasLayerFromConfig).getUnroll();
            }
        }
        for (KerasLayer kerasLayer : this.layersOrdered) {
            if (kerasLayer.getDimOrder() == KerasLayer.DimOrder.NONE) {
                kerasLayer.setDimOrder(dimOrder);
            } else if (kerasLayer.getDimOrder() != dimOrder) {
                throw new UnsupportedKerasConfigurationException("Keras layer " + kerasLayer.getLayerName() + " has conflicting dim_ordering " + kerasLayer.getDimOrder() + " (vs. dimOrder)");
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public void helperImportTrainingConfiguration(String str) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> parseJsonString = parseJsonString(str);
        ArrayList<KerasLayer> arrayList = new ArrayList();
        if (!parseJsonString.containsKey(TRAINING_CONFIG_FIELD_LOSS)) {
            throw new InvalidKerasConfigurationException("Could not determine training loss function (no loss field found in training config)");
        }
        Object obj = parseJsonString.get(TRAINING_CONFIG_FIELD_LOSS);
        if (obj instanceof String) {
            String str2 = (String) obj;
            Iterator<String> it2 = this.outputLayerNames.iterator();
            while (it2.hasNext()) {
                String next = it2.next();
                arrayList.add(new KerasLoss(next + "_loss", next, str2));
            }
        } else if (obj instanceof Map) {
            Map map = (Map) obj;
            for (String str3 : map.keySet()) {
                Object obj2 = map.get(str3);
                if (!(obj2 instanceof String)) {
                    throw new InvalidKerasConfigurationException("Unknown Keras loss " + obj2.toString());
                }
                arrayList.add(new KerasLoss(str3 + "_loss", str3, (String) obj2));
            }
        }
        this.outputLayerNames.clear();
        for (KerasLayer kerasLayer : arrayList) {
            this.layersOrdered.add(kerasLayer);
            this.layers.put(kerasLayer.getLayerName(), kerasLayer);
            this.outputLayerNames.add(kerasLayer.getLayerName());
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public void helperInferOutputTypes() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        InputType outputType;
        this.outputTypes = new HashMap();
        for (KerasLayer kerasLayer : this.layersOrdered) {
            if (kerasLayer instanceof KerasInput) {
                outputType = kerasLayer.getOutputType(new InputType[0]);
                this.truncatedBPTT = ((KerasInput) kerasLayer).getTruncatedBptt();
            } else {
                InputType[] inputTypeArr = new InputType[kerasLayer.getInboundLayerNames().size()];
                int i = 0;
                Iterator<String> it2 = kerasLayer.getInboundLayerNames().iterator();
                while (it2.hasNext()) {
                    int i2 = i;
                    i++;
                    inputTypeArr[i2] = this.outputTypes.get(it2.next());
                }
                outputType = kerasLayer.getOutputType(inputTypeArr);
            }
            this.outputTypes.put(kerasLayer.getLayerName(), outputType);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public void helperImportWeights(Hdf5Archive hdf5Archive, String str) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        List<String> groups = str != null ? hdf5Archive.getGroups(str) : hdf5Archive.getGroups(new String[0]);
        for (String str2 : groups) {
            List<String> dataSets = str != null ? hdf5Archive.getDataSets(str, str2) : hdf5Archive.getDataSets(str2);
            if (!dataSets.isEmpty()) {
                if (!dataSets.isEmpty() && !this.layers.containsKey(str2)) {
                    throw new InvalidKerasConfigurationException("Found weights for layer not in model (named " + str2 + ")");
                }
                KerasLayer kerasLayer = this.layers.get(str2);
                if (dataSets.size() != kerasLayer.getNumParams()) {
                    throw new InvalidKerasConfigurationException("Found " + dataSets.size() + " weights for layer with " + kerasLayer.getNumParams() + " trainable params (named " + str2 + ")");
                }
                HashMap hashMap = new HashMap();
                for (String str3 : dataSets) {
                    Matcher matcher = Pattern.compile(str2).matcher(str3);
                    if (!matcher.find()) {
                        throw new InvalidKerasConfigurationException("Unable to parse layer/parameter name " + str3 + " for stored weights.");
                    }
                    String replaceFirst = matcher.replaceFirst("");
                    Matcher matcher2 = Pattern.compile("^_(.+)$").matcher(replaceFirst);
                    if (matcher2.find()) {
                        replaceFirst = matcher2.group(1);
                    }
                    Matcher matcher3 = Pattern.compile(":\\d+?$").matcher(replaceFirst);
                    if (matcher3.find()) {
                        replaceFirst = matcher3.replaceFirst("");
                    }
                    Matcher matcher4 = Pattern.compile("_\\d+$").matcher(replaceFirst);
                    if (matcher4.find()) {
                        replaceFirst = matcher4.replaceFirst("");
                    }
                    hashMap.put(replaceFirst, str != null ? hdf5Archive.readDataSet(str3, str, str2) : hdf5Archive.readDataSet(str3, str2));
                }
                kerasLayer.setWeights(hashMap);
            }
        }
        HashSet<String> hashSet = new HashSet(this.layers.keySet());
        hashSet.removeAll(groups);
        for (String str4 : hashSet) {
            if (this.layers.get(str4).getNumParams() > 0) {
                throw new InvalidKerasConfigurationException("Could not find weights required for layer " + str4);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public KerasModel() {
        this.useTruncatedBPTT = false;
        this.truncatedBPTT = 0;
    }

    public ComputationGraphConfiguration getComputationGraphConfiguration() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (!this.className.equals(MODEL_CLASS_NAME_MODEL) && !this.className.equals(MODEL_CLASS_NAME_SEQUENTIAL)) {
            throw new InvalidKerasConfigurationException("Keras model class name " + this.className + " incompatible with ComputationGraph");
        }
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder();
        String[] strArr = new String[this.inputLayerNames.size()];
        this.inputLayerNames.toArray(strArr);
        graphBuilder.addInputs(strArr);
        ArrayList arrayList = new ArrayList();
        Iterator<String> it2 = this.inputLayerNames.iterator();
        while (it2.hasNext()) {
            arrayList.add(this.layers.get(it2.next()).getOutputType(new InputType[0]));
        }
        InputType[] inputTypeArr = new InputType[arrayList.size()];
        arrayList.toArray(inputTypeArr);
        graphBuilder.setInputTypes(inputTypeArr);
        String[] strArr2 = new String[this.outputLayerNames.size()];
        this.outputLayerNames.toArray(strArr2);
        graphBuilder.setOutputs(strArr2);
        HashMap hashMap = new HashMap();
        for (KerasLayer kerasLayer : this.layersOrdered) {
            List<String> inboundLayerNames = kerasLayer.getInboundLayerNames();
            String[] strArr3 = new String[inboundLayerNames.size()];
            inboundLayerNames.toArray(strArr3);
            ArrayList arrayList2 = new ArrayList();
            Iterator<String> it3 = inboundLayerNames.iterator();
            while (it3.hasNext()) {
                arrayList2.add(this.outputTypes.get(it3.next()));
            }
            InputType[] inputTypeArr2 = new InputType[arrayList2.size()];
            arrayList2.toArray(inputTypeArr2);
            InputPreProcessor inputPreprocessor = kerasLayer.getInputPreprocessor(inputTypeArr2);
            if (kerasLayer.usesRegularization()) {
                builder.setUseRegularization(true);
            }
            if (kerasLayer.isLayer()) {
                if (inputPreprocessor != null) {
                    hashMap.put(kerasLayer.getLayerName(), inputPreprocessor);
                }
                graphBuilder.addLayer(kerasLayer.getLayerName(), kerasLayer.getLayer(), strArr3);
                if (this.outputLayerNames.contains(kerasLayer.getLayerName()) && !(kerasLayer.getLayer() instanceof IOutputLayer)) {
                    log.warn("Model cannot be trained: output layer " + kerasLayer.getLayerName() + " is not an IOutputLayer (no loss function specified)");
                }
            } else if (kerasLayer.isVertex()) {
                if (inputPreprocessor != null) {
                    hashMap.put(kerasLayer.getLayerName(), inputPreprocessor);
                }
                graphBuilder.addVertex(kerasLayer.getLayerName(), kerasLayer.getVertex(), strArr3);
                if (this.outputLayerNames.contains(kerasLayer.getLayerName()) && !(kerasLayer.getVertex() instanceof IOutputLayer)) {
                    log.warn("Model cannot be trained: output vertex " + kerasLayer.getLayerName() + " is not an IOutputLayer (no loss function specified)");
                }
            } else if (kerasLayer.isInputPreProcessor()) {
                if (inputPreprocessor == null) {
                    throw new UnsupportedKerasConfigurationException("Layer " + kerasLayer.getLayerName() + " could not be mapped to Layer, Vertex, or InputPreProcessor");
                }
                graphBuilder.addVertex(kerasLayer.getLayerName(), new PreprocessorVertex(inputPreprocessor), strArr3);
            }
            if (this.outputLayerNames.contains(kerasLayer.getLayerName())) {
                log.warn("Model cannot be trained: output " + kerasLayer.getLayerName() + " is not an IOutputLayer (no loss function specified)");
            }
        }
        graphBuilder.setInputPreProcessors(hashMap);
        if (!this.useTruncatedBPTT || this.truncatedBPTT <= 0) {
            graphBuilder.backpropType(BackpropType.Standard);
        } else {
            graphBuilder.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(this.truncatedBPTT).tBPTTBackwardLength(this.truncatedBPTT);
        }
        return graphBuilder.build();
    }

    public ComputationGraph getComputationGraph() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return getComputationGraph(true);
    }

    public ComputationGraph getComputationGraph(boolean z) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        ComputationGraph computationGraph = new ComputationGraph(getComputationGraphConfiguration());
        computationGraph.init();
        if (z) {
            computationGraph = (ComputationGraph) helperCopyWeightsToModel(computationGraph);
        }
        return computationGraph;
    }

    public static Map<String, Object> parseJsonString(String str) throws IOException {
        return (Map) new ObjectMapper().readValue(str, new TypeReference<HashMap<String, Object>>() { // from class: org.deeplearning4j.nn.modelimport.keras.KerasModel.1
        });
    }

    public static Map<String, Object> parseYamlString(String str) throws IOException {
        return (Map) new ObjectMapper(new YAMLFactory()).readValue(str, new TypeReference<HashMap<String, Object>>() { // from class: org.deeplearning4j.nn.modelimport.keras.KerasModel.2
        });
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public org.deeplearning4j.nn.api.Model helperCopyWeightsToModel(org.deeplearning4j.nn.api.Model model) throws InvalidKerasConfigurationException {
        Layer[] layers = model instanceof MultiLayerNetwork ? ((MultiLayerNetwork) model).getLayers() : ((ComputationGraph) model).getLayers();
        HashSet<String> hashSet = new HashSet(this.layers.keySet());
        for (Layer layer : layers) {
            String layerName = layer.conf().getLayer().getLayerName();
            if (!this.layers.containsKey(layerName)) {
                throw new InvalidKerasConfigurationException("No weights found for layer in model (named " + layerName + ")");
            }
            this.layers.get(layerName).copyWeightsToLayer(layer);
            hashSet.remove(layerName);
        }
        for (String str : hashSet) {
            if (this.layers.get(str).getNumParams() > 0) {
                throw new InvalidKerasConfigurationException("Attemping to copy weights for layer not in model (named " + str + ")");
            }
        }
        return model;
    }
}
