package com.kotlinnlp.simplednn.deeplearning.attentionnetwork.han;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsErrorsAccumulator;
import com.kotlinnlp.simplednn.deeplearning.attentionnetwork.AttentionNetwork;
import com.kotlinnlp.simplednn.deeplearning.attentionnetwork.AttentionNetworkParameters;
import com.kotlinnlp.simplednn.deeplearning.attentionnetwork.AttentionNetworksPool;
import com.kotlinnlp.simplednn.deeplearning.birnn.BiRNNEncoder;
import com.kotlinnlp.simplednn.deeplearning.birnn.BiRNNEncodersPool;
import com.kotlinnlp.simplednn.deeplearning.birnn.BiRNNParameters;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.utils.ItemsPool;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.spek.engine.SpekTestEngine;

/* compiled from: HANEncoder.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0098\u0001\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0010\u0011\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\u0003\n\u0002\u0010\u000b\n��\n\u0002\u0010\u0006\n\u0002\b\t\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\t\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u00022\u00020\u0003B\u0017\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\b\b\u0002\u0010\u0006\u001a\u00020\u0007¢\u0006\u0002\u0010\bJ\b\u0010&\u001a\u00020'H\u0002J'\u0010(\u001a\u00020'2\u0006\u0010)\u001a\u00020\u00122\u0006\u0010*\u001a\u00020+2\n\b\u0002\u0010,\u001a\u0004\u0018\u00010-¢\u0006\u0002\u0010.J/\u0010/\u001a\u00020'2\u0006\u0010)\u001a\u00020\u00122\u0006\u00100\u001a\u00020\u00072\u0006\u00101\u001a\u00020\u00072\b\u0010,\u001a\u0004\u0018\u00010-H\u0002¢\u0006\u0002\u00102J \u00103\u001a\u00020'2\u0006\u00100\u001a\u00020\u00072\u0006\u00101\u001a\u00020\u00072\u0006\u0010*\u001a\u00020+H\u0002J7\u00104\u001a\u00020'2\u0006\u0010)\u001a\u00020\u00122\u0006\u00100\u001a\u00020\u00072\u0006\u00101\u001a\u00020\u00072\u0006\u0010*\u001a\u00020+2\b\u0010,\u001a\u0004\u0018\u00010-H\u0002¢\u0006\u0002\u00105J \u00106\u001a\u0002072\u0006\u00100\u001a\u00020\u00072\u0006\u00101\u001a\u00020\u00072\u0006\u00108\u001a\u00020-H\u0002J 
\u00109\u001a\u0002072\u0006\u00100\u001a\u00020\u00072\u0006\u00101\u001a\u00020\u00072\u0006\u0010:\u001a\u00020+H\u0002J+\u0010;\u001a\b\u0012\u0004\u0012\u00020\u00120\n2\u0006\u0010<\u001a\u00020=2\u0006\u00100\u001a\u00020\u00072\u0006\u0010>\u001a\u00020+H\u0002¢\u0006\u0002\u0010?J\u0018\u0010@\u001a\u00020\u00122\u0006\u0010A\u001a\u0002072\b\b\u0002\u0010>\u001a\u00020+J \u0010B\u001a\u00020\u00122\u0006\u0010C\u001a\u0002072\u0006\u00100\u001a\u00020\u00072\u0006\u0010>\u001a\u00020+H\u0003J\u0006\u0010D\u001a\u000207J\u0010\u0010E\u001a\u0002072\b\b\u0002\u0010:\u001a\u00020+J\u0010\u0010F\u001a\u00020G2\b\b\u0002\u0010:\u001a\u00020+J\b\u0010H\u001a\u00020'H\u0002J\b\u0010I\u001a\u00020'H\u0002R\u0016\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u000b0\nX\u0082\u0004¢\u0006\u0004\n\u0002\u0010\fR\u001c\u0010\r\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u000b0\u000e0\nX\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u000fR\u001c\u0010\u0010\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\u00110\nX\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u0013R\u001c\u0010\u0014\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00150\u000e0\nX\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u000fR\u001a\u0010\u0016\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030\u00170\nX\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u0018R\u0014\u0010\u0006\u001a\u00020\u0007X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0019\u0010\u001aR\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u001b\u0010\u001cR\u0014\u0010\u001d\u001a\b\u0012\u0004\u0012\u00020\u00120\u001eX\u0082\u0004¢\u0006\u0002\n��R0\u0010\u001f\u001a$\u0012 \u0012\u001e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\"0!j\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00120\"`#0 X\u0082\u0004¢\u0006\u0002\n��R,\u0010$\u001a \u0012\u001c\u0012\u001a\u0012\b\u0012\u0006\u0012\u0002\b\u00030%0!j\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030%`#0 X\u0082\u0004¢\u0006\u0002\n��¨\u0006J"}, d2 = 
{"Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/han/HANEncoder;", "InputNDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "Lcom/kotlinnlp/simplednn/utils/ItemsPool$IDItem;", "model", "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/han/HAN;", "id", "", "(Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/han/HAN;I)V", "attentionNetworksParamsErrors", "", "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/AttentionNetworkParameters;", "[Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/AttentionNetworkParameters;", "attentionNetworksParamsErrorsAccumulators", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsAccumulator;", "[Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsAccumulator;", "attentionNetworksPools", "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/AttentionNetworksPool;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "[Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/AttentionNetworksPool;", "encodersParamsErrorsAccumulators", "Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNNParameters;", "encodersPools", "Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNNEncodersPool;", "[Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNNEncodersPool;", "getId", "()I", "getModel", "()Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/han/HAN;", "outputProcessor", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/feedforward/FeedforwardNeuralProcessor;", "usedAttentionNetworksPerLevel", "", "Ljava/util/ArrayList;", "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/AttentionNetwork;", "Lkotlin/collections/ArrayList;", "usedEncodersPerLevel", "Lcom/kotlinnlp/simplednn/deeplearning/birnn/BiRNNEncoder;", "averageAccumulatedErrors", "", "backward", "outputErrors", "propagateToInput", "", "mePropK", "", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;ZLjava/lang/Double;)V", "backwardAttentionNetwork", "levelIndex", "groupIndex", 
"(Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;IILjava/lang/Double;)V", "backwardEncoder", "backwardHierarchicalGroup", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;IIZLjava/lang/Double;)V", "buildImportanceScoreHierarchyItem", "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/han/HierarchyItem;", "refScore", "buildInputErrorsHierarchyItem", "copy", "buildInputSequence", SpekTestEngine.GROUP_SEGMENT_TYPE, "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/han/HierarchyGroup;", "useDropout", "(Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/han/HierarchyGroup;IZ)[Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "forward", "sequencesHierarchy", "forwardItem", "item", "getInputImportanceScores", "getInputSequenceErrors", "getParamsErrors", "Lcom/kotlinnlp/simplednn/deeplearning/attentionnetwork/han/HANParameters;", "resetAccumulators", "resetUsedNetworks", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/attentionnetwork/han/HANEncoder.class */
/*
 * Encoder of a Hierarchical Attention Network (HAN).
 *
 * Every level of the model's hierarchy owns a pool of BiRNN encoders and a pool of attention networks.
 * An item of a level (a HierarchyGroup of sub-items or, at the input level, a HierarchySequence of input
 * arrays) is encoded by a BiRNN encoder of that level and reduced by an attention network of the same
 * level to a single DenseNDArray. The array obtained for the top item (level 0) is fed to the model's
 * output network.
 *
 * The networks used during the last forward are kept per level, in left-to-right order, so that
 * backward() and the input errors / importance scores getters can address them afterwards.
 */
