package com.kotlinnlp.simplednn.deeplearning.sequencelabeling;

import com.kotlinnlp.simplednn.core.functionalities.losses.MulticlassMSECalculator;
import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor;
import com.kotlinnlp.simplednn.deeplearning.sequencelabeling.SWSLabeler;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.Unit;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;

/* compiled from: SWSLabeler.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��t\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0011\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0010\b\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0015\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u000b\u0018��2\u00020\u0001:\u000278B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J\u0010\u0010\u0012\u001a\u00020\u00132\u0006\u0010\u0014\u001a\u00020\u0015H\u0002J\u0010\u0010\u0016\u001a\u00020\u00132\u0006\u0010\u0017\u001a\u00020\u0007H\u0002J\u0018\u0010\u0018\u001a\u00020\u00132\u0006\u0010\u0017\u001a\u00020\u00072\u0006\u0010\u0014\u001a\u00020\u0015H\u0002J\u0010\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\u000bH\u0002J\"\u0010\u001c\u001a\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020\u001e\u0012\u0004\u0012\u00020\u00070\u001d0\n2\u0006\u0010\u0017\u001a\u00020\u0007H\u0002J\u001f\u0010\u001f\u001a\b\u0012\u0004\u0012\u00020\u000b0\n2\f\u0010 \u001a\b\u0012\u0004\u0012\u00020\u00070\u0006¢\u0006\u0002\u0010!J\u001e\u0010\"\u001a\u00020\u00132\f\u0010#\u001a\b\u0012\u0004\u0012\u00020\u00130$2\u0006\u0010%\u001a\u00020\u001aH\u0002J\b\u0010&\u001a\u00020\u000bH\u0002J\u0010\u0010'\u001a\u00020\u000b2\u0006\u0010(\u001a\u00020)H\u0002J\u001b\u0010*\u001a\b\u0012\u0004\u0012\u00020\u00070\u00062\b\b\u0002\u0010+\u001a\u00020\u001a¢\u0006\u0002\u0010,J\b\u0010-\u001a\u00020.H\u0002J\u0010\u0010/\u001a\u00020\u00072\u0006\u00100\u001a\u00020\u000bH\u0002J\u0010\u00101\u001a\u00020\u00132\u0006\u00102\u001a\u00020\u001eH\u0002J3\u00103\u001a\u00020\u00132\f\u0010 
\u001a\b\u0012\u0004\u0012\u00020\u00070\u00062\u0006\u0010(\u001a\u00020)2\u0006\u0010\u0014\u001a\u00020\u00152\b\b\u0002\u0010%\u001a\u00020\u001a¢\u0006\u0002\u00104J\u001b\u00105\u001a\u00020\u00132\f\u0010 \u001a\b\u0012\u0004\u0012\u00020\u00070\u0006H\u0002¢\u0006\u0002\u00106R\u0016\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006X\u0082.¢\u0006\u0004\n\u0002\u0010\bR\u0014\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u000b0\nX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\f\u001a\u00020\rX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000e\u001a\b\u0012\u0004\u0012\u00020\u00070\u000fX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0010\u001a\u00020\u0011X\u0082.¢\u0006\u0002\n��¨\u00069"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLabeler;", "", "network", "Lcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLNetwork;", "(Lcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLNetwork;)V", "inputSequenceErrors", "", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "[Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "labels", "Ljava/util/ArrayList;", "Lcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLabeler$Label;", "lossCalculator", "Lcom/kotlinnlp/simplednn/core/functionalities/losses/MulticlassMSECalculator;", "processor", "Lcom/kotlinnlp/simplednn/core/neuralprocessor/feedforward/FeedforwardNeuralProcessor;", "sequence", "Lcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SlidingWindowSequence;", "accumulateErrors", "", "optimizer", "Lcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLOptimizer;", "accumulateInputErrors", "errors", "accumulateLabelsEmbeddingsErrors", "addLabel", "", "label", "alignLabelsEmbeddingsErrors", "Lkotlin/Pair;", "", "annotate", "inputSequence", "([Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;)Ljava/util/ArrayList;", "forwardSequence", "forEachPrediction", 
"Lkotlin/Function0;", "useDropout", "getBestLabel", "getGoldLabel", "goldLabels", "", "getInputSequenceErrors", "copy", "(Z)[Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getNetworkErrors", "Lcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLabeler$NetworkErrors;", "getOutputErrors", "goldLabel", "initInputErrors", "size", "learn", "([Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;[ILcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLOptimizer;Z)V", "setNewSequence", "([Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;)V", "Label", "NetworkErrors", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLabeler.class */
public final class SWSLabeler {

