package com.kotlinnlp.neuralparser.templates.actionsscorer;

import com.kotlinnlp.dependencytree.Deprel;
import com.kotlinnlp.dependencytree.POSTag;
import com.kotlinnlp.neuralparser.templates.inputcontexts.TokensEncodingContext;
import com.kotlinnlp.neuralparser.templates.supportstructure.compositeprediction.TPDSupportStructure;
import com.kotlinnlp.neuralparser.utils.features.DenseFeatures;
import com.kotlinnlp.neuralparser.utils.features.DenseFeaturesErrors;
import com.kotlinnlp.neuralparser.utils.items.DenseItem;
import com.kotlinnlp.simplednn.core.functionalities.activations.ActivationFunction;
import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.core.neuralnetwork.NeuralNetwork;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.simplednn.utils.DictionarySet;
import com.kotlinnlp.syntaxdecoder.modules.actionsscorer.ActionsScorerTrainable;
import com.kotlinnlp.syntaxdecoder.transitionsystem.Transition;
import com.kotlinnlp.syntaxdecoder.transitionsystem.state.State;
import com.kotlinnlp.syntaxdecoder.utils.DecodingContext;
import java.util.List;
import kotlin.Metadata;
import kotlin.Triple;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: TPDEmbeddingsActionsScorer.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��\u008e\u0001\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u000f\n\u0002\u0010\b\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\t\b&\u0018��*\u000e\b��\u0010\u0001*\b\u0012\u0004\u0012\u0002H\u00010\u0002*\u0014\b\u0001\u0010\u0003*\u000e\u0012\u0004\u0012\u0002H\u0003\u0012\u0004\u0012\u0002H\u00010\u0004*\u000e\b\u0002\u0010\u0005*\b\u0012\u0004\u0012\u0002H\u00050\u00062,\u0012\u0004\u0012\u0002H\u0001\u0012\u0004\u0012\u0002H\u0003\u0012\u0004\u0012\u0002H\u0005\u0012\u0004\u0012\u00020\b\u0012\u0004\u0012\u00020\t\u0012\u0004\u0012\u00020\n\u0012\u0004\u0012\u00020\u000b0\u0007Bm\u0012\b\u0010\f\u001a\u0004\u0018\u00010\r\u0012\u0006\u0010\u000e\u001a\u00020\u000f\u0012\u0006\u0010\u0010\u001a\u00020\u000f\u0012\u0006\u0010\u0011\u001a\u00020\u000f\u0012\f\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u00140\u0013\u0012\f\u0010\u0015\u001a\b\u0012\u0004\u0012\u00020\u00140\u0013\u0012\f\u0010\u0016\u001a\b\u0012\u0004\u0012\u00020\u00140\u0013\u0012\f\u0010\u0017\u001a\b\u0012\u0004\u0012\u00020\u00190\u0018\u0012\f\u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u001b0\u0018¢\u0006\u0002\u0010\u001cJ>\u00104\u001a\u0002052$\u00106\u001a 
\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00020\b\u0012\u0004\u0012\u00020\n072\u0006\u00108\u001a\u00020\u000b2\u0006\u00109\u001a\u00020:H\u0016J>\u0010;\u001a\u00020<2\u001c\u0010=\u001a\u0018\u0012\u0014\u0012\u00120,R\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028��0\u00040>2\u0006\u0010?\u001a\u00020<2\u0006\u0010@\u001a\u00020<2\u0006\u0010A\u001a\u00020<H\u0002J6\u0010B\u001a\u00020\t2$\u00106\u001a \u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00020\b\u0012\u0004\u0012\u00020\n072\u0006\u00108\u001a\u00020\u000bH\u0016JP\u0010C\u001a\u0014\u0012\u0004\u0012\u00020<\u0012\u0004\u0012\u00020<\u0012\u0004\u0012\u00020<0D2\u001c\u0010=\u001a\u0018\u0012\u0014\u0012\u00120,R\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028��0\u00040>2\u0006\u0010E\u001a\u00020<2\u0006\u0010F\u001a\u00020<2\u0006\u0010G\u001a\u00020<H\u0002J\b\u0010H\u001a\u000205H\u0016J\b\u0010I\u001a\u000205H\u0016J\b\u0010J\u001a\u000205H\u0016J6\u0010K\u001a\u0002052$\u00106\u001a \u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00020\b\u0012\u0004\u0012\u00020\n072\u0006\u00108\u001a\u00020\u000bH\u0016J\b\u0010L\u001a\u000205H\u0016R\u0013\u0010\f\u001a\u0004\u0018\u00010\r¢\u0006\b\n��\u001a\u0004\b\u001d\u0010\u001eR\u0011\u0010\u0011\u001a\u00020\u000f¢\u0006\b\n��\u001a\u0004\b\u001f\u0010 R\u0017\u0010\u0016\u001a\b\u0012\u0004\u0012\u00020\u00140\u0013¢\u0006\b\n��\u001a\u0004\b!\u0010\"R\u0017\u0010\u001a\u001a\b\u0012\u0004\u0012\u00020\u001b0\u0018¢\u0006\b\n��\u001a\u0004\b#\u0010$R\u0011\u0010\u0010\u001a\u00020\u000f¢\u0006\b\n��\u001a\u0004\b%\u0010 
R\u0017\u0010\u0015\u001a\b\u0012\u0004\u0012\u00020\u00140\u0013¢\u0006\b\n��\u001a\u0004\b&\u0010\"R\u0017\u0010\u0017\u001a\b\u0012\u0004\u0012\u00020\u00190\u0018¢\u0006\b\n��\u001a\u0004\b'\u0010$R\u0011\u0010\u000e\u001a\u00020\u000f¢\u0006\b\n��\u001a\u0004\b(\u0010 R\u0017\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u00140\u0013¢\u0006\b\n��\u001a\u0004\b)\u0010\"R&\u0010*\u001a\u00020+*\u00120,R\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028��0\u0004X¤\u0004¢\u0006\u0006\u001a\u0004\b-\u0010.R\"\u0010/\u001a\u00020+*\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028��0\u0004X¤\u0004¢\u0006\u0006\u001a\u0004\b0\u00101R&\u00102\u001a\u00020+*\u00120,R\u000e\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028��0\u0004X¤\u0004¢\u0006\u0006\u001a\u0004\b3\u0010.¨\u0006M"}, d2 = {"Lcom/kotlinnlp/neuralparser/templates/actionsscorer/TPDEmbeddingsActionsScorer;", "StateType", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/state/State;", "TransitionType", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition;", "InputContextType", "Lcom/kotlinnlp/neuralparser/templates/inputcontexts/TokensEncodingContext;", "Lcom/kotlinnlp/syntaxdecoder/modules/actionsscorer/ActionsScorerTrainable;", "Lcom/kotlinnlp/neuralparser/utils/items/DenseItem;", "Lcom/kotlinnlp/neuralparser/utils/features/DenseFeaturesErrors;", "Lcom/kotlinnlp/neuralparser/utils/features/DenseFeatures;", "Lcom/kotlinnlp/neuralparser/templates/supportstructure/compositeprediction/TPDSupportStructure;", "activationFunction", "Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;", "transitionNetwork", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "posNetwork", "deprelNetwork", "transitionOptimizer", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "posOptimizer", "deprelOptimizer", "posTags", "Lcom/kotlinnlp/simplednn/utils/DictionarySet;", 
"Lcom/kotlinnlp/dependencytree/POSTag;", "deprelTags", "Lcom/kotlinnlp/dependencytree/Deprel;", "(Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;Lcom/kotlinnlp/simplednn/utils/DictionarySet;Lcom/kotlinnlp/simplednn/utils/DictionarySet;)V", "getActivationFunction", "()Lcom/kotlinnlp/simplednn/core/functionalities/activations/ActivationFunction;", "getDeprelNetwork", "()Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "getDeprelOptimizer", "()Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", "getDeprelTags", "()Lcom/kotlinnlp/simplednn/utils/DictionarySet;", "getPosNetwork", "getPosOptimizer", "getPosTags", "getTransitionNetwork", "getTransitionOptimizer", "deprelOutcomeIndex", "", "Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition$Action;", "getDeprelOutcomeIndex", "(Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition$Action;)I", "outcomeIndex", "getOutcomeIndex", "(Lcom/kotlinnlp/syntaxdecoder/transitionsystem/Transition;)I", "posTagOutcomeIndex", "getPosTagOutcomeIndex", "backward", "", "decodingContext", "Lcom/kotlinnlp/syntaxdecoder/utils/DecodingContext;", "supportStructure", "propagateToInput", "", "combineScores", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "actions", "", "transitionPrediction", "posPrediction", "deprelPrediction", "getFeaturesErrors", "getOutputErrors", "Lkotlin/Triple;", "transitionOutcomes", "posOutcomes", "deprelOutcomes", "newBatch", "newEpoch", "newExample", "score", "update", "neuralparser"})
/* loaded from: input_file:com/kotlinnlp/neuralparser/templates/actionsscorer/TPDEmbeddingsActionsScorer.class */
/*
 * Base actions scorer built on three feed-forward processors that all read the same dense
 * features: one for transitions, one for POS tags and one for deprels. An action's score is the
 * sum of the three outcomes it maps to (selected by the abstract get*OutcomeIndex methods),
 * optionally passed through the configured activation function.
 */
public abstract class TPDEmbeddingsActionsScorer<StateType extends State<StateType>, TransitionType extends Transition<TransitionType, StateType>, InputContextType extends TokensEncodingContext<InputContextType>> extends ActionsScorerTrainable<StateType, TransitionType, InputContextType, DenseItem, DenseFeaturesErrors, DenseFeatures, TPDSupportStructure> {

    /** Activation applied to the combined action scores; when {@code null} the raw sums are used. */
    @Nullable
    private final ActivationFunction activationFunction;

    /** Transition network (only stored and exposed by this class). */
    @NotNull
    private final NeuralNetwork transitionNetwork;

    /** POS-tag network (only stored and exposed by this class). */
    @NotNull
    private final NeuralNetwork posNetwork;

    /** Deprel network (only stored and exposed by this class). */
    @NotNull
    private final NeuralNetwork deprelNetwork;

    /** Accumulates and applies the parameter errors of the transition processor. */
    @NotNull
    private final ParamsOptimizer<NetworkParameters> transitionOptimizer;

    /** Accumulates and applies the parameter errors of the POS processor. */
    @NotNull
    private final ParamsOptimizer<NetworkParameters> posOptimizer;

    /** Accumulates and applies the parameter errors of the deprel processor. */
    @NotNull
    private final ParamsOptimizer<NetworkParameters> deprelOptimizer;

    /** POS-tag dictionary (only stored and exposed here; presumably used by subclasses). */
    @NotNull
    private final DictionarySet<POSTag> posTags;

    /** Deprel dictionary (only stored and exposed here; presumably used by subclasses). */
    @NotNull
    private final DictionarySet<Deprel> deprelTags;

    /**
     * Returns the index, within the deprel processor output, of the outcome scored for the given action.
     *
     * @param action the action to map
     * @return the deprel outcome index
     */
    protected abstract int getDeprelOutcomeIndex(@NotNull Transition<TransitionType, StateType>.Action action);

    /**
     * Returns the index, within the POS processor output, of the outcome scored for the given action.
     *
     * @param action the action to map
     * @return the POS-tag outcome index
     */
    protected abstract int getPosTagOutcomeIndex(@NotNull Transition<TransitionType, StateType>.Action action);

    /**
     * Returns the index, within the transition processor output, of the outcome scored for the given transition.
     *
     * @param transition the transition to map
     * @return the transition outcome index
     */
    protected abstract int getOutcomeIndex(@NotNull Transition<TransitionType, StateType> transition);

    /**
     * Scores every action of the decoding context.
     *
     * <p>The dense features are forwarded through the transition, POS and deprel processors of the
     * support structure. Each action then gets the sum of its three outcomes, passed through the
     * activation function when one is configured. Scores are assigned in the order of
     * {@code decodingContext.getActions()}.
     *
     * @param decodingContext the context holding the actions to score and the extracted features
     * @param supportStructure the structure holding the three feed-forward processors
     */
    public void score(@NotNull DecodingContext<StateType, TransitionType, InputContextType, DenseItem, DenseFeatures> decodingContext, @NotNull TPDSupportStructure supportStructure) {
        Intrinsics.checkParameterIsNotNull(decodingContext, "decodingContext");
        Intrinsics.checkParameterIsNotNull(supportStructure, "supportStructure");

        List<? extends Transition<TransitionType, StateType>.Action> actions = decodingContext.getActions();
        DenseNDArray features = decodingContext.getFeatures().getArray();

        // NOTE(review): the compiled Kotlin called forward() relying on the default value of its
        // second parameter (useDropout); false is passed explicitly here — confirm against the
        // SimpleDNN version in use.
        DenseNDArray transitionPrediction = supportStructure.getTransitionProcessor().forward(features, false);
        DenseNDArray posPrediction = supportStructure.getPosProcessor().forward(features, false);
        DenseNDArray deprelPrediction = supportStructure.getDeprelProcessor().forward(features, false);

        DenseNDArray combined = combineScores(actions, transitionPrediction, posPrediction, deprelPrediction);
        DenseNDArray prediction = this.activationFunction != null ? this.activationFunction.f(combined) : combined;

        int index = 0;
        for (Transition<TransitionType, StateType>.Action action : actions) {
            action.setScore(prediction.get(index++).doubleValue());
        }
    }

    /**
     * Backpropagates the action errors through the three processors and accumulates their
     * parameter errors into the corresponding optimizers.
     *
     * @param decodingContext the context holding the scored actions and their errors
     * @param supportStructure the structure holding the three feed-forward processors
     * @param propagateToInput whether the processors must also compute the input errors
     */
    public void backward(@NotNull DecodingContext<StateType, TransitionType, InputContextType, DenseItem, DenseFeatures> decodingContext, @NotNull TPDSupportStructure supportStructure, boolean propagateToInput) {
        Intrinsics.checkParameterIsNotNull(decodingContext, "decodingContext");
        Intrinsics.checkParameterIsNotNull(supportStructure, "supportStructure");
        Triple<DenseNDArray, DenseNDArray, DenseNDArray> outputErrors = getOutputErrors(
            decodingContext.getActions(),
            supportStructure.getTransitionProcessor().getStructure().getOutputLayer().getOutputArray().getValues(),
            supportStructure.getPosProcessor().getStructure().getOutputLayer().getOutputArray().getValues(),
            supportStructure.getDeprelProcessor().getStructure().getOutputLayer().getOutputArray().getValues());
        DenseNDArray transitionErrors = outputErrors.getFirst();
        DenseNDArray posErrors = outputErrors.getSecond();
        DenseNDArray deprelErrors = outputErrors.getThird();
        supportStructure.getTransitionProcessor().backward(transitionErrors, propagateToInput, (List) null);
        supportStructure.getPosProcessor().backward(posErrors, propagateToInput, (List) null);
        supportStructure.getDeprelProcessor().backward(deprelErrors, propagateToInput, (List) null);
        ParamsOptimizer.accumulate$default(this.deprelOptimizer, supportStructure.getDeprelProcessor().getParamsErrors(false), false, 2, (Object) null);
        ParamsOptimizer.accumulate$default(this.transitionOptimizer, supportStructure.getTransitionProcessor().getParamsErrors(false), false, 2, (Object) null);
        ParamsOptimizer.accumulate$default(this.posOptimizer, supportStructure.getPosProcessor().getParamsErrors(false), false, 2, (Object) null);
    }

    /**
     * Returns the errors of the input features: the sum of the input errors of the transition,
     * POS and deprel processors.
     *
     * @param decodingContext the decoding context (only null-checked here)
     * @param supportStructure the structure holding the three feed-forward processors
     * @return the summed input-feature errors
     */
    @NotNull
    public DenseFeaturesErrors getFeaturesErrors(@NotNull DecodingContext<StateType, TransitionType, InputContextType, DenseItem, DenseFeatures> decodingContext, @NotNull TPDSupportStructure supportStructure) {
        Intrinsics.checkParameterIsNotNull(decodingContext, "decodingContext");
        Intrinsics.checkParameterIsNotNull(supportStructure, "supportStructure");
        // The first operand is requested with `true` (presumably a copy), so the in-place
        // assignSum calls below do not alter the transition processor's own array — TODO confirm.
        DenseNDArray summed = supportStructure.getTransitionProcessor().getInputErrors(true)
            .assignSum(supportStructure.getPosProcessor().getInputErrors(false))
            .assignSum(supportStructure.getDeprelProcessor().getInputErrors(false));
        return new DenseFeaturesErrors(summed);
    }

    /** Signals a new batch to the three optimizers. */
    public void newBatch() {
        this.transitionOptimizer.newBatch();
        this.deprelOptimizer.newBatch();
        this.posOptimizer.newBatch();
    }

    /** Signals a new example to the three optimizers. */
    public void newExample() {
        this.transitionOptimizer.newExample();
        this.deprelOptimizer.newExample();
        this.posOptimizer.newExample();
    }

    /** Signals a new epoch to the three optimizers. */
    public void newEpoch() {
        this.transitionOptimizer.newEpoch();
        this.deprelOptimizer.newEpoch();
        this.posOptimizer.newEpoch();
    }

    /** Updates the parameters through the three optimizers. */
    public void update() {
        this.transitionOptimizer.update();
        this.deprelOptimizer.update();
        this.posOptimizer.update();
    }

    /**
     * Builds one score per action by summing the transition, POS and deprel outcomes it maps to.
     *
     * @param actions the actions to score, in output order
     * @param transitionPrediction output of the transition processor
     * @param posPrediction output of the POS processor
     * @param deprelPrediction output of the deprel processor
     * @return an array whose i-th element is the combined score of the i-th action
     */
    private final DenseNDArray combineScores(List<? extends Transition<TransitionType, StateType>.Action> actions, DenseNDArray transitionPrediction, DenseNDArray posPrediction, DenseNDArray deprelPrediction) {
        double[] scores = new double[actions.size()];
        int index = 0;
        for (Transition<TransitionType, StateType>.Action action : actions) {
            scores[index++] = transitionPrediction.get(getOutcomeIndex(action.getTransition())).doubleValue()
                + posPrediction.get(getPosTagOutcomeIndex(action)).doubleValue()
                + deprelPrediction.get(getDeprelOutcomeIndex(action)).doubleValue();
        }
        return DenseNDArrayFactory.INSTANCE.arrayOf(scores);
    }

    /**
     * Distributes each action's error onto the outcomes it was scored from.
     *
     * <p>The outcome arrays are used only for their shapes. Actions that share an outcome add
     * their errors together.
     *
     * @param actions the scored actions carrying their errors
     * @param transitionOutcomes output values of the transition processor
     * @param posOutcomes output values of the POS processor
     * @param deprelOutcomes output values of the deprel processor
     * @return the (transition, POS, deprel) output errors
     */
    private final Triple<DenseNDArray, DenseNDArray, DenseNDArray> getOutputErrors(List<? extends Transition<TransitionType, StateType>.Action> actions, DenseNDArray transitionOutcomes, DenseNDArray posOutcomes, DenseNDArray deprelOutcomes) {
        DenseNDArray transitionErrors = DenseNDArrayFactory.INSTANCE.zeros(transitionOutcomes.getShape());
        DenseNDArray posErrors = DenseNDArrayFactory.INSTANCE.zeros(posOutcomes.getShape());
        DenseNDArray deprelErrors = DenseNDArrayFactory.INSTANCE.zeros(deprelOutcomes.getShape());
        for (Transition<TransitionType, StateType>.Action action : actions) {
            double error = action.getError();
            addAt(transitionErrors, getOutcomeIndex(action.getTransition()), error);
            addAt(posErrors, getPosTagOutcomeIndex(action), error);
            addAt(deprelErrors, getDeprelOutcomeIndex(action), error);
        }
        return new Triple<>(transitionErrors, posErrors, deprelErrors);
    }

    /** Adds {@code value} to the element of {@code array} at {@code index}, in place. */
    private static void addAt(DenseNDArray array, int index, double value) {
        array.set(index, Double.valueOf(array.get(index).doubleValue() + value));
    }

    @Nullable
    public final ActivationFunction getActivationFunction() {
        return this.activationFunction;
    }

    @NotNull
    public final NeuralNetwork getTransitionNetwork() {
        return this.transitionNetwork;
    }

    @NotNull
    public final NeuralNetwork getPosNetwork() {
        return this.posNetwork;
    }

    @NotNull
    public final NeuralNetwork getDeprelNetwork() {
        return this.deprelNetwork;
    }

    @NotNull
    public final ParamsOptimizer<NetworkParameters> getTransitionOptimizer() {
        return this.transitionOptimizer;
    }

    @NotNull
    public final ParamsOptimizer<NetworkParameters> getPosOptimizer() {
        return this.posOptimizer;
    }

    @NotNull
    public final ParamsOptimizer<NetworkParameters> getDeprelOptimizer() {
        return this.deprelOptimizer;
    }

    @NotNull
    public final DictionarySet<POSTag> getPosTags() {
        return this.posTags;
    }

    @NotNull
    public final DictionarySet<Deprel> getDeprelTags() {
        return this.deprelTags;
    }

    /**
     * @param activationFunction optional activation applied to the combined scores ({@code null} for none)
     * @param transitionNetwork the transition network
     * @param posNetwork the POS-tag network
     * @param deprelNetwork the deprel network
     * @param transitionOptimizer optimizer of the transition processor parameters
     * @param posOptimizer optimizer of the POS processor parameters
     * @param deprelOptimizer optimizer of the deprel processor parameters
     * @param posTags the POS-tag dictionary
     * @param deprelTags the deprel dictionary
     */
    public TPDEmbeddingsActionsScorer(@Nullable ActivationFunction activationFunction, @NotNull NeuralNetwork transitionNetwork, @NotNull NeuralNetwork posNetwork, @NotNull NeuralNetwork deprelNetwork, @NotNull ParamsOptimizer<NetworkParameters> transitionOptimizer, @NotNull ParamsOptimizer<NetworkParameters> posOptimizer, @NotNull ParamsOptimizer<NetworkParameters> deprelOptimizer, @NotNull DictionarySet<POSTag> posTags, @NotNull DictionarySet<Deprel> deprelTags) {
        Intrinsics.checkParameterIsNotNull(transitionNetwork, "transitionNetwork");
        Intrinsics.checkParameterIsNotNull(posNetwork, "posNetwork");
        Intrinsics.checkParameterIsNotNull(deprelNetwork, "deprelNetwork");
        Intrinsics.checkParameterIsNotNull(transitionOptimizer, "transitionOptimizer");
        Intrinsics.checkParameterIsNotNull(posOptimizer, "posOptimizer");
        Intrinsics.checkParameterIsNotNull(deprelOptimizer, "deprelOptimizer");
        Intrinsics.checkParameterIsNotNull(posTags, "posTags");
        Intrinsics.checkParameterIsNotNull(deprelTags, "deprelTags");
        this.activationFunction = activationFunction;
        this.transitionNetwork = transitionNetwork;
        this.posNetwork = posNetwork;
        this.deprelNetwork = deprelNetwork;
        this.transitionOptimizer = transitionOptimizer;
        this.posOptimizer = posOptimizer;
        this.deprelOptimizer = deprelOptimizer;
        this.posTags = posTags;
        this.deprelTags = deprelTags;
    }
}
