package org.deeplearning4j.datasets;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.GridLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.IOException;
import javax.swing.BorderFactory;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;

/**
 * Small Swing tool for browsing the MNIST training set and driving an {@link NN} by hand.
 *
 * <p>The window holds three panels side by side:
 * <ul>
 *   <li>{@link InputPanel} draws the current MNIST image and a grid in which the row of its label
 *       is lit;
 *   <li>{@link OutputPanel} draws the network's input values, its state, and its output values;
 *   <li>{@link OptionsPanel} holds the buttons that change the image and push it (or its label)
 *       into the network.
 * </ul>
 *
 * <p>The MNIST data is opened through the relative paths {@code MNIST/train-images-idx3-ubyte}
 * and {@code MNIST/train-labels-idx1-ubyte}.
 */
public class MNISTViewer {
    /** Reader for the MNIST image and label files. */
    MnistManager manager;
    /** Network driven by the "Set Input", "Set Output", "Update" and "Reset" buttons. */
    NN network;
    InputPanel input;
    OutputPanel output;
    OptionsPanel options;
    /** Width, in nodes, of the network state grid; the network has xnodes * ynodes nodes. */
    int xnodes = 50;
    /** Height, in nodes, of the network state grid. */
    int ynodes = 24;
    /** Passed to the {@link NN} constructor as its second argument. */
    int connections = 10;
    /** Rows of the label/output grid; the row whose index equals the label is lit. */
    int outputRows = 10;
    /** Columns of the label/output grid. */
    int outputCols = 10;

    /**
     * Converts a value that is nominally in [0, 1] to an 8-bit gray level.
     *
     * <p>The result is clamped to [0, 255] because {@link Color} rejects out-of-range components,
     * and nothing here guarantees the network's values stay within [0, 1].
     */
    private static int toGray(float value) {
        int level = (int) (255.0d * value);
        return Math.max(0, Math.min(255, level));
    }

    /**
     * Builds a flattened, row-major {@code outputRows x outputCols} grid in which every cell of
     * row {@code label} holds {@code value} and every other cell is 0.
     *
     * <p>A label outside {@code [0, outputRows)} yields an all-zero grid instead of throwing.
     */
    private float[] labelGrid(int label, float value) {
        float[] grid = new float[outputRows * outputCols];
        if (label >= 0 && label < outputRows) {
            for (int col = 0; col < outputCols; col++) {
                grid[label * outputCols + col] = value;
            }
        }
        return grid;
    }

    /** Panel that draws the current MNIST image and a grid showing its label. */
    class InputPanel extends JPanel {
        private int width = 200;
        private int height = 200;
        /** Left margin of the drawing area. */
        private int mx = 20;
        /** Top margin of the drawing area; the label grid is drawn at 3 * my. */
        private int my = 30;
        /**
         * Index passed to {@code manager.setCurrent}, kept within [1, maxIndex].
         * NOTE(review): the [1, count] range suggests 1-based indexing in MnistManager — confirm.
         */
        private int imageIndex = 1;
        /** Number of images reported by the image file. */
        private int maxIndex;

        public InputPanel() {
            setPreferredSize(new Dimension(this.width, this.height));
            setBorder(BorderFactory.createLineBorder(Color.black));
            add(new JLabel("Input"));
            MNISTViewer.this.manager.setCurrent(this.imageIndex);
            this.maxIndex = MNISTViewer.this.manager.getImages().getCount();
        }

        /** Moves to the next image, stopping at the last one. */
        public void nextImage() {
            this.imageIndex = Math.min(this.imageIndex + 1, this.maxIndex);
            MNISTViewer.this.manager.setCurrent(this.imageIndex);
        }

        /** Moves to the previous image, stopping at index 1. */
        public void previousImage() {
            this.imageIndex = Math.max(this.imageIndex - 1, 1);
            MNISTViewer.this.manager.setCurrent(this.imageIndex);
        }

        /**
         * Draws the current image, one pixel per cell, at ({@code mx}, {@code my}).
         * Draws nothing if the image cannot be read.
         */
        public void drawCurrentImage(Graphics graphics) {
            MNISTViewer.this.manager.setCurrent(this.imageIndex);
            int[][] pixels;
            try {
                pixels = MNISTViewer.this.manager.readImage();
            } catch (IOException e) {
                e.printStackTrace();
                return;
            }
            // Iterate over the image's own rows x cols instead of assuming a square image.
            for (int row = 0; row < pixels.length; row++) {
                for (int col = 0; col < pixels[row].length; col++) {
                    int gray = pixels[row][col];
                    graphics.setColor(new Color(gray, gray, gray));
                    graphics.fillRect(this.mx + col, this.my + row, 1, 1);
                }
            }
        }

