package com.kotlinnlp.simplednn.deeplearning.sequenceencoder;

import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsErrorsAccumulator;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.ArrayList;
import java.util.Arrays;
import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: SequenceFeedforwardEncoder.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��P\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\n\u0002\u0010\u0002\n��\n\u0002\u0010\u0011\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0010\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\u00020\u0003B\r\u0012\u0006\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006J!\u0010\u0012\u001a\u00020\u00132\f\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00160\u00152\u0006\u0010\u0017\u001a\u00020\u0018¢\u0006\u0002\u0010\u0019J\u001f\u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u00160\u00152\f\u0010\u001b\u001a\b\u0012\u0004\u0012\u00028��0\u0015¢\u0006\u0002\u0010\u001cJ!\u0010\u001d\u001a\b\u0012\u0004\u0012\u00020\u00160\u00152\f\u0010\u001b\u001a\b\u0012\u0004\u0012\u00028��0\u0015H\u0002¢\u0006\u0002\u0010\u001cJ\u001b\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u00160\u00152\b\b\u0002\u0010\u001f\u001a\u00020\u0018¢\u0006\u0002\u0010 J\u0010\u0010!\u001a\u00020\t2\b\b\u0002\u0010\u001f\u001a\u00020\u0018J\u0016\u0010\"\u001a\b\u0012\u0004\u0012\u00028��0\u000e2\u0006\u0010#\u001a\u00020\u0011H\u0002J&\u0010$\u001a\u00020\u00132\f\u0010%\u001a\b\u0012\u0004\u0012\u00028��0\u000e2\u0006\u0010&\u001a\u00020\u00162\u0006\u0010\u0017\u001a\u00020\u0018H\u0002J\b\u0010'\u001a\u00020\u0013H\u0002R\u0014\u0010\u0007\u001a\b\u0012\u0004\u0012\u00020\t0\bX\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\n\u0010\u000bR*\u0010\f\u001a\u001e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��0\u000e0\rj\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��0\u000e`\u000fX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0010\u001a\u00020\u0011X\u0082\u000e¢\u0006\u0002\n��¨\u0006("}, d2 = 
{"Lcom/kotlinnlp/simplednn/deeplearning/sequenceencoder/SequenceFeedforwardEncoder;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "", "network", "Lcom/kotlinnlp/simplednn/deeplearning/sequenceencoder/SequenceFeedforwardNetwork;", "(Lcom/kotlinnlp/simplednn/deeplearning/sequenceencoder/SequenceFeedforwardNetwork;)V", "errorsAccumulator", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsAccumulator;", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "getNetwork", "()Lcom/kotlinnlp/simplednn/deeplearning/sequenceencoder/SequenceFeedforwardNetwork;", "processorsList", "Ljava/util/ArrayList;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/feedforward/FeedforwardNeuralProcessor;", "Lkotlin/collections/ArrayList;", "usedProcessors", "", "backward", "", "outputErrorsSequence", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "propagateToInput", "", "([Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;Z)V", "encode", "sequence", "([Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;)[Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "forward", "getInputSequenceErrors", "copy", "(Z)[Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getParamsErrors", "getProcessor", "index", "processorBackward", "processor", "errors", "reset", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/sequenceencoder/SequenceFeedforwardEncoder.class */
/**
 * Encodes a sequence of input arrays by running the k-th element through the k-th
 * {@link FeedforwardNeuralProcessor} of an internal pool. Every processor is built on the network
 * returned by {@code network.getNetwork()}.
 *
 * <p>Processors are created on demand and are never removed. A longer sequence grows the pool; a
 * shorter one reuses only its first elements. Parameter errors produced by {@link #backward} are
 * collected in a {@link ParamsErrorsAccumulator}, which is reset at every {@link #encode}.
 *
 * @param <InputNDArrayType> the concrete type of the input arrays
 */
public final class SequenceFeedforwardEncoder<InputNDArrayType extends NDArray<InputNDArrayType>> {

    /** Processor pool: the element at position k handles the k-th element of a sequence. */
    private final ArrayList<FeedforwardNeuralProcessor<InputNDArrayType>> processorsList;

    /** Gathers the parameter errors of each processor visited by {@link #backward}. */
    private final ParamsErrorsAccumulator<NetworkParameters> errorsAccumulator;

    /** Number of processors (starting at index 0) used by the most recent {@link #encode}. */
    private int usedProcessors;

    @NotNull
    private final SequenceFeedforwardNetwork network;

    /**
     * Builds an encoder on top of the given sequence network. The processor pool starts empty.
     *
     * @param network the network whose inner network is used to create each processor
     */
    public SequenceFeedforwardEncoder(@NotNull SequenceFeedforwardNetwork network) {
        Intrinsics.checkParameterIsNotNull(network, "network");
        this.network = network;
        this.processorsList = new ArrayList<>();
        this.errorsAccumulator = new ParamsErrorsAccumulator<>();
    }

    /** @return the sequence network this encoder was built with */
    @NotNull
    public final SequenceFeedforwardNetwork getNetwork() {
        return this.network;
    }

    /**
     * Encodes a whole sequence. First clears the used-processor count and the accumulated
     * parameter errors, then forwards each element through its own processor.
     *
     * @param sequence the input elements, in order
     * @return one output array per input element, in the same order
     */
    @NotNull
    public final DenseNDArray[] encode(@NotNull InputNDArrayType[] sequence) {
        Intrinsics.checkParameterIsNotNull(sequence, "sequence");
        reset();
        return forward(sequence);
    }

    /**
     * Back-propagates one error array per processor used by the last {@link #encode}. Each
     * processor's parameter errors are accumulated, then averaged.
     *
     * @param outputErrorsSequence errors of the outputs; must have exactly as many entries as the
     *     last encoded sequence
     * @param propagateToInput forwarded to each processor's backward step (presumably controls
     *     whether input errors are computed — confirm against FeedforwardNeuralProcessor)
     * @throws IllegalArgumentException if the number of errors differs from the number of used
     *     processors
     */
    public final void backward(@NotNull DenseNDArray[] outputErrorsSequence, boolean propagateToInput) {
        Intrinsics.checkParameterIsNotNull(outputErrorsSequence, "outputErrorsSequence");
        final int expected = this.usedProcessors;
        if (outputErrorsSequence.length != expected) {
            String message = String.format(
                    "Number of errors (%d) does not reflect the number of used processors (%d)",
                    outputErrorsSequence.length, expected);
            Intrinsics.checkExpressionValueIsNotNull(message, "java.lang.String.format(this, *args)");
            throw new IllegalArgumentException(message);
        }
        for (int idx = 0; idx < expected; idx++) {
            FeedforwardNeuralProcessor<InputNDArrayType> processor = this.processorsList.get(idx);
            Intrinsics.checkExpressionValueIsNotNull(processor, "this.processorsList[i]");
            processorBackward(processor, outputErrorsSequence[idx], propagateToInput);
        }
        this.errorsAccumulator.averageErrors();
    }

    /**
     * Collects the input errors of the processors used by the last {@link #encode}, in order.
     *
     * @param copy forwarded to each processor's {@code getInputErrors}
     * @return one input-error array per used processor
     */
    @NotNull
    public final DenseNDArray[] getInputSequenceErrors(boolean copy) {
        final DenseNDArray[] inputErrors = new DenseNDArray[this.usedProcessors];
        for (int k = 0; k < inputErrors.length; k++) {
            inputErrors[k] = this.processorsList.get(k).getInputErrors(copy);
        }
        return inputErrors;
    }

    /**
     * Kotlin default-argument bridge for {@link #getInputSequenceErrors(boolean)}: when bit 0 of
     * {@code mask} is set, {@code copy} falls back to its default value {@code true}.
     */
    @NotNull
    public static /* bridge */ /* synthetic */ DenseNDArray[] getInputSequenceErrors$default(SequenceFeedforwardEncoder self, boolean copy, int mask, Object unused) {
        final boolean effectiveCopy = (mask & 1) != 0 || copy;
        return self.getInputSequenceErrors(effectiveCopy);
    }

    /**
     * Returns the parameter errors currently held by the accumulator. These are the errors gathered
     * (and averaged) by {@link #backward} since the last {@link #encode}.
     *
     * @param copy forwarded to {@code ParamsErrorsAccumulator.getParamsErrors}
     */
    @NotNull
    public final NetworkParameters getParamsErrors(boolean copy) {
        return this.errorsAccumulator.getParamsErrors(copy);
    }

    /**
     * Kotlin default-argument bridge for {@link #getParamsErrors(boolean)}: when bit 0 of
     * {@code mask} is set, {@code copy} falls back to its default value {@code true}.
     */
    @NotNull
    public static /* bridge */ /* synthetic */ NetworkParameters getParamsErrors$default(SequenceFeedforwardEncoder self, boolean copy, int mask, Object unused) {
        final boolean effectiveCopy = (mask & 1) != 0 || copy;
        return self.getParamsErrors(effectiveCopy);
    }

    /** Forgets the previous sequence: zero used processors and an empty error accumulator. */
    private final void reset() {
        this.usedProcessors = 0;
        this.errorsAccumulator.reset();
    }

    /**
     * Forwards each input element through the processor at the same position. The used-processor
     * count is incremented once per element, before that element's forward call.
     */
    private final DenseNDArray[] forward(InputNDArrayType[] inputs) {
        final DenseNDArray[] outputs = new DenseNDArray[inputs.length];
        for (int pos = 0; pos < outputs.length; pos++) {
            final FeedforwardNeuralProcessor<InputNDArrayType> processor = getProcessor(pos);
            this.usedProcessors++;
            outputs[pos] = FeedforwardNeuralProcessor.forward$default(processor, inputs[pos], false, 2, null);
        }
        return outputs;
    }

    /**
     * Returns the processor at {@code index}. When {@code index} equals the current pool size, a
     * new processor is appended first, so the pool grows by at most one per call.
     *
     * @throws IllegalArgumentException if {@code index} is greater than the pool size
     */
    private final FeedforwardNeuralProcessor<InputNDArrayType> getProcessor(int index) {
        final int poolSize = this.processorsList.size();
        if (index > poolSize) {
            String message = String.format("Invalid output processor index: %d (size = %d)", index, poolSize);
            Intrinsics.checkExpressionValueIsNotNull(message, "java.lang.String.format(this, *args)");
            throw new IllegalArgumentException(message);
        }
        if (index == poolSize) {
            // Lazily extend the pool with a processor built on the inner network.
            this.processorsList.add(new FeedforwardNeuralProcessor<>(this.network.getNetwork(), 0, 2, null));
        }
        FeedforwardNeuralProcessor<InputNDArrayType> processor = this.processorsList.get(index);
        Intrinsics.checkExpressionValueIsNotNull(processor, "this.processorsList[index]");
        return processor;
    }

    /**
     * Runs one processor's backward step, then adds its parameter errors (read without copying) to
     * the shared accumulator.
     */
    private final void processorBackward(FeedforwardNeuralProcessor<InputNDArrayType> processor, DenseNDArray outputErrors, boolean propagateToInput) {
        FeedforwardNeuralProcessor.backward$default(processor, outputErrors, propagateToInput, null, 4, null);
        ParamsErrorsAccumulator.accumulate$default(this.errorsAccumulator, processor.getParamsErrors(false), false, 2, null);
    }
}
