package com.kotlinnlp.simplednn.core.neuralprocessor.batchfeedforward;

import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.core.neuralnetwork.NeuralNetwork;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessorsPool;
import com.kotlinnlp.simplednn.core.optimizer.ParamsErrorsAccumulator;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: BatchFeedforwardProcessor.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��^\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010!\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\u0010 \n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0010\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\u00020\u0003B\u0017\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\b\b\u0002\u0010\u0006\u001a\u00020\u0007¢\u0006\u0002\u0010\bJ\u001e\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0013\u001a\u00020\u00072\u0006\u0010\u0014\u001a\u00020\u00152\u0006\u0010\u0016\u001a\u00020\u0017J\u001c\u0010\u0011\u001a\u00020\u00122\f\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00150\u00182\u0006\u0010\u0016\u001a\u00020\u0017JD\u0010\u0019\u001a\b\u0012\u0004\u0012\u00020\u00150\u00182\"\u0010\u001a\u001a\u001e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��0\u00180\u001bj\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��0\u0018`\u001c2\b\b\u0002\u0010\u001d\u001a\u00020\u00172\b\b\u0002\u0010\u001e\u001a\u00020\u0017J.\u0010\u0019\u001a\b\u0012\u0004\u0012\u00020\u00150\u00182\f\u0010\u001f\u001a\b\u0012\u0004\u0012\u00028��0\u00182\b\b\u0002\u0010\u001d\u001a\u00020\u00172\b\b\u0002\u0010\u001e\u001a\u00020\u0017J\u001f\u0010 \u001a\u00020\u00152\u0006\u0010!\u001a\u00028��2\b\b\u0002\u0010\u001e\u001a\u00020\u0017H\u0002¢\u0006\u0002\u0010\"J \u0010 
\u001a\u00020\u00152\f\u0010#\u001a\b\u0012\u0004\u0012\u00028��0\u00182\b\b\u0002\u0010\u001e\u001a\u00020\u0017H\u0002J\u0016\u0010$\u001a\b\u0012\u0004\u0012\u00020\u00150\u00182\b\b\u0002\u0010%\u001a\u00020\u0017J\u001c\u0010&\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00150\u00180\u00182\b\b\u0002\u0010%\u001a\u00020\u0017J\u0010\u0010'\u001a\u00020\u000b2\u0006\u0010%\u001a\u00020\u0017H\u0016J&\u0010(\u001a\u00020\u00122\f\u0010)\u001a\b\u0012\u0004\u0012\u00028��0\u00102\u0006\u0010*\u001a\u00020\u00152\u0006\u0010\u0016\u001a\u00020\u0017H\u0002J\b\u0010+\u001a\u00020\u0012H\u0002R\u0014\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u000b0\nX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\f\u001a\b\u0012\u0004\u0012\u00028��0\rX\u0082\u0004¢\u0006\u0002\n��R\u001a\u0010\u000e\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00028��0\u00100\u000fX\u0082\u000e¢\u0006\u0002\n��¨\u0006,"}, d2 = {"Lcom/kotlinnlp/simplednn/core/neuralprocessor/batchfeedforward/BatchFeedforwardProcessor;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;", "neuralNetwork", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "id", "", "(Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;I)V", "errorsAccumulator", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsAccumulator;", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "processorsPool", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/feedforward/FeedforwardNeuralProcessorsPool;", "usedProcessors", "", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/feedforward/FeedforwardNeuralProcessor;", "backward", "", "elementIndex", "outputErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "propagateToInput", "", "", "forward", "featuresListBatch", "Ljava/util/ArrayList;", "Lkotlin/collections/ArrayList;", "continueBatch", "useDropout", "featuresBatch", "forwardProcessor", "features", 
"(Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;Z)Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "featuresList", "getInputErrors", "copy", "getInputsErrors", "getParamsErrors", "processorBackward", "processor", "errors", "reset", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/neuralprocessor/batchfeedforward/BatchFeedforwardProcessor.class */
/*
 * Runs a feed-forward neural network over a batch of inputs.
 *
 * Every element forwarded takes its own FeedforwardNeuralProcessor from an internal pool; the
 * processors used since the last reset are remembered in forwarding order so that the matching
 * backward calls and the input-errors getters can address them by position. The params errors
 * of every processor that runs a backward step are collected in a single accumulator, which is
 * what getParamsErrors exposes.
 *
 * This class is the Java view of a Kotlin class: the static "$default" methods are the helpers
 * the Kotlin compiler generates for default arguments and must keep their exact signatures.
 */
