package com.kotlinnlp.neuralparser.parsers.transitionbased.templates.actionsscorer;

import com.kotlinnlp.dependencytree.Deprel;
import com.kotlinnlp.dependencytree.POSTag;
import com.kotlinnlp.neuralparser.parsers.transitionbased.templates.inputcontexts.TokensEncodingContext;
import com.kotlinnlp.neuralparser.parsers.transitionbased.templates.supportstructure.compositeprediction.TPDJointSupportStructure;
import com.kotlinnlp.neuralparser.parsers.transitionbased.utils.features.DenseFeatures;
import com.kotlinnlp.neuralparser.parsers.transitionbased.utils.features.DenseFeaturesErrors;
import com.kotlinnlp.neuralparser.parsers.transitionbased.utils.items.DenseItem;
import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.core.neuralnetwork.NeuralNetwork;
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.Optimizer;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.deeplearning.multitasknetwork.MultiTaskNetworkModel;
import com.kotlinnlp.simplednn.deeplearning.multitasknetwork.MultiTaskNetworkParameters;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.syntaxdecoder.modules.actionsscorer.ActionsScorerTrainable;
import com.kotlinnlp.syntaxdecoder.transitionsystem.Transition;
import com.kotlinnlp.syntaxdecoder.transitionsystem.state.State;
import com.kotlinnlp.syntaxdecoder.utils.DecodingContext;
import com.kotlinnlp.utils.DictionarySet;
import java.util.List;
import kotlin.Metadata;
import kotlin.Triple;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: TPDJointEmbeddingsActionsScorer.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0090\u0001\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u000e\n\u0002\u0010\b\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\t\b&\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u0002*\u0014\b\u0001\u0010\u0003*\u000e\u0012\u0004\u0012\u0002H\u0003\u0012\u0004\u0012\u0002H\u00010\u0004*\u000e\b\u0002\u0010\u0005*\b\u0012\u0004\u0012\u0002H\u00050\u00062,\u0012\u0004\u0012\u0002H\u0001\u0012\u0004\u0012\u0002H\u0003\u0012\u0004\u0012\u0002H\u0005\u0012\u0004\u0012\u00020\b\u0012\u0004\u0012\u00020\t\u0012\u0004\u0012\u00020\n\u0012\u0004\u0012\u00020\u000b0\u0007BW\u0012\b\u0010\f\u001a\u0004\u0018\u00010\r\u0012\u0006\u0010\u000e\u001a\u00020\u000f\u0012\u0006\u0010\u0010\u001a\u00020\u0011\u0012\f\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u00140\u0013\u0012\f\u0010\u0015\u001a\b\u0012\u0004\u0012\u00020\u00160\u0013\u0012\f\u0010\u0017\u001a\b\u0012\u0004\u0012\u00020\u00190\u0018\u0012\f\u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u001b0\u0018¢\u0006\u0002\u0010\u001cJ6\u00103\u001a\u0002042$\u00105\u001a 
\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00020\b\u0012\u0004\u0012\u00020\n062\u0006\u00107\u001a\u00020\u000bH\u0016J>\u00108\u001a\u0002092\u001c\u0010:\u001a\u0018\u0012\u0014\u0012\u00120+R\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028��0\u00040;2\u0006\u0010<\u001a\u0002092\u0006\u0010=\u001a\u0002092\u0006\u0010>\u001a\u000209H\u0014J6\u0010?\u001a\u00020\t2$\u00105\u001a \u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00020\b\u0012\u0004\u0012\u00020\n062\u0006\u00107\u001a\u00020\u000bH\u0016JP\u0010@\u001a\u0014\u0012\u0004\u0012\u000209\u0012\u0004\u0012\u000209\u0012\u0004\u0012\u0002090A2\u001c\u0010:\u001a\u0018\u0012\u0014\u0012\u00120+R\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028��0\u00040;2\u0006\u0010B\u001a\u0002092\u0006\u0010C\u001a\u0002092\u0006\u0010D\u001a\u000209H\u0014J\b\u0010E\u001a\u000204H\u0016J\b\u0010F\u001a\u000204H\u0016J\b\u0010G\u001a\u000204H\u0016J6\u0010H\u001a\u0002042$\u00105\u001a \u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00020\b\u0012\u0004\u0012\u00020\n062\u0006\u00107\u001a\u00020\u000bH\u0016J\b\u0010I\u001a\u000204H\u0016R\u0013\u0010\f\u001a\u0004\u0018\u00010\r¢\u0006\b\n��\u001a\u0004\b\u001d\u0010\u001eR\u0017\u0010\u0017\u001a\b\u0012\u0004\u0012\u00020\u00190\u0018¢\u0006\b\n��\u001a\u0004\b\u001f\u0010 R\u0011\u0010\u0010\u001a\u00020\u0011¢\u0006\b\n��\u001a\u0004\b!\u0010\"R\u0017\u0010\u0015\u001a\b\u0012\u0004\u0012\u00020\u00160\u0013¢\u0006\b\n��\u001a\u0004\b#\u0010$R\u0017\u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u001b0\u0018¢\u0006\b\n��\u001a\u0004\b%\u0010 
R\u0011\u0010\u000e\u001a\u00020\u000f¢\u0006\b\n��\u001a\u0004\b&\u0010'R\u0017\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u00140\u0013¢\u0006\b\n��\u001a\u0004\b(\u0010$R&\u0010)\u001a\u00020**\u00120+R\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028��0\u0004X¤\u0004¢\u0006\u0006\u001a\u0004\b,\u0010-R\"\u0010.\u001a\u00020**\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028��0\u0004X¤\u0004¢\u0006\u0006\u001a\u0004\b/\u00100R&\u00101\u001a\u00020**\u00120+R\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028��0\u0004X¤\u0004¢\u0006\u0006\u001a\u0004\b2\u0010-¨\u0006J"}, d2 = {"Lcom/kotlinnlp/neuralparser/parsers/transitionbased/templates/actionsscorer/TPDJointEmbeddingsActionsScorer;", "StateType", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/state/State;", "TransitionType", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition;", "InputContextType", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/templates/inputcontexts/TokensEncodingContext;", "Lcom/kotlinnlp/syntaxdecoder/modules/actionsscorer/ActionsScorerTrainable;", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/utils/items/DenseItem;", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/utils/features/DenseFeaturesErrors;", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/utils/features/DenseFeatures;", "Lcom/kotlinnlp/neuralparser/parsers/transitionbased/templates/supportstructure/compositeprediction/TPDJointSupportStructure;", "activationFunction", "Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;", "transitionNetwork", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "posDeprelNetworkModel", "Lcom/kotlinnlp/simplednn/deeplearning/multitasknetwork/MultiTaskNetworkModel;", "transitionOptimizer", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "posDeprelOptimizer", 
"Lcom/kotlinnlp/simplednn/deeplearning/multitasknetwork/MultiTaskNetworkParameters;", "deprelTags", "Lcom/kotlinnlp/utils/DictionarySet;", "Lcom/kotlinnlp/dependencytree/Deprel;", "posTags", "Lcom/kotlinnlp/dependencytree/POSTag;", "(Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;Lcom/kotlinnlp/simplednn/deeplearning/multitasknetwork/MultiTaskNetworkModel;Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;Lcom/kotlinnlp/utils/DictionarySet;Lcom/kotlinnlp/utils/DictionarySet;)V", "getActivationFunction", "()Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;", "getDeprelTags", "()Lcom/kotlinnlp/utils/DictionarySet;", "getPosDeprelNetworkModel", "()Lcom/kotlinnlp/simplednn/deeplearning/multitasknetwork/MultiTaskNetworkModel;", "getPosDeprelOptimizer", "()Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "getPosTags", "getTransitionNetwork", "()Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "getTransitionOptimizer", "deprelOutcomeIndex", "", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition$Action;", "getDeprelOutcomeIndex", "(Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition$Action;)I", "outcomeIndex", "getOutcomeIndex", "(Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition;)I", "posTagOutcomeIndex", "getPosTagOutcomeIndex", "backward", "", "decodingContext", "Lcom/kotlinnlp/syntaxdecoder/utils/DecodingContext;", "supportStructure", "combineScores", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "actions", "", "transitionPrediction", "posPrediction", "deprelPrediction", "getFeaturesErrors", "getOutputErrors", "Lkotlin/Triple;", "transitionOutcomes", "posOutcomes", "deprelOutcomes", "newBatch", "newEpoch", "newExample", "score", "update", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/parsers/transitionbased/templates/actionsscorer/TPDJointEmbeddingsActionsScorer.class */
/**
 * Trainable actions scorer that scores each candidate action by summing three predictions
 * computed from the same dense features:
 * <ul>
 *   <li>the transition network output (indexed by {@link #getOutcomeIndex}),</li>
 *   <li>output 0 of the multi-task POS/deprel network (indexed by {@link #getPosTagOutcomeIndex}),</li>
 *   <li>output 1 of the multi-task POS/deprel network (indexed by {@link #getDeprelOutcomeIndex}).</li>
 * </ul>
 * The summed scores are optionally passed through {@link #getActivationFunction()}.
 *
 * <p>Concrete subclasses define how an action maps to an index in each of the three outputs.
 */
