package org.deeplearning4j.nn.modelimport.keras.preprocessors;

import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Input preprocessor that reshapes activations from {@code inputShape} to {@code targetShape}
 * (and gradients back from {@code targetShape} to {@code inputShape}).
 *
 * <p>The shapes are supplied without a minibatch dimension. On the first call to
 * {@link #preProcess}, the minibatch size is prepended to both shapes and the instance records that
 * it now carries a minibatch dimension. Later calls with a different minibatch size replace the
 * leading dimension.
 *
 * <p>NOTE(review): because {@code preProcess} mutates {@code inputShape}/{@code targetShape} while
 * {@code hasMiniBatchDimension} is excluded from JSON, serializing an instance after it has been
 * used may persist shapes that already include the minibatch dimension — confirm this is never done.
 */
@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize"})
public class ReshapePreprocessor extends BaseInputPreProcessor {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) ReshapePreprocessor.class);
    private long[] inputShape;
    private long[] targetShape;
    private boolean hasMiniBatchDimension = false;
    private int miniBatchSize;

    /**
     * @param inputShape  shape of the incoming activations, excluding the minibatch dimension
     * @param targetShape shape to reshape to, excluding the minibatch dimension
     */
    public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape,
                               @JsonProperty("targetShape") long[] targetShape) {
        this.inputShape = inputShape;
        this.targetShape = targetShape;
    }

    /** Returns a new array equal to {@code shape} with {@code miniBatchSize} inserted at index 0. */
    private static long[] prependMiniBatchSize(long[] shape, long miniBatchSize) {
        long[] withMiniBatch = new long[shape.length + 1];
        withMiniBatch[0] = miniBatchSize;
        System.arraycopy(shape, 0, withMiniBatch, 1, shape.length);
        return withMiniBatch;
    }

    /** Returns a new array equal to {@code shape} with its leading (minibatch) entry replaced. */
    private static long[] replaceMiniBatchSize(long[] shape, long miniBatchSize) {
        // Drop the old leading entry of THIS array (using its own length), then prepend the new one.
        return prependMiniBatchSize(ArrayUtils.subarray(shape, 1, shape.length), miniBatchSize);
    }

    /**
     * Reshapes {@code input} to the target shape, first syncing the stored shapes with the current
     * minibatch size.
     *
     * @throws IllegalStateException if the element count of {@code input} differs from the target shape's
     */
    @Override // org.deeplearning4j.nn.conf.InputPreProcessor
    public INDArray preProcess(INDArray input, int miniBatch, LayerWorkspaceMgr workspaceMgr) {
        if (!this.hasMiniBatchDimension) {
            // First use: configured shapes have no minibatch dimension yet.
            this.targetShape = prependMiniBatchSize(this.targetShape, miniBatch);
            this.inputShape = prependMiniBatchSize(this.inputShape, miniBatch);
            this.hasMiniBatchDimension = true;
            this.miniBatchSize = miniBatch;
        } else if (this.miniBatchSize != miniBatch) {
            this.targetShape = replaceMiniBatchSize(this.targetShape, miniBatch);
            this.inputShape = replaceMiniBatchSize(this.inputShape, miniBatch);
            this.miniBatchSize = miniBatch;
        }
        if (ArrayUtil.prodLong(input.shape()) != ArrayUtil.prodLong(this.targetShape)) {
            throw new IllegalStateException("Input shape " + Arrays.toString(input.shape())
                    + " and output shape " + Arrays.toString(this.targetShape) + " do not match");
        }
        // reshape below is applied to a 'c'-ordered array with default strides.
        if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) {
            input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(this.targetShape));
    }

    /**
     * Reshapes the gradient {@code epsilon} from the target shape back to the input shape.
     *
     * @throws IllegalStateException if {@code epsilon}'s shape is not exactly the stored target shape
     */
    @Override // org.deeplearning4j.nn.conf.InputPreProcessor
    public INDArray backprop(INDArray epsilon, int miniBatch, LayerWorkspaceMgr workspaceMgr) {
        // Exact shape equality also guarantees equal element counts, so no separate product check.
        if (!Arrays.equals(this.targetShape, epsilon.shape())) {
            throw new IllegalStateException("Unexpected output shape " + Arrays.toString(epsilon.shape())
                    + " (expected to be " + Arrays.toString(this.targetShape) + ")");
        }
        if (epsilon.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilon)) {
            // The copy is a gradient, so allocate it in the same workspace the result is leveraged to.
            epsilon = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon, 'c');
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilon.reshape(this.inputShape));
    }

    /**
     * Infers the output {@link InputType} from the rank of the target shape (including minibatch):
     * rank 2 gives feed-forward, rank 3 gives recurrent, and rank 4 gives convolutional.
     *
     * @throws UnsupportedOperationException for any other target rank
     */
    @Override // org.deeplearning4j.nn.conf.InputPreProcessor
    public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
        long[] shape = this.hasMiniBatchDimension ? this.targetShape : prependMiniBatchSize(this.targetShape, 0L);
        // Rank of the configured input excluding minibatch; stable whether or not preProcess has run.
        int inputRank = this.hasMiniBatchDimension ? this.inputShape.length - 1 : this.inputShape.length;
        switch (shape.length) {
            case 2:
                return InputType.feedForward(shape[1]);
            case 3:
                return InputType.recurrent(shape[2], shape[1]);
            case 4:
                if (inputRank == 1 || inputType.getType() == InputType.Type.RNN) {
                    // Target interpreted as [mb, height, width, depth].
                    return InputType.convolutional(shape[1], shape[2], shape[3]);
                }
                // Target interpreted as [mb, depth, height, width].
                return InputType.convolutional(shape[2], shape[3], shape[1]);
            default:
                throw new UnsupportedOperationException("Cannot infer input type for reshape array " + Arrays.toString(shape));
        }
    }

    public long[] getInputShape() {
        return this.inputShape;
    }

    public long[] getTargetShape() {
        return this.targetShape;
    }

    public boolean isHasMiniBatchDimension() {
        return this.hasMiniBatchDimension;
    }

    public int getMiniBatchSize() {
        return this.miniBatchSize;
    }

    public void setInputShape(long[] inputShape) {
        this.inputShape = inputShape;
    }

    public void setTargetShape(long[] targetShape) {
        this.targetShape = targetShape;
    }

    public void setHasMiniBatchDimension(boolean hasMiniBatchDimension) {
        this.hasMiniBatchDimension = hasMiniBatchDimension;
    }

    public void setMiniBatchSize(int miniBatchSize) {
        this.miniBatchSize = miniBatchSize;
    }

    @Override
    public String toString() {
        return "ReshapePreprocessor(inputShape=" + Arrays.toString(getInputShape()) + ", targetShape=" + Arrays.toString(getTargetShape()) + ", hasMiniBatchDimension=" + isHasMiniBatchDimension() + ", miniBatchSize=" + getMiniBatchSize() + ")";
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ReshapePreprocessor)) {
            return false;
        }
        ReshapePreprocessor other = (ReshapePreprocessor) obj;
        return other.canEqual(this)
                && Arrays.equals(getInputShape(), other.getInputShape())
                && Arrays.equals(getTargetShape(), other.getTargetShape())
                && isHasMiniBatchDimension() == other.isHasMiniBatchDimension()
                && getMiniBatchSize() == other.getMiniBatchSize();
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof ReshapePreprocessor;
    }

    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + Arrays.hashCode(getInputShape());
        result = result * 59 + Arrays.hashCode(getTargetShape());
        result = result * 59 + (isHasMiniBatchDimension() ? 79 : 97);
        result = result * 59 + getMiniBatchSize();
        return result;
    }
}
