package com.kotlinnlp.neuralparser.parsers.lhrparser.neuralmodules.labeler;

import com.kotlinnlp.dependencytree.DependencyTree;
import com.kotlinnlp.linguisticdescription.GrammaticalConfiguration;
import com.kotlinnlp.linguisticdescription.sentence.token.TokenIdentificable;
import com.kotlinnlp.lssencoder.LatentSyntacticStructure;
import com.kotlinnlp.neuralparser.parsers.lhrparser.neuralmodules.labeler.utils.ScoredGrammar;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.neuralprocessor.batchfeedforward.BatchFeedforwardProcessor;
import com.kotlinnlp.simplednn.core.optimizer.Optimizer;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.comparisons.ComparisonsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.IntRange;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: Labeler.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��b\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��2,\u0012\u0004\u0012\u00020\u0002\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00040\u0003\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00040\u0003\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020\u00060\u0001:\u0002)*B\u001f\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\u0006\u0010\t\u001a\u00020\n\u0012\b\b\u0002\u0010\u000b\u001a\u00020\f¢\u0006\u0002\u0010\rJ\u0016\u0010\u0018\u001a\u00020\u00192\f\u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003H\u0016J&\u0010\u001b\u001a\b\u0012\u0004\u0012\u00020\u00040\u00032\u0006\u0010\u001c\u001a\u00020\f2\u000e\u0010\u001d\u001a\n\u0012\u0002\b\u0003\u0012\u0002\b\u00030\u001eH\u0002J\u0016\u0010\u001f\u001a\b\u0012\u0004\u0012\u00020\u00040\u00032\u0006\u0010 \u001a\u00020\u0002H\u0016J\u0010\u0010!\u001a\u00020\"2\u0006\u0010#\u001a\u00020\fH\u0002J\u0010\u0010$\u001a\u00020\u00052\u0006\u0010%\u001a\u00020\nH\u0016J\u0010\u0010&\u001a\u00020\u00062\u0006\u0010%\u001a\u00020\nH\u0016J\u001a\u0010'\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020(0\u00030\u00032\u0006\u0010 
\u001a\u00020\u0002R\u000e\u0010\u000e\u001a\u00020\u000fX\u0082.¢\u0006\u0002\n��R\u0014\u0010\u000b\u001a\u00020\fX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0010\u0010\u0011R\u000e\u0010\u0007\u001a\u00020\bX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u00040\u0013X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0014\u001a\u00020\nX\u0096D¢\u0006\b\n��\u001a\u0004\b\u0015\u0010\u0016R\u0014\u0010\t\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0016¨\u0006+"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/Labeler;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/Labeler$Input;", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/Labeler$InputErrors;", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/LabelerParams;", "model", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/LabelerModel;", "useDropout", "", "id", "", "(Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/LabelerModel;ZI)V", "dependencyTree", "Lcom/kotlinnlp/dependencytree/DependencyTree;", "getId", "()I", "processor", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/batchfeedforward/BatchFeedforwardProcessor;", "propagateToInput", "getPropagateToInput", "()Z", "getUseDropout", "backward", "", "outputErrors", "extractFeatures", "tokenId", "lss", "Lcom/kotlinnlp/lssencoder/LatentSyntacticStructure;", "forward", "input", "getGrammaticalConfiguration", "Lcom/kotlinnlp/linguisticdescription/GrammaticalConfiguration;", "index", "getInputErrors", "copy", "getParamsErrors", "predict", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/utils/ScoredGrammar;", "Input", "InputErrors", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/Labeler.class */
/**
 * Labels the arcs of a {@link DependencyTree}. For every token of a {@link LatentSyntacticStructure} it feeds the
 * context vector of the token together with the context vector of its head (or the virtual root when no head vector
 * is available) into a batch feed-forward network. Each output index {@code i} is then scored as the grammatical
 * configuration {@code i} of the model.
 *
 * <p>Decompiled from {@code Labeler.kt}. The decompiler-generated method names ({@code backward2},
 * {@code getInputErrors2}, ...) and the synthetic constructor are kept so the source-level interface does not change.
 */
public final class Labeler implements NeuralProcessor<Input, List<? extends DenseNDArray>, List<? extends DenseNDArray>, InputErrors, LabelerParams> {

    /** Orders {@link ScoredGrammar}s from the highest score to the lowest (stateless, so shared). */
    private static final Comparator<ScoredGrammar> BY_SCORE_DESCENDING =
            (first, second) -> Double.compare(second.getScore(), first.getScore());

    /** Always {@code true}: errors are propagated to the input (see {@link #getPropagateToInput()}). */
    private final boolean propagateToInput = true;

    /** The processor of the model's network; it receives one (dependent, governor) input pair per token. */
    private final BatchFeedforwardProcessor<DenseNDArray> processor;