public abstract class TPDJointEmbeddingsActionsScorer<StateType extends State<StateType>, TransitionType extends Transition<TransitionType, StateType>, InputContextType extends TokensEncodingContext<InputContextType>> extends ActionsScorerTrainable<StateType, TransitionType, InputContextType, DenseItem, DenseFeaturesErrors, DenseFeatures, TPDJointSupportStructure> {

    /** Optional activation applied to the combined scores; when null the raw sums are used. */
    @Nullable
    private final ActivationFunction activationFunction;

    @NotNull
    private final NeuralNetwork transitionNetwork;

    @NotNull
    private final MultiTaskNetworkModel posDeprelNetworkModel;

    /** Accumulates and applies the transition processor's parameter errors. */
    @NotNull
    private final ParamsOptimizer<NetworkParameters> transitionOptimizer;

    /** Accumulates and applies the POS/deprel multi-task network's parameter errors. */
    @NotNull
    private final ParamsOptimizer<MultiTaskNetworkParameters> posDeprelOptimizer;

    @NotNull
    private final DictionarySet<Deprel> deprelTags;

    @NotNull
    private final DictionarySet<POSTag> posTags;

    public TPDJointEmbeddingsActionsScorer(@Nullable ActivationFunction activationFunction, @NotNull NeuralNetwork transitionNetwork, @NotNull MultiTaskNetworkModel posDeprelNetworkModel, @NotNull ParamsOptimizer<NetworkParameters> transitionOptimizer, @NotNull ParamsOptimizer<MultiTaskNetworkParameters> posDeprelOptimizer, @NotNull DictionarySet<Deprel> deprelTags, @NotNull DictionarySet<POSTag> posTags) {
        Intrinsics.checkParameterIsNotNull(transitionNetwork, "transitionNetwork");
        Intrinsics.checkParameterIsNotNull(posDeprelNetworkModel, "posDeprelNetworkModel");
        Intrinsics.checkParameterIsNotNull(transitionOptimizer, "transitionOptimizer");
        Intrinsics.checkParameterIsNotNull(posDeprelOptimizer, "posDeprelOptimizer");
        Intrinsics.checkParameterIsNotNull(deprelTags, "deprelTags");
        Intrinsics.checkParameterIsNotNull(posTags, "posTags");
        this.activationFunction = activationFunction;
        this.transitionNetwork = transitionNetwork;
        this.posDeprelNetworkModel = posDeprelNetworkModel;
        this.transitionOptimizer = transitionOptimizer;
        this.posDeprelOptimizer = posDeprelOptimizer;
        this.deprelTags = deprelTags;
        this.posTags = posTags;
    }

