package com.kotlinnlp.simplednn.deeplearning.birnn;

import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.batchfeedforward.BatchFeedforwardProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.recurrent.RecurrentNeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.Optimizer;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TuplesKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: BiRNNEncoder.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��R\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\n\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u000228\u0012\n\u0012\b\u0012\u0004\u0012\u0002H\u00010\u0004\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u0004\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u0004\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u0004\u0012\u0004\u0012\u00020\u00060\u0003B'\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\u0006\u0010\t\u001a\u00020\n\u0012\u0006\u0010\u000b\u001a\u00020\n\u0012\b\b\u0002\u0010\f\u001a\u00020\r¢\u0006\u0002\u0010\u000eJ\u0016\u0010\u001c\u001a\u00020\u001d2\f\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u00050\u0004H\u0016J\u0016\u0010\u001f\u001a\u00020\u001d2\u0006\u0010 
\u001a\u00020\u00052\u0006\u0010!\u001a\u00020\u0005J\u001a\u0010\"\u001a\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u00050#0\u0004H\u0002J\u001c\u0010$\u001a\b\u0012\u0004\u0012\u00020\u00050\u00042\f\u0010%\u001a\b\u0012\u0004\u0012\u00028��0\u0004H\u0016J\u0016\u0010&\u001a\b\u0012\u0004\u0012\u00020\u00050\u00042\u0006\u0010'\u001a\u00020\nH\u0016J\u001a\u0010(\u001a\u000e\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u00050#2\u0006\u0010'\u001a\u00020\nJ\u0010\u0010)\u001a\u00020\u00062\u0006\u0010'\u001a\u00020\nH\u0016J\u001c\u0010*\u001a\u0018\u0012\u0014\u0012\u0012\u0012\u0006\u0012\u0004\u0018\u00010\u0005\u0012\u0006\u0012\u0004\u0018\u00010\u00050#0\u0004J(\u0010+\u001a\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u00050#0\u00042\f\u0010,\u001a\b\u0012\u0004\u0012\u00020\u00050\u0004H\u0002R\u0014\u0010\f\u001a\u00020\rX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u000f\u0010\u0010R\u0014\u0010\u0011\u001a\b\u0012\u0004\u0012\u00028��0\u0012X\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0007\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u0014\u0010\u0015\u001a\b\u0012\u0004\u0012\u00020\u00050\u0016X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000b\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0018R\u0014\u0010\u0019\u001a\b\u0012\u0004\u0012\u00028��0\u0012X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u001a\u001a\b\u0012\u0004\u0012\u00028��0\u0004X\u0082.¢\u0006\u0002\n��R\u0014\u0010\t\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001b\u0010\u0018¨\u0006-"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNNEncoder;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNNParameters;", "network", 
"Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNN;", "useDropout", "", "propagateToInput", "id", "", "(Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNN;ZZI)V", "getId", "()I", "leftToRightProcessor", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/recurrent/RecurrentNeuralProcessor;", "getNetwork", "()Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNN;", "outputMergeProcessors", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/batchfeedforward/BatchFeedforwardProcessor;", "getPropagateToInput", "()Z", "rightToLeftProcessor", "sequence", "getUseDropout", "backward", "", "outputErrors", "backwardLastOutput", "leftToRightErrors", "rightToLeftErrors", "biEncoding", "Lkotlin/Pair;", "forward", "input", "getInputErrors", "copy", "getLastOutput", "getParamsErrors", "getRANImportanceScores", "outputMergeBackward", "outputErrorsSequence", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/birnn/BiRNNEncoder.class */
public final class BiRNNEncoder<InputNDArrayType extends NDArray<InputNDArrayType>> implements NeuralProcessor<List<? extends InputNDArrayType>, List<? extends DenseNDArray>, List<? extends DenseNDArray>, List<? extends DenseNDArray>, BiRNNParameters> {
    // Recurrent processor built in the constructor from the BiRNN's left-to-right network.
    private final RecurrentNeuralProcessor<InputNDArrayType> leftToRightProcessor;
    // Recurrent processor built in the constructor from the BiRNN's right-to-left network.
    private final RecurrentNeuralProcessor<InputNDArrayType> rightToLeftProcessor;
    // Batch feed-forward processor built in the constructor from the BiRNN's output-merge network.
    private final BatchFeedforwardProcessor<DenseNDArray> outputMergeProcessors;
    // Input most recently passed to forward(); null until then, and readers raise Kotlin's
    // uninitialized-property exception when they find it null.
    private List<? extends InputNDArrayType> sequence;

