package com.kotlinnlp.simplednn.deeplearning.newrecirculation;

import com.kotlinnlp.simplednn.core.arrays.AugmentedArray;
import com.kotlinnlp.simplednn.core.layers.models.feedforward.simple.FeedforwardLayerParameters;
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.utils.ItemsPool;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: NewRecirculationNetwork.kt */
@Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��D\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0006\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n\u0002\b\u0006\n\u0002\u0010\u000b\n\u0002\b\u0005\u0018�� %2\u00020\u0001:\u0001%B+\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\b\b\u0002\u0010\u0004\u001a\u00020\u0005\u0012\b\b\u0002\u0010\u0006\u001a\u00020\u0005\u0012\b\b\u0002\u0010\u0007\u001a\u00020\b¢\u0006\u0002\u0010\tJ\b\u0010\u0019\u001a\u00020\u001aH\u0002J\b\u0010\u001b\u001a\u00020\u001aH\u0002J\b\u0010\u001c\u001a\u00020\u001aH\u0002J\u000e\u0010\u001d\u001a\u00020\u00052\u0006\u0010\u001e\u001a\u00020\u000eJ\u0012\u0010\u001f\u001a\u00020\u001a2\b\b\u0002\u0010 \u001a\u00020!H\u0002J\u0018\u0010\"\u001a\u00020\u000e2\u0006\u0010\u001e\u001a\u00020\u000e2\b\b\u0002\u0010#\u001a\u00020!J\b\u0010$\u001a\u00020\u001aH\u0002R\u0014\u0010\u0007\u001a\u00020\bX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\n\u0010\u000bR\u0014\u0010\f\u001a\b\u0012\u0004\u0012\u00020\u000e0\rX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000f\u001a\b\u0012\u0004\u0012\u00020\u000e0\rX\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0010\u001a\u00020\u00058F¢\u0006\u0006\u001a\u0004\b\u0011\u0010\u0012R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u000e\u0010\u0015\u001a\u00020\u0016X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0017\u001a\b\u0012\u0004\u0012\u00020\u000e0\rX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0018\u001a\b\u0012\u0004\u0012\u00020\u000e0\rX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0004\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0006\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n��¨\u0006&"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/newrecirculation/NewRecirculationNetwork;", "Lcom/kotlinnlp/utils/ItemsPool$IDItem;", "model", 
"Lcom/kotlinnlp/simplednn/deeplearning/newrecirculation/NewRecirculationModel;", "recallThreshold", "", "trainingLearningRate", "id", "", "(Lcom/kotlinnlp/simplednn/deeplearning/newrecirculation/NewRecirculationModel;DDI)V", "getId", "()I", "imaginaryInput", "Lcom/kotlinnlp/simplednn/core/arrays/AugmentedArray;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "imaginaryOutput", "meanAbsError", "getMeanAbsError", "()D", "getModel", "()Lcom/kotlinnlp/simplednn/deeplearning/newrecirculation/NewRecirculationModel;", "paramsErrors", "Lcom/kotlinnlp/simplednn/core/layers/models/feedforward/simple/FeedforwardLayerParameters;", "realInput", "realOutput", "backward", "", "calcImaginaryInput", "calcParamsErrors", "learn", "inputArray", "reEntry", "trainingMode", "", "reconstruct", "useReEntry", "update", "Companion", "simplednn"})
/* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/newrecirculation/NewRecirculationNetwork.class */
/*
 * Single-hidden-layer network driven by a recirculation scheme.
 *
 * The real input x is projected to a hidden output y = f(W x + b), and a reconstruction
 * ("imaginary input") x' = lambda * x + (1 - lambda) * (y^T W)^T is computed through the same,
 * transposed weights. Re-entry feeds x' back as the new real input until the reconstruction
 * error is below the recall threshold (or an iteration cap is hit); in training mode every
 * re-entry step also updates the model parameters in place.
 *
 * NOTE(review): instances keep mutable per-call buffers, so a single instance is presumably not
 * meant to be used from several threads at once (it is an ItemsPool item) - confirm with callers.
 */