    /**
     * @param action a candidate action
     * @return the index of the deprel-prediction array (output 1 of the POS/deprel network) that scores this action
     */
    protected abstract int getDeprelOutcomeIndex(@NotNull Transition<TransitionType, StateType>.Action action);

    /**
     * @param action a candidate action
     * @return the index of the POS-prediction array (output 0 of the POS/deprel network) that scores this action
     */
    protected abstract int getPosTagOutcomeIndex(@NotNull Transition<TransitionType, StateType>.Action action);

    /**
     * @param transition the transition of a candidate action
     * @return the index of the transition-network output that scores this transition
     */
    protected abstract int getOutcomeIndex(@NotNull Transition<TransitionType, StateType> transition);

    /**
     * Scores every action of the decoding context.
     *
     * <p>The dense features are forwarded through both the POS/deprel multi-task network and the
     * transition processor of {@code supportStructure}. The three predictions are merged by
     * {@link #combineScores}, optionally activated, and the i-th resulting score is assigned to the
     * i-th action.
     *
     * <p>Reconstructed from the bytecode dump: the previous decompiled body unconditionally threw
     * {@link UnsupportedOperationException}.
     *
     * @param decodingContext the context holding the features and the actions to score
     * @param supportStructure the structure providing the network processors used for the forward pass
     */
    public void score(@NotNull DecodingContext<StateType, TransitionType, InputContextType, DenseItem, DenseFeatures> decodingContext, @NotNull TPDJointSupportStructure supportStructure) {
        Intrinsics.checkParameterIsNotNull(decodingContext, "decodingContext");
        Intrinsics.checkParameterIsNotNull(supportStructure, "supportStructure");

        DenseNDArray features = ((DenseFeatures) decodingContext.getFeatures()).getArray();

        // Output 0 is the POS prediction, output 1 the deprel prediction.
        List<DenseNDArray> posDeprelPredictions = supportStructure.getPosDeprelNetwork().forward(features);
        DenseNDArray transitionPrediction = supportStructure.getTransitionProcessor().forward(features);

        List<? extends Transition<TransitionType, StateType>.Action> actions = decodingContext.getActions();

        DenseNDArray combined = combineScores(
                actions,
                transitionPrediction,
                posDeprelPredictions.get(0),
                posDeprelPredictions.get(1));

        // Equivalent of Kotlin `activationFunction?.f(combined) ?: combined`.
        DenseNDArray scores = combined;
        if (this.activationFunction != null) {
            DenseNDArray activated = this.activationFunction.f(combined);
            if (activated != null) {
                scores = activated;
            }
        }

        int index = 0;
        for (Transition<TransitionType, StateType>.Action action : actions) {
            action.setScore(scores.get(index).doubleValue());
            index++;
        }
    }