        /**
         * Draws an {@code outputRows x outputCols} grid at ({@code mx}, {@code 3 * my}) whose row
         * at the current image's label is white and whose other rows are black.
         */
        public void drawCurrentOutput(Graphics graphics) {
            int left = this.mx;
            int top = 3 * this.my;
            int rows = MNISTViewer.this.outputRows;
            int cols = MNISTViewer.this.outputCols;
            try {
                // Position the reader explicitly so the label belongs to the displayed image,
                // independent of what earlier reads did.
                MNISTViewer.this.manager.setCurrent(this.imageIndex);
                int label = MNISTViewer.this.manager.readLabel();
                float[] grid = MNISTViewer.this.labelGrid(label, 1.0f);
                for (int row = 0; row < rows; row++) {
                    for (int col = 0; col < cols; col++) {
                        int gray = MNISTViewer.toGray(grid[row * cols + col]);
                        graphics.setColor(new Color(gray, gray, gray));
                        graphics.fillRect(left + col, top + row, 1, 1);
                    }
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        @Override
        public void paintComponent(Graphics graphics) {
            super.paintComponent(graphics);
            graphics.drawRect(this.mx, this.my, getWidth() - (2 * this.mx), getHeight() - (2 * this.my));
            drawCurrentImage(graphics);
            drawCurrentOutput(graphics);
        }
    }

    /** Top-level window that lays out the input, output and options panels in one row. */
    class MyFrame extends JFrame {
        public MyFrame(String title) {
            super(title);
            setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
            JPanel content = new JPanel();
            content.setBorder(BorderFactory.createLineBorder(Color.black));
            content.setLayout(new GridLayout(1, 3));
            MNISTViewer.this.input = new InputPanel();
            MNISTViewer.this.output = new OutputPanel();
            MNISTViewer.this.options = new OptionsPanel();
            add(content);
            content.add(MNISTViewer.this.input);
            content.add(MNISTViewer.this.output);
            content.add(MNISTViewer.this.options);
        }
    }

    /** Panel with the buttons that navigate images and drive the network. */
    class OptionsPanel extends JPanel {
        public OptionsPanel() {
            setPreferredSize(new Dimension(200, 200));
            setBorder(BorderFactory.createLineBorder(Color.black));
            add(new JLabel("Options"));

            JButton next = new JButton("Next");
            JButton previous = new JButton("Previous");
            JButton setInput = new JButton("Set Input");
            JButton setOutput = new JButton("Set Output");
            JButton update = new JButton("Update");
            JButton reset = new JButton("Reset");

            next.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent event) {
                    MNISTViewer.this.input.nextImage();
                    MNISTViewer.this.input.repaint();
                }
            });
            previous.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent event) {
                    MNISTViewer.this.input.previousImage();
                    MNISTViewer.this.input.repaint();
                }
            });
            setInput.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent event) {
                    sendImageToNetwork();
                }
            });
            setOutput.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent event) {
                    sendLabelToNetwork();
                }
            });
            update.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent event) {
                    MNISTViewer.this.network.update();
                    MNISTViewer.this.output.repaint();
                }
            });
            reset.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent event) {
                    MNISTViewer.this.network.reset();
                    MNISTViewer.this.output.repaint();
                }
            });

            add(next);
            add(previous);
            add(setInput);
            add(setOutput);
            add(update);
            add(reset);
        }

        /**
         * Passes the displayed image to {@code network.setInput}. Node k (k = 0 .. rows*cols-1)
         * receives pixel (k / cols, k % cols) scaled from [0, 255] to [0, 1].
         */
        private void sendImageToNetwork() {
            try {
                MNISTViewer.this.manager.setCurrent(MNISTViewer.this.input.imageIndex);
                int[][] image = MNISTViewer.this.manager.readImage();
                int rows = image.length;
                int cols = image[0].length;
                int[] nodes = new int[rows * cols];
                float[] values = new float[rows * cols];
                for (int row = 0; row < rows; row++) {
                    for (int col = 0; col < cols; col++) {
                        int k = row * cols + col; // row-major: the stride is the column count
                        nodes[k] = k;
                        values[k] = image[row][col] / 255.0f;
                    }
                }
                MNISTViewer.this.network.setInput(nodes, values);
            } catch (IOException e) {
                e.printStackTrace();
            }
            MNISTViewer.this.output.repaint();
        }

        /**
         * Passes a label grid for the displayed image to {@code network.setOutput}: the row at
         * the label is 1.0 and all other cells are 0, flattened row-major.
         */
        private void sendLabelToNetwork() {
            int count = MNISTViewer.this.outputRows * MNISTViewer.this.outputCols;
            // NOTE(review): this block of node ids ends at xnodes*ynodes - 2, one before the last
            // node. That may be off by one, but it is kept because NN's matching readOutput()
            // layout is not visible here.
            int first = ((MNISTViewer.this.xnodes * MNISTViewer.this.ynodes) - 1) - count;
            int[] nodes = new int[count];
            for (int k = 0; k < count; k++) {
                nodes[k] = first + k;
            }
            try {
                // Re-position so the label read matches the displayed image.
                MNISTViewer.this.manager.setCurrent(MNISTViewer.this.input.imageIndex);
                int label = MNISTViewer.this.manager.readLabel();
                MNISTViewer.this.network.setOutput(nodes, MNISTViewer.this.labelGrid(label, 1.0f));
                MNISTViewer.this.input.repaint();
                MNISTViewer.this.output.repaint();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Panel that visualizes the network: its input values at (mx, my), its state grid at
     * (mx, 3 * my), and its output values at (mx, 4 * my).
     */
    class OutputPanel extends JPanel {
        private int width = 200;
        private int height = 200;
        /** Left margin of the drawing area. */
        private int mx = 20;
        /** Top margin / vertical unit used to place the three grids. */
        private int my = 30;

        public OutputPanel() {
            setPreferredSize(new Dimension(this.width, this.height));
            setBorder(BorderFactory.createLineBorder(Color.black));
            add(new JLabel("Output"));
        }

        /**
         * Draws {@code values} as a {@code rows x cols} block of 1-pixel gray cells, read
         * row-major, with its top-left corner at ({@code left}, {@code top}).
         */
        private void drawGrid(Graphics graphics, float[] values, int rows, int cols, int left, int top) {
            for (int row = 0; row < rows; row++) {
                for (int col = 0; col < cols; col++) {
                    int gray = MNISTViewer.toGray(values[row * cols + col]);
                    graphics.setColor(new Color(gray, gray, gray));
                    graphics.fillRect(left + col, top + row, 1, 1);
                }
            }
        }

        /** Draws the network's input values using the MNIST image dimensions (rows x cols). */
        private void drawInputNodes(Graphics graphics) {
            int rows = MNISTViewer.this.manager.getImages().getRows();
            int cols = MNISTViewer.this.manager.getImages().getCols();
            float[] values = MNISTViewer.this.network.readInput();
            if (values == null) {
                values = new float[rows * cols];
            }
            drawGrid(graphics, values, rows, cols, this.mx, this.my);
        }

        /** Draws the network state as a ynodes x xnodes grid. */
        private void drawState(Graphics graphics) {
            int rows = MNISTViewer.this.ynodes;
            int cols = MNISTViewer.this.xnodes;
            float[] state = MNISTViewer.this.network.getState();
            if (state == null) {
                // Same fallback as the input/output grids: draw all black instead of failing.
                state = new float[rows * cols];
            }
            drawGrid(graphics, state, rows, cols, this.mx, 3 * this.my);
        }

        /** Draws the network's output values as an outputRows x outputCols grid. */
        private void drawOutputNodes(Graphics graphics) {
            int rows = MNISTViewer.this.outputRows;
            int cols = MNISTViewer.this.outputCols;
            float[] values = MNISTViewer.this.network.readOutput();
            if (values == null) {
                values = new float[rows * cols];
            }
            drawGrid(graphics, values, rows, cols, this.mx, 4 * this.my);
        }

        @Override
        public void paintComponent(Graphics graphics) {
            super.paintComponent(graphics);
            graphics.drawRect(this.mx, this.my, getWidth() - (2 * this.mx), getHeight() - (2 * this.my));
            drawInputNodes(graphics);
            drawState(graphics);
            drawOutputNodes(graphics);
        }
    }

    /**
     * Opens the MNIST training files, creates and initializes the network, and shows the window.
     * Must be called on the Swing event dispatch thread.
     *
     * @throws IllegalStateException if the MNIST files cannot be opened
     */
    public MNISTViewer() {
        try {
            this.manager = new MnistManager("MNIST/train-images-idx3-ubyte", "MNIST/train-labels-idx1-ubyte");
        } catch (IOException e) {
            // InputPanel cannot be built without the manager, so fail here with the real cause
            // instead of hitting a NullPointerException later.
            throw new IllegalStateException("Could not open the MNIST training files", e);
        }
        this.network = new NN(this.xnodes * this.ynodes, this.connections);
        this.network.init();
        MyFrame frame = new MyFrame("MNIST Viewer");
        frame.pack();
        frame.setVisible(true);
    }

    public static void main(String[] args) {
        // Swing components must be created and shown on the event dispatch thread.
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                new MNISTViewer();
            }
        });
    }
}
