package com.kotlinnlp.tokenslabeler;

import com.kotlinnlp.linguisticdescription.sentence.RealSentence;
import com.kotlinnlp.linguisticdescription.sentence.properties.AnnotatedSegment;
import com.kotlinnlp.linguisticdescription.sentence.token.RealToken;
import com.kotlinnlp.simplednn.core.arrays.ParamsArray;
import com.kotlinnlp.simplednn.core.neuralprocessor.NeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.deeplearning.birnn.deepbirnn.DeepBiRNNEncoder;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.tokensencoder.TokensEncoder;
import com.kotlinnlp.tokensencoder.TokensEncoderModel;
import com.kotlinnlp.tokenslabeler.helpers.LabelsDecoder;
import com.kotlinnlp.tokenslabeler.language.IOBTag;
import com.kotlinnlp.tokenslabeler.language.Label;
import com.kotlinnlp.tokenslabeler.language.ScoredLabel;
import com.kotlinnlp.tokenslabeler.language.Segment;
import com.kotlinnlp.utils.BeamManager;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.comparisons.ComparisonsKt;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.Ref;
import kotlin.sequences.Sequence;
import kotlin.sequences.SequencesKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: TokensLabeler.kt */
@Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��p\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018�� ,2,\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u0002\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u0004\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u0004\u0012\u0004\u0012\u00020\u00060\u0001:\u0001,B+\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\b\b\u0002\u0010\t\u001a\u00020\n\u0012\b\b\u0002\u0010\u000b\u001a\u00020\n\u0012\b\b\u0002\u0010\f\u001a\u00020\r¢\u0006\u0002\u0010\u000eJ\u0016\u0010\u001b\u001a\u00020\u001c2\f\u0010\u001d\u001a\b\u0012\u0004\u0012\u00020\u00050\u0004H\u0016J\u001c\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\u00050\u00042\f\u0010\u001f\u001a\b\u0012\u0004\u0012\u00020\u00030\u0002H\u0016J\u0010\u0010 
\u001a\u00020\u00062\u0006\u0010!\u001a\u00020\u0016H\u0016J\"\u0010\"\u001a\u0014\u0012\f\u0012\n\u0012\u0002\b\u00030#R\u00020$0\u0004j\u0002`%2\u0006\u0010!\u001a\u00020\u0016H\u0016J\u001c\u0010&\u001a\b\u0012\u0004\u0012\u00020'0\u00042\f\u0010(\u001a\b\u0012\u0004\u0012\u00020\u00050\u0004H\u0002J\u001a\u0010)\u001a\b\u0012\u0004\u0012\u00020'0\u00042\f\u0010\u001f\u001a\b\u0012\u0004\u0012\u00020\u00030\u0002J\u001a\u0010*\u001a\b\u0012\u0004\u0012\u00020+0\u00042\f\u0010\u001f\u001a\b\u0012\u0004\u0012\u00020\u00030\u0002R\u0014\u0010\u000f\u001a\b\u0012\u0004\u0012\u00020\u00050\u0010X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\f\u001a\u00020\rX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0011\u0010\u0012R\u0011\u0010\u0007\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u0014\u0010\u0015\u001a\u00020\u0016X\u0096D¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0018R \u0010\u0019\u001a\u0014\u0012\u0004\u0012\u00020\u0003\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00030\u00020\u001aX\u0082\u0004¢\u0006\u0002\n��¨\u0006-"}, d2 = {"Lcom/kotlinnlp/tokenslabeler/TokensLabeler;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor;", "Lcom/kotlinnlp/linguisticdescription/sentence/RealSentence;", "Lcom/kotlinnlp/linguisticdescription/sentence/token/RealToken;", Label.EMPTY_VALUE, "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/NeuralProcessor$NoInputErrors;", "model", "Lcom/kotlinnlp/tokenslabeler/TokensLabelerModel;", "encoderDropout", Label.EMPTY_VALUE, "outputMergeDropout", "id", Label.EMPTY_VALUE, "(Lcom/kotlinnlp/tokenslabeler/TokensLabelerModel;DDI)V", "biRNNProcessor", "Lcom/kotlinnlp/simplednn/deeplearning/birnn/deepbirnn/DeepBiRNNEncoder;", "getId", "()I", "getModel", "()Lcom/kotlinnlp/tokenslabeler/TokensLabelerModel;", "propagateToInput", Label.EMPTY_VALUE, "getPropagateToInput", "()Z", "tokensEncoder", "Lcom/kotlinnlp/tokensencoder/TokensEncoder;", "backward", 
Label.EMPTY_VALUE, "outputErrors", "forward", "input", "getInputErrors", "copy", "getParamsErrors", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray$Errors;", "Lcom/kotlinnlp/simplednn/core/arrays/ParamsArray;", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsErrorsList;", "greedyDecode", "Lcom/kotlinnlp/tokenslabeler/language/ScoredLabel;", "predictions", "predict", "predictAsSegments", "Lcom/kotlinnlp/linguisticdescription/sentence/properties/AnnotatedSegment;", "Companion", "tokenslabeler"})
/* loaded from: input_file:com/kotlinnlp/tokenslabeler/TokensLabeler.class */
/*
 * Neural processor that labels the tokens of a RealSentence with IOB-tagged scored labels.
 * Tokens are encoded by a TokensEncoder, passed through a deep BiRNN, and the BiRNN outputs
 * (one per token) are decoded into ScoredLabels (predict) or AnnotatedSegments (predictAsSegments).
 */