public final class NewRecirculationNetwork implements ItemsPool.IDItem {
    /** Input currently presented to the network; overwritten on every re-entry iteration. */
    private final AugmentedArray<DenseNDArray> realInput;
    /** Hidden-layer output computed from {@link #realInput}. */
    private final AugmentedArray<DenseNDArray> realOutput;
    /** Reconstruction of the input obtained through the transposed weights. */
    private final AugmentedArray<DenseNDArray> imaginaryInput;
    /** Hidden-layer output computed from {@link #imaginaryInput}; only used while training. */
    private final AugmentedArray<DenseNDArray> imaginaryOutput;
    /** Scratch copy of the model parameters, holding the errors of the last backward step. */
    private final FeedforwardLayerParameters paramsErrors;

    @NotNull
    private final NewRecirculationModel model;
    /** Re-entry stops as soon as {@link #getMeanAbsError()} drops below this value. */
    private final double recallThreshold;
    /** Step size applied to the parameter errors in {@link #update()}. */
    private final double trainingLearningRate;
    private final int id;
    /** Upper bound on the number of re-entry iterations. */
    private static final double MAX_RECALL_ITERATIONS = 10000.0d;
    public static final Companion Companion = new Companion(null);

    /* compiled from: NewRecirculationNetwork.kt */
    @Metadata(mv = {1, 1, 13}, bv = {1, 0, 3}, k = 1, d1 = {"��\u0012\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\u0006\n��\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002R\u000e\u0010\u0003\u001a\u00020\u0004X\u0082T¢\u0006\u0002\n��¨\u0006\u0005"}, d2 = {"Lcom/kotlinnlp/simplednn/deeplearning/newrecirculation/NewRecirculationNetwork$Companion;", "", "()V", "MAX_RECALL_ITERATIONS", "", "simplednn"})
    /* loaded from: input_file:com/kotlinnlp/simplednn/deeplearning/newrecirculation/NewRecirculationNetwork$Companion.class */
    public static final class Companion {
        private Companion() {
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /**
     * Returns the mean of the absolute element-wise differences between the current real input
     * and its imaginary reconstruction.
     *
     * @return the mean absolute reconstruction error
     */
    public final double getMeanAbsError() {
        DenseNDArray realValues = this.realInput.getValues();
        // Divide by the length: the property is a *mean*, and the recall threshold is compared
        // against it independently of the input size.
        return realValues.sub(this.imaginaryInput.getValues()).abs().sum() / realValues.getLength();
    }

    /**
     * Reconstructs the given input, optionally refining it by re-entry (without training).
     *
     * <p>NOTE: the returned array is the network's internal imaginary-input buffer; it is
     * overwritten by the next call on this instance, so copy it if it must be retained.
     *
     * @param inputArray the input to reconstruct; its length must equal the model input size
     * @param useReEntry whether to iterate the reconstruction until it converges
     * @return the reconstructed input
     * @throws IllegalArgumentException if the input length differs from the model input size
     */
    @NotNull
    public final DenseNDArray reconstruct(@NotNull DenseNDArray inputArray, boolean useReEntry) {
        Intrinsics.checkParameterIsNotNull(inputArray, "inputArray");
        forward(inputArray);
        if (useReEntry) {
            reEntry(false);
        }
        return this.imaginaryInput.getValues();
    }

    @NotNull
    public static /* synthetic */ DenseNDArray reconstruct$default(NewRecirculationNetwork newRecirculationNetwork, DenseNDArray denseNDArray, boolean z, int i, Object obj) {
        if ((i & 2) != 0) {
            z = true;
        }
        return newRecirculationNetwork.reconstruct(denseNDArray, z);
    }

    /**
     * Trains the network on the given input, updating the model parameters in place.
     *
     * @param inputArray the training example; its length must equal the model input size
     * @return the mean absolute reconstruction error measured before any parameter update
     * @throws IllegalArgumentException if the input length differs from the model input size
     */
    public final double learn(@NotNull DenseNDArray inputArray) {
        Intrinsics.checkParameterIsNotNull(inputArray, "inputArray");
        forward(inputArray);
        double initialError = getMeanAbsError();
        reEntry(true);
        return initialError;
    }

    /**
     * Validates the input size, loads it as the real input and runs the first forward pass.
     */
    private void forward(DenseNDArray inputArray) {
        int expectedSize = this.model.getInputSize();
        if (inputArray.getLength() != expectedSize) {
            throw new IllegalArgumentException(
                "Failed requirement: input length " + inputArray.getLength()
                    + " does not match the model input size " + expectedSize + ".");
        }
        this.realInput.assignValues(inputArray);
        calcImaginaryInput();
    }

    /**
     * Forward pass from {@link #realInput}: writes {@code y = W x + b} into {@link #realOutput}
     * and activates it (with the model activation function, when one is configured), then writes
     * {@code x' = lambda * x + (1 - lambda) * (y^T W)^T} into {@link #imaginaryInput}.
     */
    private final void calcImaginaryInput() {
        double lambda = this.model.getLambda();
        DenseNDArray weights = modelWeights();
        DenseNDArray biases = modelBiases();
        DenseNDArray input = this.realInput.getValues();
        DenseNDArray hidden = this.realOutput.getValues();
        DenseNDArray reconstruction = this.imaginaryInput.getValues();
        hidden.assignDot(weights, input).assignSum((NDArray<?>) biases);
        this.realOutput.activate();
        reconstruction.assignSum(input.prod(lambda), hidden.getT().dot((NDArray<?>) weights).getT().assignProd(1 - lambda));
    }

    /**
     * Repeatedly feeds the reconstruction back as the real input and recomputes it, until the
     * mean absolute error falls below {@link #recallThreshold} or {@link #MAX_RECALL_ITERATIONS}
     * iterations have run.
     *
     * @param trainingMode when true, each iteration first performs a parameter update
     */
    private final void reEntry(boolean trainingMode) {
        for (int iteration = 0;
             getMeanAbsError() >= this.recallThreshold && iteration < MAX_RECALL_ITERATIONS;
             iteration++) {
            if (trainingMode) {
                backward();
            }
            this.realInput.assignValues(this.imaginaryInput.getValues());
            calcImaginaryInput();
        }
    }

    static /* synthetic */ void reEntry$default(NewRecirculationNetwork newRecirculationNetwork, boolean z, int i, Object obj) {
        if ((i & 1) != 0) {
            z = false;
        }
        newRecirculationNetwork.reEntry(z);
    }

    /**
     * Computes the hidden output of the reconstruction into {@link #imaginaryOutput}, blends it
     * with the real hidden output using the model lambda, then computes the parameter errors and
     * applies them to the model.
     */
    private final void backward() {
        double lambda = this.model.getLambda();
        DenseNDArray weights = modelWeights();
        DenseNDArray biases = modelBiases();
        DenseNDArray realHidden = this.realOutput.getValues();
        DenseNDArray reconstruction = this.imaginaryInput.getValues();
        DenseNDArray imaginaryHidden = this.imaginaryOutput.getValues();
        imaginaryHidden.assignDot(weights, reconstruction).assignSum((NDArray<?>) biases);
        this.imaginaryOutput.activate();
        // y'' = (1 - lambda) * f(W x' + b) + lambda * y, mirroring the blending of the reconstruction.
        imaginaryHidden.assignProd(1 - lambda).assignSum((NDArray<?>) realHidden.prod(lambda));
        calcParamsErrors();
        update();
    }

    /**
     * Fills {@link #paramsErrors}: weights errors {@code (y'' - y) x'^T + y (x' - x)^T} and
     * biases errors {@code y'' - y}.
     */
    private final void calcParamsErrors() {
        DenseNDArray input = this.realInput.getValues();
        DenseNDArray hidden = this.realOutput.getValues();
        DenseNDArray reconstruction = this.imaginaryInput.getValues();
        DenseNDArray imaginaryHidden = this.imaginaryOutput.getValues();
        DenseNDArray weightsErrors = errorsWeights();
        DenseNDArray biasesErrors = errorsBiases();
        DenseNDArray inputDiff = reconstruction.sub(input);
        DenseNDArray hiddenDiff = imaginaryHidden.sub(hidden);
        weightsErrors.assignDot(hiddenDiff, reconstruction.getT()).assignSum((NDArray<?>) hidden.dot((NDArray<?>) inputDiff.getT()));
        biasesErrors.assignValues((NDArray<?>) hiddenDiff);
    }

    /**
     * Gradient-descent step: subtracts {@code trainingLearningRate * errors} from the model
     * weights and biases. The errors buffers are scaled in place as a side effect.
     */
    private final void update() {
        double learningRate = this.trainingLearningRate;
        DenseNDArray weights = modelWeights();
        DenseNDArray biases = modelBiases();
        DenseNDArray weightsErrors = errorsWeights();
        DenseNDArray biasesErrors = errorsBiases();
        weights.assignSub((NDArray<?>) weightsErrors.assignProd(learningRate));
        biases.assignSub((NDArray<?>) biasesErrors.assignProd(learningRate));
    }

    /** Model weights as a dense array. */
    private DenseNDArray modelWeights() {
        return asDense(this.model.getParams().getUnit().getWeights().getValues());
    }

    /** Model biases as a dense array. */
    private DenseNDArray modelBiases() {
        return asDense(this.model.getParams().getUnit().getBiases().getValues());
    }

    /** Weights-errors buffer as a dense array. */
    private DenseNDArray errorsWeights() {
        return asDense(this.paramsErrors.getUnit().getWeights().getValues());
    }

    /** Biases-errors buffer as a dense array. */
    private DenseNDArray errorsBiases() {
        return asDense(this.paramsErrors.getUnit().getBiases().getValues());
    }

    /**
     * Casts parameter values to {@link DenseNDArray}, raising the same exception as the Kotlin
     * non-null cast when the values are null.
     */
    private static DenseNDArray asDense(Object values) {
        if (values == null) {
            throw new TypeCastException("null cannot be cast to non-null type com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray");
        }
        return (DenseNDArray) values;
    }

    /** @return the model whose parameters this network uses and trains */
    @NotNull
    public final NewRecirculationModel getModel() {
        return this.model;
    }

    /** @return the identifier of this network within its items pool */
    public int getId() {
        return this.id;
    }

    /**
     * @param model the model providing sizes, lambda, activation function and parameters
     * @param recallThreshold re-entry stops when the mean absolute error drops below this value
     * @param trainingLearningRate the learning rate used when updating the parameters
     * @param id the identifier of this network within its items pool
     */
    public NewRecirculationNetwork(@NotNull NewRecirculationModel model, double recallThreshold, double trainingLearningRate, int id) {
        Intrinsics.checkParameterIsNotNull(model, "model");
        this.model = model;
        this.recallThreshold = recallThreshold;
        this.trainingLearningRate = trainingLearningRate;
        this.id = id;
        this.realInput = AugmentedArray.Companion.invoke(DenseNDArrayFactory.INSTANCE.zeros(new Shape(this.model.getInputSize(), 0, 2, null)));
        this.realOutput = AugmentedArray.Companion.invoke(DenseNDArrayFactory.INSTANCE.zeros(new Shape(this.model.getHiddenSize(), 0, 2, null)));
        this.imaginaryInput = AugmentedArray.Companion.invoke(DenseNDArrayFactory.INSTANCE.zeros(new Shape(this.model.getInputSize(), 0, 2, null)));
        this.imaginaryOutput = AugmentedArray.Companion.invoke(DenseNDArrayFactory.INSTANCE.zeros(new Shape(this.model.getHiddenSize(), 0, 2, null)));
        // The errors buffer has the same structure as the model parameters.
        this.paramsErrors = this.model.getParams().copy();
        if (this.model.getActivationFunction() != null) {
            this.realOutput.setActivation(this.model.getActivationFunction());
            this.imaginaryOutput.setActivation(this.model.getActivationFunction());
        }
    }

    public /* synthetic */ NewRecirculationNetwork(NewRecirculationModel newRecirculationModel, double d, double d2, int i, int i2, DefaultConstructorMarker defaultConstructorMarker) {
        this(newRecirculationModel, (i2 & 2) != 0 ? 0.001d : d, (i2 & 4) != 0 ? 0.01d : d2, (i2 & 8) != 0 ? 0 : i);
    }
}
