package org.deeplearning4j.nn.conf;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
import org.deeplearning4j.util.OutputLayerUtil;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.node.ArrayNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/conf/MultiLayerConfiguration.class */
public class MultiLayerConfiguration implements Serializable, Cloneable {
    // Class-wide SLF4J logger; also used by the nested Builder and by the fromJson legacy-format fix-ups.
    private static final Logger log = LoggerFactory.getLogger((Class<?>) MultiLayerConfiguration.class);
    // Per-layer configurations; element i is the configuration of layer i (see getConf(int)).
    protected List<NeuralNetConfiguration> confs;
    // Input preprocessors keyed by layer index; the entry for i is applied to the input of layer i
    // (see getInputPreProcess(int) and getMemoryReport(InputType)). Absent key => no preprocessor.
    protected Map<Integer, InputPreProcessor> inputPreProcessors;
    // Flag copied from Builder.pretrain by Builder.build() (deprecated on the Builder).
    protected boolean pretrain;
    // Flag copied from Builder.backprop by Builder.build() (deprecated on the Builder).
    protected boolean backprop;
    // Backprop algorithm. Per the Builder's warning, the two TBPTT lengths below only take effect
    // when this is BackpropType.TruncatedBPTT.
    protected BackpropType backpropType;
    protected int tbpttFwdLength;
    protected int tbpttBackLength;
    // Copied from the Builder; semantics are defined outside this file.
    protected boolean legacyBatchScaledL2;
    // Workspace modes for training and inference (Builder default: WorkspaceMode.ENABLED).
    protected WorkspaceMode trainingWorkspaceMode;
    protected WorkspaceMode inferenceWorkspaceMode;
    // Cache mode (Builder default: CacheMode.NONE).
    protected CacheMode cacheMode;
    // Training progress counters; setEpochCount(int) also pushes the epoch into every layer configuration.
    protected int iterationCount;
    protected int epochCount;

    /* loaded from: input_file:org/deeplearning4j/nn/conf/MultiLayerConfiguration$Builder.class */
    /**
     * Fluent builder for {@link MultiLayerConfiguration}.
     * <p>
     * Defaults: {@link BackpropType#Standard}; TBPTT forward and backward lengths of 20
     * ({@code DEFAULT_TBPTT_LENGTH}); {@link WorkspaceMode#ENABLED} for both training and inference;
     * {@link CacheMode#NONE}; output-layer validation enabled.
     * <p>
     * {@link #build()} hands this builder's layer list and preprocessor map to the new configuration by
     * reference (no copy), and may add inferred preprocessors to that map.
     */
    public static class Builder {
        /** Default TBPTT length; build() warns when either length differs from it without TBPTT enabled. */
        private static final int DEFAULT_TBPTT_LENGTH = 20;
        protected InputType inputType;
        protected boolean legacyBatchScaledL2;
        protected List<NeuralNetConfiguration> confs = new ArrayList<>();
        // Only participates in equals/hashCode/toString; build() does not read it.
        protected double dampingFactor = 100.0d;
        protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();

        @Deprecated
        protected boolean pretrain = false;

        @Deprecated
        protected boolean backprop = true;
        protected BackpropType backpropType = BackpropType.Standard;
        protected int tbpttFwdLength = DEFAULT_TBPTT_LENGTH;
        protected int tbpttBackLength = DEFAULT_TBPTT_LENGTH;
        protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
        protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
        protected CacheMode cacheMode = CacheMode.NONE;
        protected boolean validateOutputConfig = true;

        /**
         * Sets (or replaces) the input preprocessor for the layer at index {@code num}.
         *
         * @param num               layer index
         * @param inputPreProcessor preprocessor applied to that layer's input
         * @return this builder
         */
        public Builder inputPreProcessor(Integer num, InputPreProcessor inputPreProcessor) {
            this.inputPreProcessors.put(num, inputPreProcessor);
            return this;
        }

        /**
         * Replaces the whole preprocessor map. The map is stored by reference; {@link #build()} may add
         * entries to it.
         *
         * @param map preprocessors keyed by layer index
         * @return this builder
         */
        public Builder inputPreProcessors(Map<Integer, InputPreProcessor> map) {
            this.inputPreProcessors = map;
            return this;
        }

        /** @deprecated retained for backward compatibility. */
        @Deprecated
        public Builder backprop(boolean z) {
            this.backprop = z;
            return this;
        }

        /** @deprecated retained for backward compatibility. */
        @Deprecated
        public Builder trainingWorkspaceMode(@NonNull WorkspaceMode workspaceMode) {
            if (workspaceMode == null) {
                throw new NullPointerException("workspaceMode is marked @NonNull but is null");
            }
            this.trainingWorkspaceMode = workspaceMode;
            return this;
        }

        /** @deprecated retained for backward compatibility. */
        @Deprecated
        public Builder inferenceWorkspaceMode(@NonNull WorkspaceMode workspaceMode) {
            if (workspaceMode == null) {
                throw new NullPointerException("workspaceMode is marked @NonNull but is null");
            }
            this.inferenceWorkspaceMode = workspaceMode;
            return this;
        }

        /**
         * Sets the cache mode.
         *
         * @throws NullPointerException if {@code cacheMode} is null
         */
        public Builder cacheMode(@NonNull CacheMode cacheMode) {
            if (cacheMode == null) {
                throw new NullPointerException("cacheMode is marked @NonNull but is null");
            }
            this.cacheMode = cacheMode;
            return this;
        }

        /**
         * Sets the backprop algorithm. TBPTT lengths only take effect with
         * {@link BackpropType#TruncatedBPTT}.
         *
         * @throws NullPointerException if {@code backpropType} is null
         */
        public Builder backpropType(@NonNull BackpropType backpropType) {
            if (backpropType == null) {
                throw new NullPointerException("type is marked @NonNull but is null");
            }
            this.backpropType = backpropType;
            return this;
        }

