package mnist;

import com.kotlinnlp.simplednn.core.arrays.DistributionArray;
import com.kotlinnlp.simplednn.core.functionalities.activations.ELU;
import com.kotlinnlp.simplednn.core.functionalities.activations.Softmax;
import com.kotlinnlp.simplednn.core.functionalities.decaymethods.HyperbolicDecay;
import com.kotlinnlp.simplednn.core.functionalities.losses.SoftmaxCrossEntropyCalculator;
import com.kotlinnlp.simplednn.core.functionalities.outputevaluation.ClassificationEvaluation;
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.learningrate.LearningRateMethod;
import com.kotlinnlp.simplednn.core.layers.LayerType;
import com.kotlinnlp.simplednn.core.neuralnetwork.NeuralNetwork;
import com.kotlinnlp.simplednn.core.neuralnetwork.preset.FeedforwardNeuralNetwork;
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor;
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer;
import com.kotlinnlp.simplednn.dataset.Corpus;
import com.kotlinnlp.simplednn.dataset.Shuffler;
import com.kotlinnlp.simplednn.dataset.SimpleExample;
import com.kotlinnlp.simplednn.helpers.training.FeedforwardTrainingHelper;
import com.kotlinnlp.simplednn.helpers.validation.FeedforwardValidationHelper;
import com.kotlinnlp.simplednn.simplemath.ndarray.Shape;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray;
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory;
import com.kotlinnlp.simplednn.simplemath.ndarray.sparsebinary.SparseBinaryNDArray;
import java.util.ArrayList;
import java.util.Arrays;
import kotlin.Metadata;
import kotlin.Unit;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: MNISTSparseBinaryTest.kt */
@Metadata(mv = {1, 1, 8}, bv = {1, 0, 2}, k = 1, d1 = {"��@\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018��2\u00020\u0001B\u0019\u0012\u0012\u0010\u0002\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u00040\u0003¢\u0006\u0002\u0010\u0006J\u0018\u0010\u000b\u001a\u00020\f2\u0006\u0010\r\u001a\u00020\u000e2\u0006\u0010\u000f\u001a\u00020\u0010H\u0002J,\u0010\u0011\u001a\u00020\f2\"\u0010\u0012\u001a\u001e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u00040\u0013j\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u0004`\u0014H\u0002J\u0006\u0010\u0015\u001a\u00020\fJ\b\u0010\u0016\u001a\u00020\fH\u0002R\u001d\u0010\u0002\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00050\u00040\u0003¢\u0006\b\n��\u001a\u0004\b\u0007\u0010\bR\u000e\u0010\t\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��¨\u0006\u0017"}, d2 = {"Lmnist/MNISTSparseBinaryTest;", "", "dataset", "Lcom/kotlinnlp/simplednn/dataset/Corpus;", "Lcom/kotlinnlp/simplednn/dataset/SimpleExample;", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/sparsebinary/SparseBinaryNDArray;", "(Lcom/kotlinnlp/simplednn/dataset/Corpus;)V", "getDataset", "()Lcom/kotlinnlp/simplednn/dataset/Corpus;", "neuralNetwork", "Lcom/kotlinnlp/simplednn/core/neuralnetwork/NeuralNetwork;", "printImage", "", "image", "Lcom/kotlinnlp/simplednn/simplemath/ndarray/dense/DenseNDArray;", "value", "", "printImages", "examples", "Ljava/util/ArrayList;", "Lkotlin/collections/ArrayList;", "start", "train", "simplednn"})
/* loaded from: input_file:mnist/MNISTSparseBinaryTest.class */
/*
 * Trains a feedforward network with a sparse-binary input layer on an MNIST-style corpus, then
 * prints, as ASCII art, the input relevance computed for a few validation examples.
 */
public final class MNISTSparseBinaryTest {

    /** Side length of the square grid rendered by {@link #printImage}. */
    private static final int IMAGE_SIDE = 28;

    /** Length of the flattened input vector; must match the network's input layer size. */
    private static final int INPUT_SIZE = IMAGE_SIDE * IMAGE_SIDE;

    /** Number of softmax output units; also the length of the distribution used for relevance. */
    private static final int OUTPUT_SIZE = 10;

    /** Upper bound on how many validation examples {@link #start} prints. */
    private static final int MAX_PRINTED_IMAGES = 20;

    private final NeuralNetwork neuralNetwork;

    @NotNull
    private final Corpus<SimpleExample<SparseBinaryNDArray>> dataset;

    /**
     * Trains the network, then prints relevance images for the first (at most)
     * {@code MAX_PRINTED_IMAGES} validation examples.
     */
    public final void start() {
        train();
        // Clamp so that a validation set smaller than MAX_PRINTED_IMAGES does not make
        // subList throw IndexOutOfBoundsException.
        int printedCount = Math.min(MAX_PRINTED_IMAGES, this.dataset.getValidation().size());
        printImages(new ArrayList<>(this.dataset.getValidation().subList(0, printedCount)));
    }

    /**
     * Trains {@link #neuralNetwork} on the training split, validating on the validation split,
     * with a softmax cross-entropy loss and a hyperbolically decaying learning rate.
     */
    private void train() {
        System.out.println("\n-- TRAINING\n");

        // NOTE: the trailing (mask, null) arguments are Kotlin synthetic default-argument
        // constructors/functions emitted by the decompiler; kept verbatim because the real
        // default values are defined in the library, not here.
        FeedforwardNeuralProcessor trainingProcessor =
            new FeedforwardNeuralProcessor(this.neuralNetwork, 0, 2, null);
        ParamsOptimizer optimizer = new ParamsOptimizer(
            this.neuralNetwork.getModel(),
            new LearningRateMethod(0.01d, new HyperbolicDecay(0.5d, 0.01d, 0.0d, 4, null), null, 4, null));
        FeedforwardTrainingHelper trainingHelper = new FeedforwardTrainingHelper(
            trainingProcessor, optimizer, new SoftmaxCrossEntropyCalculator(), null, true, 8, null);

        FeedforwardValidationHelper validationHelper = new FeedforwardValidationHelper(
            new FeedforwardNeuralProcessor(this.neuralNetwork, 0, 2, null), new ClassificationEvaluation());

        // NOTE(review): 3 and 1 are presumably the number of epochs and the batch size --
        // confirm against FeedforwardTrainingHelper.train's signature.
        trainingHelper.train(
            this.dataset.getTraining(),
            3,
            1,
            this.dataset.getValidation(),
            validationHelper,
            new Shuffler(true, 1L));
    }

    /**
     * Runs a validation pass over {@code arrayList}. After each prediction, prints the input
     * relevance, calculated with respect to a uniform distribution over the output classes,
     * labelled with the index of the gold output's maximum.
     *
     * @param arrayList the examples to evaluate and print
     */
    private void printImages(ArrayList<SimpleExample<SparseBinaryNDArray>> arrayList) {
        System.out.println("\n-- PRINT IMAGES RELEVANCE\n");

        final FeedforwardNeuralProcessor feedforwardNeuralProcessor =
            new FeedforwardNeuralProcessor(this.neuralNetwork, 0, 2, null);
        FeedforwardValidationHelper validationHelper =
            new FeedforwardValidationHelper(feedforwardNeuralProcessor, new ClassificationEvaluation());

        // Plain interface implementation: Function2 is an interface, so no super(...) call is
        // needed (or legal) here.
        Function2<SimpleExample<SparseBinaryNDArray>, Boolean, Unit> onPrediction =
            new Function2<SimpleExample<SparseBinaryNDArray>, Boolean, Unit>() {
                @Override
                public Unit invoke(SimpleExample<SparseBinaryNDArray> example, Boolean isCorrect) {
                    Intrinsics.checkParameterIsNotNull(example, "example");
                    // Scatter the relevance into a dense zero vector of INPUT_SIZE elements so
                    // printImage can index every pixel.
                    DenseNDArray relevanceImage = DenseNDArrayFactory.INSTANCE
                        .zeros(new Shape(INPUT_SIZE, 0, 2, null))
                        .assignValues(FeedforwardNeuralProcessor.calculateInputRelevance$default(
                            feedforwardNeuralProcessor,
                            DistributionArray.Companion.uniform(OUTPUT_SIZE),
                            false,
                            2,
                            null));
                    MNISTSparseBinaryTest.this.printImage(
                        relevanceImage, example.getOutputGold().argMaxIndex());
                    return Unit.INSTANCE;
                }
            };

        // NOTE(review): the final `true` is presumably "save contributions", which the
        // relevance calculation above relies on -- confirm against validate's signature.
        validationHelper.validate(arrayList, onPrediction, true);
    }

    /**
     * Prints a header containing {@code i}, then renders {@code denseNDArray} as an
     * IMAGE_SIDE x IMAGE_SIDE grid (row-major). Cells whose value is strictly positive are drawn
     * as "# " and all others as two spaces.
     *
     * <p>Public only because the decompiler widened it for access from the inner class; it is
     * not intended for external use.
     *
     * @param denseNDArray a flattened image of at least INPUT_SIZE elements
     * @param i the label printed in the header line
     */
    public final void printImage(DenseNDArray denseNDArray, int i) {
        System.out.println(String.format("------------------ %d -----------------", i));
        for (int row = 0; row < IMAGE_SIDE; row++) {
            StringBuilder line = new StringBuilder(2 * IMAGE_SIDE);
            for (int col = 0; col < IMAGE_SIDE; col++) {
                double pixel = denseNDArray.get(row * IMAGE_SIDE + col).doubleValue();
                line.append(pixel > 0.0d ? "# " : "  ");
            }
            System.out.println(line);
        }
    }

    /** Returns the corpus (training and validation splits) this test runs on. */
    @NotNull
    public final Corpus<SimpleExample<SparseBinaryNDArray>> getDataset() {
        return this.dataset;
    }

    /**
     * Builds the test with a feedforward network: INPUT_SIZE sparse-binary inputs, an ELU
     * hidden layer, and an OUTPUT_SIZE-unit softmax output layer.
     *
     * @param dataset the corpus to train and validate on; must not be null
     */
    public MNISTSparseBinaryTest(@NotNull Corpus<SimpleExample<SparseBinaryNDArray>> dataset) {
        Intrinsics.checkParameterIsNotNull(dataset, "dataset");
        this.dataset = dataset;
        // NOTE(review): 100 is presumably the hidden layer size; the remaining positional values
        // and the 3684 bitmask are Kotlin default-argument plumbing kept verbatim.
        this.neuralNetwork = FeedforwardNeuralNetwork.invoke$default(
            FeedforwardNeuralNetwork.INSTANCE,
            INPUT_SIZE, LayerType.Input.SparseBinary, 0.0d,
            100, new ELU(0.0d, 1, null), 0.0d, false,
            OUTPUT_SIZE, new Softmax(), false,
            null, null, 3684, null);
    }
}
