package org.deeplearning4j.nn.layers.convolution;

import java.util.Arrays;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.class */
/**
 * One-dimensional convolution layer built on top of {@link ConvolutionLayer}.
 *
 * <p>Rank 3 arrays (described by this class's error message as
 * {@code [minibatchSize, features, length]}) are reshaped to rank 4 by appending a trailing
 * dimension of size 1. They are then handed to the 2D implementation, and its result is reshaped
 * back to rank 3 using its first three dimension sizes.
 *
 * <p>While delegating, the inherited {@code input} field is temporarily replaced by its rank 4
 * reshape. It is always restored afterwards, including when the 2D implementation throws, so a
 * failed call cannot leave the layer holding the rank 4 view.
 */
public class Convolution1DLayer extends ConvolutionLayer {
    public Convolution1DLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    public Convolution1DLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Backpropagates through the layer by delegating to the 2D convolution implementation.
     *
     * @param epsilon gradient from the next layer; must be a rank 3 array
     * @return the parameter gradients from the 2D implementation, paired with the epsilon for the
     *     previous layer reshaped back to rank 3
     * @throws DL4JInvalidInputException if {@code epsilon} is not rank 3
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        if (epsilon.rank() != 3) {
            throw new DL4JInvalidInputException("Got rank " + epsilon.rank()
                    + " array as epsilon for Convolution1DLayer backprop with shape "
                    + Arrays.toString(epsilon.shape())
                    + ". Expected rank 3 array with shape [minibatchSize, features, length].");
        }
        INDArray epsilon4d = to4d(epsilon);
        INDArray origInput = this.input;
        // The 2D implementation reads the input from the field, so swap in the rank 4 reshape.
        this.input = to4d(origInput);
        try {
            Pair<Gradient, INDArray> gradientAndEpsNext = super.backpropGradient(epsilon4d);
            INDArray epsNext3d = to3d(gradientAndEpsNext.getSecond());
            return new Pair<>(gradientAndEpsNext.getFirst(), epsNext3d);
        } finally {
            // Restore the caller-visible rank 3 input even if the delegate call throws.
            this.input = origInput;
        }
    }

    /**
     * Returns the 2D implementation's pre-activation output without reshaping it back to rank 3.
     *
     * <p>NOTE(review): this presumably is called by {@link ConvolutionLayer} (e.g. during backprop)
     * while {@code input} already holds the rank 4 reshape. Confirm against the superclass.
     *
     * @param training whether the forward pass is in training mode; forwarded unchanged
     * @return the output of {@code ConvolutionLayer.preOutput(training)}
     */
    @Override
    protected INDArray preOutput4d(boolean training) {
        // Honour the caller's flag instead of forcing training mode.
        return super.preOutput(training);
    }

    /**
     * Computes the pre-activation output for a rank 3 input.
     *
     * @param training whether the forward pass is in training mode; forwarded to the 2D
     *     implementation
     * @return the 2D implementation's pre-activation output reshaped to rank 3
     */
    @Override
    public INDArray preOutput(boolean training) {
        INDArray origInput = this.input;
        this.input = to4d(origInput);
        INDArray preOutput4d;
        try {
            preOutput4d = super.preOutput(training);
        } finally {
            // Restore the caller-visible rank 3 input even if the delegate call throws.
            this.input = origInput;
        }
        return to3d(preOutput4d);
    }

    /** Reshapes {@code arr} to rank 4 using its first three dimension sizes plus a trailing 1. */
    private static INDArray to4d(INDArray arr) {
        return arr.reshape(arr.size(0), arr.size(1), arr.size(2), 1);
    }

    /** Reshapes {@code arr} to rank 3 using its first three dimension sizes. */
    private static INDArray to3d(INDArray arr) {
        return arr.reshape(arr.size(0), arr.size(1), arr.size(2));
    }
}