public final class BatchFeedforwardProcessor<InputNDArrayType extends NDArray<InputNDArrayType>> extends NeuralProcessor {
    /** Pool that supplies one processor per forwarded batch element. */
    private final FeedforwardNeuralProcessorsPool<InputNDArrayType> processorsPool;

    /** Collects the params errors of every processor that has run a backward step. */
    private final ParamsErrorsAccumulator<NetworkParameters> errorsAccumulator;

    /** Processors taken from the pool since the last reset, in forwarding order. */
    private final List<FeedforwardNeuralProcessor<InputNDArrayType>> usedProcessors;

    /**
     * Builds a batch processor for the given network.
     *
     * @param neuralNetwork the network shared by all pooled processors (must not be null)
     * @param id identifier passed to the {@code NeuralProcessor} super constructor
     */
    public BatchFeedforwardProcessor(@NotNull NeuralNetwork neuralNetwork, int id) {
        super(neuralNetwork, id);
        Intrinsics.checkParameterIsNotNull(neuralNetwork, "neuralNetwork");
        this.processorsPool = new FeedforwardNeuralProcessorsPool<>(neuralNetwork);
        this.errorsAccumulator = new ParamsErrorsAccumulator<>();
        this.usedProcessors = new ArrayList<>();
    }

    /**
     * Kotlin default-argument constructor: when bit 2 of {@code mask} is set the {@code id}
     * argument was omitted and 0 is used instead.
     */
    public /* synthetic */ BatchFeedforwardProcessor(NeuralNetwork neuralNetwork, int id, int mask, DefaultConstructorMarker marker) {
        this(neuralNetwork, (mask & 2) != 0 ? 0 : id);
    }

    /**
     * Collects the input errors of every used processor, in forwarding order.
     *
     * @param copy forwarded unchanged to each processor's {@code getInputErrors}
     *             (presumably selects copy vs. shared instance; confirm on FeedforwardNeuralProcessor)
     * @return one entry per used processor
     */
    @NotNull
    public final List<DenseNDArray> getInputErrors(boolean copy) {
        List<DenseNDArray> inputErrors = new ArrayList<>(this.usedProcessors.size());
        for (FeedforwardNeuralProcessor<InputNDArrayType> processor : this.usedProcessors) {
            inputErrors.add(processor.getInputErrors(copy));
        }
        return inputErrors;
    }

    /** Kotlin default-argument helper: when bit 1 of {@code mask} is set, {@code copy} defaults to true. */
    @NotNull
    public static /* bridge */ /* synthetic */ List getInputErrors$default(BatchFeedforwardProcessor batchFeedforwardProcessor, boolean copy, int mask, Object obj) {
        if ((mask & 1) != 0) {
            copy = true;
        }
        return batchFeedforwardProcessor.getInputErrors(copy);
    }

    /**
     * Collects the list-valued input errors of every used processor, in forwarding order.
     *
     * @param copy forwarded unchanged to each processor's {@code getInputsErrors}
     * @return one list of errors per used processor
     */
    @NotNull
    public final List<List<DenseNDArray>> getInputsErrors(boolean copy) {
        List<List<DenseNDArray>> inputsErrors = new ArrayList<>(this.usedProcessors.size());
        for (FeedforwardNeuralProcessor<InputNDArrayType> processor : this.usedProcessors) {
            inputsErrors.add(processor.getInputsErrors(copy));
        }
        return inputsErrors;
    }

    /** Kotlin default-argument helper: when bit 1 of {@code mask} is set, {@code copy} defaults to true. */
    @NotNull
    public static /* bridge */ /* synthetic */ List getInputsErrors$default(BatchFeedforwardProcessor batchFeedforwardProcessor, boolean copy, int mask, Object obj) {
        if ((mask & 1) != 0) {
            copy = true;
        }
        return batchFeedforwardProcessor.getInputsErrors(copy);
    }