    /**
     * The tree given to the last {@link #forward(Input)} call. It is {@code null} until then; reading it earlier
     * throws, like a Kotlin {@code lateinit} property (see {@link #requireDependencyTree()}).
     */
    private DependencyTree dependencyTree;

    private final LabelerModel model;
    private final boolean useDropout;
    private final int id;

    /**
     * Input of the labeler: the latent syntactic structure whose context vectors are labeled and the dependency tree
     * that defines, for every token, which vector is its governor.
     */
    /* compiled from: Labeler.kt */
    @Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0018\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0006\u0018��2\u00020\u0001B\u001d\u0012\u000e\u0010\u0002\u001a\n\u0012\u0002\b\u0003\u0012\u0002\b\u00030\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\bR\u0019\u0010\u0002\u001a\n\u0012\u0002\b\u0003\u0012\u0002\b\u00030\u0003¢\u0006\b\n��\u001a\u0004\b\t\u0010\n¨\u0006\u000b"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/Labeler$Input;", "", "lss", "Lcom/kotlinnlp/lssencoder/LatentSyntacticStructure;", "dependencyTree", "Lcom/kotlinnlp/dependencytree/DependencyTree;", "(Lcom/kotlinnlp/lssencoder/LatentSyntacticStructure;Lcom/kotlinnlp/dependencytree/DependencyTree;)V", "getDependencyTree", "()Lcom/kotlinnlp/dependencytree/DependencyTree;", "getLss", "()Lcom/kotlinnlp/lssencoder/LatentSyntacticStructure;", "neuralparser"})
    /* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/Labeler$Input.class */
    public static final class Input {

        @NotNull
        private final LatentSyntacticStructure<?, ?> lss;

        @NotNull
        private final DependencyTree dependencyTree;

        /** @return the latent syntactic structure providing the context vectors and the virtual root */
        @NotNull
        public final LatentSyntacticStructure<?, ?> getLss() {
            return this.lss;
        }

        /** @return the dependency tree whose heads select the governor vector of every token */
        @NotNull
        public final DependencyTree getDependencyTree() {
            return this.dependencyTree;
        }

        /**
         * @param lss the latent syntactic structure (must not be null)
         * @param dependencyTree the dependency tree of the same sentence (must not be null)
         */
        public Input(@NotNull LatentSyntacticStructure<?, ?> lss, @NotNull DependencyTree dependencyTree) {
            Intrinsics.checkParameterIsNotNull(lss, "lss");
            Intrinsics.checkParameterIsNotNull(dependencyTree, "dependencyTree");
            this.lss = lss;
            this.dependencyTree = dependencyTree;
        }
    }

    /**
     * Errors propagated back to the input, as built by {@code getInputErrors2}. It holds the errors accumulated on the
     * virtual root and one error array per token context vector, indexed by position in the dependency tree elements.
     */
    /* compiled from: Labeler.kt */
    @Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0018\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\b\u0006\u0018��2\u00020\u0001B\u001b\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00030\u0005¢\u0006\u0002\u0010\u0006R\u0017\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00030\u0005¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\bR\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\t\u0010\n¨\u0006\u000b"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/Labeler$InputErrors;", "", "rootErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "contextErrors", "", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;Ljava/util/List;)V", "getContextErrors", "()Ljava/util/List;", "getRootErrors", "()Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "neuralparser"})
    /* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/Labeler$InputErrors.class */
    public static final class InputErrors {

        @NotNull
        private final DenseNDArray rootErrors;

        @NotNull
        private final List<DenseNDArray> contextErrors;

        /** @return the errors accumulated on the virtual root (governor of tokens without a head) */
        @NotNull
        public final DenseNDArray getRootErrors() {
            return this.rootErrors;
        }

        /** @return the errors of every token context vector (the list is not copied) */
        @NotNull
        public final List<DenseNDArray> getContextErrors() {
            return this.contextErrors;
        }

        /**
         * @param rootErrors the errors of the virtual root (must not be null)
         * @param contextErrors the errors of the token context vectors (must not be null)
         */
        public InputErrors(@NotNull DenseNDArray rootErrors, @NotNull List<DenseNDArray> contextErrors) {
            Intrinsics.checkParameterIsNotNull(rootErrors, "rootErrors");
            Intrinsics.checkParameterIsNotNull(contextErrors, "contextErrors");
            this.rootErrors = rootErrors;
            this.contextErrors = contextErrors;
        }
    }

    /** @return always {@code true} */
    @Override
    public boolean getPropagateToInput() {
        return this.propagateToInput;
    }

    /**
     * Scores every grammatical configuration of the model for every network output of the given input.
     *
     * @param input the latent syntactic structure and dependency tree to label
     * @return one list per network output (one per token). Each list holds a {@link ScoredGrammar} for every output
     *     index and is sorted by descending score; ties keep index order because the sort is stable.
     */
    @NotNull
    public final List<List<ScoredGrammar>> predict(@NotNull Input input) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        List<DenseNDArray> outputs = forward(input);
        List<List<ScoredGrammar>> predictions = new ArrayList<>(outputs.size());
        for (DenseNDArray output : outputs) {
            int length = output.getLength();
            List<ScoredGrammar> scored = new ArrayList<>(length);
            for (int index = 0; index < length; index++) {
                scored.add(new ScoredGrammar(getGrammaticalConfiguration(index), output.get(index).doubleValue()));
            }
            predictions.add(CollectionsKt.sortedWith(scored, BY_SCORE_DESCENDING));
        }
        return predictions;
    }

    /**
     * Runs the network on every token of the input and remembers the input tree for the following backward step.
     *
     * @param input the latent syntactic structure and dependency tree to label
     * @return the network outputs, one per token of the sentence
     */
    @Override
    @NotNull
    public List<DenseNDArray> forward(@NotNull Input input) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        this.dependencyTree = input.getDependencyTree();

        LatentSyntacticStructure<?, ?> lss = input.getLss();
        List<?> tokens = lss.getSentence().getTokens();
        ArrayList<List<DenseNDArray>> features = new ArrayList<>(tokens.size());
        for (Object token : tokens) {
            features.add(extractFeatures(((TokenIdentificable) token).getId(), lss));
        }

        // NOTE(review): forward$default is the Kotlin default-argument bridge. Its mask 2 marks the second argument
        // as defaulted, so the `false` placeholder is not used. TODO confirm, and call the real forward(...) once
        // that default is known.
        return BatchFeedforwardProcessor.forward$default((BatchFeedforwardProcessor) this.processor, features, false, 2, (Object) null);
    }

    /**
     * Back-propagates the given output errors through the network.
     *
     * @param outputErrors the errors of the outputs of the last {@link #forward(Input)} call
     */
    /* renamed from: backward, reason: avoid collision after fix types in other method */
    public void backward2(@NotNull List<DenseNDArray> outputErrors) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        this.processor.backward2(outputErrors);
    }

    /** Bridge of {@link #backward2(List)} for the generic interface signature. */
    @Override
    public /* bridge */ /* synthetic */ void backward(List<? extends DenseNDArray> list) {
        backward2((List<DenseNDArray>) list);
    }

    /**
     * Collects the errors of the network inputs and routes them back to the context vectors they came from.
     *
     * <p>For the token at position {@code p}, the dependent errors are added to {@code contextErrors[p]}. The
     * governor errors are added to the context errors of its head, or to the root errors when the token has no head.
     * Each governor error goes to exactly one place.
     *
     * @param copy not used: the processor errors are read without copying and summed into freshly allocated arrays
     * @return the root errors and the per-token context errors
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override
    @NotNull
    /* renamed from: getInputErrors */
    public InputErrors getInputErrors2(boolean copy) {
        List<List<DenseNDArray>> inputsErrors = this.processor.getInputsErrors(false);
        int encodingSize = this.model.getContextEncodingSize();

        // NOTE(review): Shape(encodingSize, 0, 2, null) is the synthetic default-argument constructor, presumably
        // Shape(encodingSize) with the second dimension defaulted. TODO confirm.
        List<DenseNDArray> contextErrors = new ArrayList<>(inputsErrors.size());
        for (int i = 0; i < inputsErrors.size(); i++) {
            contextErrors.add(DenseNDArrayFactory.INSTANCE.zeros(new Shape(encodingSize, 0, 2, null)));
        }
        DenseNDArray rootErrors = DenseNDArrayFactory.INSTANCE.zeros(new Shape(encodingSize, 0, 2, null));

        int position = 0;
        for (List<DenseNDArray> tokenErrors : inputsErrors) {
            DependencyTree tree = requireDependencyTree();
            int tokenId = tree.getElements().get(position).intValue();
            Integer head = tree.getHead(tokenId);

            // Mirrors extractFeatures: input 0 is the token itself, input 1 its head (or the virtual root).
            DenseNDArray governorErrors = head != null
                    ? contextErrors.get(tree.getPosition(head.intValue()))
                    : rootErrors;

            contextErrors.get(position).assignSum((NDArray<?>) tokenErrors.get(0));
            governorErrors.assignSum((NDArray<?>) tokenErrors.get(1));
            position++;
        }
        return new InputErrors(rootErrors, contextErrors);
    }

    /**
     * @param copy forwarded to the processor's {@code getParamsErrors2}
     * @return the errors of the network parameters wrapped in {@link LabelerParams}
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override
    @NotNull
    /* renamed from: getParamsErrors */
    public LabelerParams getParamsErrors2(boolean copy) {
        return new LabelerParams(this.processor.getParamsErrors2(copy));
    }

    /**
     * @param index an output index of the network
     * @return the model's grammatical configuration with the given index
     * @throws NullPointerException (via {@code Intrinsics.throwNpe}) if the model has no configuration at that index
     */
    private final GrammaticalConfiguration getGrammaticalConfiguration(int index) {
        GrammaticalConfiguration element = this.model.getGrammaticalConfigurations().getElement(index);
        if (element == null) {
            Intrinsics.throwNpe();
        }
        return element;
    }

    /**
     * Builds the two network inputs of one token: its own context vector and the context vector of its governor.
     *
     * <p>Reconstructed from the bytecode dump left by the decompiler. The decompiled body only threw
     * {@code UnsupportedOperationException}, which made every {@link #forward(Input)} call fail.
     *
     * @param tokenId the id of the dependent token
     * @param lss the latent syntactic structure providing the context vectors
     * @return {@code [context(tokenId), governor]}. The governor is the context vector of the token's head, or the
     *     virtual root when the token has no head or no vector is returned for it. Throws
     *     {@code UninitializedPropertyAccessException} if no dependency tree has been set yet.
     */
    private final List<DenseNDArray> extractFeatures(int tokenId, LatentSyntacticStructure<?, ?> lss) {
        DenseNDArray dependent = lss.getContextVectorById(tokenId);
        Integer head = requireDependencyTree().getHead(tokenId);
        DenseNDArray governor = head != null ? lss.getContextVectorById(head.intValue()) : null;
        if (governor == null) {
            governor = lss.getVirtualRoot();
        }
        return CollectionsKt.listOf(dependent, governor);
    }

    /**
     * Returns the dependency tree set by the last {@link #forward(Input)} call. Like a Kotlin {@code lateinit}
     * property, it throws {@code UninitializedPropertyAccessException} (via
     * {@code Intrinsics.throwUninitializedPropertyAccessException}) when no tree has been set yet.
     */
    private DependencyTree requireDependencyTree() {
        DependencyTree tree = this.dependencyTree;
        if (tree == null) {
            Intrinsics.throwUninitializedPropertyAccessException("dependencyTree");
        }
        return tree;
    }

    /** @return the dropout flag given at construction */
    @Override
    public boolean getUseDropout() {
        return this.useDropout;
    }

    /** @return the identifier given at construction */
    @Override // com.kotlinnlp.utils.ItemsPool.IDItem
    public int getId() {
        return this.id;
    }

    /**
     * @param model the labeler model providing the network, the grammatical configurations and the encoding size
     * @param useDropout the dropout flag, also passed to the internal {@link BatchFeedforwardProcessor}
     * @param id the identifier of this processor
     */
    public Labeler(@NotNull LabelerModel model, boolean useDropout, int id) {
        Intrinsics.checkParameterIsNotNull(model, "model");
        this.model = model;
        this.useDropout = useDropout;
        this.id = id;
        // propagateToInput is already initialized at its declaration; assigning the final field again does not compile.
        // NOTE(review): synthetic default-argument constructor (mask 24); TODO confirm the defaulted parameters.
        this.processor = new BatchFeedforwardProcessor<>(this.model.getNetworkModel(), getUseDropout(), true, null, 0, 24, null);
    }

    /** Kotlin default-argument constructor: when bit {@code 4} of {@code i2} is set, {@code id} defaults to 0. */
    public /* synthetic */ Labeler(LabelerModel labelerModel, boolean z, int i, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(labelerModel, z, (i2 & 4) != 0 ? 0 : i);
    }

    /**
     * Propagates the given output errors through the labeler and updates the optimizer, using the interface's
     * default implementation.
     *
     * @param errors the output errors
     * @param optimizer the optimizer of the labeler parameters
     * @param copy forwarded to the default implementation
     * @return the errors of the input
     */
    @NotNull
    /* renamed from: propagateErrors, reason: avoid collision after fix types in other method */
    public InputErrors propagateErrors2(@NotNull List<DenseNDArray> errors, @NotNull Optimizer<? super LabelerParams> optimizer, boolean copy) {
        Intrinsics.checkParameterIsNotNull(errors, "errors");
        Intrinsics.checkParameterIsNotNull(optimizer, "optimizer");
        return (InputErrors) NeuralProcessor.DefaultImpls.propagateErrors(this, errors, optimizer, copy);
    }

    /** Bridge of {@link #propagateErrors2(List, Optimizer, boolean)} for the generic interface signature. */
    @Override
    public /* bridge */ /* synthetic */ InputErrors propagateErrors(List<? extends DenseNDArray> list, Optimizer<? super LabelerParams> optimizer, boolean z) {
        return propagateErrors2((List<DenseNDArray>) list, optimizer, z);
    }
}
