package com.kotlinnlp.simplednn.deeplearning.sequencelabeling;

import com.kotlinnlp.simplednn.core.embeddings.EmbeddingsOptimizer;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.adagrad.AdaGradMethod;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.adam.ADAMMethod;
import com.kotlinnlp.simplednn.core.neuralnetwork.NetworkParameters;
import com.kotlinnlp.simplednn.core.optimizer.Optimizer;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.core.optimizer.ScheduledUpdater;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: SWSLOptimizer.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��<\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0010\b\n��\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0006\u0018��2\u00020\u0001B)\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\f\b\u0002\u0010\u0004\u001a\u0006\u0012\u0002\b\u00030\u0005\u0012\f\b\u0002\u0010\u0006\u001a\u0006\u0012\u0002\b\u00030\u0005¢\u0006\u0002\u0010\u0007J\u0016\u0010\u000e\u001a\u00020\u000f2\u0006\u0010\u0010\u001a\u00020\r2\u0006\u0010\u0011\u001a\u00020\u0012J\u000e\u0010\u0013\u001a\u00020\u000f2\u0006\u0010\u0011\u001a\u00020\nJ\b\u0010\u0014\u001a\u00020\u000fH\u0016J\b\u0010\u0015\u001a\u00020\u000fH\u0016J\b\u0010\u0016\u001a\u00020\u000fH\u0016J\b\u0010\u0017\u001a\u00020\u000fH\u0016R\u0014\u0010\b\u001a\b\u0012\u0004\u0012\u00020\n0\tX\u0082\u0004¢\u0006\u0002\n��R\u0012\u0010\u0006\u001a\u0006\u0012\u0002\b\u00030\u0005X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\r0\fX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u0012\u0010\u0004\u001a\u0006\u0012\u0002\b\u00030\u0005X\u0082\u0004¢\u0006\u0002\n��¨\u0006\u0018"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLOptimizer;", "Lcom/kotlinnlp/simplednn/core/optimizer/ScheduledUpdater;", "network", "Lcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLNetwork;", "paramsUpdateMethod", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;", "embeddingsUpdateMethod", "(Lcom/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLNetwork;Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;)V", "classifierOptimizer", "Lcom/kotlinnlp/simplednn/core/optimizer/ParamsOptimizer;", 
"Lcom/kotlinnlp/simplednn/core/neuralnetwork/NetworkParameters;", "labelEmbeddingsOptimizer", "Lcom/kotlinnlp/simplednn/core/embeddings/EmbeddingsOptimizer;", "", "accumulateLabelEmbeddingErrors", "", "embeddingId", "errors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "accumulateParamsErrors", "newBatch", "newEpoch", "newExample", "update", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/sequencelabeling/SWSLOptimizer.class */
/*
 * Optimizer for an SWSLNetwork.
 *
 * It combines two optimizers:
 *  - a ParamsOptimizer over the parameters of the network classifier model
 *    (network.getClassifier().getModel());
 *  - an EmbeddingsOptimizer over the network labels embeddings
 *    (network.getLabelsEmbeddings()), keyed by Integer ids.
 *
 * Every ScheduledUpdater callback (update / newEpoch / newBatch / newExample) is forwarded to
 * both optimizers, the classifier optimizer first.
 */
public final class SWSLOptimizer implements ScheduledUpdater {

    /** Optimizer of the classifier model parameters of {@link #network}. */
    private final ParamsOptimizer<NetworkParameters> classifierOptimizer;

    /** Optimizer of the labels embeddings of {@link #network}, keyed by Integer embedding id. */
    private final EmbeddingsOptimizer<Integer> labelEmbeddingsOptimizer;

    /** The network whose classifier parameters and labels embeddings are optimized. */
    private final SWSLNetwork network;

    /** Update method handed to {@link #classifierOptimizer}. */
    private final UpdateMethod<?> paramsUpdateMethod;

    /** Update method handed to {@link #labelEmbeddingsOptimizer}. */
    private final UpdateMethod<?> embeddingsUpdateMethod;

    /**
     * Calls {@code update()} on the classifier parameters optimizer, then on the labels
     * embeddings optimizer.
     */
    @Override // com.kotlinnlp.simplednn.core.optimizer.ScheduledUpdater
    public void update() {
        this.classifierOptimizer.update();
        this.labelEmbeddingsOptimizer.update();
    }

