package com.kotlinnlp.simplednn.core.functionalities.updatemethods.adam;

import com.kotlinnlp.simplednn.core.arrays.UpdatableDenseArray;
import com.kotlinnlp.simplednn.core.functionalities.regularization.WeightsRegularization;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdateMethod;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdaterSupportStructure;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArrayMask;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.sparse.SparseNDArray;
import com.kotlinnlp.simplednn.utils.scheduling.ExampleScheduling;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.UninitializedPropertyAccessException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.Reflection;
import kotlin.reflect.KFunction;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: ADAMMethod.kt */
@Metadata(mv = {1, 1, 10}, bv = {1, 0, 2}, k = 1, d1 = {"��>\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u000e\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018��2\u00020\u00012\b\u0012\u0004\u0012\u00020\u00030\u0002B9\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005\u0012\b\b\u0002\u0010\u0006\u001a\u00020\u0005\u0012\b\b\u0002\u0010\u0007\u001a\u00020\u0005\u0012\b\b\u0002\u0010\b\u001a\u00020\u0005\u0012\n\b\u0002\u0010\t\u001a\u0004\u0018\u00010\n¢\u0006\u0002\u0010\u000bJ\u0010\u0010\u0017\u001a\u00020\u00032\u0006\u0010\u0018\u001a\u00020\u0019H\u0016J\b\u0010\u001a\u001a\u00020\u001bH\u0016J\u0018\u0010\u001c\u001a\u00020\u001d2\u0006\u0010\u001e\u001a\u00020\u001d2\u0006\u0010\u001f\u001a\u00020\u0003H\u0014J\u0018\u0010 \u001a\u00020!2\u0006\u0010\u001e\u001a\u00020!2\u0006\u0010\u001f\u001a\u00020\u0003H\u0014J\b\u0010\"\u001a\u00020\u001bH\u0002R$\u0010\r\u001a\u00020\u00052\u0006\u0010\f\u001a\u00020\u0005@BX\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b\u000e\u0010\u000f\"\u0004\b\u0010\u0010\u0011R\u0011\u0010\u0006\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0012\u0010\u000fR\u0011\u0010\u0007\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u000fR\u0011\u0010\b\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0014\u0010\u000fR\u000e\u0010\u0015\u001a\u00020\u0005X\u0082\u000e¢\u0006\u0002\n��R\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n��\u001a\u0004\b\u0016\u0010\u000f¨\u0006#"}, d2 = {"Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/adam/ADAMMethod;", "Lcom/kotlinnlp/simplednn/utils/scheduling/ExampleScheduling;", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/adam/ADAMStructure;", "stepSize", "", "beta1", "beta2", "epsilon", "regularization", 
"Lcom/kotlinnlp/simplednn/core/functionalities/regularization/WeightsRegularization;", "(DDDDLcom/kotlinnlp/simplednn/core/functionalities/regularization/WeightsRegularization;)V", "<set-?>", "alpha", "getAlpha", "()D", "setAlpha", "(D)V", "getBeta1", "getBeta2", "getEpsilon", "exampleCount", "getStepSize", "getSupportStructure", "array", "Lcom/kotlinnlp/simplednn/core/arrays/UpdatableDenseArray;", "newExample", "", "optimizeDenseErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "errors", "supportStructure", "optimizeSparseErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/sparse/SparseNDArray;", "updateAlpha", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/functionalities/updatemethods/adam/ADAMMethod.class */
/**
 * ADAM update method.
 *
 * <p>Each optimize call keeps two running moment arrays in an {@link ADAMStructure}:
 * <ul>
 *   <li>the first-order moments are scaled by {@code beta1}, then {@code (1 - beta1) * errors} is added;</li>
 *   <li>the second-order moments are scaled by {@code beta2}, then {@code (1 - beta2) * errors * errors} is added.</li>
 * </ul>
 * The optimized errors are {@code alpha * m / (sqrt(v) + epsilon)}.
 *
 * <p>{@code alpha} starts at {@code stepSize}. It is recomputed on every {@link #newExample()} as
 * {@code stepSize * sqrt(1 - beta2^t) / (1 - beta1^t)}, where {@code t} is the number of examples seen.
 */
public final class ADAMMethod extends UpdateMethod<ADAMStructure> implements ExampleScheduling {

    // Hyper-parameters, fixed at construction time.
    private final double stepSize;
    private final double beta1;
    private final double beta2;
    private final double epsilon;

    // Current bias-corrected step size; initialised to stepSize and refreshed by updateAlpha().
    private double alpha;

    // Number of examples seen so far; used directly as the exponent in updateAlpha().
    private double exampleCount;

    /**
     * Creates an ADAM update method.
     *
     * @param stepSize the base step size, also used as the initial value of alpha
     * @param beta1 the decay factor applied to the first-order moments
     * @param beta2 the decay factor applied to the second-order moments
     * @param epsilon the value added to the square root of the second-order moments
     * @param regularization optional weights regularization, forwarded to the superclass
     */
    public ADAMMethod(double stepSize, double beta1, double beta2, double epsilon, @Nullable WeightsRegularization regularization) {
        super(regularization);
        this.stepSize = stepSize;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.epsilon = epsilon;
        this.alpha = stepSize;
    }

    /**
     * Kotlin default-arguments bridge. Each bit set in {@code defaultsMask} replaces the
     * corresponding argument with its default:
     * <ul>
     *   <li>bit 1: stepSize = 0.001</li>
     *   <li>bit 2: beta1 = 0.9</li>
     *   <li>bit 4: beta2 = 0.999</li>
     *   <li>bit 8: epsilon = 1e-8</li>
     *   <li>bit 16: regularization = null</li>
     * </ul>
     */
    public /* synthetic */ ADAMMethod(double stepSize, double beta1, double beta2, double epsilon, WeightsRegularization regularization, int defaultsMask, DefaultConstructorMarker marker) {
        this(
            orDefault(defaultsMask, 1, 0.001d, stepSize),
            orDefault(defaultsMask, 2, 0.9d, beta1),
            orDefault(defaultsMask, 4, 0.999d, beta2),
            orDefault(defaultsMask, 8, 1.0E-8d, epsilon),
            (defaultsMask & 16) != 0 ? (WeightsRegularization) null : regularization);
    }

    /** Creates an ADAM update method with every hyper-parameter at its default value. */
    public ADAMMethod() {
        this(0.0d, 0.0d, 0.0d, 0.0d, null, 0b11111, null);
    }

    /** Returns {@code fallback} when {@code bit} is set in {@code mask}, otherwise {@code value}. */
    private static double orDefault(int mask, int bit, double fallback, double value) {
        return (mask & bit) == 0 ? value : fallback;
    }

    /**
     * Returns the ADAM support structure attached to {@code array}.
     *
     * <p>If none has been attached yet, a new {@link ADAMStructure} is built from the shape of the
     * array values and attached first.
     *
     * @throws IllegalArgumentException if the attached structure is not an {@link ADAMStructure}
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override
    @NotNull
    public ADAMStructure getSupportStructure(@NotNull UpdatableDenseArray array) {
        Intrinsics.checkParameterIsNotNull(array, "array");
        attachSupportStructureIfMissing(array);

        UpdaterSupportStructure structure = array.getUpdaterSupportStructure();
        if (structure instanceof ADAMStructure) {
            return (ADAMStructure) structure;
        }
        throw new IllegalArgumentException("Incompatible support structure");
    }

    /**
     * Attaches a freshly built {@link ADAMStructure} to {@code array} when reading its support
     * structure throws {@link UninitializedPropertyAccessException}.
     */
    private static void attachSupportStructureIfMissing(UpdatableDenseArray array) {
        try {
            array.getUpdaterSupportStructure();
        } catch (UninitializedPropertyAccessException ignored) {
            // Not assigned yet: build one through the first Kotlin constructor of ADAMStructure,
            // passing the shape of the array values as its only argument.
            KFunction<?> constructor = (KFunction<?>) CollectionsKt.first(
                Reflection.getOrCreateKotlinClass(ADAMStructure.class).getConstructors());
            Object created = constructor.call(new Object[]{array.getValues().getShape()});
            array.setUpdaterSupportStructure((UpdaterSupportStructure) created);
        }
    }

    /** Registers a new example and recomputes the bias-corrected step size. */
    @Override
    public void newExample() {
        this.exampleCount++;
        updateAlpha();
    }

    /**
     * Updates the moments restricted to the mask of the sparse {@code errors} and returns the
     * optimized errors.
     *
     * @param errors the sparse errors to optimize; their mask selects the moment entries to decay
     * @param supportStructure the structure holding the first- and second-order moments (mutated)
     * @return {@code alpha * m / (sqrt(v) + epsilon)}, computed over the masked entries
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override
    @NotNull
    public SparseNDArray optimizeSparseErrors(@NotNull SparseNDArray errors, @NotNull ADAMStructure supportStructure) {
        Intrinsics.checkParameterIsNotNull(errors, "errors");
        Intrinsics.checkParameterIsNotNull(supportStructure, "supportStructure");

        DenseNDArray m = supportStructure.getFirstOrderMoments();
        DenseNDArray v = supportStructure.getSecondOrderMoments();
        NDArrayMask active = errors.getMask();
        final double firstWeight = 1.0d - this.beta1;
        final double secondWeight = 1.0d - this.beta2;

        // m: decay the masked entries by beta1, then add (1 - beta1) * errors.
        m.assignProd(this.beta1, active)
            .assignSum((NDArray<?>) errors.prod(firstWeight));
        // v: decay the masked entries by beta2, then add (1 - beta2) * errors * errors
        // (errors.prod(errors) presumably element-wise — confirm against NDArray).
        v.assignProd(this.beta2, active)
            .assignSum((NDArray<?>) errors.prod((NDArray<?>) errors).assignProd(secondWeight));

        return m.div(v.sqrt(active).assignSum(this.epsilon)).assignProd(this.alpha);
    }

    /**
     * Updates all the moment entries with the dense {@code errors} and returns the optimized errors.
     *
     * @param errors the dense errors to optimize
     * @param supportStructure the structure holding the first- and second-order moments (mutated)
     * @return {@code alpha * m / (sqrt(v) + epsilon)}
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override
    @NotNull
    public DenseNDArray optimizeDenseErrors(@NotNull DenseNDArray errors, @NotNull ADAMStructure supportStructure) {
        Intrinsics.checkParameterIsNotNull(errors, "errors");
        Intrinsics.checkParameterIsNotNull(supportStructure, "supportStructure");

        DenseNDArray m = supportStructure.getFirstOrderMoments();
        DenseNDArray v = supportStructure.getSecondOrderMoments();
        final double firstWeight = 1.0d - this.beta1;
        final double secondWeight = 1.0d - this.beta2;

        // m: decay by beta1, then add (1 - beta1) * errors.
        m.assignProd(this.beta1)
            .assignSum((NDArray<?>) errors.prod(firstWeight));
        // v: decay by beta2, then add (1 - beta2) * errors * errors.
        v.assignProd(this.beta2)
            .assignSum((NDArray<?>) errors.prod((NDArray<?>) errors).assignProd(secondWeight));

        return m.div((NDArray<?>) v.sqrt().assignSum(this.epsilon)).assignProd(this.alpha);
    }

    /** Sets alpha to {@code stepSize * sqrt(1 - beta2^t) / (1 - beta1^t)} with t = exampleCount. */
    private final void updateAlpha() {
        double t = this.exampleCount;
        double secondMomentCorrection = Math.sqrt(1.0d - Math.pow(this.beta2, t));
        double firstMomentCorrection = 1.0d - Math.pow(this.beta1, t);
        setAlpha(this.stepSize * secondMomentCorrection / firstMomentCorrection);
    }

    /** @return the current bias-corrected step size */
    public final double getAlpha() {
        return this.alpha;
    }

    private final void setAlpha(double value) {
        this.alpha = value;
    }

    /** @return the base step size */
    public final double getStepSize() {
        return this.stepSize;
    }

    /** @return the decay factor of the first-order moments */
    public final double getBeta1() {
        return this.beta1;
    }

    /** @return the decay factor of the second-order moments */
    public final double getBeta2() {
        return this.beta2;
    }

    /** @return the value added to the square root of the second-order moments */
    public final double getEpsilon() {
        return this.epsilon;
    }
}
