package com.kotlinnlp.neuralparser.parsers.lhrparser.neuralmodels.labeler;

import com.kotlinnlp.dependencytree.DependencyTree;
import com.kotlinnlp.dependencytree.Deprel;
import com.kotlinnlp.neuralparser.parsers.lhrparser.LatentSyntacticStructure;
import com.kotlinnlp.neuralparser.parsers.lhrparser.neuralmodels.labeler.utils.ScoredDeprel;
import com.kotlinnlp.neuralparser.parsers.lhrparser.neuralmodels.labeler.utils.ScoredDeprelKt;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.batchfeedforward.BatchFeedforwardProcessor;
import com.kotlinnlp.simplednn.core.optimizer.Optimizer;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.ArrayList;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: DeprelLabeler.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��^\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��2,\u0012\u0004\u0012\u00020\u0002\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00040\u0003\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00040\u0003\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u00060\u0001:\u0002'(B\u001f\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\u0006\u0010\t\u001a\u00020\n\u0012\b\b\u0002\u0010\u000b\u001a\u00020\f¢\u0006\u0002\u0010\rJ\u0016\u0010\u0018\u001a\u00020\u00192\f\u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003H\u0016J\u001c\u0010\u001b\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00040\u00030\u00032\u0006\u0010\u001c\u001a\u00020\u0002H\u0002J\u0016\u0010\u001d\u001a\b\u0012\u0004\u0012\u00020\u00040\u00032\u0006\u0010\u001c\u001a\u00020\u0002H\u0016J\u000e\u0010\u001e\u001a\u00020\u001f2\u0006\u0010 
\u001a\u00020\fJ\u0010\u0010!\u001a\u00020\u00052\u0006\u0010\"\u001a\u00020\nH\u0016J\u0010\u0010#\u001a\u00020\u00062\u0006\u0010\"\u001a\u00020\nH\u0016J\u001e\u0010$\u001a\u0012\u0012\u000e\u0012\f\u0012\u0004\u0012\u00020%0\u0003j\u0002`&0\u00032\u0006\u0010\u001c\u001a\u00020\u0002R\u000e\u0010\u000e\u001a\u00020\u000fX\u0082.¢\u0006\u0002\n��R\u0014\u0010\u000b\u001a\u00020\fX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0010\u0010\u0011R\u000e\u0010\u0007\u001a\u00020\bX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u00040\u0013X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0014\u001a\u00020\nX\u0096D¢\u0006\b\n��\u001a\u0004\b\u0015\u0010\u0016R\u0014\u0010\t\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0016¨\u0006)"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/DeprelLabeler;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/DeprelLabeler$Input;", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/DeprelLabeler$InputErrors;", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/DeprelLabelerParams;", "model", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/DeprelLabelerModel;", "useDropout", "", "id", "", "(Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/DeprelLabelerModel;ZI)V", "dependencyTree", "Lcom/kotlinnlp/dependencytree/DependencyTree;", "getId", "()I", "processor", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/batchfeedforward/BatchFeedforwardProcessor;", "propagateToInput", "getPropagateToInput", "()Z", "getUseDropout", "backward", "", "outputErrors", "extractFeatures", "input", "forward", "getDeprel", "Lcom/kotlinnlp/dependencytree/Deprel;", "index", "getInputErrors", "copy", "getParamsErrors", "predict", 
"Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/utils/ScoredDeprel;", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/utils/ScoredDeprelList;", "Input", "InputErrors", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/DeprelLabeler.class */
/**
 * Labels the arcs of a dependency tree with deprels.
 *
 * <p>For every token, a feature made of its context vector and the context vector of its governor is built. The
 * lss virtual-root vector is used instead when the token has no head. The features are scored through a
 * batch feed-forward network, which yields one score per deprel of the model.
 */
public final class DeprelLabeler implements NeuralProcessor<Input, List<? extends DenseNDArray>, List<? extends DenseNDArray>, InputErrors, DeprelLabelerParams> {

    /** Errors are always propagated back to the input context vectors. */
    private final boolean propagateToInput = true;

    /** Network scoring every deprel for a (dependent vector, governor vector) feature. */
    private final BatchFeedforwardProcessor<DenseNDArray> processor;

