package org.deeplearning4j.nn.graph.vertex.impl;

import java.util.Arrays;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.Or;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.class */
/**
 * Graph vertex that concatenates the activations of all of its inputs along dimension 1.
 *
 * <p>Inputs must all have the same rank (2, 3 or 4), the same size in dimension 0, and (for rank 3
 * and 4) the same size in every dimension after 1. With a single input the vertex is a pass-through.
 * The backward pass splits the incoming epsilon back into one array per input, using the input
 * shapes recorded by the most recent forward pass. This vertex has no layer and no parameters.
 */
public class MergeVertex extends BaseGraphVertex {
    /** Copies of each input's shape from the most recent forward pass; used to split epsilon. */
    private int[][] forwardPassShapes;
    /** Rank of the inputs from the most recent multi-input forward pass. */
    private int fwdPassRank;

    /**
     * Creates a merge vertex with no input/output vertex indices; arguments are passed straight
     * through to {@link BaseGraphVertex}.
     */
    public MergeVertex(ComputationGraph computationGraph, String str, int i) {
        this(computationGraph, str, i, null, null);
    }

    /** Creates a merge vertex; all arguments are passed straight through to {@link BaseGraphVertex}. */
    public MergeVertex(ComputationGraph computationGraph, String str, int i, VertexIndices[] vertexIndicesArr, VertexIndices[] vertexIndicesArr2) {
        super(computationGraph, str, i, vertexIndicesArr, vertexIndicesArr2);
    }

    @Override
    public String toString() {
        return "MergeVertex(id=" + getVertexIndex() + ",name=\"" + getVertexName() + "\")";
    }

    @Override
    public boolean hasLayer() {
        return false;
    }

    @Override
    public boolean isOutputVertex() {
        return false;
    }

    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Concatenates all inputs along dimension 1.
     *
     * @param z unused by this vertex
     * @return the single input itself if there is only one input, otherwise a newly allocated array
     *     whose dimension 1 is the sum of the inputs' dimension-1 sizes
     * @throws IllegalStateException if inputs are not set, or inputs differ in rank, in dimension 0,
     *     or in any dimension after 1
     * @throws UnsupportedOperationException if the inputs are not rank 2, 3 or 4
     */
    @Override
    public INDArray doForward(boolean z) {
        if (!canDoForward()) {
            throw new IllegalStateException("Cannot do forward pass: inputs not set");
        }
        if (this.inputs.length == 1) {
            // Nothing to merge. Record a single shape so doBackward() passes epsilon straight through.
            int[] shape = this.inputs[0].shape();
            this.forwardPassShapes = new int[][] {Arrays.copyOf(shape, shape.length)};
            return this.inputs[0];
        }

        int[] firstShape = this.inputs[0].shape();
        int rank = firstShape.length;
        // Check the rank before any shape[1] access, so rank < 2 gets a meaningful error instead of
        // an ArrayIndexOutOfBoundsException.
        if (rank < 2 || rank > 4) {
            throw new UnsupportedOperationException("Cannot merge activations with rank " + rank
                    + " (shape=" + Arrays.toString(firstShape) + "): only rank 2, 3 and 4 are supported");
        }

        int numExamples = firstShape[0];
        int[][] shapes = new int[this.inputs.length][];
        int totalSize = 0;
        for (int i = 0; i < this.inputs.length; i++) {
            int[] shape = this.inputs[i].shape();
            if (shape.length != rank) {
                throw new IllegalStateException("Cannot merge activations with different ranks: first activations have rank " + rank + ", activations[" + i + "] have rank " + shape.length + " (shape=" + Arrays.toString(shape) + ")");
            }
            if (shape[0] != numExamples) {
                throw new IllegalStateException("Cannot merge activations with different number of examples (activations[0] shape: " + Arrays.toString(firstShape) + ", activations[" + i + "] shape: " + Arrays.toString(shape) + ")");
            }
            // Only dimension 1 is concatenated; all trailing dimensions must line up exactly.
            for (int d = 2; d < rank; d++) {
                if (shape[d] != firstShape[d]) {
                    throw new IllegalStateException("Cannot merge activations with different sizes in dimension " + d + " (activations[0] shape: " + Arrays.toString(firstShape) + ", activations[" + i + "] shape: " + Arrays.toString(shape) + ")");
                }
            }
            shapes[i] = Arrays.copyOf(shape, shape.length);
            totalSize += shape[1];
        }

        int[] outShape = Arrays.copyOf(firstShape, rank);
        outShape[1] = totalSize;
        INDArray merged = Nd4j.create(outShape);
        int offset = 0;
        for (int i = 0; i < this.inputs.length; i++) {
            int width = shapes[i][1];
            merged.get(mergeDimInterval(rank, offset, offset + width)).assign(this.inputs[i]);
            offset += width;
        }

        // Update state only after validation and merging succeeded, so a failed call leaves the
        // previous forward-pass record intact.
        this.forwardPassShapes = shapes;
        this.fwdPassRank = rank;
        return merged;
    }