    /**
     * Returns the params errors held by the internal accumulator.
     *
     * @param copy forwarded unchanged to the accumulator's {@code getParamsErrors}
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    public NetworkParameters getParamsErrors(boolean copy) {
        return this.errorsAccumulator.getParamsErrors(copy);
    }

    /**
     * Forwards every element of the batch through its own pooled processor.
     *
     * @param featuresBatch the inputs, one array per batch element (must not be null)
     * @param continueBatch when false the state of the previous batch is reset first; when true
     *                      the new elements are appended to the processors already in use
     * @param useDropout forwarded unchanged to each processor's {@code forward}
     * @return the output of each element, in input order
     */
    @NotNull
    public final List<DenseNDArray> forward(@NotNull List<? extends InputNDArrayType> featuresBatch, boolean continueBatch, boolean useDropout) {
        Intrinsics.checkParameterIsNotNull(featuresBatch, "featuresBatch");
        // Reset up front rather than on the first element, so that an empty batch also clears
        // the processors and errors left over from the previous one.
        if (!continueBatch) {
            reset();
        }
        List<DenseNDArray> outputs = new ArrayList<>(featuresBatch.size());
        for (InputNDArrayType features : featuresBatch) {
            outputs.add(forwardProcessor(features, useDropout));
        }
        return outputs;
    }

    /**
     * Kotlin default-argument helper: bit 2 of {@code mask} defaults {@code continueBatch} to
     * false, bit 4 defaults {@code useDropout} to false.
     */
    @NotNull
    public static /* bridge */ /* synthetic */ List forward$default(BatchFeedforwardProcessor batchFeedforwardProcessor, List list, boolean continueBatch, boolean useDropout, int mask, Object obj) {
        if ((mask & 2) != 0) {
            continueBatch = false;
        }
        if ((mask & 4) != 0) {
            useDropout = false;
        }
        return batchFeedforwardProcessor.forward(list, continueBatch, useDropout);
    }

    /**
     * Forwards every element of the batch through its own pooled processor, where each element
     * is itself a list of input arrays.
     *
     * @param featuresListBatch the inputs, one list of arrays per batch element (must not be null)
     * @param continueBatch when false the state of the previous batch is reset first; when true
     *                      the new elements are appended to the processors already in use
     * @param useDropout forwarded unchanged to each processor's {@code forward}
     * @return the output of each element, in input order
     */
    @NotNull
    public final List<DenseNDArray> forward(@NotNull ArrayList<List<InputNDArrayType>> featuresListBatch, boolean continueBatch, boolean useDropout) {
        Intrinsics.checkParameterIsNotNull(featuresListBatch, "featuresListBatch");
        // Same up-front reset as the single-array overload: an empty batch must not keep stale state.
        if (!continueBatch) {
            reset();
        }
        List<DenseNDArray> outputs = new ArrayList<>(featuresListBatch.size());
        for (List<InputNDArrayType> featuresList : featuresListBatch) {
            outputs.add(forwardProcessor(featuresList, useDropout));
        }
        return outputs;
    }

    /**
     * Kotlin default-argument helper: bit 2 of {@code mask} defaults {@code continueBatch} to
     * false, bit 4 defaults {@code useDropout} to false.
     */
    @NotNull
    public static /* bridge */ /* synthetic */ List forward$default(BatchFeedforwardProcessor batchFeedforwardProcessor, ArrayList arrayList, boolean continueBatch, boolean useDropout, int mask, Object obj) {
        if ((mask & 2) != 0) {
            continueBatch = false;
        }
        if ((mask & 4) != 0) {
            useDropout = false;
        }
        return batchFeedforwardProcessor.forward(arrayList, continueBatch, useDropout);
    }