    // The BiRNN whose three sub-networks back the processors above.
    @NotNull
    private final BiRNN network;
    // Constructor flag, passed on to all three processors.
    private final boolean useDropout;
    // Constructor flag, passed on to the two recurrent processors only.
    private final boolean propagateToInput;
    // Identifier of this encoder; 0 when the Kotlin default argument is used.
    private final int id;

    /**
     * Encodes a sequence in both directions and merges the two encodings.
     *
     * <p>The input is remembered as the current sequence, {@link #biEncoding()} produces one
     * (left-to-right, right-to-left) output pair per element, each pair is turned into a
     * two-element list, and the resulting batch is handed to the output-merge processor.
     *
     * @param list the input sequence ("input" in the Kotlin source); must not be null
     * @return whatever the output-merge processor returns for the batch of encoded pairs
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    public List<DenseNDArray> forward(@NotNull List<? extends InputNDArrayType> list) {
        Intrinsics.checkParameterIsNotNull(list, "input");
        this.sequence = list;
        List<Pair<DenseNDArray, DenseNDArray>> encodedPairs = biEncoding();
        // Must stay an ArrayList: that is the type the processor's forward call is given in the original.
        ArrayList<List<DenseNDArray>> mergeInputs = new ArrayList<>(encodedPairs.size());
        for (Pair<DenseNDArray, DenseNDArray> encodedPair : encodedPairs) {
            mergeInputs.add(TuplesKt.toList(encodedPair));
        }
        // Kotlin "$default" stub: mask 2 requests the declared default for the second parameter,
        // so the literal false passed here is only a placeholder.
        return BatchFeedforwardProcessor.forward$default((BatchFeedforwardProcessor) this.outputMergeProcessors, mergeInputs, false, 2, (Object) null);
    }

    /**
     * Collects, for every position of the current sequence, the RAN importance scores reported
     * by the two recurrent processors.
     *
     * <p>For position {@code i} the left component is taken from the left-to-right processor at
     * state {@code i}, and the right component from the right-to-left processor at state
     * {@code size - i - 1}. A component is {@code null} when its state index is 0.
     *
     * @return one pair per sequence element, in sequence order; pair components may be null
     */
    @NotNull
    public final List<Pair<DenseNDArray, DenseNDArray>> getRANImportanceScores() {
        List<? extends InputNDArrayType> currentSequence = this.sequence;
        if (currentSequence == null) {
            // Always throws: forward() has not been called yet.
            Intrinsics.throwUninitializedPropertyAccessException("sequence");
        }
        int size = currentSequence.size();
        List<Pair<DenseNDArray, DenseNDArray>> scores = new ArrayList<>(size);
        for (int index = 0; index < size; index++) {
            int mirroredIndex = (size - index) - 1;
            DenseNDArray leftToRightScores = index > 0 ? this.leftToRightProcessor.getRANImportanceScores(index) : null;
            DenseNDArray rightToLeftScores = mirroredIndex > 0 ? this.rightToLeftProcessor.getRANImportanceScores(mirroredIndex) : null;
            scores.add(new Pair<>(leftToRightScores, rightToLeftScores));
        }
        return scores;
    }

    /**
     * Returns the current outputs of the two recurrent processors as a pair.
     *
     * @param z named "copy" in the Kotlin metadata; forwarded unchanged to both processors' {@code getOutput}
     * @return (left-to-right output, right-to-left output)
     */
    @NotNull
    public final Pair<DenseNDArray, DenseNDArray> getLastOutput(boolean z) {
        final DenseNDArray leftToRightOutput = this.leftToRightProcessor.getOutput(z);
        final DenseNDArray rightToLeftOutput = this.rightToLeftProcessor.getOutput(z);
        return new Pair<>(leftToRightOutput, rightToLeftOutput);
    }

    /**
     * Runs the backward step of each recurrent processor with its own error array, bypassing
     * the output-merge processor.
     *
     * @param denseNDArray errors handed to the left-to-right processor ("leftToRightErrors"); must not be null
     * @param denseNDArray2 errors handed to the right-to-left processor ("rightToLeftErrors"); must not be null
     */
    public final void backwardLastOutput(@NotNull DenseNDArray denseNDArray, @NotNull DenseNDArray denseNDArray2) {
        // Both arguments are validated before either processor is touched.
        Intrinsics.checkParameterIsNotNull(denseNDArray, "leftToRightErrors");
        Intrinsics.checkParameterIsNotNull(denseNDArray2, "rightToLeftErrors");

        final RecurrentNeuralProcessor<InputNDArrayType> forwardDirection = this.leftToRightProcessor;
        final RecurrentNeuralProcessor<InputNDArrayType> reverseDirection = this.rightToLeftProcessor;
        forwardDirection.backward(denseNDArray);
        reverseDirection.backward(denseNDArray2);
    }

