package org.datavec.image.loader;

import com.clearspring.analytics.stream.frequency.CountMinSketch;
import com.github.jaiimageio.impl.plugins.tiff.TIFFImageReaderSpi;
import com.github.jaiimageio.impl.plugins.tiff.TIFFImageWriterSpi;
import com.twelvemonkeys.imageio.plugins.bmp.BMPImageReaderSpi;
import com.twelvemonkeys.imageio.plugins.bmp.CURImageReaderSpi;
import com.twelvemonkeys.imageio.plugins.bmp.ICOImageReaderSpi;
import com.twelvemonkeys.imageio.plugins.jpeg.JPEGImageReaderSpi;
import com.twelvemonkeys.imageio.plugins.jpeg.JPEGImageWriterSpi;
import com.twelvemonkeys.imageio.plugins.psd.PSDImageReaderSpi;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.awt.image.ImageObserver;
import java.awt.image.WritableRaster;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import javax.imageio.ImageIO;
import javax.imageio.spi.IIORegistry;
import org.datavec.image.data.Image;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.NDArrayUtil;

/* loaded from: input_file:org/datavec/image/loader/ImageLoader.class */
/**
 * Loads images through {@link ImageIO} and converts them into ND4J arrays, plain {@code int}
 * arrays, or DataVec {@link Image} objects.
 *
 * <p>The target {@code height}, {@code width}, {@code channels} and the {@code centerCropIfNeeded}
 * flag are fields inherited from {@link BaseImageLoader}. A non-positive height or width disables
 * resizing (see {@link #scalingIfNeed(BufferedImage, long, long, long, boolean)}).
 *
 * <p>The static initializer registers extra ImageIO service providers (TIFF, JPEG, PSD, BMP, CUR,
 * ICO) in the default {@link IIORegistry}.
 */
public class ImageLoader extends BaseImageLoader {

    /** Creates a loader that keeps the settings inherited from {@link BaseImageLoader}. */
    public ImageLoader() {
    }

    /**
     * Creates a loader that resizes images to the given size.
     *
     * @param height target height in pixels (non-positive disables resizing)
     * @param width target width in pixels (non-positive disables resizing)
     */
    public ImageLoader(long height, long width) {
        this.height = height;
        this.width = width;
    }

    /**
     * Creates a loader that resizes images and requests a channel count.
     *
     * @param height target height in pixels (non-positive disables resizing)
     * @param width target width in pixels (non-positive disables resizing)
     * @param channels requested number of channels; {@code 3} selects the BGR code paths
     */
    public ImageLoader(long height, long width, long channels) {
        this.height = height;
        this.width = width;
        this.channels = channels;
    }

    /**
     * Creates a loader that resizes images, requests a channel count and optionally center-crops
     * images to a square before {@link #asRowVector(BufferedImage)} scales them.
     *
     * @param height target height in pixels (non-positive disables resizing)
     * @param width target width in pixels (non-positive disables resizing)
     * @param channels requested number of channels; {@code 3} selects the BGR code paths
     * @param centerCropIfNeeded whether {@link #asRowVector(BufferedImage)} crops to a square first
     */
    public ImageLoader(long height, long width, long channels, boolean centerCropIfNeeded) {
        this(height, width, channels);
        this.centerCropIfNeeded = centerCropIfNeeded;
    }

    /** Reads {@code file} with ImageIO and delegates to {@link #asRowVector(BufferedImage)}. */
    @Override // org.datavec.image.loader.BaseImageLoader
    public INDArray asRowVector(File file) throws IOException {
        return asRowVector(ImageIO.read(file));
    }

    /** Reads {@code inputStream} with ImageIO and delegates to {@link #asRowVector(BufferedImage)}. */
    @Override // org.datavec.image.loader.BaseImageLoader
    public INDArray asRowVector(InputStream inputStream) throws IOException {
        return asRowVector(ImageIO.read(inputStream));
    }