    /**
     * Back-propagates the errors of the scored actions through the transition processor and the
     * POS/deprel network, then accumulates the resulting parameter errors into the two optimizers.
     *
     * @param decodingContext the context holding the actions whose errors are propagated
     * @param supportStructure the structure holding the processors used during {@link #score}
     */
    public void backward(@NotNull DecodingContext<StateType, TransitionType, InputContextType, DenseItem, DenseFeatures> decodingContext, @NotNull TPDJointSupportStructure supportStructure) {
        Intrinsics.checkParameterIsNotNull(decodingContext, "decodingContext");
        Intrinsics.checkParameterIsNotNull(supportStructure, "supportStructure");

        DenseNDArray transitionOutcomes = supportStructure.getTransitionProcessor().getStructure().getOutputLayer().getOutputArray().getValues();
        DenseNDArray posOutcomes = ((FeedforwardNeuralProcessor) supportStructure.getPosDeprelNetwork().getOutputProcessors().get(0)).getStructure().getOutputLayer().getOutputArray().getValues();
        DenseNDArray deprelOutcomes = ((FeedforwardNeuralProcessor) supportStructure.getPosDeprelNetwork().getOutputProcessors().get(1)).getStructure().getOutputLayer().getOutputArray().getValues();

        Triple<DenseNDArray, DenseNDArray, DenseNDArray> outputErrors = getOutputErrors(decodingContext.getActions(), transitionOutcomes, posOutcomes, deprelOutcomes);
        DenseNDArray transitionErrors = (DenseNDArray) outputErrors.component1();
        DenseNDArray posErrors = (DenseNDArray) outputErrors.component2();
        DenseNDArray deprelErrors = (DenseNDArray) outputErrors.component3();

        supportStructure.getTransitionProcessor().backward(transitionErrors);
        supportStructure.getPosDeprelNetwork().backward(CollectionsKt.listOf(new DenseNDArray[]{posErrors, deprelErrors}));

        // `accumulate$default` is the Kotlin-generated overload that fills in the default argument (mask 2).
        Optimizer.accumulate$default(this.transitionOptimizer, supportStructure.getTransitionProcessor().getParamsErrors(false), false, 2, (Object) null);
        Optimizer.accumulate$default(this.posDeprelOptimizer, supportStructure.getPosDeprelNetwork().getParamsErrors(false), false, 2, (Object) null);
    }