    /**
     * Splits epsilon along dimension 1 into one array per input, using the recorded input shapes.
     *
     * @param z unused by this vertex
     * @return a pair of (no gradient, one epsilon array per input, in input order)
     * @throws IllegalStateException if epsilon is not set or no forward pass has been recorded
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean z) {
        if (!canDoBackward()) {
            throw new IllegalStateException("Cannot do backward pass: errors not set");
        }
        if (this.forwardPassShapes == null) {
            throw new IllegalStateException("Cannot do backward pass: no forward pass has been performed");
        }
        if (this.forwardPassShapes.length == 1) {
            // Forward pass was a pass-through, so epsilon goes straight back to the single input.
            return new Pair<>(null, new INDArray[] {this.epsilon});
        }
        if (this.fwdPassRank < 2 || this.fwdPassRank > 4) {
            throw new RuntimeException("Invalid rank during forward pass (not 2, 3, 4)");
        }

        INDArray[] epsilons = new INDArray[this.forwardPassShapes.length];
        int offset = 0;
        for (int i = 0; i < epsilons.length; i++) {
            int[] shape = this.forwardPassShapes[i];
            epsilons[i] = Nd4j.create(shape);
            epsilons[i].assign(this.epsilon.get(mergeDimInterval(this.fwdPassRank, offset, offset + shape[1])));
            offset += shape[1];
        }
        return new Pair<>(null, epsilons);
    }

    /**
     * Builds an index for an array of the given rank that selects {@code [from, to)} along
     * dimension 1 and everything along every other dimension.
     */
    private static INDArrayIndex[] mergeDimInterval(int rank, int from, int to) {
        INDArrayIndex[] indices = new INDArrayIndex[rank];
        for (int d = 0; d < rank; d++) {
            // A fresh all() per dimension, as the original inline index expressions created.
            indices[d] = (d == 1) ? NDArrayIndex.interval(from, to) : NDArrayIndex.all();
        }
        return indices;
    }

    /**
     * Rejects any attempt to set a gradients view; this vertex has no parameters.
     *
     * @throws RuntimeException if {@code iNDArray} is non-null
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        if (iNDArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    /**
     * Combines the input masks into a single mask.
     *
     * <p>If the mask array is null or contains any null mask, the result mask is null. A single mask
     * is returned as-is. Otherwise the masks are combined with the {@link Or} op, writing into a copy
     * of the first mask so that no input mask is modified.
     *
     * @return a pair of (combined mask or null, the unchanged {@code maskState})
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] iNDArrayArr, MaskState maskState, int i) {
        if (iNDArrayArr == null) {
            return new Pair<>(null, maskState);
        }
        for (INDArray iNDArray : iNDArrayArr) {
            if (iNDArray == null) {
                return new Pair<>(null, maskState);
            }
        }
        if (iNDArrayArr.length == 1) {
            return new Pair<>(iNDArrayArr[0], maskState);
        }
        INDArray dup = iNDArrayArr[0].dup(iNDArrayArr[0].ordering());
        Nd4j.getExecutioner().exec(new Or(iNDArrayArr[0], iNDArrayArr[1], dup));
        for (int i2 = 2; i2 < iNDArrayArr.length; i2++) {
            Nd4j.getExecutioner().exec(new Or(iNDArrayArr[i2], dup, dup));
        }
        return new Pair<>(dup, maskState);
    }
}