    /** Signals the start of a new epoch to both optimizers (classifier first). */
    @Override // com.kotlinnlp.simplednn.core.optimizer.ScheduledUpdater, com.kotlinnlp.simplednn.utils.scheduling.EpochScheduling
    public void newEpoch() {
        this.classifierOptimizer.newEpoch();
        this.labelEmbeddingsOptimizer.newEpoch();
    }

    /** Signals the start of a new batch to both optimizers (classifier first). */
    @Override // com.kotlinnlp.simplednn.core.optimizer.ScheduledUpdater, com.kotlinnlp.simplednn.utils.scheduling.BatchScheduling
    public void newBatch() {
        this.classifierOptimizer.newBatch();
        this.labelEmbeddingsOptimizer.newBatch();
    }

    /** Signals the start of a new example to both optimizers (classifier first). */
    @Override // com.kotlinnlp.simplednn.core.optimizer.ScheduledUpdater, com.kotlinnlp.simplednn.utils.scheduling.ExampleScheduling
    public void newExample() {
        this.classifierOptimizer.newExample();
        this.labelEmbeddingsOptimizer.newExample();
    }

    /**
     * Accumulates errors of the classifier parameters into the classifier optimizer.
     *
     * @param errors the errors of the classifier parameters (must not be null)
     */
    public final void accumulateParamsErrors(@NotNull NetworkParameters errors) {
        Intrinsics.checkParameterIsNotNull(errors, "errors");
        // Kotlin default-argument helper: mask 2 means the second parameter of accumulate() takes
        // its declared default value ('false' is only a placeholder).
        // NOTE(review): javac normally refuses to reference synthetic '$default' methods; confirm
        // the declared default of accumulate()'s second parameter and call it explicitly.
        Optimizer.accumulate$default(this.classifierOptimizer, errors, false, 2, null);
    }

    /**
     * Accumulates errors of a single label embedding into the labels embeddings optimizer.
     *
     * @param embeddingId the id of the label embedding the errors refer to
     * @param errors the errors of that embedding (must not be null)
     */
    public final void accumulateLabelEmbeddingErrors(int embeddingId, @NotNull DenseNDArray errors) {
        Intrinsics.checkParameterIsNotNull(errors, "errors");
        // The optimizer is keyed by Integer: pass the boxed id itself as the key (casting the key to
        // EmbeddingsOptimizer<Integer> is an inconvertible cast).
        this.labelEmbeddingsOptimizer.accumulate(Integer.valueOf(embeddingId), errors);
    }

    /**
     * Creates an optimizer for the given network.
     *
     * @param network the network to optimize (must not be null)
     * @param paramsUpdateMethod the update method of the classifier parameters (must not be null)
     * @param embeddingsUpdateMethod the update method of the labels embeddings (must not be null)
     */
    public SWSLOptimizer(@NotNull SWSLNetwork network, @NotNull UpdateMethod<?> paramsUpdateMethod, @NotNull UpdateMethod<?> embeddingsUpdateMethod) {
        Intrinsics.checkParameterIsNotNull(network, "network");
        Intrinsics.checkParameterIsNotNull(paramsUpdateMethod, "paramsUpdateMethod");
        Intrinsics.checkParameterIsNotNull(embeddingsUpdateMethod, "embeddingsUpdateMethod");
        this.network = network;
        this.paramsUpdateMethod = paramsUpdateMethod;
        this.embeddingsUpdateMethod = embeddingsUpdateMethod;
        this.classifierOptimizer = new ParamsOptimizer<>(this.network.getClassifier().getModel(), this.paramsUpdateMethod);
        this.labelEmbeddingsOptimizer = new EmbeddingsOptimizer<>(this.network.getLabelsEmbeddings(), this.embeddingsUpdateMethod);
    }

    /*
     * Kotlin default-arguments constructor.
     *
     * Bit 2 of the mask replaces 'updateMethod' with ADAMMethod(0.001, ...).
     * Bit 4 replaces 'updateMethod2' with AdaGradMethod(0.1, ...).
     * Each default is built with that class's own default-arguments constructor.
     */
    public /* synthetic */ SWSLOptimizer(SWSLNetwork sWSLNetwork, UpdateMethod updateMethod, UpdateMethod updateMethod2, int i, DefaultConstructorMarker defaultConstructorMarker) {
        this(sWSLNetwork, (i & 2) != 0 ? new ADAMMethod(0.001d, 0.0d, 0.0d, 0.0d, null, 30, null) : updateMethod, (i & 4) != 0 ? new AdaGradMethod(0.1d, 0.0d, null, 6, null) : updateMethod2);
    }
}