    /**
     * Tree of the last {@link #forward} call, needed to route governor errors back to the right token.
     * Reading it before the first forward throws an uninitialized-property exception.
     */
    private DependencyTree dependencyTree;

    private final DeprelLabelerModel model;
    private final boolean useDropout;
    private final int id;

    /* compiled from: DeprelLabeler.kt */
    /**
     * Input of the labeler: the latent syntactic structure of a sentence and the dependency tree to label.
     */
    @Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0018\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0006\u0018��2\u00020\u0001B\u0015\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\bR\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\t\u0010\n¨\u0006\u000b"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/DeprelLabeler$Input;", "", "lss", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/LatentSyntacticStructure;", "dependencyTree", "Lcom/kotlinnlp/dependencytree/DependencyTree;", "(Lcom/kotlinnlp/neuralparser/parsers/lhrparser/LatentSyntacticStructure;Lcom/kotlinnlp/dependencytree/DependencyTree;)V", "getDependencyTree", "()Lcom/kotlinnlp/dependencytree/DependencyTree;", "getLss", "()Lcom/kotlinnlp/neuralparser/parsers/lhrparser/LatentSyntacticStructure;", "neuralparser"})
    /* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/DeprelLabeler$Input.class */
    public static final class Input {

        @NotNull
        private final LatentSyntacticStructure lss;

        @NotNull
        private final DependencyTree dependencyTree;

        /** @return the latent syntactic structure (context vectors, virtual root, sentence) */
        @NotNull
        public final LatentSyntacticStructure getLss() {
            return this.lss;
        }

        /** @return the dependency tree whose arcs must be labeled */
        @NotNull
        public final DependencyTree getDependencyTree() {
            return this.dependencyTree;
        }

        /**
         * @param lss the latent syntactic structure of the sentence
         * @param dependencyTree the dependency tree of the sentence
         */
        public Input(@NotNull LatentSyntacticStructure lss, @NotNull DependencyTree dependencyTree) {
            Intrinsics.checkParameterIsNotNull(lss, "lss");
            Intrinsics.checkParameterIsNotNull(dependencyTree, "dependencyTree");
            this.lss = lss;
            this.dependencyTree = dependencyTree;
        }
    }

    /* compiled from: DeprelLabeler.kt */
    /**
     * Errors of the labeler input: the errors of the virtual root vector and those of each context vector.
     */
    @Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0018\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\b\u0006\u0018��2\u00020\u0001B\u001b\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00030\u0005¢\u0006\u0002\u0010\u0006R\u0017\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00030\u0005¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\bR\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\t\u0010\n¨\u0006\u000b"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/DeprelLabeler$InputErrors;", "", "rootErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "contextErrors", "", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;Ljava/util/List;)V", "getContextErrors", "()Ljava/util/List;", "getRootErrors", "()Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "neuralparser"})
    /* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodels/labeler/DeprelLabeler$InputErrors.class */
    public static final class InputErrors {

        @NotNull
        private final DenseNDArray rootErrors;

        @NotNull
        private final List<DenseNDArray> contextErrors;

        /** @return the errors accumulated on the virtual root vector */
        @NotNull
        public final DenseNDArray getRootErrors() {
            return this.rootErrors;
        }

        /** @return the errors of each context vector, in token order (the list is not copied) */
        @NotNull
        public final List<DenseNDArray> getContextErrors() {
            return this.contextErrors;
        }

        /**
         * @param rootErrors the errors of the virtual root vector
         * @param contextErrors the errors of each context vector, in token order
         */
        public InputErrors(@NotNull DenseNDArray rootErrors, @NotNull List<DenseNDArray> contextErrors) {
            Intrinsics.checkParameterIsNotNull(rootErrors, "rootErrors");
            Intrinsics.checkParameterIsNotNull(contextErrors, "contextErrors");
            this.rootErrors = rootErrors;
            this.contextErrors = contextErrors;
        }
    }

    /** @return always {@code true} */
    public boolean getPropagateToInput() {
        return this.propagateToInput;
    }