    /**
     * Back-propagates the output errors through the merge processor and then through both
     * recurrent processors.
     *
     * <p>The per-element error pairs returned by {@link #outputMergeBackward(List)} are split
     * into two lists. The first list goes to the left-to-right processor as is. The second
     * list is reversed before it goes to the right-to-left processor.
     *
     * @param list the errors of the merged outputs ("outputErrors"); must not be null
     */
    /* renamed from: backward (decompiler rename to avoid a signature collision with the bridge method) */
    public void backward2(@NotNull List<DenseNDArray> list) {
        Intrinsics.checkParameterIsNotNull(list, "outputErrors");
        Pair<List<DenseNDArray>, List<DenseNDArray>> directionalErrors = CollectionsKt.unzip(outputMergeBackward(list));
        List<DenseNDArray> leftToRightErrors = directionalErrors.getFirst();
        List<DenseNDArray> rightToLeftErrors = directionalErrors.getSecond();
        this.leftToRightProcessor.backward2(leftToRightErrors);
        this.rightToLeftProcessor.backward2(CollectionsKt.reversed(rightToLeftErrors));
    }

    // Synthetic bridge for NeuralProcessor.backward: casts the wildcard-typed list and delegates to backward2.
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public /* bridge */ /* synthetic */ void backward(List<? extends DenseNDArray> list) {
        backward2((List<DenseNDArray>) list);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Returns the input errors of the encoder, obtained by passing the input errors of both
     * recurrent processors to {@code BiRNNUtils.sumBidirectionalErrors}.
     *
     * NOTE(review): presumably the helper re-aligns the right-to-left errors with the sequence
     * order before summing; confirm against BiRNNUtils.
     *
     * @param z named "copy" in the Kotlin metadata; forwarded unchanged to both processors
     * @return the result of {@code sumBidirectionalErrors}
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    /* renamed from: getInputErrors */
    public List<? extends DenseNDArray> getInputErrors2(boolean z) {
        return BiRNNUtils.INSTANCE.sumBidirectionalErrors(this.leftToRightProcessor.getInputErrors2(z), this.rightToLeftProcessor.getInputErrors2(z));
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Bundles the parameter errors of the three internal processors into one
     * {@code BiRNNParameters}, in the order left-to-right, right-to-left, output-merge.
     *
     * @param z named "copy" in the Kotlin metadata; forwarded unchanged to all three processors
     * @return a new {@code BiRNNParameters} built from the three results
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    public BiRNNParameters getParamsErrors(boolean z) {
        return new BiRNNParameters(this.leftToRightProcessor.getParamsErrors(z), this.rightToLeftProcessor.getParamsErrors(z), this.outputMergeProcessors.getParamsErrors(z));
    }

    /**
     * Runs both recurrent processors over the current sequence and pairs their outputs.
     *
     * <p>At step {@code k} the left-to-right processor receives element {@code k} and the
     * right-to-left processor receives element {@code size - 1 - k}. The two forward calls
     * are interleaved in that order. The second argument of each forward call is {@code true}
     * only on the first step.
     *
     * @return one (left-to-right output, right-to-left output) pair per sequence element, in
     *     sequence order; empty for an empty sequence
     */
    private final List<Pair<DenseNDArray, DenseNDArray>> biEncoding() {
        List<? extends InputNDArrayType> currentSequence = this.sequence;
        if (currentSequence == null) {
            // Always throws: forward() has not been called yet.
            Intrinsics.throwUninitializedPropertyAccessException("sequence");
        }
        int size = currentSequence.size();
        List<DenseNDArray> leftToRightOutputs = new ArrayList<>(size);
        // Collected in processing order (last element first) and reversed once at the end,
        // instead of inserting at index 0 on every step.
        List<DenseNDArray> rightToLeftOutputs = new ArrayList<>(size);
        for (int index = 0; index < size; index++) {
            int mirroredIndex = (size - index) - 1;
            boolean isFirstStep = index == 0;
            // Kotlin "$default" stubs: mask 12 requests the declared defaults for the last two
            // parameters, so the null/false passed for them are only placeholders.
            leftToRightOutputs.add((DenseNDArray) RecurrentNeuralProcessor.forward$default((RecurrentNeuralProcessor) this.leftToRightProcessor, (NDArray) currentSequence.get(index), isFirstStep, (List) null, false, 12, (Object) null));
            rightToLeftOutputs.add((DenseNDArray) RecurrentNeuralProcessor.forward$default((RecurrentNeuralProcessor) this.rightToLeftProcessor, (NDArray) currentSequence.get(mirroredIndex), isFirstStep, (List) null, false, 12, (Object) null));
        }
        return CollectionsKt.zip(leftToRightOutputs, CollectionsKt.reversed(rightToLeftOutputs));
    }

    /**
     * Back-propagates the given errors through the output-merge processor and returns the
     * errors of its inputs as pairs.
     *
     * <p>Each entry of the processor's input errors is read as a two-element list: element 0
     * becomes the first component of the pair and element 1 the second.
     *
     * @param list the errors of the merged outputs ("outputErrorsSequence")
     * @return one pair of input errors per entry returned by {@code getInputsErrors(false)}
     */
    private final List<Pair<DenseNDArray, DenseNDArray>> outputMergeBackward(List<DenseNDArray> list) {
        this.outputMergeProcessors.backward2(list);
        List<List<DenseNDArray>> mergeInputErrors = this.outputMergeProcessors.getInputsErrors(false);
        List<Pair<DenseNDArray, DenseNDArray>> directionalErrors = new ArrayList<>(mergeInputErrors.size());
        for (List<DenseNDArray> errorsOfOneElement : mergeInputErrors) {
            directionalErrors.add(new Pair<>(errorsOfOneElement.get(0), errorsOfOneElement.get(1)));
        }
        return directionalErrors;
    }

    /**
     * @return the BiRNN this encoder was constructed with
     */
    @NotNull
    public final BiRNN getNetwork() {
        return network;
    }

    /**
     * @return the dropout flag given at construction
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public boolean getUseDropout() {
        return useDropout;
    }

    /**
     * @return the propagate-to-input flag given at construction
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public boolean getPropagateToInput() {
        return propagateToInput;
    }

    /**
     * @return the identifier given at construction
     */
    public int getId() {
        return id;
    }

    /**
     * Creates an encoder over the given BiRNN and builds its three internal processors.
     *
     * @param biRNN the network ("network"); must not be null
     * @param z the dropout flag ("useDropout"), given to all three processors
     * @param z2 the propagate-to-input flag ("propagateToInput"), given to the two recurrent
     *     processors; the output-merge processor always receives {@code true}
     * @param i the identifier of this encoder ("id")
     */
    public BiRNNEncoder(@NotNull BiRNN biRNN, boolean z, boolean z2, int i) {
        Intrinsics.checkParameterIsNotNull(biRNN, "network");
        this.network = biRNN;
        this.useDropout = z;
        this.propagateToInput = z2;
        this.id = i;
        // The trailing (null, 0, 24, null) arguments select Kotlin's synthetic defaults
        // constructor of each processor; mask 24 requests defaults for its last two parameters.
        this.leftToRightProcessor = new RecurrentNeuralProcessor<>(biRNN.getLeftToRightNetwork(), z, z2, null, 0, 24, null);
        this.rightToLeftProcessor = new RecurrentNeuralProcessor<>(biRNN.getRightToLeftNetwork(), z, z2, null, 0, 24, null);
        this.outputMergeProcessors = new BatchFeedforwardProcessor<>(biRNN.getOutputMergeNetwork(), z, true, null, 0, 24, null);
    }

    /**
     * Synthetic Kotlin defaults constructor: when bit 8 of {@code i2} is set, the id argument
     * is replaced by 0; otherwise {@code i} is used. The marker argument is ignored.
     */
    public /* synthetic */ BiRNNEncoder(BiRNN biRNN, boolean z, boolean z2, int i, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(biRNN, z, z2, (i2 & 8) == 0 ? i : 0);
    }

    /**
     * Delegates to the interface's default implementation,
     * {@code NeuralProcessor.DefaultImpls.propagateErrors}, passing this encoder and the three
     * arguments unchanged.
     *
     * @param list the errors to propagate ("errors"); must not be null
     * @param optimizer the optimizer handed to the default implementation; must not be null
     * @param z flag forwarded unchanged to the default implementation
     * @return the default implementation's result, cast to a list
     */
    @NotNull
    /* renamed from: propagateErrors, reason: avoid collision after fix types in other method */
    public List<DenseNDArray> propagateErrors2(@NotNull List<DenseNDArray> list, @NotNull Optimizer<? super BiRNNParameters> optimizer, boolean z) {
        Intrinsics.checkParameterIsNotNull(list, "errors");
        Intrinsics.checkParameterIsNotNull(optimizer, "optimizer");
        return (List) NeuralProcessor.DefaultImpls.propagateErrors(this, list, optimizer, z);
    }

    // Synthetic bridge for NeuralProcessor.propagateErrors: casts the wildcard-typed list and delegates to propagateErrors2.
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public /* bridge */ /* synthetic */ List<? extends DenseNDArray> propagateErrors(List<? extends DenseNDArray> list, Optimizer<? super BiRNNParameters> optimizer, boolean z) {
        return propagateErrors2((List<DenseNDArray>) list, optimizer, z);
    }
}
