package org.deeplearning4j.plot;

import java.awt.Color;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.Image;
import java.awt.image.BufferStrategy;
import java.awt.image.BufferedImage;
import java.awt.image.DataBuffer;
import java.awt.image.ImageObserver;
import java.awt.image.WritableRaster;
import java.io.File;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import org.deeplearning4j.nn.NeuralNetwork;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Renders matrices (hidden biases, activations, filter weights and value histograms) to PNG
 * files, and can show the most recently rendered image in a Swing window.
 *
 * <p>{@link #renderHiddenBiases}, {@link #renderActivations} and {@link #renderFilters} replace
 * {@link #img} and save it; {@link #renderHistogram} draws into a separate image. Instances hold
 * mutable, unsynchronized state, so one instance should not be shared between threads.
 */
public class FilterRenderer {
    /** Filters are laid out in a grid with this many filters per row. */
    private static final int FILTERS_PER_ROW = 10;
    /** Pixels added to each filter patch's width and height to separate grid cells. */
    private static final int FILTER_PADDING = 2;
    /** Keeps the filter normalization finite when every weight of a filter is equal. */
    private static final double NORMALIZATION_EPSILON = 1.0E-6d;

    private static final int HISTOGRAM_WIDTH = 600;
    private static final int HISTOGRAM_HEIGHT = 400;
    /** Y coordinate of the histogram baseline (bottom edge of every bar). */
    private static final int HISTOGRAM_BASELINE = HISTOGRAM_HEIGHT - 50;
    /** Height in pixels of the bar for the largest bucket count. */
    private static final int HISTOGRAM_MAX_BAR_HEIGHT = HISTOGRAM_BASELINE - 20;
    private static final int HISTOGRAM_FIRST_BAR_X = 50;
    private static final int HISTOGRAM_BAR_WIDTH = 40;
    private static final int HISTOGRAM_BAR_GAP = 10;
    /** Number of count labels drawn along the histogram's y axis. */
    private static final int HISTOGRAM_TICKS = 5;

    /** Returned by {@link #computeHistogramBucketIndex} when no bucket contains the value. */
    private static final int NO_BUCKET = -10;

    private static final Logger log = LoggerFactory.getLogger(FilterRenderer.class);

    /** Window created by {@link #draw()}; null until then. */
    public JFrame frame;
    /** Most recently rendered image. */
    BufferedImage img;
    /** Size used by {@link #start()} when drawing {@link #img}. */
    private int width = 28;
    private int height = 28;
    /** Title of the window created by {@link #draw()}. */
    public String title = "TEST";
    /** Passed by {@link #start()} as the x coordinate when drawing. */
    private int heightOffset = 0;
    /** Passed by {@link #start()} as the y coordinate when drawing. */
    private int widthOffset = 0;

    /**
     * Renders a bias matrix as a TYPE_INT_RGB image (one pixel per entry) and saves it as PNG.
     *
     * <p>Each entry is scaled by 255, rounded and clamped to [0, 255] before being stored as a
     * packed RGB value, so it appears in the blue channel only.
     *
     * @param heightOffset stored for {@link #start()}, which uses it as the x drawing offset
     * @param widthOffset stored for {@link #start()}, which uses it as the y drawing offset
     * @param biases values to render
     * @param path output PNG path; write failures are logged, not thrown
     */
    public void renderHiddenBiases(int heightOffset, int widthOffset, DoubleMatrix biases, String path) {
        this.width = biases.columns;
        this.height = biases.rows;
        this.img = new BufferedImage(this.width, this.height, BufferedImage.TYPE_INT_RGB);
        this.heightOffset = heightOffset;
        this.widthOffset = widthOffset;
        int[] levels = toPixelLevels(biases);
        if (log.isDebugEnabled()) {
            for (int level : levels) {
                log.debug("> {}", level);
            }
        }
        log.debug("hbias size: Cols: {}, Rows: {}", biases.columns, biases.rows);
        this.img.getRaster().setDataElements(0, 0, this.width, this.height, levels);
        saveToDisk(path);
    }

    /** Currently a no-op: the method body is empty. */
    public void renderAllHistograms(NeuralNetwork neuralNetwork) {
    }

    /**
     * Finds the bucket containing {@code value} by scanning buckets
     * {@code [min + k * bucketWidth, min + (k + 1) * bucketWidth]}. Both edges are inclusive, so a
     * value on a shared edge lands in the lower bucket.
     *
     * @param min lower edge of the first bucket
     * @param bucketWidth width of every bucket
     * @param value value to place
     * @param buckets number of buckets
     * @return the bucket index, or -10 when no bucket contains the value
     */
    public int computeHistogramBucketIndex(double min, double bucketWidth, double value, int buckets) {
        for (int bucket = 0; bucket < buckets; bucket++) {
            double lower = (bucket * bucketWidth) + min;
            if (value >= lower && value <= lower + bucketWidth) {
                return bucket;
            }
        }
        return NO_BUCKET;
    }

    /**
     * Computes the bucket index arithmetically as {@code (int) ((value - min) / bucketWidth)}.
     *
     * <p>A value exactly on the top edge maps to the last bucket ({@code buckets - 1}). This matches
     * {@link #computeHistogramBucketIndex}, where the top edge is inclusive, instead of returning an
     * index one past the end. No other range checking is done.
     *
     * @param min lower edge of the first bucket
     * @param bucketWidth width of every bucket
     * @param value value to place
     * @param buckets number of buckets
     * @return the computed bucket index
     */
    public int computeHistogramBucketIndexAlt(double min, double bucketWidth, double value, int buckets) {
        int bucket = (int) ((value - min) / bucketWidth);
        return (bucket == buckets && buckets > 0) ? buckets - 1 : bucket;
    }

    /**
     * Rounds {@code value} to {@code scale} decimal places.
     *
     * @param value value to round
     * @param scale number of digits after the decimal point
     * @param roundingMode legacy {@code BigDecimal.ROUND_*} constant (e.g. 4 = HALF_UP)
     * @return the rounded value
     * @throws IllegalArgumentException if {@code roundingMode} is not a valid rounding mode
     * @throws NumberFormatException if {@code value} is NaN or infinite
     */
    public static double round(double value, int scale, int roundingMode) {
        return round(value, scale, RoundingMode.valueOf(roundingMode));
    }

    /** Rounds {@code value} to {@code scale} decimal places with the given mode. */
    private static double round(double value, int scale, RoundingMode mode) {
        // BigDecimal.valueOf starts from the double's shortest decimal form. new BigDecimal(double)
        // would expose binary error (1.005 -> 1.00499...) and round half-up cases the wrong way.
        return BigDecimal.valueOf(value).setScale(scale, mode).doubleValue();
    }

    /** Returns the lower edge of {@code bucket}, rounded half-up to two decimals, as a label. */
    private String buildBucketLabel(int bucket, double bucketWidth, double min) {
        return String.valueOf(round(min + (bucket * bucketWidth), 2, RoundingMode.HALF_UP));
    }

    /**
     * Counts how many matrix entries fall into each of {@code buckets} equal-width buckets that
     * span the matrix's min..max range.
     *
     * @param values entries to count
     * @param buckets number of buckets; must be at least 1
     * @return bucket index to count, sorted by index; only non-empty buckets are present
     * @throws IllegalArgumentException if {@code buckets < 1}
     */
    public Map<Integer, Integer> generateHistogramBuckets(DoubleMatrix values, int buckets) {
        if (buckets < 1) {
            throw new IllegalArgumentException("bucket count must be at least 1: " + buckets);
        }
        Map<Integer, Integer> counts = new TreeMap<Integer, Integer>();
        double min = MatrixUtil.min(values);
        double bucketWidth = (MatrixUtil.max(values) - min) / buckets;
        for (int row = 0; row < values.rows; row++) {
            for (int column = 0; column < values.columns; column++) {
                double value = values.get(row, column);
                int bucket = computeHistogramBucketIndex(min, bucketWidth, value, buckets);
                if (bucket == NO_BUCKET && value >= min) {
                    // Assuming MatrixUtil.min/max are this matrix's extremes, a miss here comes from
                    // floating-point rounding of the top edge: the value belongs in the last bucket.
                    bucket = buckets - 1;
                }
                Integer previous = counts.get(bucket);
                counts.put(bucket, previous == null ? 1 : previous.intValue() + 1);
            }
        }
        return counts;
    }

    /**
     * Draws a 600x400 bar chart of the matrix's value distribution and saves it as PNG.
     *
     * @param values entries to plot
     * @param path output PNG path; write failures are logged, not thrown
     * @param buckets number of buckets; must be at least 1
     * @throws IllegalArgumentException if {@code buckets < 1}
     */
    public void renderHistogram(DoubleMatrix values, String path, int buckets) {
        Map<Integer, Integer> counts = generateHistogramBuckets(values, buckets);
        double min = MatrixUtil.min(values);
        double bucketWidth = (MatrixUtil.max(values) - min) / buckets;
        int maxCount = 0;
        for (Integer count : counts.values()) {
            maxCount = Math.max(maxCount, count.intValue());
        }

        BufferedImage chart =
                new BufferedImage(HISTOGRAM_WIDTH, HISTOGRAM_HEIGHT, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = chart.createGraphics();
        try {
            g.setColor(Color.LIGHT_GRAY);
            g.fillRect(0, 0, HISTOGRAM_WIDTH, HISTOGRAM_HEIGHT);

            // Y axis: evenly spaced count labels from 0 up to the largest bucket count.
            g.setColor(Color.BLACK);
            double tickStep = maxCount / (double) (HISTOGRAM_TICKS - 1);
            for (int tick = 0; tick < HISTOGRAM_TICKS; tick++) {
                double tickValue = tick * tickStep;
                g.drawString(String.valueOf(tickValue), 10,
                        HISTOGRAM_BASELINE - barHeight(tickValue, maxCount));
            }

            // One bar per non-empty bucket, left to right in bucket order, labelled below the baseline.
            int x = HISTOGRAM_FIRST_BAR_X;
            for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
                int barPixels = barHeight(entry.getValue().intValue(), maxCount);
                int top = HISTOGRAM_BASELINE - barPixels;
                g.setColor(Color.BLUE);
                g.fillRect(x, top, HISTOGRAM_BAR_WIDTH, barPixels);
                g.setColor(Color.DARK_GRAY);
                g.drawRect(x, top, HISTOGRAM_BAR_WIDTH, barPixels);
                g.setColor(Color.BLACK);
                g.drawString(buildBucketLabel(entry.getKey().intValue(), bucketWidth, min),
                        x + (HISTOGRAM_BAR_WIDTH / 2 - 10), HISTOGRAM_BASELINE + 20);
                x += HISTOGRAM_BAR_WIDTH + HISTOGRAM_BAR_GAP;
            }
        } finally {
            g.dispose();
        }
        try {
            saveImageToDisk(chart, path);
        } catch (IOException e) {
            log.error("Could not save histogram to " + path, e);
        }
    }

    /** Scales a count to a bar height in pixels; {@code maxCount} gets the full bar height. */
    private static int barHeight(double count, int maxCount) {
        if (maxCount <= 0) {
            return 0;
        }
        // Floating-point division: integer division would flatten every bar below the max to 0.
        return (int) Math.round(count / maxCount * HISTOGRAM_MAX_BAR_HEIGHT);
    }

    /**
     * Renders each column of {@code filters} as a {@code patchWidth x patchHeight} grayscale patch.
     * Patches are laid out ten per row, and the image is saved as PNG.
     *
     * <p>Each patch is normalized on its own. The column's maximum maps to 0 (black) and its
     * minimum to 255 (white); a constant column renders black.
     *
     * @param filters one filter per column; must have {@code patchWidth * patchHeight} rows
     * @param path output PNG path
     * @param patchWidth width in pixels of each patch
     * @param patchHeight height in pixels of each patch
     * @throws IllegalArgumentException if the row count does not equal
     *     {@code patchWidth * patchHeight}, or the matrix has no columns
     * @throws RuntimeException wrapping the {@link IOException} if the PNG cannot be written
     * @throws Exception declared for compatibility with existing callers
     */
    public void renderFilters(DoubleMatrix filters, String path, int patchWidth, int patchHeight)
            throws Exception {
        if (patchWidth * patchHeight != filters.rows) {
            throw new IllegalArgumentException("patch size does not match filter patch size: "
                    + patchWidth + " x " + patchHeight + " != " + filters.rows + " rows");
        }
        int filterCount = filters.columns;
        if (filterCount < 1) {
            throw new IllegalArgumentException("no filters to render");
        }
        int cellWidth = patchWidth + FILTER_PADDING;
        int cellHeight = patchHeight + FILTER_PADDING;
        // Round up so that a final, partially filled row of filters still fits inside the image.
        int gridRows = (filterCount + FILTERS_PER_ROW - 1) / FILTERS_PER_ROW;
        int imageWidth = cellWidth * FILTERS_PER_ROW;
        int imageHeight = cellHeight * gridRows;
        log.debug("Filter Width: {}", imageWidth);
        log.debug("Filter Height: {}", imageHeight);
        log.debug("Patch array size: {}", filters.rows);

        this.img = new BufferedImage(imageWidth, imageHeight, BufferedImage.TYPE_BYTE_GRAY);
        // Keep the stored size in sync with img, as the other render methods do, for start().
        this.width = imageWidth;
        this.height = imageHeight;
        WritableRaster raster = this.img.getRaster();
        int[] levels = new int[filters.rows];
        for (int filter = 0; filter < filterCount; filter++) {
            int x = (filter % FILTERS_PER_ROW) * cellWidth;
            int y = (filter / FILTERS_PER_ROW) * cellHeight;
            DoubleMatrix column = filters.getColumn(filter);
            double min = MatrixUtil.min(column);
            double max = MatrixUtil.max(column);
            log.debug("rendering {} pixels in column {} for filter patch {} x {}, total size: {} at {}",
                    new Object[] {column.length, filter, patchWidth, patchHeight,
                        patchWidth * patchHeight, x});
            // The epsilon is added to the non-negative range, so every level stays in [0, 255]
            // (no byte wrap-around) and the divisor can never reach zero.
            double range = (max - min) + NORMALIZATION_EPSILON;
            for (int p = 0; p < column.length; p++) {
                levels[p] = (int) Math.round(255.0d * ((max - column.get(p)) / range));
            }
            raster.setPixels(x, y, patchWidth, patchHeight, levels);
        }
        try {
            saveImageToDisk(this.img, path);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Renders an activation matrix as a TYPE_BYTE_GRAY image (one pixel per entry) and saves it
     * as PNG. Each entry is scaled by 255, rounded and clamped to [0, 255].
     *
     * @param heightOffset stored for {@link #start()}, which uses it as the x drawing offset
     * @param widthOffset stored for {@link #start()}, which uses it as the y drawing offset
     * @param activations values to render
     * @param path output PNG path; write failures are logged, not thrown
     * @param scale currently unused
     */
    public void renderActivations(int heightOffset, int widthOffset, DoubleMatrix activations,
            String path, int scale) {
        this.width = activations.columns;
        this.height = activations.rows;
        log.debug("----- renderActivations ------");
        this.img = new BufferedImage(this.width, this.height, BufferedImage.TYPE_BYTE_GRAY);
        this.heightOffset = heightOffset;
        this.widthOffset = widthOffset;
        int[] levels = toPixelLevels(activations);
        log.debug("activations size: Cols: {}, Rows: {}", activations.columns, activations.rows);
        this.img.getRaster().setPixels(0, 0, this.width, this.height, levels);
        saveToDisk(path);
    }

    /**
     * Scales every entry by 255, rounds it and clamps it to [0, 255]. The result is returned in
     * row-major order ({@code row * columns + column}) to match a raster's pixel layout.
     */
    private static int[] toPixelLevels(DoubleMatrix matrix) {
        int rows = matrix.rows;
        int columns = matrix.columns;
        int[] levels = new int[rows * columns];
        for (int row = 0; row < rows; row++) {
            for (int column = 0; column < columns; column++) {
                // Index by (row, column): the linear get(index) follows the matrix's storage order
                // (column-major in jblas), which would transpose/scramble the image.
                long level = Math.round(matrix.get(row, column) * 255.0d);
                levels[row * columns + column] = (int) Math.max(0L, Math.min(255L, level));
            }
        }
        return levels;
    }

    /**
     * Writes {@code bufferedImage} to {@code str} as PNG, replacing any existing file.
     *
     * @throws IOException if writing fails or no PNG writer accepts the image
     */
    public static void saveImageToDisk(BufferedImage bufferedImage, String str) throws IOException {
        File file = new File(str);
        // ImageIO.write creates or overwrites the file itself; it returns false (rather than
        // throwing) when no writer supports the image, which must not pass silently.
        if (!ImageIO.write(bufferedImage, "png", file)) {
            throw new IOException("No PNG writer available for image type " + bufferedImage.getType());
        }
    }

    /** Saves {@link #img} as PNG at {@code str}; failures are logged, not thrown. */
    public void saveToDisk(String str) {
        try {
            saveImageToDisk(this.img, str);
        } catch (IOException e) {
            log.error("Could not save image to " + str, e);
        }
    }

    /**
     * Opens a window titled {@link #title} that shows {@link #img}. The window is disposed when
     * closed.
     */
    public void draw() {
        this.frame = new JFrame(this.title);
        // NOTE(review): the frame is made visible before start() creates its buffer strategy;
        // keep this order unless strategy creation is verified to work on a hidden frame.
        this.frame.setVisible(true);
        start();
        this.frame.add(new JLabel(new ImageIcon(getImage())));
        this.frame.pack();
        this.frame.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
    }

    /** Disposes the window created by {@link #draw()}; does nothing if none was created. */
    public void close() {
        if (this.frame != null) {
            this.frame.dispose();
        }
    }

    /** Returns the most recently rendered image, or null if nothing has been rendered. */
    public Image getImage() {
        return this.img;
    }

    /**
     * If {@link #frame} has no buffer strategy yet, creates a 4-buffer strategy and returns.
     * Otherwise it loops forever: it zeroes {@link #img}'s pixel data, draws the image into the
     * strategy at (heightOffset, widthOffset) scaled to width x height, and shows the buffer.
     *
     * <p>NOTE(review): the loop has no exit condition and clears the image before drawing it;
     * {@link #draw()} only reaches the create-and-return branch. Confirm intent before calling
     * this when a strategy already exists.
     */
    public void start() {
        // Zero elements through the generic DataBuffer API, so this works for both the
        // TYPE_INT_RGB and TYPE_BYTE_GRAY images produced by the render methods.
        DataBuffer pixels = this.img.getRaster().getDataBuffer();
        while (true) {
            BufferStrategy strategy = this.frame.getBufferStrategy();
            if (strategy == null) {
                this.frame.createBufferStrategy(4);
                return;
            }
            for (int i = 0; i < this.width * this.height; i++) {
                pixels.setElem(i, 0);
            }
            Graphics g = strategy.getDrawGraphics();
            g.drawImage(this.img, this.heightOffset, this.widthOffset, this.width, this.height, null);
            g.dispose();
            strategy.show();
        }
    }
}