public final class HANEncoder<InputNDArrayType extends NDArray<InputNDArrayType>> implements ItemsPool.IDItem {

    /** One pool of BiRNN encoders per hierarchy level. */
    private final BiRNNEncodersPool<?>[] encodersPools;

    /** One pool of attention networks per hierarchy level. */
    private final AttentionNetworksPool<DenseNDArray>[] attentionNetworksPools;

    /** Encoders used in the last forward, per level, in left-to-right order of the hierarchy. */
    private final List<ArrayList<BiRNNEncoder<?>>> usedEncodersPerLevel;

    /** Attention networks used in the last forward, per level, in left-to-right order of the hierarchy. */
    private final List<ArrayList<AttentionNetwork<DenseNDArray>>> usedAttentionNetworksPerLevel;

    /** Accumulators of the encoders' params errors, one per level. */
    private final ParamsErrorsAccumulator<BiRNNParameters>[] encodersParamsErrorsAccumulators;

    /** Accumulators of the attention networks' params errors, one per level. */
    private final ParamsErrorsAccumulator<AttentionNetworkParameters>[] attentionNetworksParamsErrorsAccumulators;

    /** Per-level buffer that receives the params errors of a single attention network backward. */
    private final AttentionNetworkParameters[] attentionNetworksParamsErrors;

