package org.deeplearning4j.nn.conf.layers.variational;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * {@link ReconstructionDistribution} composed of several component distributions, each of which
 * models its own consecutive block of data columns.
 *
 * <p>Component {@code k} owns {@code distributionSizes[k]} columns of the data array and
 * {@code reconstructionDistributions[k].distributionInputSize(distributionSizes[k])} columns of the
 * reconstruction (pre-output) array. Segments are laid out in the order the components were added:
 * both arrays are sliced left-to-right along dimension 1 (all rows kept), and each slice is handed
 * to the owning component. Arrays are therefore assumed to be 2d, {@code [examples, columns]}.
 *
 * <p>Instances are normally created through {@link Builder}; the {@link JsonProperty}-annotated
 * constructor is the entry point used by Jackson when deserializing.
 */
public class CompositeReconstructionDistribution implements ReconstructionDistribution {
    /** Number of data columns owned by each component, in layout order. */
    private final int[] distributionSizes;
    /** Component distributions, parallel to {@link #distributionSizes}. */
    private final ReconstructionDistribution[] reconstructionDistributions;
    /** Expected total number of data columns (the sum of the sizes when built via {@link Builder}). */
    private final int totalSize;

    /** Builder that collects component distributions in the order their data columns appear. */
    public static class Builder {
        private final List<Integer> distributionSizes = new ArrayList<>();
        private final List<ReconstructionDistribution> reconstructionDistributions = new ArrayList<>();

        /**
         * Appends a component that models the next {@code distributionSize} data columns.
         *
         * @param distributionSize number of data columns for this component; must be positive
         * @param distribution     distribution used for those columns; must not be null
         * @return this builder
         * @throws IllegalArgumentException if {@code distributionSize <= 0} or {@code distribution} is null
         */
        public Builder addDistribution(int distributionSize, ReconstructionDistribution distribution) {
            if (distributionSize <= 0) {
                throw new IllegalArgumentException("Distribution size must be positive, got " + distributionSize);
            }
            if (distribution == null) {
                throw new IllegalArgumentException("Reconstruction distribution must not be null");
            }
            distributionSizes.add(distributionSize);
            reconstructionDistributions.add(distribution);
            return this;
        }

        /**
         * Creates the composite distribution from the components added so far.
         *
         * @return a new composite distribution
         * @throws IllegalStateException if no distribution has been added
         */
        public CompositeReconstructionDistribution build() {
            if (distributionSizes.isEmpty()) {
                throw new IllegalStateException(
                        "At least one distribution must be added before calling build()");
            }
            return new CompositeReconstructionDistribution(this);
        }
    }

    /**
     * Creates a composite from already-computed state. Arrays are stored as given (no copy and no
     * consistency check between {@code totalSize} and the sum of {@code distributionSizes}).
     *
     * @param distributionSizes           number of data columns per component
     * @param reconstructionDistributions component distributions, parallel to {@code distributionSizes}
     * @param totalSize                   total number of data columns
     */
    public CompositeReconstructionDistribution(@JsonProperty("distributionSizes") int[] distributionSizes,
                    @JsonProperty("reconstructionDistributions") ReconstructionDistribution[] reconstructionDistributions,
                    @JsonProperty("totalSize") int totalSize) {
        this.distributionSizes = distributionSizes;
        this.reconstructionDistributions = reconstructionDistributions;
        this.totalSize = totalSize;
    }

    private CompositeReconstructionDistribution(Builder builder) {
        int count = builder.distributionSizes.size();
        this.distributionSizes = new int[count];
        this.reconstructionDistributions =
                builder.reconstructionDistributions.toArray(new ReconstructionDistribution[0]);
        int sum = 0;
        for (int k = 0; k < count; k++) {
            distributionSizes[k] = builder.distributionSizes.get(k);
            // addExact: fail loudly rather than silently wrapping to a negative total
            sum = Math.addExact(sum, distributionSizes[k]);
        }
        this.totalSize = sum;
    }