    /**
     * @param decodingContext the current decoding context (only null-checked here)
     * @param supportStructure the structure holding the processors used during {@link #backward}
     * @return the sum of the input errors of the POS/deprel network and of the transition processor
     */
    @NotNull
    public DenseFeaturesErrors getFeaturesErrors(@NotNull DecodingContext<StateType, TransitionType, InputContextType, DenseItem, DenseFeatures> decodingContext, @NotNull TPDJointSupportStructure supportStructure) {
        Intrinsics.checkParameterIsNotNull(decodingContext, "decodingContext");
        Intrinsics.checkParameterIsNotNull(supportStructure, "supportStructure");
        // The first array is requested as a copy (true), so the in-place assignSum does not
        // modify the network's own input-errors array.
        return new DenseFeaturesErrors(supportStructure.getPosDeprelNetwork().getInputErrors(true).assignSum(supportStructure.getTransitionProcessor().getInputErrors(false)));
    }

    /** Notifies both optimizers that a new batch starts. */
    public void newBatch() {
        this.transitionOptimizer.newBatch();
        this.posDeprelOptimizer.newBatch();
    }

    /** Notifies both optimizers that a new example starts. */
    public void newExample() {
        this.transitionOptimizer.newExample();
        this.posDeprelOptimizer.newExample();
    }

    /** Notifies both optimizers that a new epoch starts. */
    public void newEpoch() {
        this.transitionOptimizer.newEpoch();
        this.posDeprelOptimizer.newEpoch();
    }

    /** Applies the accumulated parameter errors of both optimizers. */
    public void update() {
        this.transitionOptimizer.update();
        this.posDeprelOptimizer.update();
    }

    /**
     * Combines the three predictions into one score per action.
     *
     * @param actions the actions to score
     * @param transitionPrediction the transition-network output
     * @param posPrediction the POS output of the multi-task network
     * @param deprelPrediction the deprel output of the multi-task network
     * @return an array whose i-th element is the sum of the three prediction values selected for the i-th action
     */
    @NotNull
    protected DenseNDArray combineScores(@NotNull List<? extends Transition<TransitionType, StateType>.Action> actions, @NotNull DenseNDArray transitionPrediction, @NotNull DenseNDArray posPrediction, @NotNull DenseNDArray deprelPrediction) {
        Intrinsics.checkParameterIsNotNull(actions, "actions");
        Intrinsics.checkParameterIsNotNull(transitionPrediction, "transitionPrediction");
        Intrinsics.checkParameterIsNotNull(posPrediction, "posPrediction");
        Intrinsics.checkParameterIsNotNull(deprelPrediction, "deprelPrediction");
        double[] scores = new double[actions.size()];
        for (int i = 0; i < scores.length; i++) {
            Transition<TransitionType, StateType>.Action action = actions.get(i);
            scores[i] = transitionPrediction.get(getOutcomeIndex(action.getTransition())).doubleValue()
                    + posPrediction.get(getPosTagOutcomeIndex(action)).doubleValue()
                    + deprelPrediction.get(getDeprelOutcomeIndex(action)).doubleValue();
        }
        return DenseNDArrayFactory.INSTANCE.arrayOf(scores);
    }