public final class TokensLabeler implements NeuralProcessor<RealSentence<RealToken>, List<? extends DenseNDArray>, List<? extends DenseNDArray>, NeuralProcessor.NoInputErrors> {

    /** Always false: this processor exposes no input errors (getInputErrors returns NoInputErrors). */
    private final boolean propagateToInput = false;

    /** Encodes the tokens of the input sentence; built from the model's tokens-encoder model. */
    private final TokensEncoder<RealToken, RealSentence<RealToken>> tokensEncoder;

    /** Deep bidirectional RNN applied to the token encodings; its outputs are the label predictions. */
    private final DeepBiRNNEncoder<DenseNDArray> biRNNProcessor;

    @NotNull
    private final TokensLabelerModel model;
    private final int id;
    public static final Companion Companion = new Companion(null);

    /* compiled from: TokensLabeler.kt */
    @Metadata(mv = {1, 1, 15}, bv = {1, 0, 3}, k = 1, d1 = {"��\"\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J(\u0010\u0003\u001a\b\u0012\u0004\u0012\u00020\u00050\u00042\f\u0010\u0006\u001a\b\u0012\u0004\u0012\u00020\u00070\u00042\f\u0010\b\u001a\b\u0012\u0004\u0012\u00020\t0\u0004¨\u0006\n"}, d2 = {"Lcom/kotlinnlp/tokenslabeler/TokensLabeler$Companion;", Label.EMPTY_VALUE, "()V", "buildSegments", Label.EMPTY_VALUE, "Lcom/kotlinnlp/linguisticdescription/sentence/properties/AnnotatedSegment;", "tokens", "Lcom/kotlinnlp/linguisticdescription/sentence/token/RealToken;", "labels", "Lcom/kotlinnlp/tokenslabeler/language/ScoredLabel;", "tokenslabeler"})
    /* loaded from: input_file:com/kotlinnlp/tokenslabeler/TokensLabeler$Companion.class */
    public static final class Companion {

        /**
         * Groups a sequence of IOB-tagged labels into annotated segments.
         *
         * <p>The i-th label refers to the i-th token. A {@code Beginning} label opens a new segment
         * starting at the token's start character. An {@code Inside} label adds its score to the most
         * recently opened segment. Any non-{@code Outside} label that is not followed by an
         * {@code Inside} label closes the current segment at the token's end character.
         *
         * @param tokens the tokens of the sentence, aligned by index with {@code labels}
         * @param labels one scored label per token
         * @return the segments found, in order of appearance
         * @throws NullPointerException if {@code labels} contains a {@code null} element
         * @throws NoSuchElementException if a label needs an open segment before any
         *     {@code Beginning} label was seen (the sequence is not valid IOB)
         */
        @NotNull
        public final List<AnnotatedSegment> buildSegments(@NotNull List<? extends RealToken> tokens, @NotNull List<ScoredLabel> labels) {
            Intrinsics.checkParameterIsNotNull(tokens, "tokens");
            Intrinsics.checkParameterIsNotNull(labels, "labels");

            List<Segment> segments = new ArrayList<>();
            int size = labels.size();

            for (int i = 0; i < size; i++) {
                ScoredLabel current = labels.get(i);
                // The last label has no successor, so it is treated as "not followed by Inside".
                ScoredLabel next = (i + 1 < size) ? labels.get(i + 1) : null;

                if (current == null) {
                    throw new NullPointerException("Null label at index " + i + ".");
                }

                if (current.getType() == IOBTag.Beginning) {
                    segments.add(new Segment(i, tokens.get(i).getPosition().getStart(), current.getValue(), current.getScore()));
                }

                if (current.getType() == IOBTag.Inside) {
                    lastSegment(segments, i).addScore(current.getScore());
                }

                if (current.getType() != IOBTag.Outside && (next == null || next.getType() != IOBTag.Inside)) {
                    Segment open = lastSegment(segments, i);
                    open.setEndToken(i);
                    open.setEndChar(tokens.get(i).getPosition().getEnd());
                }
            }

            List<AnnotatedSegment> result = new ArrayList<>(segments.size());
            for (Segment segment : segments) {
                result.add(segment.toAnnotatedSegment());
            }
            return result;
        }