        /** Sets both the forward and backward TBPTT lengths to {@code i}. */
        public Builder tBPTTLength(int i) {
            tBPTTForwardLength(i);
            return tBPTTBackwardLength(i);
        }

        /** Sets the forward TBPTT length. */
        public Builder tBPTTForwardLength(int i) {
            this.tbpttFwdLength = i;
            return this;
        }

        /** Sets the backward TBPTT length. */
        public Builder tBPTTBackwardLength(int i) {
            this.tbpttBackLength = i;
            return this;
        }

        /** @deprecated retained for backward compatibility. */
        @Deprecated
        public Builder pretrain(boolean z) {
            this.pretrain = z;
            return this;
        }

        /**
         * Replaces the layer configuration list. The list is stored by reference and becomes the list of
         * the built configuration.
         */
        public Builder confs(List<NeuralNetConfiguration> list) {
            this.confs = list;
            return this;
        }

        /** Sets the network input type, used by {@link #build()} to add preprocessors and set nIn. */
        public Builder setInputType(InputType inputType) {
            this.inputType = inputType;
            return this;
        }

        /** Enables or disables output-layer validation in {@link #build()}. */
        public Builder validateOutputLayerConfig(boolean z) {
            this.validateOutputConfig = z;
            return this;
        }

        /** Sets the legacyBatchScaledL2 flag copied into the built configuration. */
        public Builder legacyBatchScaledL2(boolean z) {
            this.legacyBatchScaledL2 = z;
            return this;
        }

        /**
         * Builds the {@link MultiLayerConfiguration}.
         * <ol>
         *   <li>Warns if a TBPTT length differs from the default while the backprop type is not
         *       {@link BackpropType#TruncatedBPTT}.</li>
         *   <li>If neither an input type nor a preprocessor for layer 0 was set, tries to infer the input
         *       type from the first layer's nIn.</li>
         *   <li>If an input type is known, walks the layers in order. For each layer it adds a missing
         *       preprocessor when the layer supplies one, calls {@code setNIn(type, false)}, and propagates
         *       the output type to the next layer.</li>
         *   <li>Creates the configuration, which shares this builder's list and map. Seeds the ND4J random
         *       generator with the first layer's seed and, if enabled, validates every layer via
         *       {@link OutputLayerUtil#validateOutputLayer}.</li>
         * </ol>
         *
         * @return the built configuration
         * @throws IllegalStateException if no layer configurations have been added
         */
        public MultiLayerConfiguration build() {
            // Every path below needs layer 0; fail with a clear message instead of an IndexOutOfBoundsException.
            if (this.confs == null || this.confs.isEmpty()) {
                throw new IllegalStateException("Cannot build MultiLayerConfiguration: no layer configurations have been added");
            }
            boolean customTbpttLengths = this.tbpttBackLength != DEFAULT_TBPTT_LENGTH || this.tbpttFwdLength != DEFAULT_TBPTT_LENGTH;
            if (customTbpttLengths && this.backpropType != BackpropType.TruncatedBPTT) {
                MultiLayerConfiguration.log.warn("Truncated backpropagation through time lengths have been configured with values " + this.tbpttFwdLength + " and " + this.tbpttBackLength + " but backprop type is set to " + this.backpropType + ". TBPTT configuration settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT");
            }
            if (this.inputType == null && this.inputPreProcessors.get(0) == null) {
                this.inputType = inferInputType(this.confs.get(0).getLayer());
            }
            if (this.inputType != null) {
                InputType currentType = this.inputType;
                for (int i = 0; i < this.confs.size(); i++) {
                    Layer layer = this.confs.get(i).getLayer();
                    // Never override a user-supplied preprocessor; only fill gaps.
                    if (this.inputPreProcessors.get(i) == null) {
                        InputPreProcessor inferred = layer.getPreProcessorForInputType(currentType);
                        if (inferred != null) {
                            this.inputPreProcessors.put(i, inferred);
                        }
                    }
                    InputPreProcessor preProcessor = this.inputPreProcessors.get(i);
                    if (preProcessor != null) {
                        currentType = preProcessor.getOutputType(currentType);
                    }
                    // NOTE(review): the 'false' flag presumably keeps an explicitly set nIn -- confirm in Layer#setNIn.
                    layer.setNIn(currentType, false);
                    currentType = layer.getOutputType(i, currentType);
                }
            }
            if (isPretrain()) {
                for (NeuralNetConfiguration conf : this.confs) {
                    if (conf.getLayer() instanceof BasePretrainNetwork) {
                        conf.setPretrain(this.pretrain);
                    }
                }
            }
            MultiLayerConfiguration multiLayerConfiguration = new MultiLayerConfiguration();
            multiLayerConfiguration.confs = this.confs;
            multiLayerConfiguration.pretrain = this.pretrain;
            multiLayerConfiguration.backprop = this.backprop;
            multiLayerConfiguration.inputPreProcessors = this.inputPreProcessors;
            multiLayerConfiguration.backpropType = this.backpropType;
            multiLayerConfiguration.tbpttFwdLength = this.tbpttFwdLength;
            multiLayerConfiguration.tbpttBackLength = this.tbpttBackLength;
            multiLayerConfiguration.trainingWorkspaceMode = this.trainingWorkspaceMode;
            multiLayerConfiguration.inferenceWorkspaceMode = this.inferenceWorkspaceMode;
            multiLayerConfiguration.cacheMode = this.cacheMode;
            multiLayerConfiguration.legacyBatchScaledL2 = this.legacyBatchScaledL2;
            // Global side effect: seeds ND4J's random generator from the first layer's seed.
            Nd4j.getRandom().setSeed(multiLayerConfiguration.getConf(0).getSeed());
            if (this.validateOutputConfig) {
                for (NeuralNetConfiguration conf : multiLayerConfiguration.getConfs()) {
                    Layer layer = conf.getLayer();
                    OutputLayerUtil.validateOutputLayer(layer.getLayerName(), layer);
                }
            }
            return multiLayerConfiguration;
        }