    /** Processor of the model's output network, fed with the encoding of the top-level item. */
    private final FeedforwardNeuralProcessor<DenseNDArray> outputProcessor;

    @NotNull
    private final HAN model;
    private final int id;

    /**
     * Encodes a hierarchy of sequences and returns the output of the model's output network.
     *
     * @param sequencesHierarchy the root item of the hierarchy (encoded at level 0)
     * @param useDropout whether to use dropout; it is passed only to the encoders of the input level
     * @return the output of the output network
     */
    @NotNull
    public final DenseNDArray forward(@NotNull HierarchyItem sequencesHierarchy, boolean useDropout) {
        Intrinsics.checkParameterIsNotNull(sequencesHierarchy, "sequencesHierarchy");
        resetUsedNetworks();
        DenseNDArray topLevelEncoding = forwardItem(sequencesHierarchy, 0, useDropout);
        FeedforwardNeuralProcessor.forward$default(this.outputProcessor, topLevelEncoding, false, 2, null);
        return NeuralProcessor.getOutput$default(this.outputProcessor, false, 1, null);
    }

    /** Kotlin default-arguments bridge of {@link #forward(HierarchyItem, boolean)} ({@code useDropout = false}). */
    @NotNull
    public static /* bridge */ /* synthetic */ DenseNDArray forward$default(HANEncoder hANEncoder, HierarchyItem hierarchyItem, boolean z, int i, Object obj) {
        if ((i & 2) != 0) {
            z = false;
        }
        return hANEncoder.forward(hierarchyItem, z);
    }