        /** Returns the most recently opened segment, failing clearly if none was opened yet. */
        private static Segment lastSegment(List<Segment> segments, int tokenIndex) {
            if (segments.isEmpty()) {
                throw new NoSuchElementException("No open segment at token " + tokenIndex + ": the label sequence has no preceding Beginning label.");
            }
            return segments.get(segments.size() - 1);
        }

        private Companion() {
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    public boolean getPropagateToInput() {
        return this.propagateToInput;
    }

    /**
     * Predicts one scored label per token of the given sentence.
     *
     * <p>The forward outputs are decoded with {@link LabelsDecoder}. The elements of the best
     * configuration it finds are ordered by id and their values are returned. If the decoder
     * yields no configuration, a constrained greedy decoding is used instead.
     *
     * @param input the sentence to label
     * @return the predicted labels
     */
    @NotNull
    public final List<ScoredLabel> predict(@NotNull RealSentence<RealToken> input) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        List<DenseNDArray> predictions = forward(input);

        // NOTE(review): the literals (3, 5, 10) and the `true` flag configure the decoder search
        // (presumably beam size / fork size / max iterations and "valid configurations only");
        // confirm against LabelsDecoder before changing them.
        LabelsDecoder.LabeledState best = (LabelsDecoder.LabeledState) new LabelsDecoder(predictions, this.model, 3, 5, 10).findBestConfiguration(true);

        List<BeamManager.StateElement<ScoredLabel>> elements = (best != null) ? best.getElements() : null;
        if (elements == null) {
            return greedyDecode(predictions);
        }

        // Stable sort by element id, so labels come out in id order.
        List<BeamManager.StateElement<ScoredLabel>> sorted = new ArrayList<>(elements);
        sorted.sort(Comparator.comparingInt(element -> element.getId()));

        List<ScoredLabel> labels = new ArrayList<>(sorted.size());
        for (BeamManager.StateElement<ScoredLabel> element : sorted) {
            labels.add(element.getValue());
        }
        return labels;
    }

    /**
     * Predicts the labels of the given sentence and groups them into annotated segments.
     *
     * @param input the sentence to label
     * @return the segments built from the predicted labels
     */
    @NotNull
    public final List<AnnotatedSegment> predictAsSegments(@NotNull RealSentence<RealToken> input) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        return Companion.buildSegments(input.getTokens(), predict(input));
    }