        /**
         * Infers the network input type from the first layer's nIn.
         * <p>
         * Returns recurrent input for a {@link BaseRecurrentLayer}, and feed-forward input for a
         * {@link DenseLayer}, {@link EmbeddingLayer} or {@link OutputLayer}, in each case only when
         * nIn &gt; 0. Other layer types (including other FeedForwardLayer subclasses) give {@code null}.
         */
        private static InputType inferInputType(Layer firstLayer) {
            if (firstLayer instanceof BaseRecurrentLayer) {
                long nIn = ((BaseRecurrentLayer) firstLayer).getNIn();
                return nIn > 0 ? InputType.recurrent(nIn) : null;
            }
            if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer || firstLayer instanceof OutputLayer) {
                long nIn = ((FeedForwardLayer) firstLayer).getNIn();
                return nIn > 0 ? InputType.feedForward(nIn) : null;
            }
            return null;
        }

        public List<NeuralNetConfiguration> getConfs() {
            return this.confs;
        }

        public double getDampingFactor() {
            return this.dampingFactor;
        }

        public Map<Integer, InputPreProcessor> getInputPreProcessors() {
            return this.inputPreProcessors;
        }

        @Deprecated
        public boolean isPretrain() {
            return this.pretrain;
        }

        @Deprecated
        public boolean isBackprop() {
            return this.backprop;
        }

        public BackpropType getBackpropType() {
            return this.backpropType;
        }

        public int getTbpttFwdLength() {
            return this.tbpttFwdLength;
        }

        public int getTbpttBackLength() {
            return this.tbpttBackLength;
        }

        public InputType getInputType() {
            return this.inputType;
        }

        public WorkspaceMode getTrainingWorkspaceMode() {
            return this.trainingWorkspaceMode;
        }

        public WorkspaceMode getInferenceWorkspaceMode() {
            return this.inferenceWorkspaceMode;
        }

        public CacheMode getCacheMode() {
            return this.cacheMode;
        }

        public boolean isValidateOutputConfig() {
            return this.validateOutputConfig;
        }

        public boolean isLegacyBatchScaledL2() {
            return this.legacyBatchScaledL2;
        }

        public void setConfs(List<NeuralNetConfiguration> list) {
            this.confs = list;
        }

        public void setDampingFactor(double d) {
            this.dampingFactor = d;
        }

        public void setInputPreProcessors(Map<Integer, InputPreProcessor> map) {
            this.inputPreProcessors = map;
        }

        @Deprecated
        public void setPretrain(boolean z) {
            this.pretrain = z;
        }

        @Deprecated
        public void setBackprop(boolean z) {
            this.backprop = z;
        }

        public void setBackpropType(BackpropType backpropType) {
            this.backpropType = backpropType;
        }

        public void setTbpttFwdLength(int i) {
            this.tbpttFwdLength = i;
        }

        public void setTbpttBackLength(int i) {
            this.tbpttBackLength = i;
        }

        public void setTrainingWorkspaceMode(WorkspaceMode workspaceMode) {
            this.trainingWorkspaceMode = workspaceMode;
        }

        public void setInferenceWorkspaceMode(WorkspaceMode workspaceMode) {
            this.inferenceWorkspaceMode = workspaceMode;
        }

        public void setCacheMode(CacheMode cacheMode) {
            this.cacheMode = cacheMode;
        }

        public void setValidateOutputConfig(boolean z) {
            this.validateOutputConfig = z;
        }

        public void setLegacyBatchScaledL2(boolean z) {
            this.legacyBatchScaledL2 = z;
        }

