package org.deeplearning4j.nn.updater;

import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/updater/MultiLayerUpdater.class */
/**
 * {@link Updater} for a {@link MultiLayerNetwork} that delegates to one updater per layer.
 *
 * <p>The per-layer updaters share a single flat state array ({@code viewArray}, shape
 * {@code [1, totalStateSize]}); each layer updater that reports a non-zero state size is handed
 * a contiguous interval of that array. When no layer requires updater state, {@code viewArray}
 * stays {@code null}.
 */
public class MultiLayerUpdater implements Updater {
    /** Separator between the layer index and the parameter name in network-level gradient keys. */
    private static final char LAYER_KEY_SEPARATOR = '_';

    /** One updater per network layer, in layer order. */
    private final Updater[] layerUpdaters;

    /** Flat state array shared by all layer updaters; {@code null} when the total state size is 0. */
    private INDArray viewArray;

    /**
     * Creates the per-layer updaters and allocates a fresh (uninitialized) state array for them.
     *
     * @param multiLayerNetwork network whose layers are to be updated
     */
    public MultiLayerUpdater(MultiLayerNetwork multiLayerNetwork) {
        Layer[] layers = multiLayerNetwork.getLayers();
        this.layerUpdaters = new Updater[layers.length];
        int totalStateSize = createLayerUpdaters(layers, this.layerUpdaters);
        if (totalStateSize > 0) {
            this.viewArray = Nd4j.createUninitialized(new int[]{1, totalStateSize}, Nd4j.order().charValue());
        }
        // 'true' is forwarded to each layer updater because the array was just allocated here
        // (presumably an "initialize" flag - confirm against Updater#setStateViewArray).
        assignStateViews(layers, true);
    }

    /**
     * Creates the per-layer updaters on top of an existing state array.
     *
     * @param multiLayerNetwork network whose layers are to be updated
     * @param iNDArray existing updater state; may be {@code null} only when no layer needs state
     * @throws IllegalStateException if the length of {@code iNDArray} (or its absence) does not
     *         match the total state size required by the layer updaters
     */
    public MultiLayerUpdater(MultiLayerNetwork multiLayerNetwork, INDArray iNDArray) {
        Layer[] layers = multiLayerNetwork.getLayers();
        this.layerUpdaters = new Updater[layers.length];
        int totalStateSize = createLayerUpdaters(layers, this.layerUpdaters);
        if (iNDArray == null) {
            if (totalStateSize != 0) {
                throw new IllegalStateException("Expected updater state with size " + totalStateSize + ", got null input");
            }
            // Nothing to distribute: no layer keeps updater state.
            return;
        }
        if (iNDArray.length() != totalStateSize) {
            throw new IllegalStateException("Expected updater state with size " + totalStateSize + ", got size " + iNDArray.length());
        }
        this.viewArray = iNDArray;
        // 'false' is forwarded because the state array was supplied by the caller.
        assignStateViews(layers, false);
    }

    /**
     * Fills {@code target} with one updater per layer and returns the summed state size.
     * Creation and size query are interleaved per layer, in layer order.
     */
    private static int createLayerUpdaters(Layer[] layers, Updater[] target) {
        int totalStateSize = 0;
        for (int idx = 0; idx < layers.length; idx++) {
            target[idx] = UpdaterCreator.getUpdater(layers[idx]);
            totalStateSize += target[idx].stateSizeForLayer(layers[idx]);
        }
        return totalStateSize;
    }