    /**
     * Propagates the given output errors back through the output network and through the attention
     * networks and encoders of every level used in the last forward. The params errors of each level
     * are accumulated and then averaged.
     *
     * @param outputErrors the errors of the output network's output
     * @param propagateToInput whether the encoders of the input level propagate the errors to their input
     * @param mePropK optional value passed as-is to the attention networks' backward (may be null)
     */
    public final void backward(@NotNull DenseNDArray outputErrors, boolean propagateToInput, @Nullable Double mePropK) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        resetAccumulators();
        FeedforwardNeuralProcessor.backward$default(this.outputProcessor, outputErrors, true, null, 4, null);
        backwardHierarchicalGroup(
            this.outputProcessor.getInputErrors(false), 0, 0, propagateToInput, mePropK,
            new int[this.model.getHierarchySize()]);
        averageAccumulatedErrors();
    }

    /** Kotlin default-arguments bridge of {@link #backward(DenseNDArray, boolean, Double)} ({@code mePropK = null}). */
    public static /* bridge */ /* synthetic */ void backward$default(HANEncoder hANEncoder, DenseNDArray denseNDArray, boolean z, Double d, int i, Object obj) {
        if ((i & 4) != 0) {
            d = (Double) null;
        }
        hANEncoder.backward(denseNDArray, z, d);
    }

    /**
     * Builds the hierarchy of the errors of the input sequences, mirroring the shape of the hierarchy
     * given to the last forward. Meaningful after {@link #backward(DenseNDArray, boolean, Double)}.
     *
     * @param copy passed to the input-level encoders when reading their input sequence errors
     * @return the root item of the errors hierarchy
     */
    @NotNull
    public final HierarchyItem getInputSequenceErrors(boolean copy) {
        return buildInputErrorsHierarchyItem(0, 0, copy, new int[this.model.getHierarchySize()]);
    }

    /** Kotlin default-arguments bridge of {@link #getInputSequenceErrors(boolean)} ({@code copy = true}). */
    @NotNull
    public static /* bridge */ /* synthetic */ HierarchyItem getInputSequenceErrors$default(HANEncoder hANEncoder, boolean z, int i, Object obj) {
        if ((i & 1) != 0) {
            z = true;
        }
        return hANEncoder.getInputSequenceErrors(z);
    }

    /**
     * Builds the hierarchy of the importance scores of the attention networks used in the last forward,
     * mirroring the shape of the input hierarchy.
     *
     * @return the root item of the importance scores hierarchy
     */
    @NotNull
    public final HierarchyItem getInputImportanceScores() {
        return buildImportanceScoreHierarchyItem(0, 0, 1.0d, new int[this.model.getHierarchySize()]);
    }

    /**
     * Returns the params errors accumulated by the last backward, one set of encoder and attention
     * network errors per level plus the errors of the output network.
     *
     * @param copy whether to return copies of the accumulated errors
     * @return the params errors of the whole HAN
     */
    @NotNull
    public final HANParameters getParamsErrors(boolean copy) {
        int hierarchySize = this.model.getHierarchySize();

        BiRNNParameters[] encodersErrors = new BiRNNParameters[hierarchySize];
        for (int level = 0; level < hierarchySize; level++) {
            BiRNNParameters errors = (BiRNNParameters) ParamsErrorsAccumulator.getParamsErrors$default(
                this.encodersParamsErrorsAccumulators[level], false, 1, null);
            encodersErrors[level] = copy ? errors.copy() : errors;
        }

        AttentionNetworkParameters[] attentionErrors = new AttentionNetworkParameters[hierarchySize];
        for (int level = 0; level < hierarchySize; level++) {
            AttentionNetworkParameters errors = (AttentionNetworkParameters) ParamsErrorsAccumulator.getParamsErrors$default(
                this.attentionNetworksParamsErrorsAccumulators[level], false, 1, null);
            attentionErrors[level] = copy ? errors.copy() : errors;
        }

        return new HANParameters(encodersErrors, attentionErrors, this.outputProcessor.getParamsErrors(copy));
    }

    /** Kotlin default-arguments bridge of {@link #getParamsErrors(boolean)} ({@code copy = true}). */
    @NotNull
    public static /* bridge */ /* synthetic */ HANParameters getParamsErrors$default(HANEncoder hANEncoder, boolean z, int i, Object obj) {
        if ((i & 1) != 0) {
            z = true;
        }
        return hANEncoder.getParamsErrors(z);
    }

    /**
     * Encodes an item of the hierarchy at the given level. The sub-items of a group are encoded first,
     * so the networks of every level are registered in left-to-right order of the hierarchy.
     *
     * @param item a HierarchyGroup (upper levels) or a HierarchySequence (input level)
     * @param levelIndex the level of the item
     * @param useDropout whether to use dropout (effective only at the input level)
     * @return the output of the attention network used for this item
     * @throws RuntimeException if the item type is unknown or does not fit its level
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private final DenseNDArray forwardItem(HierarchyItem item, int levelIndex, boolean useDropout) {
        boolean isInputLevel = this.model.isInputLevel(levelIndex);
        NDArray[] inputSequence;

        if (item instanceof HierarchyGroup) {
            if (isInputLevel) {
                throw new RuntimeException("Invalid hierarchy: a group cannot be placed at the input level (" + levelIndex + ")");
            }
            inputSequence = buildInputSequence((HierarchyGroup) item, levelIndex, useDropout);
        } else if (item instanceof HierarchySequence) {
            if (!isInputLevel) {
                throw new RuntimeException("Invalid hierarchy: a sequence can be placed only at the input level, found at level " + levelIndex);
            }
            Collection<?> elements = (Collection<?>) item;
            inputSequence = elements.toArray(new NDArray[elements.size()]);
        } else {
            throw new RuntimeException("Invalid hierarchy item type");
        }

        // Raw-typed on purpose: the encoder input type is InputNDArrayType at the input level and
        // DenseNDArray at the upper levels, which a single generic type cannot express here.
        BiRNNEncoder encoder = (BiRNNEncoder) this.encodersPools[levelIndex].getItem();
        AttentionNetwork<DenseNDArray> attentionNetwork = this.attentionNetworksPools[levelIndex].getItem();
        this.usedEncodersPerLevel.get(levelIndex).add(encoder);
        this.usedAttentionNetworksPerLevel.get(levelIndex).add(attentionNetwork);

        // Dropout is applied only by the encoders of the input level.
        DenseNDArray[] encodedSequence = encoder.encode(inputSequence, useDropout && isInputLevel);

        AugmentedArray[] attentionInput = new AugmentedArray[encodedSequence.length];
        for (int k = 0; k < encodedSequence.length; k++) {
            attentionInput[k] = AugmentedArray.Companion.invoke(encodedSequence[k]);
        }
        return AttentionNetwork.forward$default(attentionNetwork, CollectionsKt.arrayListOf(attentionInput), false, 2, null);
    }

    /**
     * Encodes each sub-item of a group at the next level.
     *
     * @param group the group to encode
     * @param levelIndex the level of the group
     * @param useDropout whether to use dropout
     * @return the encodings of the sub-items, in order
     */
    private final DenseNDArray[] buildInputSequence(HierarchyGroup group, int levelIndex, boolean useDropout) {
        DenseNDArray[] encodedItems = new DenseNDArray[group.size()];
        for (int k = 0; k < encodedItems.length; k++) {
            HierarchyItem child = group.get(k);
            Intrinsics.checkExpressionValueIsNotNull(child, "group[i]");
            encodedItems[k] = forwardItem(child, levelIndex + 1, useDropout);
        }
        return encodedItems;
    }

    /**
     * Backward of a group and, recursively, of its sub-groups.
     *
     * Groups of every level are stored in left-to-right order, and this depth-first walk visits the groups
     * of each level in that same order, so the k-th group reached at a level is the k-th stored one.
     * {@code nextGroupIndexPerLevel} tracks that position per level: a child index local to its parent
     * would address the wrong groups from the third level on.
     *
     * @param outputErrors the errors of the attention network output of this group
     * @param levelIndex the level of the group
     * @param groupIndex the position of the group among the groups of its level
     * @param propagateToInput whether the input-level encoders propagate errors to their input
     * @param mePropK optional value passed to the attention networks' backward
     * @param nextGroupIndexPerLevel per-level cursor of the next group to visit (mutated)
     */
    private final void backwardHierarchicalGroup(DenseNDArray outputErrors, int levelIndex, int groupIndex, boolean propagateToInput, Double mePropK, int[] nextGroupIndexPerLevel) {
        boolean isNotInputLevel = levelIndex < this.model.getHierarchySize() - 1;

        backwardAttentionNetwork(outputErrors, levelIndex, groupIndex, mePropK);
        // Upper-level encoders must always propagate: their input errors feed the level below.
        backwardEncoder(levelIndex, groupIndex, isNotInputLevel || propagateToInput);

        if (isNotInputLevel) {
            int childLevel = levelIndex + 1;
            DenseNDArray[] childrenErrors = BiRNNEncoder.getInputSequenceErrors$default(
                this.usedEncodersPerLevel.get(levelIndex).get(groupIndex), false, 1, null);
            for (DenseNDArray childErrors : childrenErrors) {
                backwardHierarchicalGroup(
                    childErrors, childLevel, nextGroupIndexPerLevel[childLevel]++, propagateToInput, mePropK, nextGroupIndexPerLevel);
            }
        }
    }

    /**
     * Backward of the attention network of a group, accumulating its params errors in the level accumulator.
     */
    private final void backwardAttentionNetwork(DenseNDArray outputErrors, int levelIndex, int groupIndex, Double mePropK) {
        ParamsErrorsAccumulator<AttentionNetworkParameters> accumulator = this.attentionNetworksParamsErrorsAccumulators[levelIndex];
        AttentionNetworkParameters paramsErrors = this.attentionNetworksParamsErrors[levelIndex];
        this.usedAttentionNetworksPerLevel.get(levelIndex).get(groupIndex).backward(outputErrors, paramsErrors, true, mePropK);
        ParamsErrorsAccumulator.accumulate$default(accumulator, paramsErrors, false, 2, null);
    }

    /**
     * Backward of the encoder of a group, using the input errors of the attention network of the same
     * group, and accumulation of its params errors in the level accumulator.
     */
    private final void backwardEncoder(int levelIndex, int groupIndex, boolean propagateToInput) {
        ParamsErrorsAccumulator<BiRNNParameters> accumulator = this.encodersParamsErrorsAccumulators[levelIndex];
        BiRNNEncoder<?> encoder = this.usedEncodersPerLevel.get(levelIndex).get(groupIndex);
        Intrinsics.checkExpressionValueIsNotNull(encoder, "this.usedEncodersPerLevel[levelIndex][groupIndex]");
        encoder.backward(this.usedAttentionNetworksPerLevel.get(levelIndex).get(groupIndex).getInputErrors(), propagateToInput);
        ParamsErrorsAccumulator.accumulate$default(accumulator, encoder.getParamsErrors(false), false, 2, null);
    }

    /** Averages the accumulated params errors of every level (encoders first, then attention networks). */
    private final void averageAccumulatedErrors() {
        for (ParamsErrorsAccumulator<BiRNNParameters> accumulator : this.encodersParamsErrorsAccumulators) {
            accumulator.averageErrors();
        }
        for (ParamsErrorsAccumulator<AttentionNetworkParameters> accumulator : this.attentionNetworksParamsErrorsAccumulators) {
            accumulator.averageErrors();
        }
    }

    /** Resets the params errors accumulators of every level. */
    private final void resetAccumulators() {
        for (ParamsErrorsAccumulator<BiRNNParameters> accumulator : this.encodersParamsErrorsAccumulators) {
            accumulator.reset();
        }
        for (ParamsErrorsAccumulator<AttentionNetworkParameters> accumulator : this.attentionNetworksParamsErrorsAccumulators) {
            accumulator.reset();
        }
    }

    /** Forgets the networks used in the previous forward and releases all the items of the pools. */
    private final void resetUsedNetworks() {
        for (ArrayList<BiRNNEncoder<?>> levelEncoders : this.usedEncodersPerLevel) {
            levelEncoders.clear();
        }
        for (ArrayList<AttentionNetwork<DenseNDArray>> levelNetworks : this.usedAttentionNetworksPerLevel) {
            levelNetworks.clear();
        }
        for (BiRNNEncodersPool<?> pool : this.encodersPools) {
            pool.releaseAll();
        }
        for (AttentionNetworksPool<DenseNDArray> pool : this.attentionNetworksPools) {
            pool.releaseAll();
        }
    }

    /**
     * Builds the errors hierarchy item of a group (see {@link #backwardHierarchicalGroup} for the meaning
     * of {@code nextGroupIndexPerLevel}).
     */
    private final HierarchyItem buildInputErrorsHierarchyItem(int levelIndex, int groupIndex, boolean copy, int[] nextGroupIndexPerLevel) {
        BiRNNEncoder<?> encoder = this.usedEncodersPerLevel.get(levelIndex).get(groupIndex);

        if (levelIndex == this.model.getHierarchySize() - 1) {
            DenseNDArray[] inputErrors = encoder.getInputSequenceErrors(copy);
            return new HierarchySequence((DenseNDArray[]) Arrays.copyOf(inputErrors, inputErrors.length));
        }

        int childLevel = levelIndex + 1;
        HierarchyItem[] children = new HierarchyItem[encoder.getInputSequenceErrors(false).length];
        for (int k = 0; k < children.length; k++) {
            children[k] = buildInputErrorsHierarchyItem(childLevel, nextGroupIndexPerLevel[childLevel]++, copy, nextGroupIndexPerLevel);
        }
        return new HierarchyGroup(children);
    }

    /**
     * Builds the importance scores hierarchy item of a group: at the input level the scores of the group
     * multiplied by {@code refScore}, otherwise a group whose k-th child uses the k-th score as reference
     * (see {@link #backwardHierarchicalGroup} for the meaning of {@code nextGroupIndexPerLevel}).
     */
    private final HierarchyItem buildImportanceScoreHierarchyItem(int levelIndex, int groupIndex, double refScore, int[] nextGroupIndexPerLevel) {
        DenseNDArray scores = this.usedAttentionNetworksPerLevel.get(levelIndex).get(groupIndex).getImportanceScore(false);

        if (levelIndex == this.model.getHierarchySize() - 1) {
            return new HierarchySequence(scores.prod(refScore));
        }

        int childLevel = levelIndex + 1;
        HierarchyItem[] children = new HierarchyItem[scores.getLength()];
        for (int k = 0; k < children.length; k++) {
            // NOTE(review): each child gets only its own score as reference (not refScore * score), so in
            // hierarchies deeper than two levels the leaves are not scaled by their grand-parents' scores - confirm intent.
            children[k] = buildImportanceScoreHierarchyItem(
                childLevel, nextGroupIndexPerLevel[childLevel]++, scores.get(k).doubleValue(), nextGroupIndexPerLevel);
        }
        return new HierarchyGroup(children);
    }

    @NotNull
    public final HAN getModel() {
        return this.model;
    }

    @Override // com.kotlinnlp.simplednn.utils.ItemsPool.IDItem
    public int getId() {
        return this.id;
    }

    /**
     * @param model the HAN model whose per-level BiRNNs, attention networks params and output network are used
     * @param id the id of this encoder within an ItemsPool
     */
    @SuppressWarnings("unchecked")
    public HANEncoder(@NotNull HAN model, int id) {
        Intrinsics.checkParameterIsNotNull(model, "model");
        this.model = model;
        this.id = id;

        int hierarchySize = this.model.getHierarchySize();

        BiRNNEncodersPool<?>[] encoders = new BiRNNEncodersPool[hierarchySize];
        for (int level = 0; level < hierarchySize; level++) {
            encoders[level] = new BiRNNEncodersPool<>(this.model.getBiRNNs()[level]);
        }
        this.encodersPools = encoders;

        AttentionNetworksPool<DenseNDArray>[] attentionPools = new AttentionNetworksPool[hierarchySize];
        for (int level = 0; level < hierarchySize; level++) {
            attentionPools[level] = new AttentionNetworksPool<>(this.model.getAttentionNetworksParams()[level], LayerType.Input.Dense, 0.0d, 4, null);
        }
        this.attentionNetworksPools = attentionPools;

        List<ArrayList<BiRNNEncoder<?>>> usedEncoders = new ArrayList<>(hierarchySize);
        for (int level = 0; level < hierarchySize; level++) {
            usedEncoders.add(new ArrayList<>());
        }
        this.usedEncodersPerLevel = usedEncoders;

        List<ArrayList<AttentionNetwork<DenseNDArray>>> usedAttentionNetworks = new ArrayList<>(hierarchySize);
        for (int level = 0; level < hierarchySize; level++) {
            usedAttentionNetworks.add(new ArrayList<>());
        }
        this.usedAttentionNetworksPerLevel = usedAttentionNetworks;

        ParamsErrorsAccumulator<BiRNNParameters>[] encodersAccumulators = new ParamsErrorsAccumulator[hierarchySize];
        for (int level = 0; level < hierarchySize; level++) {
            encodersAccumulators[level] = new ParamsErrorsAccumulator<>();
        }
        this.encodersParamsErrorsAccumulators = encodersAccumulators;

        ParamsErrorsAccumulator<AttentionNetworkParameters>[] attentionAccumulators = new ParamsErrorsAccumulator[hierarchySize];
        for (int level = 0; level < hierarchySize; level++) {
            attentionAccumulators[level] = new ParamsErrorsAccumulator<>();
        }
        this.attentionNetworksParamsErrorsAccumulators = attentionAccumulators;

        AttentionNetworkParameters[] attentionErrorsBuffers = new AttentionNetworkParameters[hierarchySize];
        for (int level = 0; level < hierarchySize; level++) {
            attentionErrorsBuffers[level] = new AttentionNetworkParameters(
                this.attentionNetworksPools[level].getModel().getInputSize(),
                this.attentionNetworksPools[level].getModel().getAttentionSize(),
                false, null, null, 24, null);
        }
        this.attentionNetworksParamsErrors = attentionErrorsBuffers;

        this.outputProcessor = new FeedforwardNeuralProcessor<>(this.model.getOutputNetwork(), 0, 2, null);
    }

    /** Kotlin default-arguments bridge of the constructor ({@code id = 0}). */
    public /* synthetic */ HANEncoder(HAN han, int i, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(han, (i2 & 2) != 0 ? 0 : i);
    }
}
