package com.kotlinnlp.hanclassifier;

import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.Optimizer;
import com.kotlinnlp.simplednn.deeplearning.attention.han.HANEncoder;
import com.kotlinnlp.simplednn.deeplearning.attention.han.HANParameters;
import com.kotlinnlp.simplednn.deeplearning.attention.han.HierarchyGroup;
import com.kotlinnlp.simplednn.deeplearning.attention.han.HierarchyItem;
import com.kotlinnlp.simplednn.deeplearning.attention.han.HierarchySequence;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: HANClassifier.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��L\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0010\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\u0018��2,\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u0002\u0012\u0004\u0012\u00020\u0004\u0012\u0004\u0012\u00020\u0004\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u0002\u0012\u0004\u0012\u00020\u00050\u0001B+\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\b\b\u0002\u0010\b\u001a\u00020\t\u0012\b\b\u0002\u0010\n\u001a\u00020\t\u0012\b\b\u0002\u0010\u000b\u001a\u00020\f¢\u0006\u0002\u0010\rJ\u0010\u0010\u0017\u001a\u00020\u00182\u0006\u0010\u0019\u001a\u00020\u0004H\u0016J\u0016\u0010\u001a\u001a\u00020\u00042\f\u0010\u001b\u001a\b\u0012\u0004\u0012\u00020\u00030\u0002H\u0016J\u0016\u0010\u001c\u001a\b\u0012\u0004\u0012\u00020\u00030\u00022\u0006\u0010\u001d\u001a\u00020\tH\u0016J\u0010\u0010\u001e\u001a\u00020\u00052\u0006\u0010\u001d\u001a\u00020\tH\u0016J\u0012\u0010\u001f\u001a\u00020 *\b\u0012\u0004\u0012\u00020\u00030\u0002H\u0002J\u0012\u0010!\u001a\b\u0012\u0004\u0012\u00020\u00040\"*\u00020\u0003H\u0002R\u0014\u0010\u000e\u001a\b\u0012\u0004\u0012\u00020\u00040\u000fX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000b\u001a\u00020\fX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0010\u0010\u0011R\u0011\u0010\u0006\u001a\u00020\u0007¢\u0006\b\n��\u001a\u0004\b\u0012\u0010\u0013R\u0014\u0010\n\u001a\u00020\tX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0014\u0010\u0015R\u0014\u0010\b\u001a\u00020\tX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0016\u0010\u0015¨\u0006#"}, d2 = {"Lcom/kotlinnlp/hanclassifier/HANClassifier;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;", "", "Lcom/kotlinnlp/hanclassifier/EncodedSentence;", 
"Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "Lcom/kotlinnlp/simplednn/deeplearning/attention/han/HANParameters;", "model", "Lcom/kotlinnlp/hanclassifier/HANClassifierModel;", "useDropout", "", "propagateToInput", "id", "", "(Lcom/kotlinnlp/hanclassifier/HANClassifierModel;ZZI)V", "encoder", "Lcom/kotlinnlp/simplednn/deeplearning/attention/han/HANEncoder;", "getId", "()I", "getModel", "()Lcom/kotlinnlp/hanclassifier/HANClassifierModel;", "getPropagateToInput", "()Z", "getUseDropout", "backward", "", "outputErrors", "forward", "input", "getInputErrors", "copy", "getParamsErrors", "toHierarchyGroup", "Lcom/kotlinnlp/simplednn/deeplearning/attention/han/HierarchyGroup;", "toHierarchySequence", "Lcom/kotlinnlp/simplednn/deeplearning/attention/han/HierarchySequence;", "hanclassifier"})
/* loaded from: input_file:com/kotlinnlp/hanclassifier/HANClassifier.class */
/*
 * Neural processor that classifies a list of encoded sentences with a HANEncoder.
 *
 * The input sentences are converted into the encoder's hierarchy types (a HierarchyGroup that
 * contains one HierarchySequence per sentence) before being forwarded; input errors coming back
 * from the encoder are converted the other way round into EncodedSentence objects.
 *
 * This file is decompiled Kotlin: the methods suffixed with "2" carry names assigned by the
 * decompiler (see the "renamed from" notes) and are kept as they are so that existing references
 * keep resolving.
 */
public final class HANClassifier implements NeuralProcessor<List<? extends EncodedSentence>, DenseNDArray, DenseNDArray, List<? extends EncodedSentence>, HANParameters> {
    /** Encoder that performs the forward and backward passes; created in the constructor. */
    private final HANEncoder<DenseNDArray> encoder;

    /** Model whose {@code getHan()} value is handed to the encoder. */
    @NotNull
    private final HANClassifierModel model;

    /** Value returned by {@link #getUseDropout()}; also passed to the encoder. */
    private final boolean useDropout;

    /** Value returned by {@link #getPropagateToInput()}; also passed to the encoder. */
    private final boolean propagateToInput;

    /** Value returned by {@link #getId()}. */
    private final int id;

    /**
     * Runs the encoder on the given sentences.
     *
     * @param input the sentences to classify, must not be null
     * @return the array returned by the encoder's forward pass
     */
    @NotNull
    /* renamed from: forward, reason: avoid collision after fix types in other method */
    public DenseNDArray forward2(@NotNull List<EncodedSentence> input) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        HierarchyItem hierarchy = toHierarchyGroup(input);
        return this.encoder.forward(hierarchy);
    }

    /**
     * Bridge for the interface method; delegates to {@link #forward2(List)}.
     *
     * @param list the sentences to classify
     * @return the result of {@link #forward2(List)}
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public /* bridge */ /* synthetic */ DenseNDArray forward(List<? extends EncodedSentence> list) {
        return forward2((List<EncodedSentence>) list);
    }

    /**
     * Forwards the output errors to the encoder's backward pass.
     *
     * @param outputErrors the errors of the output, must not be null
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public void backward(@NotNull DenseNDArray outputErrors) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        this.encoder.backward(outputErrors);
    }

    /**
     * Reads the input errors from the encoder and converts them into encoded sentences, one per
     * hierarchy sequence, each holding the sequence's elements in iteration order.
     *
     * NOTE(review): {@code z} (named "copy" in the Kotlin metadata) is not forwarded - the encoder
     * is always queried with {@code false}. The returned lists and sentences are new objects, but
     * the arrays inside them are the references obtained from the encoder. This looks intentional;
     * confirm against the Kotlin source before changing it.
     *
     * @param z ignored by this implementation
     * @return a new list with one EncodedSentence per sequence of the encoder's input errors
     * @throws TypeCastException if the encoder's result, one of its items, or one of the elements
     *     of an item is null
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    /* renamed from: getInputErrors */
    public List<? extends EncodedSentence> getInputErrors2(boolean z) {
        HierarchyItem encoderErrors = this.encoder.getInputErrors2(false);
        if (encoderErrors == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.deeplearning.attention.han.HierarchyGroup");
        }
        HierarchyGroup groupErrors = (HierarchyGroup) encoderErrors;
        List<EncodedSentence> sentencesErrors = new ArrayList<>(CollectionsKt.collectionSizeOrDefault(groupErrors, 10));
        for (HierarchyItem item : groupErrors) {
            if (item == null) {
                throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.deeplearning.attention.han.HierarchySequence<*>");
            }
            // The element type is unknown at this point, hence the wildcard and the per-element cast.
            HierarchySequence<?> sequenceErrors = (HierarchySequence<?>) item;
            List<DenseNDArray> tokensErrors = new ArrayList<>(CollectionsKt.collectionSizeOrDefault(sequenceErrors, 10));
            for (Object tokenErrors : sequenceErrors) {
                if (tokenErrors == null) {
                    throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
                }
                tokensErrors.add((DenseNDArray) tokenErrors);
            }
            sentencesErrors.add(new EncodedSentence(tokensErrors));
        }
        return sentencesErrors;
    }

    /**
     * Returns the parameters errors of the encoder.
     *
     * @param z forwarded unchanged to the encoder (named "copy" in the Kotlin metadata)
     * @return the value returned by the encoder
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    /* renamed from: getParamsErrors */
    public HANParameters getParamsErrors2(boolean z) {
        return this.encoder.getParamsErrors2(z);
    }

    /**
     * Converts a list of sentences into a hierarchy group holding one sequence per sentence, in
     * the same order.
     *
     * @param list the sentences to convert
     * @return a new group built from the converted sentences
     */
    private HierarchyGroup toHierarchyGroup(@NotNull List<EncodedSentence> list) {
        List<HierarchyItem> sequences = new ArrayList<>(list.size());
        for (EncodedSentence sentence : list) {
            sequences.add(toHierarchySequence(sentence));
        }
        // toArray never returns null and yields a fresh array, so no null check or copy is needed.
        return new HierarchyGroup(sequences.toArray(new HierarchyItem[0]));
    }

    /**
     * Converts a sentence into a hierarchy sequence holding the sentence's tokens in order.
     *
     * @param encodedSentence the sentence to convert
     * @return a new sequence built from the sentence's tokens
     */
    private HierarchySequence<DenseNDArray> toHierarchySequence(@NotNull EncodedSentence encodedSentence) {
        List<DenseNDArray> tokens = encodedSentence.getTokens();
        // A fresh array is handed to the sequence, so the token list itself is never shared.
        return new HierarchySequence<>(tokens.toArray(new DenseNDArray[0]));
    }

    /**
     * @return the model given at construction time
     */
    @NotNull
    public final HANClassifierModel getModel() {
        return this.model;
    }

    /**
     * @return the dropout flag given at construction time
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public boolean getUseDropout() {
        return this.useDropout;
    }

    /**
     * @return the propagate-to-input flag given at construction time
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    public boolean getPropagateToInput() {
        return this.propagateToInput;
    }

    /**
     * @return the id given at construction time
     */
    @Override // com.kotlinnlp.utils.ItemsPool.IDItem
    public int getId() {
        return this.id;
    }

    /**
     * Builds a classifier and its encoder.
     *
     * @param model the classifier model, must not be null
     * @param z the value of the useDropout property, also passed to the encoder
     * @param z2 the value of the propagateToInput property, also passed to the encoder
     * @param i the value of the id property
     */
    public HANClassifier(@NotNull HANClassifierModel model, boolean z, boolean z2, int i) {
        Intrinsics.checkParameterIsNotNull(model, "model");
        this.model = model;
        this.useDropout = z;
        this.propagateToInput = z2;
        this.id = i;
        // The trailing (null, 0, 24, null) arguments target a compiler-generated constructor of
        // HANEncoder; presumably the mask 24 selects the defaults of its 4th and 5th parameters,
        // following the same bit-mask convention as the synthetic constructor of this class -
        // TODO confirm against HANEncoder.
        this.encoder = new HANEncoder<>(this.model.getHan(), getUseDropout(), getPropagateToInput(), null, 0, 24, null);
    }

    /**
     * Compiler-generated constructor that applies default values: each bit set in {@code i2}
     * replaces the matching argument with its default (bit 2: {@code z} becomes false,
     * bit 4: {@code z2} becomes false, bit 8: {@code i} becomes 0).
     *
     * @param hANClassifierModel the classifier model
     * @param z the useDropout value, used only when bit 2 of {@code i2} is clear
     * @param z2 the propagateToInput value, used only when bit 4 of {@code i2} is clear
     * @param i the id value, used only when bit 8 of {@code i2} is clear
     * @param i2 the bit mask of the arguments to replace with defaults
     * @param defaultConstructorMarker unused marker that distinguishes this overload
     */
    public /* synthetic */ HANClassifier(HANClassifierModel hANClassifierModel, boolean z, boolean z2, int i, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(hANClassifierModel, (i2 & 2) != 0 ? false : z, (i2 & 4) != 0 ? false : z2, (i2 & 8) != 0 ? 0 : i);
    }

    /**
     * Delegates to the default implementation provided by the NeuralProcessor interface.
     *
     * @param errors the errors to propagate, must not be null
     * @param optimizer the optimizer handed to the default implementation, must not be null
     * @param z forwarded unchanged to the default implementation
     * @return the value returned by the default implementation
     */
    @Override // com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor
    @NotNull
    public List<EncodedSentence> propagateErrors(@NotNull DenseNDArray errors, @NotNull Optimizer<? super HANParameters> optimizer, boolean z) {
        Intrinsics.checkParameterIsNotNull(errors, "errors");
        Intrinsics.checkParameterIsNotNull(optimizer, "optimizer");
        // The default implementation hands back this processor's own input errors type, which is
        // a list of EncodedSentence here; the cast only restores the element type.
        @SuppressWarnings("unchecked")
        List<EncodedSentence> inputErrors = (List<EncodedSentence>) NeuralProcessor.DefaultImpls.propagateErrors(this, errors, optimizer, z);
        return inputErrors;
    }
}