    /**
     * Runs the backward step of a single batch element and accumulates its params errors.
     * Unlike the whole-batch overload, this does not average the accumulated errors.
     *
     * @param elementIndex position of the element among the used processors
     * @param outputErrors the errors of that element's output (must not be null)
     * @param propagateToInput forwarded unchanged to the processor's {@code backward}
     * @throws IllegalArgumentException if {@code elementIndex} is not a valid used-processor index
     */
    public final void backward(int elementIndex, @NotNull DenseNDArray outputErrors, boolean propagateToInput) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        if (elementIndex < 0 || elementIndex >= this.usedProcessors.size()) {
            throw new IllegalArgumentException("The processor index exceeds the last index of the used processors.");
        }
        processorBackward(this.usedProcessors.get(elementIndex), outputErrors, propagateToInput);
    }

    /**
     * Runs the backward step of every used processor with the matching output errors, then
     * averages the accumulated params errors.
     *
     * @param outputErrors one errors array per used processor, in forwarding order (must not be null)
     * @param propagateToInput forwarded unchanged to each processor's {@code backward}
     * @throws IllegalArgumentException if the number of errors differs from the number of used processors
     */
    public final void backward(@NotNull List<DenseNDArray> outputErrors, boolean propagateToInput) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        if (outputErrors.size() != this.usedProcessors.size()) {
            throw new IllegalArgumentException(String.format(
                "Number of errors (%d) does not reflect the number of used processors (%d)",
                outputErrors.size(), this.usedProcessors.size()));
        }
        // Sizes are equal, so walking both lists in lock-step pairs each processor with its errors.
        Iterator<DenseNDArray> errorsIterator = outputErrors.iterator();
        for (FeedforwardNeuralProcessor<InputNDArrayType> processor : this.usedProcessors) {
            processorBackward(processor, errorsIterator.next(), propagateToInput);
        }
        this.errorsAccumulator.averageErrors();
    }

    /** Forwards a single input array through a processor freshly taken from the pool. */
    private DenseNDArray forwardProcessor(InputNDArrayType features, boolean useDropout) {
        return acquireProcessor().forward(features, useDropout);
    }

    /** Kotlin default-argument helper: when bit 2 of {@code mask} is set, {@code useDropout} defaults to false. */
    static /* bridge */ /* synthetic */ DenseNDArray forwardProcessor$default(BatchFeedforwardProcessor batchFeedforwardProcessor, NDArray nDArray, boolean useDropout, int mask, Object obj) {
        if ((mask & 2) != 0) {
            useDropout = false;
        }
        return batchFeedforwardProcessor.forwardProcessor(nDArray, useDropout);
    }

    /** Forwards a list of input arrays through a processor freshly taken from the pool. */
    private DenseNDArray forwardProcessor(List<? extends InputNDArrayType> featuresList, boolean useDropout) {
        return acquireProcessor().forward(featuresList, useDropout);
    }

    /** Kotlin default-argument helper: when bit 2 of {@code mask} is set, {@code useDropout} defaults to false. */
    static /* bridge */ /* synthetic */ DenseNDArray forwardProcessor$default(BatchFeedforwardProcessor batchFeedforwardProcessor, List list, boolean useDropout, int mask, Object obj) {
        if ((mask & 2) != 0) {
            useDropout = false;
        }
        return batchFeedforwardProcessor.forwardProcessor(list, useDropout);
    }

    /**
     * Takes a processor from the pool and records it as used.
     *
     * The cast mirrors the one in the compiled code; the pool's declared item type is not
     * visible from this file, so it is kept (and is harmless if the type already matches).
     */
    @SuppressWarnings("unchecked")
    private FeedforwardNeuralProcessor<InputNDArrayType> acquireProcessor() {
        FeedforwardNeuralProcessor<InputNDArrayType> processor =
            (FeedforwardNeuralProcessor<InputNDArrayType>) this.processorsPool.getItem();
        this.usedProcessors.add(processor);
        return processor;
    }

    /** Releases every pooled processor, forgets the used ones and resets the errors accumulator. */
    private void reset() {
        this.processorsPool.releaseAll();
        this.usedProcessors.clear();
        this.errorsAccumulator.reset();
    }

    /**
     * Runs the backward step of one processor and adds its params errors to the accumulator.
     *
     * Both calls go through Kotlin "$default" helpers: mask 4 leaves the third argument of
     * {@code backward} at its default, mask 2 leaves the second argument of {@code accumulate}
     * at its default.
     */
    private void processorBackward(FeedforwardNeuralProcessor<InputNDArrayType> processor, DenseNDArray errors, boolean propagateToInput) {
        FeedforwardNeuralProcessor.backward$default(processor, errors, propagateToInput, null, 4, null);
        ParamsErrorsAccumulator.accumulate$default(this.errorsAccumulator, processor.getParamsErrors(false), false, 2, null);
    }
}