    /**
     * Flattens an image into a single vector.
     *
     * <p>The image is optionally center-cropped, then scaled. With {@code channels == 3} the raw BGR
     * bytes are raveled. Otherwise the per-pixel values from {@link #toIntArrayArray} are flattened
     * row by row.
     *
     * @param bufferedImage the decoded image
     * @return the flattened image
     * @throws IllegalStateException if {@code bufferedImage} is {@code null} (ImageIO could not
     *     decode the input)
     */
    public INDArray asRowVector(BufferedImage bufferedImage) {
        BufferedImage image = requireImage(bufferedImage);
        if (this.centerCropIfNeeded) {
            image = centerCropIfNeeded(image);
        }
        BufferedImage scaled = scalingIfNeed(image, true);
        if (this.channels == 3) {
            return toINDArrayBGR(scaled).ravel();
        }
        return NDArrayUtil.toNDArray(ArrayUtil.flatten(toIntArrayArray(scaled)));
    }

    /**
     * Loads {@code file} as a BGR tensor and ravels it.
     *
     * @throws RuntimeException wrapping any {@link IOException} from opening, reading or closing
     *     the file
     */
    public INDArray toRaveledTensor(File file) {
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            // toRaveledTensor(InputStream) already ravels; no second ravel() needed.
            return toRaveledTensor(in);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Loads the stream as a BGR tensor (see {@link #toBgr(InputStream)}) and ravels it. */
    public INDArray toRaveledTensor(InputStream inputStream) {
        return toBgr(inputStream).ravel();
    }

    /**
     * Converts an image to a BGR tensor (see {@link #toBgr(BufferedImage)}) and ravels it.
     *
     * @throws RuntimeException with message "Unable to load image" wrapping any failure
     */
    public INDArray toRaveledTensor(BufferedImage bufferedImage) {
        try {
            return toBgr(bufferedImage).ravel();
        } catch (Exception e) {
            throw new RuntimeException("Unable to load image", e);
        }
    }

    /**
     * Loads {@code file} as a {@code [bands, height, width]} BGR tensor.
     *
     * @throws RuntimeException wrapping any {@link IOException} from opening, reading or closing
     *     the file
     */
    public INDArray toBgr(File file) {
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            return toBgr(in);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Decodes the stream with ImageIO and converts it via {@link #toBgr(BufferedImage)}.
     *
     * @throws RuntimeException with message "Unable to load image" if ImageIO fails with an
     *     {@link IOException}
     */
    public INDArray toBgr(InputStream inputStream) {
        try {
            return toBgr(ImageIO.read(inputStream));
        } catch (IOException e) {
            throw new RuntimeException("Unable to load image", e);
        }
    }

    /**
     * Decodes the stream and wraps the BGR tensor in an {@link Image} that records the band count
     * and the size of the decoded (pre-scaling) image.
     */
    private Image toBgrImage(InputStream inputStream) {
        BufferedImage image;
        try {
            image = ImageIO.read(inputStream);
        } catch (IOException e) {
            throw new RuntimeException("Unable to load image", e);
        }
        // toBgr rejects a null image before the metadata getters below are reached.
        INDArray bgr = toBgr(image);
        return new Image(bgr, image.getSampleModel().getNumBands(), image.getHeight(), image.getWidth());
    }

    /**
     * Scales the image (without alpha) and returns its raw bytes as a {@code [bands, height, width]}
     * tensor.
     *
     * @throws IllegalStateException if {@code bufferedImage} is {@code null}
     */
    public INDArray toBgr(BufferedImage bufferedImage) {
        return toINDArrayBGR(scalingIfNeed(requireImage(bufferedImage), false));
    }

    /**
     * Loads {@code file} as a 2-D {@code [height, width]} matrix of per-pixel values from
     * {@link #fromFile(File)}.
     *
     * <p>NOTE(review): unlike {@link #asMatrix(InputStream)}, this overload does not switch to the
     * BGR tensor when {@code channels == 3}. This is kept for compatibility; confirm intent.
     */
    @Override // org.datavec.image.loader.BaseImageLoader
    public INDArray asMatrix(File file) throws IOException {
        return NDArrayUtil.toNDArray(fromFile(file));
    }

    /**
     * Decodes the stream into a matrix. Returns the BGR tensor when {@code channels == 3}, and
     * otherwise the matrix from {@link #asMatrix(BufferedImage)}.
     *
     * @throws IOException with message "Unable to load image" if ImageIO fails
     */
    @Override // org.datavec.image.loader.BaseImageLoader
    public INDArray asMatrix(InputStream inputStream) throws IOException {
        if (this.channels == 3) {
            return toBgr(inputStream);
        }
        try {
            return asMatrix(ImageIO.read(inputStream));
        } catch (IOException e) {
            throw new IOException("Unable to load image", e);
        }
    }

    /** Opens {@code file}, delegates to {@link #asImageMatrix(InputStream)} and always closes it. */
    @Override // org.datavec.image.loader.BaseImageLoader
    public Image asImageMatrix(File file) throws IOException {
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            return asImageMatrix(in);
        }
    }

    /**
     * Decodes the stream and wraps the matrix in an {@link Image} that records the band count and
     * the size of the decoded (pre-scaling) image.
     *
     * @throws IOException with message "Unable to load image" if ImageIO fails
     */
    @Override // org.datavec.image.loader.BaseImageLoader
    public Image asImageMatrix(InputStream inputStream) throws IOException {
        if (this.channels == 3) {
            return toBgrImage(inputStream);
        }
        BufferedImage image;
        try {
            image = ImageIO.read(inputStream);
        } catch (IOException e) {
            throw new IOException("Unable to load image", e);
        }
        // asMatrix rejects a null image before the metadata getters below are reached.
        INDArray matrix = asMatrix(image);
        return new Image(matrix, image.getSampleModel().getNumBands(), image.getHeight(), image.getWidth());
    }

    /**
     * Converts an image into a matrix.
     *
     * <p>With {@code channels == 3} this returns the {@code [bands, height, width]} BGR tensor.
     * Otherwise it returns a {@code [height, width]} matrix whose entries are the packed ints
     * returned by {@link BufferedImage#getRGB(int, int)}.
     *
     * @throws IllegalStateException if {@code bufferedImage} is {@code null}
     */
    public INDArray asMatrix(BufferedImage bufferedImage) {
        if (this.channels == 3) {
            return toBgr(bufferedImage);
        }
        BufferedImage image = scalingIfNeed(requireImage(bufferedImage), true);
        int width = image.getWidth();
        int height = image.getHeight();
        INDArray matrix = Nd4j.create(height, width);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                matrix.putScalar(new int[] {y, x}, image.getRGB(x, y));
            }
        }
        return matrix;
    }

    /**
     * Allocates a new {@code [numMiniBatches, numRowsPerSlice, columns]} array, where
     * {@code columns} is the column count of {@link #asMatrix(File)} for {@code file}.
     *
     * <p>NOTE(review): the image's pixel values are not copied into the returned array. Only its
     * column count is used. Confirm whether this is intended.
     *
     * @throws RuntimeException wrapping any failure to load {@code file}
     */
    public INDArray asImageMiniBatches(File file, int numMiniBatches, int numRowsPerSlice) {
        try {
            return Nd4j.create(numMiniBatches, numRowsPerSlice, asMatrix(file).columns());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns {@link #fromFile(File)} flattened row by row. */
    public int[] flattenedImageFromFile(File file) throws IOException {
        return ArrayUtil.flatten(fromFile(file));
    }

    /**
     * Reads {@code file}, scales it (alpha allowed) and returns per-pixel values as
     * {@code [height][width]}. See {@link #toIntArrayArray(BufferedImage)}.
     *
     * @throws IllegalStateException if ImageIO cannot decode the file
     */
    public int[][] fromFile(File file) throws IOException {
        return toIntArrayArray(scalingIfNeed(requireImage(ImageIO.read(file)), true));
    }

    /**
     * Reads {@code file} and returns its unsigned byte samples as {@code [channels][height][width]}.
     *
     * <p>Samples are taken from the image's interleaved byte buffer, in buffer order. Channels
     * beyond the image's band count are left as {@code 0}.
     *
     * @throws IllegalStateException if ImageIO cannot decode the file
     * @throws ClassCastException if the scaled image is not backed by a {@link DataBufferByte}
     */
    public int[][][] fromFileMultipleChannels(File file) throws IOException {
        BufferedImage image = scalingIfNeed(requireImage(ImageIO.read(file)), this.channels > 3);
        int width = image.getWidth();
        int height = image.getHeight();
        int numBands = image.getSampleModel().getNumBands();
        int[][][] result = new int[(int) Math.min(this.channels, Integer.MAX_VALUE)][height][width];
        byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
        int usableChannels = (int) Math.min(this.channels, numBands);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                // The buffer is interleaved by the image's real band count, not the requested one.
                int base = (y * width + x) * numBands;
                for (int c = 0; c < usableChannels; c++) {
                    // Mask to read bytes as unsigned 0..255 (Java bytes are signed).
                    result[c][y][x] = pixels[base + c] & 0xFF;
                }
            }
        }
        return result;
    }

    /**
     * Builds an ARGB image from a 2-D matrix of packed ARGB ints: the inverse of the non-BGR path
     * of {@link #asMatrix(BufferedImage)}.
     *
     * <p>Rows map to image height and columns to image width. Elements are read by linear index,
     * which presumably follows row-major order; confirm for 'f'-ordered arrays.
     *
     * @param matrix a {@code [height, width]} matrix
     * @return a new {@link BufferedImage#TYPE_INT_ARGB} image
     */
    public static BufferedImage toImage(INDArray matrix) {
        int height = matrix.rows();
        int width = matrix.columns();
        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        WritableRaster raster = image.getRaster();
        int[] pixels = new int[(int) matrix.length()];
        for (int i = 0; i < pixels.length; i++) {
            pixels[i] = (int) matrix.getDouble(i);
        }
        raster.setDataElements(0, 0, width, height, pixels);
        return image;
    }

    /** Rounds every element of {@code matrix}, by linear index, to an {@code int}. */
    private static int[] rasterData(INDArray matrix) {
        int[] values = new int[(int) matrix.length()];
        for (int i = 0; i < values.length; i++) {
            // getDouble works for any dtype; casting element() to Double failed on float arrays.
            values[i] = (int) Math.round(matrix.getDouble(i));
        }
        return values;
    }

    /**
     * Writes a {@code [3, height, width]} BGR array into an image as opaque RGB pixels.
     *
     * <p>Slice 0 is blue, 1 is green and 2 is red. Channel values are clamped to 0..255.
     *
     * <p>NOTE(review): if the image's size differs from the array's last two dimensions, the pixels
     * are written into a rescaled copy that is not returned, so the caller's image is unchanged.
     * The {@code void} signature is kept for compatibility.
     *
     * @throws IllegalArgumentException if {@code iNDArray} has rank below 3
     */
    public void toBufferedImageRGB(INDArray iNDArray, BufferedImage bufferedImage) {
        if (iNDArray.rank() < 3) {
            throw new IllegalArgumentException("Arr must be 3d");
        }
        BufferedImage target = scalingIfNeed(bufferedImage, iNDArray.size(-2), iNDArray.size(-1), bufferedImage.getType(), true);
        // Slice once instead of three times per pixel.
        INDArray blue = iNDArray.slice(0L);
        INDArray green = iNDArray.slice(1L);
        INDArray red = iNDArray.slice(2L);
        for (int y = 0; y < target.getHeight(); y++) {
            for (int x = 0; x < target.getWidth(); x++) {
                int argb = (0xFF << 24)
                        | (toColorByte(red.getInt(y, x)) << 16)
                        | (toColorByte(green.getInt(y, x)) << 8)
                        | toColorByte(blue.getInt(y, x));
                target.setRGB(x, y, argb);
            }
        }
    }

    /** Clamps a channel value to the 0..255 range of one byte in a packed RGB int. */
    private static int toColorByte(int value) {
        return Math.max(0, Math.min(255, value));
    }

    /**
     * Returns {@code image} itself if it is already a {@link BufferedImage} of the requested type.
     * Otherwise draws it into a new {@link BufferedImage} of that type.
     *
     * @param image source image
     * @param i a {@code BufferedImage.TYPE_*} constant
     */
    public static BufferedImage toBufferedImage(java.awt.Image image, int i) {
        if (image instanceof BufferedImage && ((BufferedImage) image).getType() == i) {
            return (BufferedImage) image;
        }
        BufferedImage converted = new BufferedImage(image.getWidth((ImageObserver) null), image.getHeight((ImageObserver) null), i);
        Graphics2D g = converted.createGraphics();
        try {
            g.drawImage(image, 0, 0, (ImageObserver) null);
        } finally {
            g.dispose();
        }
        return converted;
    }

    /**
     * Returns per-pixel values as {@code [height][width]}.
     *
     * <p>Single-element rasters (for example grayscale) yield sample 0. Other rasters yield the
     * packed int from {@link BufferedImage#getRGB(int, int)}.
     */
    protected int[][] toIntArrayArray(BufferedImage bufferedImage) {
        int width = bufferedImage.getWidth();
        int height = bufferedImage.getHeight();
        int[][] values = new int[height][width];
        WritableRaster raster = bufferedImage.getRaster();
        boolean singleElement = raster.getNumDataElements() == 1;
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                values[y][x] = singleElement ? raster.getSample(x, y, 0) : bufferedImage.getRGB(x, y);
            }
        }
        return values;
    }

    /**
     * Copies the raw bytes of a byte-backed image into a {@code [bands, height, width]} tensor of
     * unsigned values.
     *
     * <p>For {@link BufferedImage#TYPE_3BYTE_BGR} the bytes in memory are ordered B, G, R, which
     * gives channel 0 = blue. Assumes the data buffer is not shared with a larger parent image
     * (e.g. a {@code getSubimage} view). TODO confirm at call sites.
     *
     * @throws ClassCastException if the image is not backed by a {@link DataBufferByte}
     */
    protected INDArray toINDArrayBGR(BufferedImage bufferedImage) {
        int height = bufferedImage.getHeight();
        int width = bufferedImage.getWidth();
        int numBands = bufferedImage.getSampleModel().getNumBands();
        byte[] pixels = ((DataBufferByte) bufferedImage.getRaster().getDataBuffer()).getData();
        INDArray flat = Nd4j.create(1, pixels.length);
        for (int i = 0; i < pixels.length; i++) {
            flat.putScalar(i, pixels[i] & 0xFF);
        }
        return flat.reshape(new int[] {height, width, numBands}).permute(2, 0, 1);
    }

    /**
     * Crops the largest centered square out of the image.
     *
     * <p>The result is a {@link BufferedImage#getSubimage} view that shares the source's data.
     * Square images are returned as a full-size view.
     */
    public BufferedImage centerCropIfNeeded(BufferedImage bufferedImage) {
        int height = bufferedImage.getHeight();
        int width = bufferedImage.getWidth();
        int side = Math.min(width, height);
        // Previously only half the size difference was removed, which produced a non-square,
        // off-center crop.
        int x = (width - side) / 2;
        int y = (height - side) / 2;
        return bufferedImage.getSubimage(x, y, side, side);
    }

    /**
     * Scales to this loader's height and width, passing {@code channels} as the destination type.
     *
     * <p>NOTE(review): {@code channels} is passed where the 5-argument overload expects a
     * {@code BufferedImage.TYPE_*} constant. The "type already matches" short-circuit therefore
     * compares unrelated values, and most images end up converted to
     * {@link BufferedImage#TYPE_3BYTE_BGR}. Kept as-is because callers depend on the current
     * output; confirm intent.
     */
    protected BufferedImage scalingIfNeed(BufferedImage bufferedImage, boolean z) {
        return scalingIfNeed(bufferedImage, this.height, this.width, this.channels, z);
    }

    /**
     * Resizes the image if needed and converts it to a {@link BufferedImage} of a supported type.
     *
     * <p>Resizing (with {@link java.awt.Image#SCALE_SMOOTH}) happens only when both target
     * dimensions are positive and differ from the image. The result is returned unchanged if it is
     * already a {@link BufferedImage} whose type equals {@code dstImageType}. Otherwise it is
     * converted to one of the following:
     * <ul>
     *   <li>{@code TYPE_4BYTE_ABGR}, if {@code needAlpha} is set, the source has alpha and that
     *       type was requested;</li>
     *   <li>{@code TYPE_BYTE_GRAY}, if that type was requested;</li>
     *   <li>{@code TYPE_3BYTE_BGR}, otherwise.</li>
     * </ul>
     *
     * @param bufferedImage source image
     * @param dstHeight target height (non-positive disables resizing)
     * @param dstWidth target width (non-positive disables resizing)
     * @param dstImageType requested {@code BufferedImage.TYPE_*} value
     * @param needAlpha whether to keep alpha when converting to {@code TYPE_4BYTE_ABGR}
     */
    protected BufferedImage scalingIfNeed(BufferedImage bufferedImage, long dstHeight, long dstWidth, long dstImageType, boolean needAlpha) {
        java.awt.Image scaled;
        if (dstHeight > 0 && dstWidth > 0
                && (bufferedImage.getHeight() != dstHeight || bufferedImage.getWidth() != dstWidth)) {
            scaled = bufferedImage.getScaledInstance((int) dstWidth, (int) dstHeight, java.awt.Image.SCALE_SMOOTH);
        } else {
            scaled = bufferedImage;
        }
        if (scaled instanceof BufferedImage && ((BufferedImage) scaled).getType() == dstImageType) {
            return (BufferedImage) scaled;
        }
        if (needAlpha && bufferedImage.getColorModel().hasAlpha() && dstImageType == BufferedImage.TYPE_4BYTE_ABGR) {
            return toBufferedImage(scaled, BufferedImage.TYPE_4BYTE_ABGR);
        }
        if (dstImageType == BufferedImage.TYPE_BYTE_GRAY) {
            return toBufferedImage(scaled, BufferedImage.TYPE_BYTE_GRAY);
        }
        return toBufferedImage(scaled, BufferedImage.TYPE_3BYTE_BGR);
    }

    /**
     * Rejects a {@code null} decode result (ImageIO returns {@code null} when no reader accepts the
     * input) with the same exception and message that {@link #toBgr(BufferedImage)} has always
     * used.
     */
    private static BufferedImage requireImage(BufferedImage image) {
        if (image == null) {
            throw new IllegalStateException("Unable to load image");
        }
        return image;
    }

    // Runs once per class load: rescans the classpath for ImageIO plugins, then explicitly
    // registers the TIFF, JPEG, PSD, BMP, CUR and ICO service providers.
    static {
        ImageIO.scanForPlugins();
        IIORegistry defaultInstance = IIORegistry.getDefaultInstance();
        defaultInstance.registerServiceProvider(new TIFFImageWriterSpi());
        defaultInstance.registerServiceProvider(new TIFFImageReaderSpi());
        defaultInstance.registerServiceProvider(new JPEGImageReaderSpi());
        defaultInstance.registerServiceProvider(new JPEGImageWriterSpi());
        defaultInstance.registerServiceProvider(new PSDImageReaderSpi());
        defaultInstance.registerServiceProvider(Arrays.asList(new BMPImageReaderSpi(), new CURImageReaderSpi(), new ICOImageReaderSpi()));
    }
}
