package com.kotlinnlp.neuralparser.parsers.lhrparser.neuralmodules.labeler;

import com.kotlinnlp.dependencytree.DependencyTree;
import com.kotlinnlp.linguisticdescription.GrammaticalConfiguration;
import com.kotlinnlp.neuralparser.parsers.lhrparser.neuralmodules.labeler.utils.LossCriterion;
import com.kotlinnlp.neuralparser.parsers.lhrparser.neuralmodules.labeler.utils.LossCriterionType;
import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.functionalities.activations.Softmax;
import com.kotlinnlp.simplednn.core.functionalities.activations.Tanh;
import com.kotlinnlp.simplednn.core.functionalities.initializers.Initializer;
import com.kotlinnlp.simplednn.core.layers.LayerInterface;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.neuralnetwork.NeuralNetwork;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.utils.DictionarySet;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import kotlin.Metadata;
import kotlin.NoWhenBranchMatchedException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: LabelerModel.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��>\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018�� \u001a2\u00020\u0001:\u0001\u001aB#\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u0012\u0006\u0010\u0007\u001a\u00020\b¢\u0006\u0002\u0010\tJ\"\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00160\u00152\f\u0010\u0017\u001a\b\u0012\u0004\u0012\u00020\u00160\u00152\u0006\u0010\u0018\u001a\u00020\u0019R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\n\u0010\u000bR\u0017\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005¢\u0006\b\n��\u001a\u0004\b\f\u0010\rR\u0011\u0010\u0007\u001a\u00020\b¢\u0006\b\n��\u001a\u0004\b\u000e\u0010\u000fR\u0011\u0010\u0010\u001a\u00020\u0011¢\u0006\b\n��\u001a\u0004\b\u0012\u0010\u0013¨\u0006\u001b"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/LabelerModel;", "Ljava/io/Serializable;", "contextEncodingSize", "", "grammaticalConfigurations", "Lcom/kotlinnlp/utils/DictionarySet;", "Lcom/kotlinnlp/linguisticdescription/GrammaticalConfiguration;", "lossCriterionType", "Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/utils/LossCriterionType;", "(ILcom/kotlinnlp/utils/DictionarySet;Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/utils/LossCriterionType;)V", "getContextEncodingSize", "()I", "getGrammaticalConfigurations", "()Lcom/kotlinnlp/utils/DictionarySet;", "getLossCriterionType", "()Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/utils/LossCriterionType;", "networkModel", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "getNetworkModel", "()Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "calculateLoss", "", 
"Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "predictions", "goldTree", "Lcom/kotlinnlp/dependencytree/DependencyTree;", "Companion", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/LabelerModel.class */
public final class LabelerModel implements Serializable {

    /**
     * Model of the labeler neural module: the grammatical-configuration dictionary, the loss
     * criterion type and the {@link NeuralNetwork} built from them.
     *
     * <p>The constructor builds a three-layer network:
     * <ul>
     *   <li>an input layer declaring two inputs of size {@code contextEncodingSize};</li>
     *   <li>a hidden layer of size {@code contextEncodingSize} with an {@code Affine} connection
     *       and {@link Tanh} activation;</li>
     *   <li>a {@code Dense}/{@code Feedforward} output layer with one unit per grammatical
     *       configuration, using a {@link Softmax} activation for the {@code Softmax} criterion
     *       and none ({@code null}) for {@code HingeLoss}.</li>
     * </ul>
     *
     * <p>NOTE(review): this class appears to be decompiled from Kotlin ({@code LabelerModel.kt}).
     * The {@code Companion}, the {@code Intrinsics} parameter checks and the default-argument
     * constructor calls mirror that origin and are kept as-is.
     */

    /** The network built by the constructor from the other properties. */
    @NotNull
    private final NeuralNetwork networkModel;

    /** Size of each of the two network inputs and of the hidden layer. */
    private final int contextEncodingSize;

    /** Dictionary of grammatical configurations; its size is the output layer size. */
    @NotNull
    private final DictionarySet<GrammaticalConfiguration> grammaticalConfigurations;

    /** Selects the output activation and the {@link LossCriterion} used by {@link #calculateLoss}. */
    @NotNull
    private final LossCriterionType lossCriterionType;
    private static final long serialVersionUID = 1;
    public static final Companion Companion = new Companion(null);