    /**
     * Encodes the sentence tokens and runs them through the BiRNN.
     *
     * @param input the sentence to encode
     * @return the BiRNN outputs for the encoded tokens
     */
    @NotNull
    public List<DenseNDArray> forward(@NotNull RealSentence<RealToken> input) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        return this.biRNNProcessor.forward((List) this.tokensEncoder.forward(input));
    }

    /**
     * Back-propagates the given output errors through the BiRNN and then into the tokens encoder.
     *
     * @param outputErrors the errors of the outputs returned by the last {@link #forward} call
     */
    public void backward(@NotNull List<DenseNDArray> outputErrors) {
        Intrinsics.checkParameterIsNotNull(outputErrors, "outputErrors");
        this.biRNNProcessor.backward(outputErrors);
        // The BiRNN input errors are the errors of the token encodings (copy = false).
        this.tokensEncoder.backward(this.biRNNProcessor.getInputErrors(false));
    }

    /**
     * Returns {@link NeuralProcessor.NoInputErrors}: errors are not propagated to the input sentence.
     *
     * @param copy ignored
     */
    @NotNull
    public NeuralProcessor.NoInputErrors getInputErrors(boolean copy) {
        return NeuralProcessor.NoInputErrors.INSTANCE;
    }

    /**
     * Decompiler-generated alias of {@link #getInputErrors(boolean)}, kept for compatibility.
     *
     * @deprecated use {@link #getInputErrors(boolean)}
     */
    @Deprecated
    @NotNull
    public NeuralProcessor.NoInputErrors m1getInputErrors(boolean copy) {
        return getInputErrors(copy);
    }

    /**
     * Returns the parameter errors of the tokens encoder followed by those of the BiRNN.
     *
     * @param copy forwarded to both sub-processors
     */
    @NotNull
    public List<ParamsArray.Errors<?>> getParamsErrors(boolean copy) {
        return CollectionsKt.plus(this.tokensEncoder.getParamsErrors(copy), this.biRNNProcessor.getParamsErrors(copy));
    }

    /**
     * Decodes the predictions greedily, one token at a time.
     *
     * <p>For each token, candidate labels are tried in the order given by
     * {@code argSorted(true)} (presumably highest score first). The first candidate that can follow
     * the previously chosen label ({@code LabelsDecoder.Companion.canFollow}) is taken. On the last
     * token, only an {@code Outside} label is accepted.
     *
     * @param predictions one prediction array per token
     * @return one chosen label per token (empty if {@code predictions} is empty)
     * @throws NoSuchElementException if no candidate satisfies the constraints for some token
     */
    private final List<ScoredLabel> greedyDecode(final List<DenseNDArray> predictions) {
        int lastIndex = predictions.size() - 1;
        List<ScoredLabel> labels = new ArrayList<>(predictions.size());
        ScoredLabel previous = null;

        int index = 0;
        for (DenseNDArray prediction : predictions) {
            previous = bestAdmissibleLabel(prediction, previous, index == lastIndex);
            labels.add(previous);
            index++;
        }
        return labels;
    }

    /**
     * Returns the first label, in {@code argSorted(true)} order, that satisfies the greedy
     * decoding constraints.
     */
    private ScoredLabel bestAdmissibleLabel(DenseNDArray prediction, ScoredLabel previous, boolean isLastToken) {
        for (int labelIndex : prediction.argSorted(true)) {
            Label label = (Label) this.model.getOutputLabels().getElement(labelIndex);
            if (label == null) {
                throw new NullPointerException("No output label with index " + labelIndex + ".");
            }
            ScoredLabel candidate = new ScoredLabel(label, prediction.get(labelIndex).doubleValue());

            // NOTE(review): forcing Outside on the last token means no segment can end on the final
            // token in greedy decoding; confirm this is intended.
            if (LabelsDecoder.Companion.canFollow(candidate, previous) && (!isLastToken || candidate.getType() == IOBTag.Outside)) {
                return candidate;
            }
        }
        throw new NoSuchElementException("Sequence contains no element matching the predicate.");
    }

    @NotNull
    public final TokensLabelerModel getModel() {
        return this.model;
    }

    public int getId() {
        return this.id;
    }

    /**
     * @param model the tokens-labeler model
     * @param encoderDropout passed to the DeepBiRNNEncoder
     * @param outputMergeDropout passed to the DeepBiRNNEncoder
     * @param id the identifier of this processor
     */
    public TokensLabeler(@NotNull TokensLabelerModel model, double encoderDropout, double outputMergeDropout, int id) {
        Intrinsics.checkParameterIsNotNull(model, "model");
        this.model = model;
        this.id = id;
        // NOTE(review): the trailing (0, 1, null) is Kotlin default-argument plumbing (mask 1 selects
        // the default of the first optional parameter); confirm against TokensEncoderModel.buildEncoder.
        this.tokensEncoder = TokensEncoderModel.DefaultImpls.buildEncoder$default(this.model.getTokensEncoderModel(), 0, 1, (Object) null);
        // NOTE(review): (true, 0, 16, null) mixes explicit arguments with Kotlin default-argument
        // plumbing (mask 16 + DefaultConstructorMarker); confirm against DeepBiRNNEncoder's constructor.
        this.biRNNProcessor = new DeepBiRNNEncoder<>(this.model.getBiRNN(), encoderDropout, outputMergeDropout, true, 0, 16, (DefaultConstructorMarker) null);
    }

    /**
     * Kotlin default-arguments constructor: bits 2, 4 and 8 of {@code mask} select the defaults
     * 0.0 (encoderDropout), 0.0 (outputMergeDropout) and 0 (id) respectively.
     */
    public /* synthetic */ TokensLabeler(TokensLabelerModel model, double encoderDropout, double outputMergeDropout, int id, int mask, DefaultConstructorMarker marker) {
        this(model, (mask & 2) != 0 ? 0.0d : encoderDropout, (mask & 4) != 0 ? 0.0d : outputMergeDropout, (mask & 8) != 0 ? 0 : id);
    }

    @NotNull
    public NeuralProcessor.NoInputErrors propagateErrors(@NotNull List<DenseNDArray> errors, @NotNull ParamsOptimizer optimizer, boolean copy) {
        Intrinsics.checkParameterIsNotNull(errors, "errors");
        Intrinsics.checkParameterIsNotNull(optimizer, "optimizer");
        return (NeuralProcessor.NoInputErrors) NeuralProcessor.DefaultImpls.propagateErrors(this, errors, optimizer, copy);
    }
}
