package org.deeplearning4j.nn.conf.layers.recurrent;

import java.util.Collection;
import java.util.List;
import java.util.Objects;
import lombok.NonNull;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
import org.deeplearning4j.nn.params.BidirectionalParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;

/**
 * Layer configuration that runs a recurrent layer in both time directions.
 *
 * <p>The wrapped layer becomes the forward configuration, and a clone of it becomes the backward
 * configuration. The two directions are combined according to {@link Mode}. Of the modes, only
 * {@link Mode#CONCAT} changes the reported output size (it doubles it; see
 * {@link #getOutputType(int, InputType)}).
 *
 * <p>Accepted wrapped layers are {@link BaseRecurrentLayer}s, {@link LastTimeStep}, and other
 * {@link BaseWrapperLayer}s.
 */
@JsonIgnoreProperties({"initializer"})
public class Bidirectional extends Layer {

    /** Forward-direction layer configuration. */
    private Layer fwd;

    /** Backward-direction layer configuration (a clone of the wrapped layer when constructed). */
    private Layer bwd;

    /** How the forward and backward outputs are combined. */
    private Mode mode;

    /** Created lazily by {@link #initializer()}; transient and ignored by Jackson. */
    private transient BidirectionalParamInitializer initializer;

    /** Builder for {@link Bidirectional}. */
    public static class Builder extends Layer.Builder<Builder> {
        private Mode mode;
        private Layer layer;

        /**
         * Creates an empty builder. Set the layer via {@link #rnnLayer(Layer)} before calling
         * {@link #build()}; if no mode is set, {@link Mode#CONCAT} is used.
         */
        public Builder() {
        }

        /**
         * Creates a builder with the given mode and layer.
         *
         * @param mode  combination mode (null means {@link Mode#CONCAT} at build time)
         * @param layer layer to wrap; validated when {@link #build()} is called
         */
        public Builder(Mode mode, Layer layer) {
            this.mode = mode;
            this.layer = layer;
        }

        /** Fluent setter for the combination mode. */
        public Builder mode(Mode mode) {
            setMode(mode);
            return this;
        }

        /**
         * Sets the layer to wrap.
         *
         * @throws IllegalArgumentException if the layer is not a recurrent or wrapper layer
         */
        public Builder rnnLayer(Layer layer) {
            checkWrappable(layer);
            // Assign the field directly: delegating to setLayer(..) would recurse back into this
            // method forever (setLayer itself delegates here for validation).
            this.layer = layer;
            return this;
        }

        /**
         * Builds the configuration.
         *
         * @throws IllegalStateException    if no layer was set
         * @throws IllegalArgumentException if the layer is not a recurrent or wrapper layer
         */
        @Override
        public Bidirectional build() {
            return new Bidirectional(this);
        }

        public Mode getMode() {
            return this.mode;
        }

        public Layer getLayer() {
            return this.layer;
        }

        public void setMode(Mode mode) {
            this.mode = mode;
        }

        /** Validating setter; equivalent to {@link #rnnLayer(Layer)}. */
        public void setLayer(Layer layer) {
            rnnLayer(layer);
        }
    }

    /**
     * How forward and backward activations are merged. Only {@code CONCAT} is treated specially
     * in this class (the output size is doubled). The exact element-wise semantics of the others
     * are presumably applied by {@code BidirectionalLayer} — confirm there.
     */
    public enum Mode {
        ADD,
        MUL,
        AVERAGE,
        CONCAT
    }

    /** Builder path: applies base layer settings, then wraps the builder's layer. */
    private Bidirectional(Builder builder) {
        super(builder);
        if (builder.layer == null) {
            throw new IllegalStateException("Cannot build Bidirectional: no layer was set (use Builder.rnnLayer)");
        }
        // Default to CONCAT, matching the single-argument constructor.
        initLayers(builder.mode == null ? Mode.CONCAT : builder.mode, builder.layer);
    }