    /**
     * Computes the per-example loss score by summing the score arrays of every component.
     * Only components that are {@link LossFunctionWrapper}s or nested composites are supported.
     *
     * @param data           data array, sliced by the components' data sizes
     * @param reconstruction reconstruction (pre-output) array, sliced by the components' input sizes
     * @return element-wise sum of the components' score arrays
     * @throws IllegalStateException         if {@link #hasLossFunction()} is false or there are no components
     * @throws UnsupportedOperationException if a component type is not supported
     */
    public INDArray computeLossFunctionScoreArray(INDArray data, INDArray reconstruction) {
        if (!hasLossFunction()) {
            throw new IllegalStateException("Cannot compute score array unless hasLossFunction() == true");
        }
        requireComponents();
        INDArray total = null;
        int dataOffset = 0;
        int inputOffset = 0;
        for (int k = 0; k < distributionSizes.length; k++) {
            ReconstructionDistribution component = reconstructionDistributions[k];
            int dataSize = distributionSizes[k];
            int inputSize = component.distributionInputSize(dataSize);
            INDArray segmentScore = getScoreArray(component, columns(data, dataOffset, dataSize),
                            columns(reconstruction, inputOffset, inputSize));
            // The first component's result becomes the accumulator; later ones are added in place
            total = (total == null) ? segmentScore : total.addi(segmentScore);
            dataOffset += dataSize;
            inputOffset += inputSize;
        }
        return total;
    }

    /**
     * Computes the score array for a single component by delegating to its loss function
     * (identity activation, no mask) or, for a nested composite, recursing.
     */
    private INDArray getScoreArray(ReconstructionDistribution distribution, INDArray data, INDArray reconstruction) {
        if (distribution instanceof LossFunctionWrapper) {
            return ((LossFunctionWrapper) distribution).getLossFunction()
                            .computeScoreArray(data, reconstruction, new ActivationIdentity(), null);
        }
        if (distribution instanceof CompositeReconstructionDistribution) {
            return ((CompositeReconstructionDistribution) distribution)
                            .computeLossFunctionScoreArray(data, reconstruction);
        }
        throw new UnsupportedOperationException("Cannot calculate composite reconstruction distribution");
    }

