package com.kotlinnlp.simplednn.core.functionalities.updatemethods;

import com.kotlinnlp.simplednn.core.arrays.UpdatableDenseArray;
import com.kotlinnlp.simplednn.core.functionalities.regularization.WeightsRegularization;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.UpdaterSupportStructure;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.sparse.SparseNDArray;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.UninitializedPropertyAccessException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import kotlin.reflect.KClass;
import kotlin.reflect.KFunction;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: UpdateMethod.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��H\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0002\b\u0002\b&\u0018��*\b\b��\u0010\u0001*\u00020\u00022\u00020\u0003B\u000f\u0012\b\u0010\u0004\u001a\u0004\u0018\u00010\u0005¢\u0006\u0002\u0010\u0006J\u0013\u0010\r\u001a\u00028��2\u0006\u0010\u000e\u001a\u00020\u000f¢\u0006\u0002\u0010\u0010J\u0018\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0013\u001a\u00020\u00122\u0006\u0010\u000e\u001a\u00020\u000fH$J-\u0010\u0014\u001a\u0002H\u0015\"\u000e\b\u0001\u0010\u0015*\b\u0012\u0004\u0012\u0002H\u00150\u00162\u0006\u0010\u0013\u001a\u0002H\u00152\u0006\u0010\u000e\u001a\u00020\u000fH\u0002¢\u0006\u0002\u0010\u0017J\u0018\u0010\u0018\u001a\u00020\u00192\u0006\u0010\u0013\u001a\u00020\u00192\u0006\u0010\u000e\u001a\u00020\u000fH$J\u0015\u0010\u001a\u001a\u00028��2\u0006\u0010\u000e\u001a\u00020\u000fH\u0002¢\u0006\u0002\u0010\u0010J+\u0010\u001b\u001a\u00020\u001c\"\u000e\b\u0001\u0010\u0015*\b\u0012\u0004\u0012\u0002H\u00150\u00162\u0006\u0010\u000e\u001a\u00020\u000f2\u0006\u0010\u0013\u001a\u0002H\u0015¢\u0006\u0002\u0010\u001dR\u0013\u0010\u0004\u001a\u0004\u0018\u00010\u0005¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\bR\u0018\u0010\t\u001a\b\u0012\u0004\u0012\u00028��0\nX¤\u0004¢\u0006\u0006\u001a\u0004\b\u000b\u0010\f¨\u0006\u001e"}, d2 = {"Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod;", "SupportStructureType", "Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdaterSupportStructure;", "", "regularization", "Lcom/kotlinnlp/simplednn/core/functionalities/regularization/WeightsRegularization;", "(Lcom/kotlinnlp/simplednn/core/functionalities/regularization/WeightsRegularization;)V", "getRegularization", 
"()Lcom/kotlinnlp/simplednn/core/functionalities/regularization/WeightsRegularization;", "structureClass", "Lkotlin/reflect/KClass;", "getStructureClass", "()Lkotlin/reflect/KClass;", "getSupportStructure", "array", "Lcom/kotlinnlp/simplednn/core/arrays/UpdatableDenseArray;", "(Lcom/kotlinnlp/simplednn/core/arrays/UpdatableDenseArray;)Lcom/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdaterSupportStructure;", "optimizeDenseErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "errors", "optimizeErrors", "NDArrayType", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "(Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;Lcom/kotlinnlp/simplednn/core/arrays/UpdatableDenseArray;)Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;", "optimizeSparseErrors", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/sparse/SparseNDArray;", "structureFactory", "update", "", "(Lcom/kotlinnlp/simplednn/core/arrays/UpdatableDenseArray;Lcom/kotlinnlp/simplednn/simplemath/ndarray/NDArray;)V", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/core/functionalities/updatemethods/UpdateMethod.class */
/**
 * Base class of the methods that update the values of an {@link UpdatableDenseArray} given the
 * errors of those values.
 *
 * <p>An update optimizes the errors through the concrete method, optionally applies a
 * {@link WeightsRegularization} to the array, and finally subtracts the optimized errors from the
 * array values.
 *
 * <p>Each concrete method declares the class of the support structure it needs (see
 * {@link #getStructureClass()}). The structure is created on first use and stored in the updated
 * array itself.
 *
 * @param <SupportStructureType> the type of the support structure used by this update method
 */
public abstract class UpdateMethod<SupportStructureType extends UpdaterSupportStructure> {

    /** Regularization applied to the array before each subtraction, or {@code null} for none. */
    @Nullable
    private final WeightsRegularization regularization;

    /**
     * @param weightsRegularization the regularization to apply during each update, or {@code null}
     *     to update without regularization
     */
    public UpdateMethod(@Nullable WeightsRegularization weightsRegularization) {
        this.regularization = weightsRegularization;
    }

    /**
     * @return the regularization applied during each update, or {@code null} if none was given
     */
    @Nullable
    public final WeightsRegularization getRegularization() {
        return this.regularization;
    }

    /**
     * @return the class of the support structure required by this update method
     */
    @NotNull
    protected abstract KClass<SupportStructureType> getStructureClass();

    /**
     * Optimizes sparse errors before they are subtracted from the array values.
     *
     * @param sparseNDArray the errors to optimize
     * @param updatableDenseArray the array that is being updated
     * @return the optimized errors
     */
    @NotNull
    protected abstract SparseNDArray optimizeSparseErrors(@NotNull SparseNDArray sparseNDArray, @NotNull UpdatableDenseArray updatableDenseArray);

    /**
     * Optimizes dense errors before they are subtracted from the array values.
     *
     * @param denseNDArray the errors to optimize
     * @param updatableDenseArray the array that is being updated
     * @return the optimized errors
     */
    @NotNull
    protected abstract DenseNDArray optimizeDenseErrors(@NotNull DenseNDArray denseNDArray, @NotNull UpdatableDenseArray updatableDenseArray);

    /**
     * Returns the support structure stored in the given array.
     *
     * <p>If the array has no structure yet, a new one is created with {@link #structureFactory} and
     * stored in the array first.
     *
     * @param updatableDenseArray the array whose support structure is requested
     * @return the support structure of the array, typed for this update method
     * @throws IllegalArgumentException if the stored structure is not an instance of
     *     {@link #getStructureClass()}
     */
    @NotNull
    public final SupportStructureType getSupportStructure(@NotNull UpdatableDenseArray updatableDenseArray) {
        Intrinsics.checkParameterIsNotNull(updatableDenseArray, "array");
        try {
            updatableDenseArray.getUpdaterSupportStructure();
        } catch (UninitializedPropertyAccessException ignored) {
            // Reading the (Kotlin lateinit) property before it has been set throws; that is the
            // signal that this array has no support structure yet, so build and store one.
            updatableDenseArray.setUpdaterSupportStructure(structureFactory(updatableDenseArray));
        }
        UpdaterSupportStructure structure = updatableDenseArray.getUpdaterSupportStructure();
        // isInstance(null) is false, so this check also rejects a null structure.
        if (!getStructureClass().isInstance(structure)) {
            throw new IllegalArgumentException("Incompatible updaterSupportStructure");
        }
        @SuppressWarnings("unchecked") // safe: verified by getStructureClass().isInstance above
        SupportStructureType typedStructure = (SupportStructureType) structure;
        return typedStructure;
    }

    /**
     * Updates the values of the given array with the given errors.
     *
     * <p>The errors are first optimized by the concrete method. The regularization, if any, is then
     * applied to the array. Finally the optimized errors are subtracted in place from the array
     * values.
     *
     * @param updatableDenseArray the array to update
     * @param ndarraytype the errors of the array values (sparse or dense)
     * @param <NDArrayType> the concrete type of the errors
     * @throws RuntimeException if the errors are neither a {@link SparseNDArray} nor a
     *     {@link DenseNDArray}
     */
    public final <NDArrayType extends NDArray<NDArrayType>> void update(@NotNull UpdatableDenseArray updatableDenseArray, @NotNull NDArrayType ndarraytype) {
        Intrinsics.checkParameterIsNotNull(updatableDenseArray, "array");
        Intrinsics.checkParameterIsNotNull(ndarraytype, "errors");
        NDArray<?> optimizedErrors = optimizeErrors(ndarraytype, updatableDenseArray);
        WeightsRegularization weightsRegularization = this.regularization;
        if (weightsRegularization != null) {
            weightsRegularization.apply(updatableDenseArray);
        }
        updatableDenseArray.getValues().assignSub(optimizedErrors);
    }

    /**
     * Builds a new support structure for the given array.
     *
     * <p>It calls the first constructor exposed by {@link #getStructureClass()}, passing the shape
     * of the array values as the only argument.
     *
     * <p>NOTE(review): this assumes the structure class exposes a single constructor that takes the
     * shape. {@code CollectionsKt.first} throws {@code NoSuchElementException} if no constructor is
     * exposed.
     */
    private SupportStructureType structureFactory(UpdatableDenseArray updatableDenseArray) {
        KFunction<SupportStructureType> constructor = CollectionsKt.first(getStructureClass().getConstructors());
        return constructor.call(new Object[]{updatableDenseArray.getValues().getShape()});
    }

    /**
     * Dispatches the errors to the sparse or dense optimization of the concrete method.
     *
     * @return the optimized errors, which have the same concrete type as the input
     * @throws RuntimeException if the errors are neither sparse nor dense
     * @throws TypeCastException if the concrete method returned {@code null}
     */
    @SuppressWarnings("unchecked")
    private <NDArrayType extends NDArray<NDArrayType>> NDArrayType optimizeErrors(NDArrayType ndarraytype, UpdatableDenseArray updatableDenseArray) {
        NDArray<?> optimized;
        if (ndarraytype instanceof SparseNDArray) {
            optimized = optimizeSparseErrors((SparseNDArray) ndarraytype, updatableDenseArray);
        } else if (ndarraytype instanceof DenseNDArray) {
            optimized = optimizeDenseErrors((DenseNDArray) ndarraytype, updatableDenseArray);
        } else {
            throw new RuntimeException("Invalid errors type");
        }
        // The abstract optimizers are declared @NotNull, but a Java subclass can still return null.
        if (optimized == null) {
            throw new TypeCastException("null cannot be cast to non-null type NDArrayType");
        }
        // The sparse (dense) optimizer is given a sparse (dense) array and returns the same
        // concrete type, so the result matches NDArrayType. Java cannot prove this statically.
        return (NDArrayType) optimized;
    }
}