    /**
     * Builds the output errors of the three predictions from the actions' errors.
     *
     * <p>Each error array starts as zeros with the shape of the corresponding outcomes array. Each
     * action's error is added at the index that action selects in each of the three outputs, so
     * actions sharing an index accumulate.
     *
     * @param actions the scored actions carrying their errors
     * @param transitionOutcomes the transition-network output (used for its shape)
     * @param posOutcomes the POS output (used for its shape)
     * @param deprelOutcomes the deprel output (used for its shape)
     * @return the (transition, POS, deprel) error arrays
     */
    @NotNull
    protected Triple<DenseNDArray, DenseNDArray, DenseNDArray> getOutputErrors(@NotNull List<? extends Transition<TransitionType, StateType>.Action> actions, @NotNull DenseNDArray transitionOutcomes, @NotNull DenseNDArray posOutcomes, @NotNull DenseNDArray deprelOutcomes) {
        Intrinsics.checkParameterIsNotNull(actions, "actions");
        Intrinsics.checkParameterIsNotNull(transitionOutcomes, "transitionOutcomes");
        Intrinsics.checkParameterIsNotNull(posOutcomes, "posOutcomes");
        Intrinsics.checkParameterIsNotNull(deprelOutcomes, "deprelOutcomes");
        DenseNDArray transitionErrors = DenseNDArrayFactory.INSTANCE.zeros(transitionOutcomes.getShape());
        DenseNDArray posErrors = DenseNDArrayFactory.INSTANCE.zeros(posOutcomes.getShape());
        DenseNDArray deprelErrors = DenseNDArrayFactory.INSTANCE.zeros(deprelOutcomes.getShape());
        for (Transition<TransitionType, StateType>.Action action : actions) {
            double error = action.getError();
            addAt(transitionErrors, getOutcomeIndex(action.getTransition()), error);
            addAt(posErrors, getPosTagOutcomeIndex(action), error);
            addAt(deprelErrors, getDeprelOutcomeIndex(action), error);
        }
        return new Triple<>(transitionErrors, posErrors, deprelErrors);
    }

    /** Adds {@code value} to the element of {@code array} at {@code index}, in place. */
    private static void addAt(DenseNDArray array, int index, double value) {
        array.set(index, Double.valueOf(array.get(index).doubleValue() + value));
    }

    @Nullable
    public final ActivationFunction getActivationFunction() {
        return this.activationFunction;
    }

    @NotNull
    public final NeuralNetwork getTransitionNetwork() {
        return this.transitionNetwork;
    }

    @NotNull
    public final MultiTaskNetworkModel getPosDeprelNetworkModel() {
        return this.posDeprelNetworkModel;
    }

    @NotNull
    public final ParamsOptimizer<NetworkParameters> getTransitionOptimizer() {
        return this.transitionOptimizer;
    }

    @NotNull
    public final ParamsOptimizer<MultiTaskNetworkParameters> getPosDeprelOptimizer() {
        return this.posDeprelOptimizer;
    }

    @NotNull
    public final DictionarySet<Deprel> getDeprelTags() {
        return this.deprelTags;
    }

    @NotNull
    public final DictionarySet<POSTag> getPosTags() {
        return this.posTags;
    }
}