    /**
     * Scores every deprel of the model for each token of the input.
     *
     * @param input the latent syntactic structure and the dependency tree to label
     * @return one list per processed token, containing all the deprels with their scores, ordered by
     *         {@code ScoredDeprelKt.sortByScore}
     */
    @NotNull
    public final List<List<ScoredDeprel>> predict(@NotNull Input input) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        List<DenseNDArray> outputs = forward(input);
        List<List<ScoredDeprel>> predictions = new ArrayList<>(outputs.size());
        for (DenseNDArray scores : outputs) {
            int length = scores.getLength();
            List<ScoredDeprel> scoredDeprels = new ArrayList<>(length);
            // Output index i is the score of the i-th deprel of the model.
            for (int i = 0; i < length; i++) {
                scoredDeprels.add(new ScoredDeprel(getDeprel(i), scores.get(i).doubleValue()));
            }
            predictions.add(ScoredDeprelKt.sortByScore(scoredDeprels));
        }
        return predictions;
    }

    /**
     * Runs the network on the features of the input and remembers its tree for the following backward.
     *
     * @param input the latent syntactic structure and the dependency tree to label
     * @return one output array per processed token
     */
    @NotNull
    public List<DenseNDArray> forward(@NotNull Input input) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        this.dependencyTree = input.getDependencyTree();
        return BatchFeedforwardProcessor.forward$default(this.processor, new ArrayList(extractFeatures(input)), false, 2, (Object) null);
    }

    /**
     * Propagates the output errors through the network.
     *
     * @param list the errors of the outputs of the last forward, one per processed token
     */
    public void backward(@NotNull List<DenseNDArray> list) {
        Intrinsics.checkParameterIsNotNull(list, "outputErrors");
        this.processor.backward(list);
    }

    /**
     * Collects the errors of the input vectors after a backward.
     *
     * <p>Each processed feature is a (dependent vector, governor vector) pair. The errors of the first element are
     * added to the dependent token. Those of the second element are added to its head token, or to the root errors
     * when the token has no head.
     *
     * @param copy not used: new arrays are always allocated
     * @return the root errors and the errors of each context vector
     */
    @NotNull
    public InputErrors getInputErrors(boolean copy) {
        List inputsErrors = this.processor.getInputsErrors(false);
        int size = inputsErrors.size();
        int encodingSize = this.model.getContextEncodingSize();

        List<DenseNDArray> contextErrors = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            contextErrors.add(DenseNDArrayFactory.INSTANCE.zeros(new Shape(encodingSize, 0, 2, (DefaultConstructorMarker) null)));
        }
        DenseNDArray rootErrors = DenseNDArrayFactory.INSTANCE.zeros(new Shape(encodingSize, 0, 2, (DefaultConstructorMarker) null));

        int tokenIndex = 0;
        for (Object featureErrorsObj : inputsErrors) {
            List featureErrors = (List) featureErrorsObj;
            DenseNDArray dependentErrors = (DenseNDArray) featureErrors.get(0);
            DenseNDArray governorErrors = (DenseNDArray) featureErrors.get(1);

            DependencyTree tree = this.dependencyTree;
            if (tree == null) {
                Intrinsics.throwUninitializedPropertyAccessException("dependencyTree");
            }
            Integer head = tree.getHeads()[tokenIndex];
            DenseNDArray governorTarget = head == null ? rootErrors : contextErrors.get(head.intValue());

            contextErrors.get(tokenIndex).assignSum(dependentErrors);
            governorTarget.assignSum(governorErrors);
            tokenIndex++;
        }
        return new InputErrors(rootErrors, contextErrors);
    }

    /**
     * Alias of {@link #getInputErrors(boolean)} kept for source compatibility with the decompiler-generated name.
     *
     * @deprecated use {@link #getInputErrors(boolean)}
     */
    @Deprecated
    @NotNull
    public InputErrors m17getInputErrors(boolean z) {
        return getInputErrors(z);
    }

    /**
     * @param copy passed through to the processor
     * @return the errors of the network parameters
     */
    @NotNull
    public DeprelLabelerParams getParamsErrors(boolean copy) {
        return new DeprelLabelerParams(this.processor.getParamsErrors(copy));
    }

    /**
     * Alias of {@link #getParamsErrors(boolean)} kept for source compatibility with the decompiler-generated name.
     *
     * @deprecated use {@link #getParamsErrors(boolean)}
     */
    @Deprecated
    @NotNull
    public DeprelLabelerParams m18getParamsErrors(boolean z) {
        return getParamsErrors(z);
    }

    /**
     * @param i the index of a deprel in the model dictionary
     * @return the deprel at the given index (throws a Kotlin NPE if it is missing)
     */
    @NotNull
    public final Deprel getDeprel(int i) {
        Object element = this.model.getDeprels().getElement(i);
        if (element == null) {
            Intrinsics.throwNpe();
        }
        return (Deprel) element;
    }

    /**
     * Builds one feature per token: [context vector of the token, context vector of its head], using the virtual
     * root vector when the token has no head.
     *
     * <p>Like Kotlin's {@code zip}, tokens are paired with heads only up to the shorter of the two lengths.
     */
    private final List<List<DenseNDArray>> extractFeatures(Input input) {
        LatentSyntacticStructure lss = input.getLss();
        int tokensCount = lss.getSentence().getTokens().size();

        DependencyTree tree = this.dependencyTree;
        if (tree == null) {
            Intrinsics.throwUninitializedPropertyAccessException("dependencyTree");
        }
        Integer[] heads = tree.getHeads();

        List<DenseNDArray> contextVectors = lss.getContextVectors();
        int featuresCount = Math.min(tokensCount, heads.length);
        List<List<DenseNDArray>> features = new ArrayList<>(featuresCount);
        for (int tokenIndex = 0; tokenIndex < featuresCount; tokenIndex++) {
            Integer head = heads[tokenIndex];
            DenseNDArray governorVector = head != null ? contextVectors.get(head.intValue()) : lss.getVirtualRoot();

            // A fresh list per token: features must not share a backing array.
            List<DenseNDArray> feature = new ArrayList<>(2);
            feature.add(contextVectors.get(tokenIndex));
            feature.add(governorVector);
            features.add(feature);
        }
        return features;
    }

    /** @return whether the network is run with dropout */
    public boolean getUseDropout() {
        return this.useDropout;
    }

    /** @return the identifier of this processor */
    public int getId() {
        return this.id;
    }

    /**
     * @param deprelLabelerModel the model of the labeler
     * @param z whether to use dropout
     * @param i the identifier of this processor
     */
    public DeprelLabeler(@NotNull DeprelLabelerModel deprelLabelerModel, boolean z, int i) {
        Intrinsics.checkParameterIsNotNull(deprelLabelerModel, "model");
        this.model = deprelLabelerModel;
        this.useDropout = z;
        this.id = i;
        this.processor = new BatchFeedforwardProcessor<>(this.model.getNetworkModel(), getUseDropout(), true, (List) null, 0, 24, (DefaultConstructorMarker) null);
    }

    /** Kotlin default-arguments constructor: {@code id} defaults to 0 when bit 4 of {@code i2} is set. */
    public /* synthetic */ DeprelLabeler(DeprelLabelerModel deprelLabelerModel, boolean z, int i, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(deprelLabelerModel, z, (i2 & 4) != 0 ? 0 : i);
    }

    /**
     * Backward + parameter update through the interface default implementation.
     *
     * @return the errors of the input
     */
    @NotNull
    public InputErrors propagateErrors(@NotNull List<DenseNDArray> list, @NotNull Optimizer<? super DeprelLabelerParams> optimizer, boolean z) {
        Intrinsics.checkParameterIsNotNull(list, "errors");
        Intrinsics.checkParameterIsNotNull(optimizer, "optimizer");
        return (InputErrors) NeuralProcessor.DefaultImpls.propagateErrors(this, list, optimizer, z);
    }

    /** Erased bridge to {@link #propagateErrors(List, Optimizer, boolean)}. */
    public /* bridge */ /* synthetic */ Object propagateErrors(Object obj, Optimizer optimizer, boolean z) {
        return propagateErrors((List<DenseNDArray>) obj, (Optimizer<? super DeprelLabelerParams>) optimizer, z);
    }
}
