package org.deeplearning4j.nn.conf.preprocessor;

import java.util.Arrays;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;

/* loaded from: input_file:org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.class */
/**
 * Pre-processor that flattens rank-3 recurrent activations into rank-2 activations for a
 * feed-forward layer, and restores rank-3 epsilons on the backward pass.
 *
 * <p>Forward: {@code [d0, d1, d2]} is permuted to {@code [d0, d2, d1]} and reshaped ('f' order)
 * to {@code [d0 * d2, d1]}. Backward: {@code [rows, cols]} is reshaped to
 * {@code [miniBatchSize, rows / miniBatchSize, cols]} and permuted to
 * {@code [miniBatchSize, cols, rows / miniBatchSize]}. The dimension order
 * {@code [minibatch, size, timeSteps]} is assumed from {@link #backprop} and
 * {@link #getOutputType}.
 *
 * <p>The class holds no fields, so all instances are equal.
 */
public class RnnToFeedForwardPreProcessor implements InputPreProcessor {

    /**
     * Converts rank-3 RNN activations into a rank-2 matrix for a feed-forward layer.
     *
     * @param input rank-3 activations
     * @param miniBatchSize minibatch size (not used by this direction)
     * @return a rank-2 array with one row per (example, time step) pair
     * @throws IllegalArgumentException if {@code input} is not rank 3
     */
    @Override // org.deeplearning4j.nn.conf.InputPreProcessor
    public INDArray preProcess(INDArray input, int miniBatchSize) {
        if (input.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 3 (i.e., activations for RNN layer)");
        }
        // The general path reshapes in 'f' order, so work on an 'f'-ordered copy when needed.
        if (input.ordering() != 'f') {
            input = input.dup('f');
        }
        int[] shape = input.shape();
        if (shape[0] == 1) {
            // Edge case: size-1 leading dimension; take the single 2d tensor and transpose it.
            return input.tensorAlongDimension(0, 1, 2).permutei(1, 0);
        }
        if (shape[2] == 1) {
            // Edge case: size-1 trailing dimension; take the single 2d tensor directly.
            return input.tensorAlongDimension(0, 1, 0);
        }
        // Move dimension 2 next to dimension 0 so the 'f' reshape merges them into rows.
        return input.permute(0, 2, 1).reshape('f', shape[0] * shape[2], shape[1]);
    }

    /**
     * Converts rank-2 epsilons from a feed-forward layer back into rank-3 form.
     *
     * @param output rank-2 epsilons; may be {@code null}, in which case {@code null} is returned
     * @param miniBatchSize minibatch size; must be positive and divide the number of rows
     * @return rank-3 epsilons of shape {@code [miniBatchSize, cols, rows / miniBatchSize]}, or
     *     {@code null}
     * @throws IllegalArgumentException if {@code output} is not rank 2, or its row count is not a
     *     positive multiple of {@code miniBatchSize}
     */
    @Override // org.deeplearning4j.nn.conf.InputPreProcessor
    public INDArray backprop(INDArray output, int miniBatchSize) {
        if (output == null) {
            return null;
        }
        if (output.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2 (i.e., epsilons from feed forward layer)");
        }
        // Reshape below is done in 'f' order; copy 'c'-ordered arrays into 'f' order first.
        if (output.ordering() == 'c') {
            output = Shape.toOffsetZeroCopy(output, 'f');
        }
        int[] shape = output.shape();
        // Fail fast: without this check, miniBatchSize == 0 throws ArithmeticException, and a
        // non-dividing size truncates shape[0] / miniBatchSize into an invalid reshape.
        if (miniBatchSize <= 0 || shape[0] % miniBatchSize != 0) {
            throw new IllegalArgumentException("Invalid minibatch size " + miniBatchSize
                    + " for epsilons with shape " + Arrays.toString(shape)
                    + ": number of rows must be a positive multiple of the minibatch size");
        }
        return output.reshape('f', miniBatchSize, shape[0] / miniBatchSize, shape[1]).permute(0, 2, 1);
    }

    /**
     * Returns a copy of this pre-processor via {@link Object#clone()}.
     *
     * @return the clone
     */
    @Override // org.deeplearning4j.nn.conf.InputPreProcessor
    public RnnToFeedForwardPreProcessor clone() {
        try {
            return (RnnToFeedForwardPreProcessor) super.clone();
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Decompiler-generated alias of {@link #clone()}.
     *
     * <p>Retained only so existing references keep compiling.
     *
     * @return the clone
     * @deprecated use {@link #clone()}
     */
    @Deprecated
    public RnnToFeedForwardPreProcessor m3432clone() {
        return clone();
    }

    /**
     * Maps a recurrent input type to the equivalent feed-forward type with the same size.
     *
     * @param inputType input type; must be of type {@link InputType.Type#RNN}
     * @return a feed-forward input type of the recurrent input's size
     * @throws IllegalStateException if {@code inputType} is {@code null} or not an RNN type
     */
    @Override // org.deeplearning4j.nn.conf.InputPreProcessor
    public InputType getOutputType(InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input: expected input of type RNN, got " + inputType);
        }
        return InputType.feedForward(((InputType.InputTypeRecurrent) inputType).getSize());
    }

    /**
     * Converts a rank-2 time-series mask into the mask form used by the feed-forward activations.
     *
     * <p>The conversion is delegated to {@link TimeSeriesUtils#reshapeTimeSeriesMaskToVector}. The
     * mask state is passed through unchanged.
     *
     * @param maskArray rank-2 mask, or {@code null} for no mask
     * @param currentMaskState current mask state, returned unchanged
     * @param minibatchSize minibatch size (not used)
     * @return the reshaped mask (or {@code null}) paired with {@code currentMaskState}
     * @throws IllegalArgumentException if {@code maskArray} is non-null and not rank 2
     */
    @Override // org.deeplearning4j.nn.conf.InputPreProcessor
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        if (maskArray == null) {
            return new Pair<>(maskArray, currentMaskState);
        }
        if (maskArray.rank() != 2) {
            throw new IllegalArgumentException("Received mask array of rank " + maskArray.rank() + "; expected rank 2 mask array. Mask array shape: " + Arrays.toString(maskArray.shape()));
        }
        return new Pair<>(TimeSeriesUtils.reshapeTimeSeriesMaskToVector(maskArray), currentMaskState);
    }

    /**
     * Returns {@code true} for any instance of this class that accepts the comparison.
     *
     * <p>The class has no fields, so there is nothing else to compare.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        return (obj instanceof RnnToFeedForwardPreProcessor) && ((RnnToFeedForwardPreProcessor) obj).canEqual(this);
    }

    /** Symmetry hook for {@link #equals(Object)}; subclasses may restrict equality. */
    protected boolean canEqual(Object obj) {
        return obj instanceof RnnToFeedForwardPreProcessor;
    }

    /** Returns a constant, consistent with all instances being equal. */
    @Override
    public int hashCode() {
        return 1;
    }

    @Override
    public String toString() {
        return "RnnToFeedForwardPreProcessor()";
    }
}