    /**
     * Wraps {@code layer} using {@link Mode#CONCAT}.
     *
     * @throws NullPointerException     if {@code layer} is null
     * @throws IllegalArgumentException if the layer is not a recurrent or wrapper layer
     */
    public Bidirectional(@NonNull Layer layer) {
        this(Mode.CONCAT, layer);
    }

    /**
     * Wraps {@code layer} with the given combination mode.
     *
     * @throws NullPointerException     if {@code mode} or {@code layer} is null
     * @throws IllegalArgumentException if the layer is not a recurrent or wrapper layer
     */
    public Bidirectional(@NonNull Mode mode, @NonNull Layer layer) {
        if (mode == null) {
            throw new NullPointerException("mode is marked @NonNull but is null");
        }
        if (layer == null) {
            throw new NullPointerException("layer is marked @NonNull but is null");
        }
        initLayers(mode, layer);
    }

    /** No-arg constructor (e.g. for deserialization); leaves all fields unset. */
    public Bidirectional() {
    }

    /** Validates {@code layer}, then stores it as forward and a clone of it as backward. */
    private void initLayers(Mode mode, Layer layer) {
        checkWrappable(layer);
        this.fwd = layer;
        this.bwd = layer.mo6800clone();
        this.mode = mode;
    }

    /** Rejects layers that are neither recurrent layers nor wrapper layers. */
    private static void checkWrappable(Layer layer) {
        if (!(layer instanceof BaseRecurrentLayer) && !(layer instanceof LastTimeStep) && !(layer instanceof BaseWrapperLayer)) {
            throw new IllegalArgumentException("Cannot wrap a non-recurrent layer: config must extend BaseRecurrentLayer or LastTimeStep Got class: " + layer.getClass());
        }
    }

    /**
     * Strips any chain of wrapper layers (including {@link LastTimeStep}) and returns the innermost
     * layer as a {@link FeedForwardLayer}.
     */
    private static FeedForwardLayer innermostFeedForward(Layer layer) {
        Layer current = layer;
        while (current instanceof BaseWrapperLayer) {
            current = ((BaseWrapperLayer) current).getUnderlying();
        }
        return (FeedForwardLayer) current;
    }

    /** nOut of the wrapped (innermost) forward layer; not doubled for {@link Mode#CONCAT}. */
    public long getNOut() {
        return innermostFeedForward(this.fwd).getNOut();
    }

    /** nIn of the wrapped (innermost) forward layer. */
    public long getNIn() {
        return innermostFeedForward(this.fwd).getNIn();
    }

    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        NeuralNetConfiguration fwdConf = conf.m6779clone();
        NeuralNetConfiguration bwdConf = conf.m6779clone();
        fwdConf.setLayer(this.fwd);
        bwdConf.setLayer(this.bwd);