    /**
     * @return true only if every component distribution reports that it has a loss function
     */
    @Override
    public boolean hasLossFunction() {
        for (ReconstructionDistribution component : reconstructionDistributions) {
            if (!component.hasLossFunction()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the number of reconstruction columns required for {@code dataSize} data columns:
     * the sum of every component's input size.
     *
     * @param dataSize number of data columns; must equal the total size
     * @return total reconstruction (pre-output) size
     * @throws IllegalStateException if {@code dataSize} differs from the total size
     */
    @Override
    public int distributionInputSize(int dataSize) {
        if (dataSize != totalSize) {
            throw new IllegalStateException("Invalid input size: Got input size " + dataSize
                            + " for data, but expected input size for all distributions is " + totalSize
                            + ". Distribution sizes: " + Arrays.toString(distributionSizes));
        }
        int inputSize = 0;
        for (int k = 0; k < distributionSizes.length; k++) {
            inputSize += reconstructionDistributions[k].distributionInputSize(distributionSizes[k]);
        }
        return inputSize;
    }

    /**
     * @return the sum of the components' negative log probabilities (0.0 if there are no components)
     */
    @Override
    public double negLogProbability(INDArray data, INDArray reconstruction, boolean average) {
        double total = 0.0;
        int dataOffset = 0;
        int inputOffset = 0;
        for (int k = 0; k < distributionSizes.length; k++) {
            ReconstructionDistribution component = reconstructionDistributions[k];
            int dataSize = distributionSizes[k];
            int inputSize = component.distributionInputSize(dataSize);
            total += component.negLogProbability(columns(data, dataOffset, dataSize),
                            columns(reconstruction, inputOffset, inputSize), average);
            dataOffset += dataSize;
            inputOffset += inputSize;
        }
        return total;
    }

    /**
     * @return element-wise sum of the components' per-example negative log probabilities
     * @throws IllegalStateException if there are no components
     */
    @Override
    public INDArray exampleNegLogProbability(INDArray data, INDArray reconstruction) {
        requireComponents();
        INDArray total = null;
        int dataOffset = 0;
        int inputOffset = 0;
        for (int k = 0; k < distributionSizes.length; k++) {
            ReconstructionDistribution component = reconstructionDistributions[k];
            int dataSize = distributionSizes[k];
            int inputSize = component.distributionInputSize(dataSize);
            INDArray segment = component.exampleNegLogProbability(columns(data, dataOffset, dataSize),
                            columns(reconstruction, inputOffset, inputSize));
            total = (total == null) ? segment : total.addi(segment);
            dataOffset += dataSize;
            inputOffset += inputSize;
        }
        return total;
    }

    /**
     * Assembles the gradient with respect to the reconstruction array. Each component's gradient
     * is written into that component's reconstruction columns of a new array shaped like
     * {@code reconstruction}.
     */
    @Override
    public INDArray gradient(INDArray data, INDArray reconstruction) {
        // Uninitialized is fine only because the component segments are expected to cover every column
        INDArray result = Nd4j.createUninitialized(reconstruction.shape());
        int dataOffset = 0;
        int inputOffset = 0;
        for (int k = 0; k < distributionSizes.length; k++) {
            ReconstructionDistribution component = reconstructionDistributions[k];
            int dataSize = distributionSizes[k];
            int inputSize = component.distributionInputSize(dataSize);
            INDArray segmentGradient = component.gradient(columns(data, dataOffset, dataSize),
                            columns(reconstruction, inputOffset, inputSize));
            putColumns(result, inputOffset, inputSize, segmentGradient);
            dataOffset += dataSize;
            inputOffset += inputSize;
        }
        return result;
    }

    /** Draws a random sample from every component and concatenates the results column-wise. */
    @Override
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        return randomSample(preOutDistributionParams, false);
    }

    /** Takes the mean from every component and concatenates the results column-wise. */
    @Override
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        return randomSample(preOutDistributionParams, true);
    }

    /**
     * Builds a {@code [examples, totalSize]} array by asking each component to generate values
     * from its slice of the distribution parameters.
     *
     * @param params distribution parameters (reconstruction/pre-output columns)
     * @param atMean true to use each component's mean, false to sample randomly
     */
    private INDArray randomSample(INDArray params, boolean atMean) {
        INDArray result = Nd4j.createUninitialized(new int[] {params.size(0), totalSize});
        int dataOffset = 0;
        int inputOffset = 0;
        for (int k = 0; k < distributionSizes.length; k++) {
            ReconstructionDistribution component = reconstructionDistributions[k];
            int dataSize = distributionSizes[k];
            int inputSize = component.distributionInputSize(dataSize);
            INDArray componentParams = columns(params, inputOffset, inputSize);
            INDArray generated = atMean ? component.generateAtMean(componentParams)
                            : component.generateRandom(componentParams);
            putColumns(result, dataOffset, dataSize, generated);
            dataOffset += dataSize;
            inputOffset += inputSize;
        }
        return result;
    }

    /** Returns a view of columns {@code [from, from + count)} of {@code array}, across all rows. */
    private static INDArray columns(INDArray array, int from, int count) {
        return array.get(NDArrayIndex.all(), NDArrayIndex.interval(from, from + count));
    }

    /** Writes {@code values} into columns {@code [from, from + count)} of {@code target}, across all rows. */
    private static void putColumns(INDArray target, int from, int count, INDArray values) {
        target.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(from, from + count)}, values);
    }

    /** Fails fast instead of returning a null result when there is nothing to combine. */
    private void requireComponents() {
        if (distributionSizes.length == 0) {
            throw new IllegalStateException("CompositeReconstructionDistribution has no component distributions");
        }
    }

    /** @return a copy of the per-component data sizes */
    public int[] getDistributionSizes() {
        return distributionSizes == null ? null : distributionSizes.clone();
    }

    /** @return a (shallow) copy of the component distributions */
    public ReconstructionDistribution[] getReconstructionDistributions() {
        return reconstructionDistributions == null ? null : reconstructionDistributions.clone();
    }

    /** @return the total number of data columns */
    public int getTotalSize() {
        return totalSize;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof CompositeReconstructionDistribution)) {
            return false;
        }
        CompositeReconstructionDistribution other = (CompositeReconstructionDistribution) obj;
        return other.canEqual(this)
                        && Arrays.equals(distributionSizes, other.distributionSizes)
                        && Arrays.deepEquals(reconstructionDistributions, other.reconstructionDistributions)
                        && totalSize == other.totalSize;
    }

    /** Symmetry hook for subclasses overriding {@link #equals(Object)}. */
    protected boolean canEqual(Object obj) {
        return obj instanceof CompositeReconstructionDistribution;
    }

    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + Arrays.hashCode(distributionSizes);
        result = result * 59 + Arrays.deepHashCode(reconstructionDistributions);
        result = result * 59 + totalSize;
        return result;
    }

    @Override
    public String toString() {
        return "CompositeReconstructionDistribution(distributionSizes=" + Arrays.toString(distributionSizes)
                        + ", reconstructionDistributions=" + Arrays.deepToString(reconstructionDistributions)
                        + ", totalSize=" + totalSize + ")";
    }
}