    /* compiled from: LabelerModel.kt */
    @Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0014\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\t\n\u0002\b\u0002\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002R\u0016\u0010\u0003\u001a\u00020\u00048\u0002X\u0083T¢\u0006\b\n��\u0012\u0004\b\u0005\u0010\u0002¨\u0006\u0006"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/LabelerModel$Companion;", "", "()V", "serialVersionUID", "", "serialVersionUID$annotations", "neuralparser"})
    /* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/lhrparser/neuralmodules/labeler/LabelerModel$Companion.class */
    public static final class Companion {
        private static /* synthetic */ void serialVersionUID$annotations() {
        }

        private Companion() {
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /** Returns the network built by the constructor. */
    @NotNull
    public final NeuralNetwork getNetworkModel() {
        return this.networkModel;
    }

    /**
     * Computes the prediction errors of each prediction against the gold tree.
     *
     * <p>The i-th prediction is paired with the i-th entry of {@code goldTree.getElements()}.
     * The gold index passed to the loss criterion is the id, in
     * {@link #getGrammaticalConfigurations()}, of the grammatical configuration the gold tree
     * assigns to that element.
     *
     * @param list the predictions, one per gold-tree element, in element order
     * @param dependencyTree the gold dependency tree
     * @return the prediction errors, in the same order as {@code list}
     * @throws NullPointerException if the gold tree has no configuration for an element, or that
     *     configuration is missing from the dictionary
     */
    @NotNull
    public final List<DenseNDArray> calculateLoss(@NotNull List<DenseNDArray> list, @NotNull DependencyTree dependencyTree) {
        Intrinsics.checkParameterIsNotNull(list, "predictions");
        Intrinsics.checkParameterIsNotNull(dependencyTree, "goldTree");

        // The criterion depends only on lossCriterionType, so build it once instead of per token.
        LossCriterion criterion = LossCriterion.Companion.invoke(this.lossCriterionType);
        List<DenseNDArray> errors = new ArrayList<>(list.size());

        int tokenIndex = 0;
        for (DenseNDArray prediction : list) {
            // Predictions are positional: the i-th one belongs to the i-th element of the gold tree.
            int tokenId = ((Number) dependencyTree.getElements().get(tokenIndex)).intValue();
            errors.add(criterion.getPredictionErrors(prediction, goldIndexOf(dependencyTree, tokenId)));
            tokenIndex++;
        }
        return errors;
    }

    /**
     * Returns the dictionary id of the grammatical configuration that {@code goldTree} assigns to
     * the element {@code elementId}.
     *
     * @throws NullPointerException if the tree has no configuration for the element, or the
     *     configuration is not in {@link #grammaticalConfigurations}
     */
    private int goldIndexOf(DependencyTree goldTree, int elementId) {
        GrammaticalConfiguration configuration = goldTree.getConfiguration(elementId);
        if (configuration == null) {
            throw new NullPointerException(
                "The gold tree has no grammatical configuration for element " + elementId);
        }
        Integer id = this.grammaticalConfigurations.getId(configuration);
        if (id == null) {
            throw new NullPointerException(
                "Grammatical configuration of element " + elementId
                    + " is not in the dictionary: " + configuration);
        }
        return id;
    }

    /** Returns the size of each network input and of the hidden layer. */
    public final int getContextEncodingSize() {
        return this.contextEncodingSize;
    }

    /** Returns the grammatical-configuration dictionary. */
    @NotNull
    public final DictionarySet<GrammaticalConfiguration> getGrammaticalConfigurations() {
        return this.grammaticalConfigurations;
    }

    /** Returns the loss criterion type. */
    @NotNull
    public final LossCriterionType getLossCriterionType() {
        return this.lossCriterionType;
    }

    /**
     * Creates the model and builds its network.
     *
     * @param i the context encoding size (see {@link #getContextEncodingSize()})
     * @param dictionarySet the grammatical configurations; their count is the output layer size
     * @param lossCriterionType the loss criterion type; selects the output activation
     */
    public LabelerModel(int i, @NotNull DictionarySet<GrammaticalConfiguration> dictionarySet, @NotNull LossCriterionType lossCriterionType) {
        Intrinsics.checkParameterIsNotNull(dictionarySet, "grammaticalConfigurations");
        Intrinsics.checkParameterIsNotNull(lossCriterionType, "lossCriterionType");
        this.contextEncodingSize = i;
        this.grammaticalConfigurations = dictionarySet;
        this.lossCriterionType = lossCriterionType;

        // These are Kotlin default-argument bridge constructors. The trailing int is presumably
        // the bitmask of parameters left at their defaults (followed by a DefaultConstructorMarker),
        // so it must stay in sync with the arguments actually passed.
        LayerInterface inputLayer = new LayerInterface(CollectionsKt.listOf(new Integer[]{Integer.valueOf(i), Integer.valueOf(i)}), (LayerType.Input) null, (LayerType.Connection) null, (ActivationFunction) null, false, 0.0d, 62, (DefaultConstructorMarker) null);
        LayerInterface hiddenLayer = new LayerInterface(i, (LayerType.Input) null, LayerType.Connection.Affine, new Tanh(), false, 0.0d, 50, (DefaultConstructorMarker) null);
        LayerInterface outputLayer = new LayerInterface(dictionarySet.getSize(), LayerType.Input.Dense, LayerType.Connection.Feedforward, outputActivation(lossCriterionType), false, 0.0d, 16, (DefaultConstructorMarker) null);

        this.networkModel = new NeuralNetwork(new LayerInterface[]{inputLayer, hiddenLayer, outputLayer}, (Initializer) null, (Initializer) null, 6, (DefaultConstructorMarker) null);
    }

    /**
     * Returns the output-layer activation for the given loss criterion: a new {@link Softmax} for
     * {@code Softmax}, and {@code null} for {@code HingeLoss} (presumably meaning no activation).
     *
     * @throws NoWhenBranchMatchedException for an unhandled criterion type
     */
    private static ActivationFunction outputActivation(LossCriterionType lossCriterionType) {
        switch (lossCriterionType) {
            case Softmax:
                return new Softmax();
            case HingeLoss:
                return null;
            default:
                throw new NoWhenBranchMatchedException();
        }
    }
}
