package org.deeplearning4j.nn.weights;

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/weights/WeightInitIdentity.class */
/**
 * Weight initializer that fills a parameter view with an identity mapping.
 *
 * <p>For rank-2 shapes the view is filled with an identity matrix. For rank-3 to rank-5
 * shapes the view is zeroed and a single 1 is written at the centre position of the
 * trailing dimensions for every index pair {@code (i, i)} of the first two dimensions.
 * In every case the first two dimensions of the shape must be equal.
 */
public class WeightInitIdentity implements IWeightInit {
    /**
     * Initializes {@code iNDArray} with an identity mapping of the requested shape.
     *
     * @param d        not used by this initializer
     * @param d2       not used by this initializer
     * @param jArr     target shape; must have rank 2 to 5 with {@code jArr[0] == jArr[1]}
     * @param c        ordering character passed on to the flatten/reshape calls
     * @param iNDArray array that receives the values and is then reshaped to {@code jArr}
     * @return {@code iNDArray} reshaped to {@code jArr} using ordering {@code c}
     * @throws IllegalStateException if the rank is outside 2..5, if the first two dimensions
     *                               differ, or (rank 3 to 5) if any trailing dimension is even
     */
    @Override // org.deeplearning4j.nn.weights.IWeightInit
    public INDArray init(double d, double d2, long[] jArr, char c, INDArray iNDArray) {
        // Guard the rank before touching jArr[0]/jArr[1]; otherwise a rank-0 or rank-1 shape
        // would surface as a raw ArrayIndexOutOfBoundsException instead of the intended error.
        if (jArr.length < 2) {
            throw new IllegalStateException("Identity mapping for " + jArr.length + " dimensions not defined!");
        }
        if (jArr[0] != jArr[1]) {
            throw new IllegalStateException("Cannot use IDENTITY init with parameters of shape " + Arrays.toString(jArr) + ": weights must be a square matrix for identity");
        }
        switch (jArr.length) {
            case 2:
                return setIdentity2D(jArr, c, iNDArray);
            case 3:
            case 4:
            case 5:
                return setIdentityConv(jArr, c, iNDArray);
            default:
                throw new IllegalStateException("Identity mapping for " + jArr.length + " dimensions not defined!");
        }
    }

    /**
     * Writes an identity matrix of size {@code shape[0]} into {@code paramView}.
     *
     * @param shape     rank-2 shape whose two dimensions are equal (checked by the caller)
     * @param order     ordering character used for flattening and for the final reshape
     * @param paramView array that receives the flattened identity matrix
     * @return {@code paramView} reshaped to {@code shape} with ordering {@code order}
     */
    private INDArray setIdentity2D(long[] shape, char order, INDArray paramView) {
        INDArray identity = Nd4j.eye(shape[0]);
        // When the requested ordering differs from the Nd4j default, copy the identity
        // matrix into an array allocated with the requested ordering before flattening.
        INDArray ordered = order == Nd4j.order().charValue()
                ? identity
                : Nd4j.createUninitialized(shape, order).assign(identity);
        paramView.assign(Nd4j.toFlattened(order, ordered));
        return paramView.reshape(order, shape);
    }

    /**
     * Zeroes {@code paramView} and writes a 1 at the centre of the trailing dimensions for
     * each index pair {@code (i, i)} of the first two dimensions.
     *
     * @param shape     rank-3 to rank-5 shape; every dimension from index 2 onward must be odd
     * @param order     ordering character used for the reshape
     * @param paramView array that receives the values
     * @return the reshaped view of {@code paramView} that was written to
     * @throws IllegalStateException if any dimension from index 2 onward has an even size
     */
    private INDArray setIdentityConv(long[] shape, char order, INDArray paramView) {
        INDArrayIndex[] indices = new INDArrayIndex[shape.length];
        for (int dim = 2; dim < shape.length; dim++) {
            if (shape[dim] % 2 == 0) {
                throw new IllegalStateException("Cannot use IDENTITY init with parameters of shape " + Arrays.toString(shape) + "! Must have odd sized kernels!");
            }
            // An odd size has a unique middle element at index size / 2.
            indices[dim] = NDArrayIndex.point(shape[dim] / 2);
        }
        paramView.assign(Nd4j.zeros(paramView.shape()));
        INDArray params = paramView.reshape(order, shape);
        // The value written is the same for every position, so allocate it once.
        INDArray one = Nd4j.ones(1);
        for (int i = 0; i < shape[0]; i++) {
            indices[0] = NDArrayIndex.point(i);
            indices[1] = NDArrayIndex.point(i);
            params.put(indices, one);
        }
        return params;
    }

    /**
     * Two instances are equal when both are {@code WeightInitIdentity} and the other
     * instance accepts this one via {@link #canEqual(Object)}; the class has no state.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        return (obj instanceof WeightInitIdentity) && ((WeightInitIdentity) obj).canEqual(this);
    }

    /**
     * Hook used by {@link #equals(Object)} so that the other side of the comparison can
     * veto equality.
     *
     * @param obj the object attempting the comparison
     * @return {@code true} if {@code obj} is a {@code WeightInitIdentity}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof WeightInitIdentity;
    }

    /** Constant hash code; consistent with {@link #equals(Object)} because there are no fields. */
    @Override
    public int hashCode() {
        return 1;
    }
}