        // The parameter view is split in half: [0, half) for forward, [half, 2*half) for backward.
        // Only row 0 is selected (the view is presumably a 1 x N row vector — confirm).
        long half = layerParamsView.length() / 2;
        INDArray fwdParams = layerParamsView.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(0L, half));
        org.deeplearning4j.nn.api.Layer fwdLayer = this.fwd.instantiate(fwdConf, trainingListeners, layerIndex, fwdParams, initializeParams, networkDataType);
        INDArray bwdParams = layerParamsView.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(half, 2 * half));
        org.deeplearning4j.nn.api.Layer bwdLayer = this.bwd.instantiate(bwdConf, trainingListeners, layerIndex, bwdParams, initializeParams, networkDataType);

        BidirectionalLayer result = new BidirectionalLayer(conf, fwdLayer, bwdLayer, layerParamsView);
        result.setParamTable(initializer().init(conf, layerParamsView, initializeParams));
        result.setConf(conf);
        return result;
    }

    /** Returns the parameter initializer, creating it on first use. */
    @Override
    public ParamInitializer initializer() {
        if (this.initializer == null) {
            this.initializer = new BidirectionalParamInitializer(this);
        }
        return this.initializer;
    }

    /**
     * Output type of the forward layer, with its size doubled for {@link Mode#CONCAT}. It is
     * feed-forward when the forward layer is a {@link LastTimeStep}, and recurrent otherwise.
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        InputType fwdOut = this.fwd.getOutputType(layerIndex, inputType);
        if (this.fwd instanceof LastTimeStep) {
            InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward) fwdOut;
            return this.mode == Mode.CONCAT ? InputType.feedForward(2 * ff.getSize()) : ff;
        }
        InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) fwdOut;
        return this.mode == Mode.CONCAT ? InputType.recurrent(2 * rnn.getSize()) : rnn;
    }

    /** Propagates nIn inference to both directions. */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        this.fwd.setNIn(inputType, override);
        this.bwd.setNIn(inputType, override);
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return this.fwd.getPreProcessorForInputType(inputType);
    }

    // The per-parameter lookups below drop the first character of the parameter name (presumably
    // the direction prefix added by BidirectionalParamInitializer) and always consult the forward
    // configuration, since the backward one is created as a clone of it.

    @Override
    public List<Regularization> getRegularizationByParam(String paramName) {
        return this.fwd.getRegularizationByParam(paramName.substring(1));
    }

    @Override
    public boolean isPretrainParam(String paramName) {
        return this.fwd.isPretrainParam(paramName.substring(1));
    }

    @Override
    public IUpdater getUpdaterByParam(String paramName) {
        return this.fwd.getUpdaterByParam(paramName.substring(1));
    }

    @Override
    public GradientNormalization getGradientNormalization() {
        return this.fwd.getGradientNormalization();
    }

    @Override
    public double getGradientNormalizationThreshold() {
        return this.fwd.getGradientNormalizationThreshold();
    }

    /** Sets the name on this layer and on both directional layers. */
    @Override
    public void setLayerName(String layerName) {
        this.layerName = layerName;
        this.fwd.setLayerName(layerName);
        this.bwd.setLayerName(layerName);
    }

    /** Memory report of the forward layer scaled by 2 (one copy per direction). */
    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        LayerMemoryReport report = this.fwd.getMemoryReport(inputType);
        report.scale(2);
        return report;
    }

    public Layer getFwd() {
        return this.fwd;
    }

    public Layer getBwd() {
        return this.bwd;
    }

    public Mode getMode() {
        return this.mode;
    }

    /** Returns the cached initializer field directly (may be null before {@link #initializer()}). */
    public BidirectionalParamInitializer getInitializer() {
        return this.initializer;
    }

    public void setFwd(Layer layer) {
        this.fwd = layer;
    }

    public void setBwd(Layer layer) {
        this.bwd = layer;
    }

    public void setMode(Mode mode) {
        this.mode = mode;
    }

    public void setInitializer(BidirectionalParamInitializer bidirectionalParamInitializer) {
        this.initializer = bidirectionalParamInitializer;
    }

    @Override
    public String toString() {
        return "Bidirectional(fwd=" + getFwd() + ", bwd=" + getBwd() + ", mode=" + getMode() + ", initializer=" + getInitializer() + ")";
    }

    /** Equality over the superclass state plus fwd, bwd and mode (the initializer is excluded). */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof Bidirectional)) {
            return false;
        }
        Bidirectional other = (Bidirectional) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        return Objects.equals(getFwd(), other.getFwd())
                && Objects.equals(getBwd(), other.getBwd())
                && Objects.equals(getMode(), other.getMode());
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof Bidirectional;
    }

    /** Hash over the same fields as {@link #equals(Object)}; 43 stands in for null fields. */
    @Override
    public int hashCode() {
        int result = super.hashCode();
        Layer fwdLayer = getFwd();
        result = (result * 59) + (fwdLayer == null ? 43 : fwdLayer.hashCode());
        Layer bwdLayer = getBwd();
        result = (result * 59) + (bwdLayer == null ? 43 : bwdLayer.hashCode());
        Mode currentMode = getMode();
        return (result * 59) + (currentMode == null ? 43 : currentMode.hashCode());
    }
}