    /**
     * Hands each layer updater with a non-zero state size its interval of {@code viewArray}.
     * Layers with state size 0 are skipped and do not consume any part of the array.
     */
    private void assignStateViews(Layer[] layers, boolean flag) {
        int offset = 0;
        for (int idx = 0; idx < layers.length; idx++) {
            int stateSize = this.layerUpdaters[idx].stateSizeForLayer(layers[idx]);
            if (stateSize == 0) {
                continue;
            }
            INDArray layerView = this.viewArray.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + stateSize)});
            this.layerUpdaters[idx].setStateViewArray(layers[idx], layerView, flag);
            offset += stateSize;
        }
    }

    /**
     * Copies the given state into this updater's existing state array (the array reference itself
     * is not replaced, so the per-layer views stay valid).
     *
     * <p>If this updater holds no state, a {@code null} or zero-length input is accepted as a no-op.
     *
     * @param layer unused by this implementation
     * @param iNDArray state values to copy in
     * @param z unused by this implementation
     * @throws IllegalStateException if the lengths of the two state arrays differ
     */
    @Override // org.deeplearning4j.nn.api.Updater
    public void setStateViewArray(Layer layer, INDArray iNDArray, boolean z) {
        if (this.viewArray == null) {
            // No layer keeps updater state; only an empty input is consistent with that.
            if (iNDArray != null && iNDArray.length() != 0) {
                throw new IllegalStateException("Invalid input: view arrays differ in length. Expected length 0, got length " + iNDArray.length());
            }
            return;
        }
        if (iNDArray == null) {
            throw new IllegalStateException("Expected updater state with size " + this.viewArray.length() + ", got null input");
        }
        if (this.viewArray.length() != iNDArray.length()) {
            throw new IllegalStateException("Invalid input: view arrays differ in length. Expected length " + this.viewArray.length() + ", got length " + iNDArray.length());
        }
        this.viewArray.assign(iNDArray);
    }

    /**
     * @return the shared state array, or {@code null} when no layer keeps updater state
     */
    @Override // org.deeplearning4j.nn.api.Updater
    public INDArray getStateViewArray() {
        return this.viewArray;
    }

    /**
     * Returns the total updater state size for the whole network.
     *
     * @param layer must be a {@link MultiLayerNetwork}
     * @return length of the shared state array, or 0 when there is none
     * @throws IllegalArgumentException if {@code layer} is not a {@link MultiLayerNetwork}
     */
    @Override // org.deeplearning4j.nn.api.Updater
    public int stateSizeForLayer(Layer layer) {
        if (!(layer instanceof MultiLayerNetwork)) {
            throw new IllegalArgumentException("Expected MultiLayerNetwork");
        }
        // viewArray is never allocated when the total state size is 0.
        return this.viewArray == null ? 0 : this.viewArray.length();
    }

    /**
     * Splits a network-level gradient into per-layer gradients and applies each layer's updater.
     *
     * <p>Gradient keys are expected in the form {@code "<layerIndex>_<paramName>"}; the part
     * before the first underscore selects the layer, the remainder becomes the per-layer key.
     *
     * @param layer the network being updated; must be a {@link MultiLayerNetwork}
     * @param gradient network-level gradient
     * @param i forwarded unchanged to every layer updater
     * @param i2 forwarded unchanged to every layer updater
     * @throws IllegalStateException if a gradient key contains no underscore separator
     */
    @Override // org.deeplearning4j.nn.api.Updater
    public void update(Layer layer, Gradient gradient, int i, int i2) {
        MultiLayerNetwork network = (MultiLayerNetwork) layer;
        Gradient[] perLayer = new Gradient[this.layerUpdaters.length];
        for (int idx = 0; idx < perLayer.length; idx++) {
            perLayer[idx] = new DefaultGradient();
        }
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            int separatorPos = key.indexOf(LAYER_KEY_SEPARATOR);
            if (separatorPos == -1) {
                throw new IllegalStateException("Invalid key: MuliLayerNetwork Gradient key does not have layer separator: \"" + key + "\"");
            }
            int layerIndex = Integer.parseInt(key.substring(0, separatorPos));
            String paramName = key.substring(separatorPos + 1);
            perLayer[layerIndex].gradientForVariable().put(paramName, entry.getValue());
        }
        for (int idx = 0; idx < this.layerUpdaters.length; idx++) {
            this.layerUpdaters[idx].update(network.getLayer(idx), perLayer[idx], i, i2);
        }
    }

    @Override // org.deeplearning4j.nn.api.Updater
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Updater m90clone() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Two instances are equal when they hold the same number of layer updaters and the updaters
     * are pairwise equal. The state array is not compared directly.
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof MultiLayerUpdater)) {
            return false;
        }
        MultiLayerUpdater other = (MultiLayerUpdater) obj;
        if (this.layerUpdaters.length != other.layerUpdaters.length) {
            return false;
        }
        for (int idx = 0; idx < this.layerUpdaters.length; idx++) {
            if (!this.layerUpdaters[idx].equals(other.layerUpdaters[idx])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Hash code consistent with {@link #equals(Object)}.
     *
     * <p>Only the number of layer updaters is hashed: equal instances always have the same count,
     * so the contract holds regardless of whether the individual layer updaters define a
     * {@code hashCode} that agrees with their own {@code equals}.
     */
    @Override
    public int hashCode() {
        return this.layerUpdaters.length;
    }

    /**
     * @return the internal per-layer updater array (not a copy)
     */
    public Updater[] getLayerUpdaters() {
        return this.layerUpdaters;
    }

    /**
     * @return the shared state array, or {@code null} when no layer keeps updater state
     */
    public INDArray getViewArray() {
        return this.viewArray;
    }
}