        /** Field-by-field equality over all builder state (Lombok-style, with canEqual). */
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Builder)) {
                return false;
            }
            Builder builder = (Builder) obj;
            if (!builder.canEqual(this)) {
                return false;
            }
            List<NeuralNetConfiguration> confs = getConfs();
            List<NeuralNetConfiguration> confs2 = builder.getConfs();
            if (confs == null) {
                if (confs2 != null) {
                    return false;
                }
            } else if (!confs.equals(confs2)) {
                return false;
            }
            if (Double.compare(getDampingFactor(), builder.getDampingFactor()) != 0) {
                return false;
            }
            Map<Integer, InputPreProcessor> inputPreProcessors = getInputPreProcessors();
            Map<Integer, InputPreProcessor> inputPreProcessors2 = builder.getInputPreProcessors();
            if (inputPreProcessors == null) {
                if (inputPreProcessors2 != null) {
                    return false;
                }
            } else if (!inputPreProcessors.equals(inputPreProcessors2)) {
                return false;
            }
            if (isPretrain() != builder.isPretrain() || isBackprop() != builder.isBackprop()) {
                return false;
            }
            BackpropType backpropType = getBackpropType();
            BackpropType backpropType2 = builder.getBackpropType();
            if (backpropType == null) {
                if (backpropType2 != null) {
                    return false;
                }
            } else if (!backpropType.equals(backpropType2)) {
                return false;
            }
            if (getTbpttFwdLength() != builder.getTbpttFwdLength() || getTbpttBackLength() != builder.getTbpttBackLength()) {
                return false;
            }
            InputType inputType = getInputType();
            InputType inputType2 = builder.getInputType();
            if (inputType == null) {
                if (inputType2 != null) {
                    return false;
                }
            } else if (!inputType.equals(inputType2)) {
                return false;
            }
            WorkspaceMode trainingWorkspaceMode = getTrainingWorkspaceMode();
            WorkspaceMode trainingWorkspaceMode2 = builder.getTrainingWorkspaceMode();
            if (trainingWorkspaceMode == null) {
                if (trainingWorkspaceMode2 != null) {
                    return false;
                }
            } else if (!trainingWorkspaceMode.equals(trainingWorkspaceMode2)) {
                return false;
            }
            WorkspaceMode inferenceWorkspaceMode = getInferenceWorkspaceMode();
            WorkspaceMode inferenceWorkspaceMode2 = builder.getInferenceWorkspaceMode();
            if (inferenceWorkspaceMode == null) {
                if (inferenceWorkspaceMode2 != null) {
                    return false;
                }
            } else if (!inferenceWorkspaceMode.equals(inferenceWorkspaceMode2)) {
                return false;
            }
            CacheMode cacheMode = getCacheMode();
            CacheMode cacheMode2 = builder.getCacheMode();
            if (cacheMode == null) {
                if (cacheMode2 != null) {
                    return false;
                }
            } else if (!cacheMode.equals(cacheMode2)) {
                return false;
            }
            return isValidateOutputConfig() == builder.isValidateOutputConfig() && isLegacyBatchScaledL2() == builder.isLegacyBatchScaledL2();
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof Builder;
        }

        /** Hash consistent with {@link #equals(Object)} (Lombok-style 59-multiplier scheme). */
        public int hashCode() {
            List<NeuralNetConfiguration> confs = getConfs();
            int hashCode = (1 * 59) + (confs == null ? 43 : confs.hashCode());
            long doubleToLongBits = Double.doubleToLongBits(getDampingFactor());
            int i = (hashCode * 59) + ((int) ((doubleToLongBits >>> 32) ^ doubleToLongBits));
            Map<Integer, InputPreProcessor> inputPreProcessors = getInputPreProcessors();
            int hashCode2 = (((((i * 59) + (inputPreProcessors == null ? 43 : inputPreProcessors.hashCode())) * 59) + (isPretrain() ? 79 : 97)) * 59) + (isBackprop() ? 79 : 97);
            BackpropType backpropType = getBackpropType();
            int hashCode3 = (((((hashCode2 * 59) + (backpropType == null ? 43 : backpropType.hashCode())) * 59) + getTbpttFwdLength()) * 59) + getTbpttBackLength();
            InputType inputType = getInputType();
            int hashCode4 = (hashCode3 * 59) + (inputType == null ? 43 : inputType.hashCode());
            WorkspaceMode trainingWorkspaceMode = getTrainingWorkspaceMode();
            int hashCode5 = (hashCode4 * 59) + (trainingWorkspaceMode == null ? 43 : trainingWorkspaceMode.hashCode());
            WorkspaceMode inferenceWorkspaceMode = getInferenceWorkspaceMode();
            int hashCode6 = (hashCode5 * 59) + (inferenceWorkspaceMode == null ? 43 : inferenceWorkspaceMode.hashCode());
            CacheMode cacheMode = getCacheMode();
            return (((((hashCode6 * 59) + (cacheMode == null ? 43 : cacheMode.hashCode())) * 59) + (isValidateOutputConfig() ? 79 : 97)) * 59) + (isLegacyBatchScaledL2() ? 79 : 97);
        }

        public String toString() {
            return "MultiLayerConfiguration.Builder(confs=" + getConfs() + ", dampingFactor=" + getDampingFactor() + ", inputPreProcessors=" + getInputPreProcessors() + ", pretrain=" + isPretrain() + ", backprop=" + isBackprop() + ", backpropType=" + getBackpropType() + ", tbpttFwdLength=" + getTbpttFwdLength() + ", tbpttBackLength=" + getTbpttBackLength() + ", inputType=" + getInputType() + ", trainingWorkspaceMode=" + getTrainingWorkspaceMode() + ", inferenceWorkspaceMode=" + getInferenceWorkspaceMode() + ", cacheMode=" + getCacheMode() + ", validateOutputConfig=" + isValidateOutputConfig() + ", legacyBatchScaledL2=" + isLegacyBatchScaledL2() + ")";
        }
    }

    /**
     * Returns the epoch counter stored on this configuration.
     *
     * @return current epoch count
     */
    public int getEpochCount() {
        return epochCount;
    }

    /**
     * Updates the epoch counter here and on every layer configuration.
     *
     * @param i new epoch count
     */
    public void setEpochCount(int i) {
        this.epochCount = i;
        // Keep each layer's own counter in step with the network-level one.
        int layer = 0;
        while (layer < confs.size()) {
            getConf(layer).setEpochCount(i);
            layer++;
        }
    }

    /**
     * Serializes this configuration to YAML with the mapper returned by
     * {@link NeuralNetConfiguration#mapperYaml()}, synchronizing on that mapper instance while writing.
     *
     * @return YAML text
     * @throws RuntimeException wrapping any {@link JsonProcessingException}
     */
    public String toYaml() {
        final ObjectMapper yamlMapper = NeuralNetConfiguration.mapperYaml();
        synchronized (yamlMapper) {
            try {
                return yamlMapper.writeValueAsString(this);
            } catch (JsonProcessingException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Deserializes a configuration from YAML (as written by {@link #toYaml()}).
     * <p>
     * Unlike {@link #fromJson(String)}, no legacy-format fix-ups are applied.
     *
     * @param str YAML text
     * @return the deserialized configuration
     * @throws RuntimeException wrapping any {@link IOException} from the parser
     */
    public static MultiLayerConfiguration fromYaml(String str) {
        final ObjectMapper yamlMapper = NeuralNetConfiguration.mapperYaml();
        final MultiLayerConfiguration parsed;
        try {
            parsed = yamlMapper.readValue(str, MultiLayerConfiguration.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return parsed;
    }

    /**
     * Serializes this configuration to JSON with the mapper returned by
     * {@link NeuralNetConfiguration#mapper()}, synchronizing on that mapper instance while writing.
     *
     * @return JSON text
     * @throws RuntimeException wrapping any {@link JsonProcessingException}
     */
    public String toJson() {
        final ObjectMapper jsonMapper = NeuralNetConfiguration.mapper();
        synchronized (jsonMapper) {
            try {
                return jsonMapper.writeValueAsString(this);
            } catch (JsonProcessingException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Deserializes a configuration from JSON (as written by {@link #toJson()}).
     * <p>
     * After Jackson builds the object, a best-effort compatibility pass walks the layers in order:
     * <ul>
     *   <li>A {@link BaseOutputLayer} whose loss function came back {@code null} is checked for a legacy
     *       {@code lossFunction} enum name under its {@code "output"} or {@code "rnnoutput"} wrapper.
     *       MSE, XENT, NEGATIVELOGLIKELIHOOD and MCXENT are mapped to loss classes; other values are
     *       only logged.</li>
     *   <li>A {@link BaseLayer} whose activation function came back {@code null} is checked for a legacy
     *       {@code activationFunction} string, resolved with {@link Activation#fromString(String)}.</li>
     * </ul>
     * The JSON text is re-parsed for this pass only until its {@code confs} node has been obtained, not
     * once per layer. Re-parse failures and unparseable legacy loss names are logged. The pass returns
     * early if the JSON {@code confs} array has fewer entries than there are layers.
     *
     * @param str JSON text
     * @return the deserialized configuration
     * @throws RuntimeException if the primary deserialization fails. When the parser message mentions
     *     "legacy", the exception message explains how to register legacy custom classes.
     */
    public static MultiLayerConfiguration fromJson(String str) {
        ObjectMapper mapper = NeuralNetConfiguration.mapper();
        MultiLayerConfiguration multiLayerConfiguration;
        try {
            multiLayerConfiguration = mapper.readValue(str, MultiLayerConfiguration.class);
        } catch (IOException e) {
            String message = e.getMessage();
            if (message != null && message.contains("legacy")) {
                throw new RuntimeException("Error deserializing MultiLayerConfiguration - configuration may have a custom layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)", e);
            }
            throw new RuntimeException(e);
        }

        // The JSON "confs" node, parsed lazily and only if some layer needs a legacy fix-up.
        JsonNode confsNode = null;
        int layerIndex = 0;
        for (NeuralNetConfiguration nnc : multiLayerConfiguration.getConfs()) {
            Layer layer = nnc.getLayer();

            if (layer instanceof BaseOutputLayer && ((BaseOutputLayer) layer).getLossFn() == null) {
                try {
                    if (confsNode == null) {
                        confsNode = mapper.readTree(str).get("confs");
                    }
                    if (confsNode instanceof ArrayNode) {
                        JsonNode layerConfNode = ((ArrayNode) confsNode).get(layerIndex);
                        if (layerConfNode == null) {
                            // JSON has fewer layer entries than the deserialized configuration.
                            return multiLayerConfiguration;
                        }
                        // The "layer" wrapper may be absent; the helper tolerates null instead of throwing NPE.
                        JsonNode lossFunctionNode = legacyLossFunctionNode(layerConfNode.get("layer"));
                        if (lossFunctionNode != null) {
                            applyLegacyLossFunction((BaseOutputLayer) layer, lossFunctionNode.asText());
                        }
                    } else {
                        log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})", confsNode != null ? confsNode.getClass() : null);
                    }
                } catch (IOException e) {
                    log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", e);
                }
            }

            if (layer instanceof BaseLayer && ((BaseLayer) layer).getActivationFn() == null) {
                if (confsNode == null) {
                    try {
                        confsNode = mapper.readTree(str).get("confs");
                    } catch (IOException e) {
                        log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", e);
                    }
                }
                if (confsNode instanceof ArrayNode) {
                    JsonNode layerConfNode = ((ArrayNode) confsNode).get(layerIndex);
                    if (layerConfNode == null) {
                        return multiLayerConfiguration;
                    }
                    JsonNode layerWrapper = layerConfNode.get("layer");
                    // Only a single-entry wrapper (e.g. {"dense": {...}}) is unambiguous.
                    if (layerWrapper != null && layerWrapper.size() == 1) {
                        JsonNode activationNode = layerWrapper.elements().next().get("activationFunction");
                        if (activationNode != null) {
                            ((BaseLayer) layer).setActivationFn(Activation.fromString(activationNode.asText()).getActivationFunction());
                        }
                    }
                }
            }
            layerIndex++;
        }
        return multiLayerConfiguration;
    }

    /**
     * Returns the legacy {@code lossFunction} node under a layer wrapper's {@code "output"} or
     * {@code "rnnoutput"} entry. Returns {@code null} when there is none, including when the wrapper
     * itself is {@code null}.
     */
    private static JsonNode legacyLossFunctionNode(JsonNode layerWrapper) {
        if (layerWrapper == null) {
            return null;
        }
        if (layerWrapper.has("output")) {
            return layerWrapper.get("output").get("lossFunction");
        }
        if (layerWrapper.has("rnnoutput")) {
            return layerWrapper.get("rnnoutput").get("lossFunction");
        }
        return null;
    }

    /**
     * Parses a legacy {@link LossFunctions.LossFunction} enum name and sets the matching loss instance
     * on {@code outputLayer}. Only MSE, XENT, NEGATIVELOGLIKELIHOOD and MCXENT are mapped. Unparseable
     * or unmapped names are logged and leave the layer unchanged.
     */
    private static void applyLegacyLossFunction(BaseOutputLayer outputLayer, String lossFunctionName) {
        LossFunctions.LossFunction lossFunction;
        try {
            lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionName);
        } catch (Exception e) {
            log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", e);
            return;
        }
        switch (lossFunction) {
            case MSE:
                outputLayer.setLossFn(new LossMSE());
                break;
            case XENT:
                outputLayer.setLossFn(new LossBinaryXENT());
                break;
            case NEGATIVELOGLIKELIHOOD:
                outputLayer.setLossFn(new LossNegativeLogLikelihood());
                break;
            case MCXENT:
                outputLayer.setLossFn(new LossMCXENT());
                break;
            // Known legacy values with no mapping yet.
            case EXPLL:
            case RMSE_XENT:
            case SQUARED_LOSS:
            case RECONSTRUCTION_CROSSENTROPY:
            case CUSTOM:
            default:
                log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}", lossFunction);
                break;
        }
    }

    /**
     * Returns the JSON form of this configuration.
     *
     * @return the same text as {@link #toJson()}
     */
    public String toString() {
        final String json = toJson();
        return json;
    }

    /**
     * Returns the configuration of the layer at index {@code i}.
     *
     * @param i layer index
     * @return that layer's configuration
     * @throws IndexOutOfBoundsException if {@code i} is outside the layer list
     */
    public NeuralNetConfiguration getConf(int i) {
        return confs.get(i);
    }

    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    /**
     * Returns a deep copy of this configuration.
     * <p>
     * {@link Object#clone()} copies every field. The layer configurations and input preprocessors are
     * then cloned individually into fresh collections, so the copy shares no list, map or element with
     * this instance. A {@code null} preprocessor entry stays {@code null} in the copy.
     *
     * @return the copy
     * @throws RuntimeException wrapping {@link CloneNotSupportedException}
     */
    public MultiLayerConfiguration m6404clone() {
        try {
            MultiLayerConfiguration copy = (MultiLayerConfiguration) super.clone();
            if (copy.confs != null) {
                List<NeuralNetConfiguration> clonedConfs = new ArrayList<>(copy.confs.size());
                for (NeuralNetConfiguration conf : copy.confs) {
                    clonedConfs.add(conf.m6407clone());
                }
                copy.confs = clonedConfs;
            }
            if (copy.inputPreProcessors != null) {
                Map<Integer, InputPreProcessor> clonedPreProcessors = new HashMap<>();
                for (Map.Entry<Integer, InputPreProcessor> entry : copy.inputPreProcessors.entrySet()) {
                    InputPreProcessor preProcessor = entry.getValue();
                    // Builder.inputPreProcessor(Integer, null) can store a null value; keep it instead of NPE.
                    clonedPreProcessors.put(entry.getKey(), preProcessor == null ? null : preProcessor.m6472clone());
                }
                copy.inputPreProcessors = clonedPreProcessors;
            }
            return copy;
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the preprocessor applied to the input of layer {@code i}, or {@code null} if none is set.
     *
     * @param i layer index
     * @return the preprocessor, or {@code null}
     */
    public InputPreProcessor getInputPreProcess(int i) {
        final Integer layerKey = i;
        return inputPreProcessors.get(layerKey);
    }

    /**
     * Builds a per-layer memory report for this network.
     * <p>
     * The input type is threaded through the layers. For each layer, its preprocessor (if any) is
     * applied first, then the layer's report is computed for the resulting type, and the layer's output
     * type becomes the next layer's input. Reports are keyed by layer name, falling back to the layer
     * index when the name is {@code null}, and are inserted in layer order.
     *
     * @param inputType type of the network input
     * @return the network report; its final argument is the type produced by the last layer
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public NetworkMemoryReport getMemoryReport(InputType inputType) {
        // Raw map: the per-layer report type is not imported in this file.
        LinkedHashMap reportsByName = new LinkedHashMap();
        InputType current = inputType;
        final int layerCount = confs.size();
        for (int idx = 0; idx < layerCount; idx++) {
            Layer layer = confs.get(idx).getLayer();
            String name = layer.getLayerName();
            if (name == null) {
                name = String.valueOf(idx);
            }
            InputPreProcessor preProcessor = getInputPreProcess(idx);
            if (preProcessor != null) {
                current = preProcessor.getOutputType(current);
            }
            reportsByName.put(name, layer.getMemoryReport(current));
            current = layer.getOutputType(idx, current);
        }
        return new NetworkMemoryReport(reportsByName, MultiLayerConfiguration.class, "MultiLayerNetwork", current);
    }

    /**
     * Computes the activation (output) type of every layer for a given network input type.
     * <p>
     * For each layer in order, its preprocessor (if any) is applied to the running type, and then the
     * layer's output type is recorded.
     *
     * @param inputType type of the network input
     * @return one entry per layer, in layer order
     * @throws NullPointerException if {@code inputType} is null
     */
    public List<InputType> getLayerActivationTypes(@NonNull InputType inputType) {
        if (inputType == null) {
            throw new NullPointerException("inputType is marked @NonNull but is null");
        }
        final int layerCount = confs.size();
        List<InputType> outputTypes = new ArrayList<>(layerCount);
        InputType current = inputType;
        for (int idx = 0; idx < layerCount; idx++) {
            InputPreProcessor preProcessor = getInputPreProcess(idx);
            current = preProcessor == null ? current : preProcessor.getOutputType(current);
            current = confs.get(idx).getLayer().getOutputType(idx, current);
            outputTypes.add(current);
        }
        return outputTypes;
    }

    /**
     * Returns the layer configurations; element i belongs to layer i.
     *
     * @return the backing list (not a copy)
     */
    public List<NeuralNetConfiguration> getConfs() {
        return confs;
    }

    /**
     * Returns the input preprocessors keyed by layer index.
     *
     * @return the backing map (not a copy)
     */
    public Map<Integer, InputPreProcessor> getInputPreProcessors() {
        return inputPreProcessors;
    }

    /** @return the pretrain flag copied from the builder */
    public boolean isPretrain() {
        return pretrain;
    }

    /** @return the backprop flag copied from the builder */
    public boolean isBackprop() {
        return backprop;
    }

    /** @return the configured backprop algorithm */
    public BackpropType getBackpropType() {
        return backpropType;
    }

    /** @return the forward TBPTT length (only used with truncated BPTT) */
    public int getTbpttFwdLength() {
        return tbpttFwdLength;
    }

    /** @return the backward TBPTT length (only used with truncated BPTT) */
    public int getTbpttBackLength() {
        return tbpttBackLength;
    }

    /** @return the iteration counter stored on this configuration */
    public int getIterationCount() {
        return iterationCount;
    }

    /**
     * Replaces the layer configuration list (stored by reference).
     *
     * @param list new per-layer configurations
     */
    public void setConfs(List<NeuralNetConfiguration> list) {
        confs = list;
    }

    /**
     * Replaces the preprocessor map (stored by reference).
     *
     * @param map preprocessors keyed by layer index
     */
    public void setInputPreProcessors(Map<Integer, InputPreProcessor> map) {
        inputPreProcessors = map;
    }

    /** @param z new value of the pretrain flag */
    public void setPretrain(boolean z) {
        pretrain = z;
    }

    /** @param z new value of the backprop flag */
    public void setBackprop(boolean z) {
        backprop = z;
    }

    /** @param backpropType new backprop algorithm (not null-checked) */
    public void setBackpropType(BackpropType backpropType) {
        this.backpropType = backpropType;
    }

    /** @param i new forward TBPTT length */
    public void setTbpttFwdLength(int i) {
        tbpttFwdLength = i;
    }

    /** @param i new backward TBPTT length */
    public void setTbpttBackLength(int i) {
        tbpttBackLength = i;
    }

    /**
     * Overwrites the stored iteration counter.
     *
     * @param iterationCount the new counter value
     */
    public void setIterationCount(int iterationCount) {
        this.iterationCount = iterationCount;
    }

    /**
     * Compares this configuration with another object field by field.
     * <p>
     * {@code obj} is equal to this configuration when:
     * <ul>
     *   <li>it is a {@code MultiLayerConfiguration};</li>
     *   <li>its {@link #canEqual(Object)} accepts this instance;</li>
     *   <li>all of the following match: the layer configurations, the input pre-processor map, the
     *       pretrain/backprop flags, the backprop type, both TBPTT lengths, the legacy batch-scaled L2 flag,
     *       both workspace modes, the cache mode, and the iteration and epoch counters.</li>
     * </ul>
     * Reference-typed fields are compared null-safely with {@code equals}. Checks stop at the first mismatch.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof MultiLayerConfiguration)) {
            return false;
        }
        MultiLayerConfiguration that = (MultiLayerConfiguration) obj;
        // Ask the other side too, so a subclass that narrows canEqual keeps equals symmetric.
        if (!that.canEqual(this)) {
            return false;
        }

        List<NeuralNetConfiguration> ourConfs = getConfs();
        List<NeuralNetConfiguration> theirConfs = that.getConfs();
        if (ourConfs == null ? theirConfs != null : !ourConfs.equals(theirConfs)) {
            return false;
        }

        Map<Integer, InputPreProcessor> ourPre = getInputPreProcessors();
        Map<Integer, InputPreProcessor> theirPre = that.getInputPreProcessors();
        if (ourPre == null ? theirPre != null : !ourPre.equals(theirPre)) {
            return false;
        }

        boolean flagsMatch = isPretrain() == that.isPretrain() && isBackprop() == that.isBackprop();
        if (!flagsMatch) {
            return false;
        }

        BackpropType ourType = getBackpropType();
        BackpropType theirType = that.getBackpropType();
        if (ourType == null ? theirType != null : !ourType.equals(theirType)) {
            return false;
        }

        boolean tbpttAndL2Match = getTbpttFwdLength() == that.getTbpttFwdLength()
                && getTbpttBackLength() == that.getTbpttBackLength()
                && isLegacyBatchScaledL2() == that.isLegacyBatchScaledL2();
        if (!tbpttAndL2Match) {
            return false;
        }

        WorkspaceMode ourTrainWs = getTrainingWorkspaceMode();
        WorkspaceMode theirTrainWs = that.getTrainingWorkspaceMode();
        if (ourTrainWs == null ? theirTrainWs != null : !ourTrainWs.equals(theirTrainWs)) {
            return false;
        }

        WorkspaceMode ourInferWs = getInferenceWorkspaceMode();
        WorkspaceMode theirInferWs = that.getInferenceWorkspaceMode();
        if (ourInferWs == null ? theirInferWs != null : !ourInferWs.equals(theirInferWs)) {
            return false;
        }

        CacheMode ourCache = getCacheMode();
        CacheMode theirCache = that.getCacheMode();
        if (ourCache == null ? theirCache != null : !ourCache.equals(theirCache)) {
            return false;
        }

        return getIterationCount() == that.getIterationCount() && getEpochCount() == that.getEpochCount();
    }

    /**
     * Symmetry hook consulted by {@link #equals(Object)}: reports whether {@code obj} is of a type this class
     * is willing to be compared with. It is protected and non-final, so a subclass can override it.
     *
     * @param obj the candidate object (may be null)
     * @return {@code true} if {@code obj} is a non-null {@code MultiLayerConfiguration}
     */
    protected boolean canEqual(Object obj) {
        return MultiLayerConfiguration.class.isInstance(obj);
    }

    /**
     * Returns a hash code built from the same fields, in the same order, that {@link #equals(Object)} compares.
     * <p>
     * The running value starts at 1. Before each field is added, the value is multiplied by 59. Each field
     * contributes as follows:
     * <ul>
     *   <li>a null reference contributes 43;</li>
     *   <li>a non-null reference contributes its own {@code hashCode()};</li>
     *   <li>a boolean contributes 79 for {@code true} and 97 for {@code false};</li>
     *   <li>an int contributes its own value.</li>
     * </ul>
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;

        List<NeuralNetConfiguration> layerConfs = getConfs();
        result = result * prime + (layerConfs == null ? 43 : layerConfs.hashCode());

        Map<Integer, InputPreProcessor> preProcessors = getInputPreProcessors();
        result = result * prime + (preProcessors == null ? 43 : preProcessors.hashCode());

        result = result * prime + (isPretrain() ? 79 : 97);
        result = result * prime + (isBackprop() ? 79 : 97);

        BackpropType type = getBackpropType();
        result = result * prime + (type == null ? 43 : type.hashCode());

        result = result * prime + getTbpttFwdLength();
        result = result * prime + getTbpttBackLength();
        result = result * prime + (isLegacyBatchScaledL2() ? 79 : 97);

        WorkspaceMode trainingWs = getTrainingWorkspaceMode();
        result = result * prime + (trainingWs == null ? 43 : trainingWs.hashCode());

        WorkspaceMode inferenceWs = getInferenceWorkspaceMode();
        result = result * prime + (inferenceWs == null ? 43 : inferenceWs.hashCode());

        CacheMode cache = getCacheMode();
        result = result * prime + (cache == null ? 43 : cache.hashCode());

        result = result * prime + getIterationCount();
        result = result * prime + getEpochCount();
        return result;
    }

    /**
     * All-arguments constructor: stores every argument directly into the field of the same name.
     * <p>
     * No defaults are applied and nothing is validated or copied. For example, a null
     * {@code inputPreProcessors} map or {@code backpropType} is stored as-is, and the list and map references
     * are shared with the caller.
     */
    private MultiLayerConfiguration(List<NeuralNetConfiguration> confs,
                                    Map<Integer, InputPreProcessor> inputPreProcessors,
                                    boolean pretrain,
                                    boolean backprop,
                                    BackpropType backpropType,
                                    int tbpttFwdLength,
                                    int tbpttBackLength,
                                    boolean legacyBatchScaledL2,
                                    WorkspaceMode trainingWorkspaceMode,
                                    WorkspaceMode inferenceWorkspaceMode,
                                    CacheMode cacheMode,
                                    int iterationCount,
                                    int epochCount) {
        this.confs = confs;
        this.inputPreProcessors = inputPreProcessors;
        this.pretrain = pretrain;
        this.backprop = backprop;
        this.backpropType = backpropType;
        this.tbpttFwdLength = tbpttFwdLength;
        this.tbpttBackLength = tbpttBackLength;
        this.legacyBatchScaledL2 = legacyBatchScaledL2;
        this.trainingWorkspaceMode = trainingWorkspaceMode;
        this.inferenceWorkspaceMode = inferenceWorkspaceMode;
        this.cacheMode = cacheMode;
        this.iterationCount = iterationCount;
        this.epochCount = epochCount;
    }

    /**
     * Creates a configuration with default settings:
     * <ul>
     *   <li>an empty mutable input pre-processor map;</li>
     *   <li>pretraining off and backprop on;</li>
     *   <li>{@code BackpropType.Standard}, with TBPTT forward and backward lengths of 20;</li>
     *   <li>legacy batch-scaled L2 on;</li>
     *   <li>{@code WorkspaceMode.ENABLED} for both training and inference;</li>
     *   <li>iteration and epoch counters at 0.</li>
     * </ul>
     * {@code confs} and {@code cacheMode} are not assigned here.
     */
    public MultiLayerConfiguration() {
        // Counters
        this.iterationCount = 0;
        this.epochCount = 0;

        // Memory workspaces
        this.trainingWorkspaceMode = WorkspaceMode.ENABLED;
        this.inferenceWorkspaceMode = WorkspaceMode.ENABLED;

        // Training procedure
        this.pretrain = false;
        this.backprop = true;
        this.backpropType = BackpropType.Standard;
        this.tbpttFwdLength = 20;
        this.tbpttBackLength = 20;
        this.legacyBatchScaledL2 = true;

        // Network structure
        this.inputPreProcessors = new HashMap<>();
    }

    /**
     * Returns the legacy batch-scaled L2 flag.
     *
     * @return the legacy batch-scaled L2 flag ({@code true} after the no-argument constructor)
     */
    public boolean isLegacyBatchScaledL2() {
        return legacyBatchScaledL2;
    }

    /**
     * Sets the legacy batch-scaled L2 flag.
     *
     * @param legacyBatchScaledL2 the new value
     */
    public void setLegacyBatchScaledL2(boolean legacyBatchScaledL2) {
        this.legacyBatchScaledL2 = legacyBatchScaledL2;
    }

    /**
     * Returns the workspace mode used for training.
     *
     * @return the training workspace mode ({@code WorkspaceMode.ENABLED} after the no-argument constructor)
     */
    public WorkspaceMode getTrainingWorkspaceMode() {
        return trainingWorkspaceMode;
    }

    /**
     * Sets the training workspace mode. The value is not validated and may be null.
     *
     * @param mode the new training workspace mode
     */
    public void setTrainingWorkspaceMode(WorkspaceMode mode) {
        trainingWorkspaceMode = mode;
    }

    /**
     * Returns the workspace mode used for inference.
     *
     * @return the inference workspace mode ({@code WorkspaceMode.ENABLED} after the no-argument constructor)
     */
    public WorkspaceMode getInferenceWorkspaceMode() {
        return inferenceWorkspaceMode;
    }

    /**
     * Sets the inference workspace mode. The value is not validated and may be null.
     *
     * @param mode the new inference workspace mode
     */
    public void setInferenceWorkspaceMode(WorkspaceMode mode) {
        inferenceWorkspaceMode = mode;
    }

    /**
     * Returns the configured cache mode.
     *
     * @return the configured cache mode (not assigned by the no-argument constructor)
     */
    public CacheMode getCacheMode() {
        return cacheMode;
    }

    /**
     * Sets the cache mode. The value is not validated and may be null.
     *
     * @param mode the new cache mode
     */
    public void setCacheMode(CacheMode mode) {
        cacheMode = mode;
    }
}