    /** Sliding-window view over the sequence currently being processed (set by {@link #setNewSequence}). */
    private SlidingWindowSequence sequence;

    /**
     * Labels assigned so far to the current sequence.
     * annotate() appends the predicted labels; learn() appends the gold labels.
     * The list is cleared whenever a new sequence is set.
     */
    private final ArrayList<Label> labels;

    /** Feed-forward processor running the network's classifier. */
    private final FeedforwardNeuralProcessor<DenseNDArray> processor;

    /** Loss used in learn() to compute output errors against one-hot gold labels. */
    private final MulticlassMSECalculator lossCalculator;

    /** Per-element input errors accumulated during the last learn() call (lateinit in the Kotlin source). */
    private DenseNDArray[] inputSequenceErrors;

    private final SWSLNetwork network;

    /* compiled from: SWSLabeler.kt */
    /**
     * A label assigned to one element of the sequence.
     * {@code index} is the position of the label in the classifier output; {@code score} is the
     * output value for a predicted label, or 1.0 for a gold label (also the default score).
     */
    @Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��&\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0010\b\n��\n\u0002\u0010\u0006\n\u0002\b\t\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0010\u000e\n��\b\u0086\b\u0018��2\u00020\u0001B\u0017\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006J\t\u0010\u000b\u001a\u00020\u0003HÆ\u0003J\t\u0010\f\u001a\u00020\u0005HÆ\u0003J\u001d\u0010\r\u001a\u00020��2\b\b\u0002\u0010\u0002\u001a\u00020\u00032\b\b\u0002\u0010\u0004\u001a\u00020\u0005HÆ\u0001J\u0013\u0010\u000e\u001a\u00020\u000f2\b\u0010\u0010\u001a\u0004\u0018\u00010\u0001HÖ\u0003J\t\u0010\u0011\u001a\u00020\u0003HÖ\u0001J\t\u0010\u0012\u001a\u00020\u0013HÖ\u0001R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\bR\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\t\u0010\n¨\u0006\u0014"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLabeler$Label;", "", "index", "", "score", "", "(ID)V", "getIndex", "()I", "getScore", "()D", "component1", "component2", "copy", "equals", "", "other", "hashCode", "toString", "", "simplednn"})
    /* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLabeler$Label.class */
    public static final class Label {
        private final int index;
        private final double score;

        public final int getIndex() {
            return this.index;
        }

        public final double getScore() {
            return this.score;
        }

        public Label(int i, double d) {
            this.index = i;
            this.score = d;
        }

        /** Synthetic Kotlin default-argument constructor: bit 2 of {@code i2} selects the default score 1.0. */
        public /* synthetic */ Label(int i, double d, int i2, DefaultConstructorMarker defaultConstructorMarker) {
            this(i, (i2 & 2) != 0 ? 1.0d : d);
        }

        public final int component1() {
            return this.index;
        }

        public final double component2() {
            return this.score;
        }

        @NotNull
        public final Label copy(int i, double d) {
            return new Label(i, d);
        }

        @NotNull
        public static /* bridge */ /* synthetic */ Label copy$default(Label label, int i, double d, int i2, Object obj) {
            int newIndex = (i2 & 1) != 0 ? label.index : i;
            double newScore = (i2 & 2) != 0 ? label.score : d;
            return label.copy(newIndex, newScore);
        }

        @Override
        public String toString() {
            return "Label(index=" + this.index + ", score=" + this.score + ")";
        }

        /**
         * Combines the index with all 64 bits of the score (the same fold used by Double.hashCode).
         * Labels equal under {@link #equals} have identical score bits, hence identical hashes.
         */
        @Override
        public int hashCode() {
            long scoreBits = Double.doubleToLongBits(this.score);
            return (this.index * 31) + (int) (scoreBits ^ (scoreBits >>> 32));
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Label)) {
                return false;
            }
            Label other = (Label) obj;
            return this.index == other.index && Double.compare(this.score, other.score) == 0;
        }
    }

    /* compiled from: SWSLabeler.kt */
    /**
     * Errors produced by one backward step of the classifier.
     * Holds the parameter errors, the errors of the label-embedding part of the input, and the
     * errors of the remaining part of the input (see getNetworkErrors()).
     */
    @Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��,\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\f\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0010\b\n��\n\u0002\u0010\u000e\n��\b\u0086\b\u0018��2\u00020\u0001B\u001d\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0005¢\u0006\u0002\u0010\u0007J\t\u0010\r\u001a\u00020\u0003HÆ\u0003J\t\u0010\u000e\u001a\u00020\u0005HÆ\u0003J\t\u0010\u000f\u001a\u00020\u0005HÆ\u0003J'\u0010\u0010\u001a\u00020��2\b\b\u0002\u0010\u0002\u001a\u00020\u00032\b\b\u0002\u0010\u0004\u001a\u00020\u00052\b\b\u0002\u0010\u0006\u001a\u00020\u0005HÆ\u0001J\u0013\u0010\u0011\u001a\u00020\u00122\b\u0010\u0013\u001a\u0004\u0018\u00010\u0001HÖ\u0003J\t\u0010\u0014\u001a\u00020\u0015HÖ\u0001J\t\u0010\u0016\u001a\u00020\u0017HÖ\u0001R\u0011\u0010\u0006\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\b\u0010\tR\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\n\u0010\tR\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u000b\u0010\f¨\u0006\u0018"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLabeler$NetworkErrors;", "", "paramsErrors", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "labelsEmbeddingsErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "inputErrors", "(Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;)V", "getInputErrors", "()Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "getLabelsEmbeddingsErrors", "getParamsErrors", "()Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "component1", "component2", "component3", "copy", "equals", "", "other", "hashCode", "", "toString", "", "simplednn"})
    /* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLabeler$NetworkErrors.class */
    public static final class NetworkErrors {

        @NotNull
        private final NetworkParameters paramsErrors;

        @NotNull
        private final DenseNDArray labelsEmbeddingsErrors;

        @NotNull
        private final DenseNDArray inputErrors;

        @NotNull
        public final NetworkParameters getParamsErrors() {
            return this.paramsErrors;
        }

        @NotNull
        public final DenseNDArray getLabelsEmbeddingsErrors() {
            return this.labelsEmbeddingsErrors;
        }

        @NotNull
        public final DenseNDArray getInputErrors() {
            return this.inputErrors;
        }

        public NetworkErrors(@NotNull NetworkParameters networkParameters, @NotNull DenseNDArray denseNDArray, @NotNull DenseNDArray denseNDArray2) {
            Intrinsics.checkParameterIsNotNull(networkParameters, "paramsErrors");
            Intrinsics.checkParameterIsNotNull(denseNDArray, "labelsEmbeddingsErrors");
            Intrinsics.checkParameterIsNotNull(denseNDArray2, "inputErrors");
            this.paramsErrors = networkParameters;
            this.labelsEmbeddingsErrors = denseNDArray;
            this.inputErrors = denseNDArray2;
        }

        @NotNull
        public final NetworkParameters component1() {
            return this.paramsErrors;
        }

        @NotNull
        public final DenseNDArray component2() {
            return this.labelsEmbeddingsErrors;
        }

        @NotNull
        public final DenseNDArray component3() {
            return this.inputErrors;
        }

        @NotNull
        public final NetworkErrors copy(@NotNull NetworkParameters networkParameters, @NotNull DenseNDArray denseNDArray, @NotNull DenseNDArray denseNDArray2) {
            Intrinsics.checkParameterIsNotNull(networkParameters, "paramsErrors");
            Intrinsics.checkParameterIsNotNull(denseNDArray, "labelsEmbeddingsErrors");
            Intrinsics.checkParameterIsNotNull(denseNDArray2, "inputErrors");
            return new NetworkErrors(networkParameters, denseNDArray, denseNDArray2);
        }

        @NotNull
        public static /* bridge */ /* synthetic */ NetworkErrors copy$default(NetworkErrors networkErrors, NetworkParameters networkParameters, DenseNDArray denseNDArray, DenseNDArray denseNDArray2, int i, Object obj) {
            NetworkParameters params = (i & 1) != 0 ? networkErrors.paramsErrors : networkParameters;
            DenseNDArray labelsErrors = (i & 2) != 0 ? networkErrors.labelsEmbeddingsErrors : denseNDArray;
            DenseNDArray inErrors = (i & 4) != 0 ? networkErrors.inputErrors : denseNDArray2;
            return networkErrors.copy(params, labelsErrors, inErrors);
        }

        @Override
        public String toString() {
            return "NetworkErrors(paramsErrors=" + this.paramsErrors + ", labelsEmbeddingsErrors=" + this.labelsEmbeddingsErrors + ", inputErrors=" + this.inputErrors + ")";
        }

        @Override
        public int hashCode() {
            NetworkParameters networkParameters = this.paramsErrors;
            int hashCode = (networkParameters != null ? networkParameters.hashCode() : 0) * 31;
            DenseNDArray denseNDArray = this.labelsEmbeddingsErrors;
            int hashCode2 = (hashCode + (denseNDArray != null ? denseNDArray.hashCode() : 0)) * 31;
            DenseNDArray denseNDArray2 = this.inputErrors;
            return hashCode2 + (denseNDArray2 != null ? denseNDArray2.hashCode() : 0);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof NetworkErrors)) {
                return false;
            }
            NetworkErrors other = (NetworkErrors) obj;
            return Intrinsics.areEqual(this.paramsErrors, other.paramsErrors)
                && Intrinsics.areEqual(this.labelsEmbeddingsErrors, other.labelsEmbeddingsErrors)
                && Intrinsics.areEqual(this.inputErrors, other.inputErrors);
        }
    }

    /**
     * Assigns a label to every element of the given sequence.
     *
     * At each focus position of the sliding window the classifier is run forward without dropout.
     * The argmax of its output is appended to the label history, which is also passed to the
     * features extractor.
     *
     * @param denseNDArrayArr the input sequence (must not be empty)
     * @return a new list with one predicted label per visited focus position; later calls do not modify it
     * @throws IllegalArgumentException if the sequence is empty
     */
    @NotNull
    public final ArrayList<Label> annotate(@NotNull DenseNDArray[] denseNDArrayArr) {
        Intrinsics.checkParameterIsNotNull(denseNDArrayArr, "inputSequence");
        if (denseNDArrayArr.length == 0) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        setNewSequence(denseNDArrayArr);
        forwardSequence(new Function0<Unit>() {
            @Override
            public Unit invoke() {
                SWSLabeler.this.addLabel(SWSLabeler.this.getBestLabel());
                return Unit.INSTANCE;
            }
        }, false);
        // Return a snapshot: the internal list is cleared and refilled by the next annotate()/learn().
        return new ArrayList<Label>(this.labels);
    }

    /**
     * Runs one training pass over the sequence.
     *
     * The optimizer only accumulates errors here; no update is applied.
     *
     * At each focus position the classifier is run forward. The MSE errors against the one-hot
     * gold label are then backpropagated, and the resulting errors are distributed as follows:
     * - parameter errors and label-embedding errors go to the optimizer;
     * - input errors are summed into the per-element buffer returned by getInputSequenceErrors().
     *
     * The gold label (not the prediction) is then appended to the label history.
     *
     * @param denseNDArrayArr the input sequence (must not be empty)
     * @param iArr gold label indices, read at each focus index of the sliding window
     * @param sWSLOptimizer the optimizer that accumulates the errors
     * @param z whether to use dropout during the forward passes
     * @throws IllegalArgumentException if the sequence is empty
     */
    public final void learn(@NotNull DenseNDArray[] denseNDArrayArr, @NotNull final int[] iArr, @NotNull final SWSLOptimizer sWSLOptimizer, boolean z) {
        Intrinsics.checkParameterIsNotNull(denseNDArrayArr, "inputSequence");
        Intrinsics.checkParameterIsNotNull(iArr, "goldLabels");
        Intrinsics.checkParameterIsNotNull(sWSLOptimizer, "optimizer");
        if (denseNDArrayArr.length == 0) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        setNewSequence(denseNDArrayArr);
        initInputErrors(denseNDArrayArr.length);
        forwardSequence(new Function0<Unit>() {
            @Override
            public Unit invoke() {
                Label goldLabel = SWSLabeler.this.getGoldLabel(iArr);
                DenseNDArray outputErrors = SWSLabeler.this.getOutputErrors(goldLabel);
                // Calls the Kotlin default-argument bridge; mask 4 leaves the third parameter at its default.
                FeedforwardNeuralProcessor.backward$default(SWSLabeler.this.processor, outputErrors, true, null, 4, null);
                SWSLabeler.this.accumulateErrors(sWSLOptimizer);
                SWSLabeler.this.addLabel(goldLabel);
                return Unit.INSTANCE;
            }
        }, z);
    }

    /** Synthetic Kotlin default-argument bridge for learn(): {@code useDropout} defaults to false. */
    public static /* bridge */ /* synthetic */ void learn$default(SWSLabeler sWSLabeler, DenseNDArray[] denseNDArrayArr, int[] iArr, SWSLOptimizer sWSLOptimizer, boolean z, int i, Object obj) {
        boolean useDropout = (i & 8) != 0 ? false : z;
        sWSLabeler.learn(denseNDArrayArr, iArr, sWSLOptimizer, useDropout);
    }

    /**
     * Returns the per-element input errors accumulated by the last learn() call.
     *
     * @param z if true, returns deep copies of the error arrays; otherwise returns the internal array itself
     * @return the input errors, one array per element of the learned sequence
     * @throws kotlin.UninitializedPropertyAccessException if learn() has never been called
     */
    @NotNull
    public final DenseNDArray[] getInputSequenceErrors(boolean z) {
        DenseNDArray[] errors = this.inputSequenceErrors;
        if (errors == null) {
            Intrinsics.throwUninitializedPropertyAccessException("inputSequenceErrors");
        }
        if (!z) {
            return errors;
        }
        DenseNDArray[] copies = new DenseNDArray[errors.length];
        for (int i = 0; i < errors.length; i++) {
            copies[i] = errors[i].copy();
        }
        return copies;
    }

    /** Synthetic Kotlin default-argument bridge for getInputSequenceErrors(): {@code copy} defaults to true. */
    @NotNull
    public static /* bridge */ /* synthetic */ DenseNDArray[] getInputSequenceErrors$default(SWSLabeler sWSLabeler, boolean z, int i, Object obj) {
        boolean copy = (i & 1) != 0 ? true : z;
        return sWSLabeler.getInputSequenceErrors(copy);
    }

    /** Wraps the input in a new sliding window (using the network's context sizes) and clears the label history. */
    private final void setNewSequence(DenseNDArray[] denseNDArrayArr) {
        this.sequence = new SlidingWindowSequence(denseNDArrayArr, this.network.getLeftContextSize(), this.network.getRightContextSize());
        this.labels.clear();
    }

    /** Resets the input-errors buffer to {@code i} zero arrays, each built from the network's element size. */
    private final void initInputErrors(int i) {
        int elementSize = this.network.getElementSize();
        DenseNDArray[] zeros = new DenseNDArray[i];
        for (int k = 0; k < zeros.length; k++) {
            zeros[k] = DenseNDArrayFactory.INSTANCE.zeros(new Shape(elementSize, 0, 2, null));
        }
        this.inputSequenceErrors = zeros;
    }

    /* Originally private (per decompiler note); not part of the intended public API. */
    /** Appends a label to the label history. */
    public final boolean addLabel(Label label) {
        return this.labels.add(label);
    }

    /* Originally private (per decompiler note); not part of the intended public API. */
    /** Returns the argmax of the current classifier output, scored with its output value. */
    public final Label getBestLabel() {
        DenseNDArray output = this.processor.getOutput(false);
        int best = output.argMaxIndex();
        return new Label(best, output.get(best).doubleValue());
    }

    /* Originally private (per decompiler note); not part of the intended public API. */
    /** Returns the gold label at the current focus index, with score 1.0. */
    public final Label getGoldLabel(int[] iArr) {
        SlidingWindowSequence seq = this.sequence;
        if (seq == null) {
            Intrinsics.throwUninitializedPropertyAccessException("sequence");
        }
        return new Label(iArr[seq.getFocusIndex()], 1.0d);
    }

    /**
     * Walks the sliding window while its focus is in range.
     *
     * At each focus position the classifier is run forward on the extracted features, then
     * {@code function0} is invoked, then the window is shifted.
     *
     * @param function0 callback run after each forward pass
     * @param z whether to use dropout
     */
    private final void forwardSequence(Function0<Unit> function0, boolean z) {
        SlidingWindowSequence seq = this.sequence;
        if (seq == null) {
            Intrinsics.throwUninitializedPropertyAccessException("sequence");
        }
        SWSLFeaturesExtractor featuresExtractor = new SWSLFeaturesExtractor(seq, this.labels, this.network);
        while (seq.focusInRange()) {
            this.processor.forward(featuresExtractor.getFeatures(), z);
            function0.invoke();
            seq.shift();
        }
    }

    /* Originally private (per decompiler note); not part of the intended public API. */
    /** MSE errors between the current output and the one-hot encoding of {@code label}'s index. */
    public final DenseNDArray getOutputErrors(Label label) {
        DenseNDArray goldOutput = DenseNDArrayFactory.INSTANCE.oneHotEncoder(this.network.getNumberOfLabels(), label.getIndex());
        return this.lossCalculator.calculateErrors(this.processor.getOutput(false), goldOutput);
    }

    /* Originally private (per decompiler note); not part of the intended public API. */
    /**
     * Distributes the errors of the last backward step.
     * Parameter and label-embedding errors go to the optimizer; input errors go to the local buffer.
     */
    public final void accumulateErrors(SWSLOptimizer sWSLOptimizer) {
        NetworkErrors networkErrors = getNetworkErrors();
        sWSLOptimizer.accumulateParamsErrors(networkErrors.getParamsErrors());
        accumulateLabelsEmbeddingsErrors(networkErrors.getLabelsEmbeddingsErrors(), sWSLOptimizer);
        accumulateInputErrors(networkErrors.getInputErrors());
    }

    /**
     * Collects the processor's parameter errors and splits its input errors at getLabelsEmbeddingsSize().
     * The first part holds the label-embedding errors; the part up to getFeaturesSize() holds the element errors.
     * Assumes getRange(from, to) is half-open — TODO confirm.
     */
    private final NetworkErrors getNetworkErrors() {
        DenseNDArray inputErrors = this.processor.getInputErrors(false);
        int labelsEmbeddingsSize = this.network.getLabelsEmbeddingsSize();
        return new NetworkErrors(
            this.processor.getParamsErrors(false),
            inputErrors.getRange(0, labelsEmbeddingsSize),
            inputErrors.getRange(labelsEmbeddingsSize, this.network.getFeaturesSize()));
    }

    /** Sends each label-embedding error chunk to the optimizer, keyed by the label index it belongs to. */
    private final void accumulateLabelsEmbeddingsErrors(DenseNDArray denseNDArray, SWSLOptimizer sWSLOptimizer) {
        for (Pair<Integer, DenseNDArray> pair : alignLabelsEmbeddingsErrors(denseNDArray)) {
            sWSLOptimizer.accumulateLabelEmbeddingErrors(pair.component1().intValue(), pair.component2());
        }
    }

    /**
     * Splits the element part of the input errors into element-size chunks and zips them with the
     * window context. Each chunk is summed into the buffer entry of its context index; null context
     * entries are skipped.
     */
    private final void accumulateInputErrors(DenseNDArray denseNDArray) {
        SlidingWindowSequence seq = this.sequence;
        if (seq == null) {
            Intrinsics.throwUninitializedPropertyAccessException("sequence");
        }
        for (Pair pair : ArraysKt.zip(seq.getContext(), denseNDArray.splitV(this.network.getElementSize()))) {
            Integer contextIndex = (Integer) pair.component1();
            DenseNDArray chunkErrors = (DenseNDArray) pair.component2();
            if (contextIndex == null) {
                continue;
            }
            DenseNDArray[] errors = this.inputSequenceErrors;
            if (errors == null) {
                Intrinsics.throwUninitializedPropertyAccessException("inputSequenceErrors");
            }
            errors[contextIndex.intValue()].assignSum((NDArray<?>) chunkErrors);
        }
    }

    /**
     * Pairs label-embedding error chunks with the labels they came from.
     *
     * The chunks are reversed so that the last chunk pairs with the most recent label, the previous
     * chunk with the label before it, and so on. Pairing stops when either the chunks or the label
     * history run out.
     *
     * @return (label index, error chunk) pairs; empty if no label has been assigned yet
     */
    private final ArrayList<Pair<Integer, DenseNDArray>> alignLabelsEmbeddingsErrors(DenseNDArray denseNDArray) {
        ArrayList<Pair<Integer, DenseNDArray>> aligned = new ArrayList<>();
        if (this.labels.isEmpty()) {
            return aligned;
        }
        List<DenseNDArray> reversedChunks = ArraysKt.reversed(denseNDArray.splitV(this.network.getLabelEmbeddingSize()));
        int lastLabelPos = this.labels.size() - 1;
        int firstLabelPos = Math.max(0, this.labels.size() - reversedChunks.size());
        int chunkPos = 0;
        for (int labelPos = lastLabelPos; labelPos >= firstLabelPos; labelPos--) {
            aligned.add(new Pair<Integer, DenseNDArray>(Integer.valueOf(this.labels.get(labelPos).getIndex()), reversedChunks.get(chunkPos)));
            chunkPos++;
        }
        return aligned;
    }

    public SWSLabeler(@NotNull SWSLNetwork sWSLNetwork) {
        Intrinsics.checkParameterIsNotNull(sWSLNetwork, "network");
        this.network = sWSLNetwork;
        this.labels = new ArrayList<>();
        this.processor = new FeedforwardNeuralProcessor<>(this.network.getClassifier(), 0, 2, null);
        this.lossCalculator = new MulticlassMSECalculator();
    }
}
