package org.datavec.image.recordreader;

import com.google.common.base.Preconditions;
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.datavec.api.conf.Configuration;
import org.datavec.api.io.labels.PathLabelGenerator;
import org.datavec.api.io.labels.PathMultiLabelGenerator;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.BaseRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.util.files.FileFromPathIterator;
import org.datavec.api.util.files.URIUtil;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.loader.ImageLoader;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.transform.ImageTransform;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/datavec/image/recordreader/BaseImageRecordReader.class */
/**
 * Base class for record readers that turn image files (or a single image stream) into
 * {@link Writable} records.
 *
 * <p>Each record starts with the image pixels, loaded through {@link #imageLoader}. That loader
 * defaults to a {@link NativeImageLoader} sized {@code height x width x channels}. When labels
 * are enabled ({@link #appendLabel} or {@link #writeLabel}), the pixels are followed by label
 * writables:
 * <ul>
 *   <li>all writables from {@link #labelMultiGenerator}, if one is set;</li>
 *   <li>otherwise, if there is no {@link #labelGenerator} or it infers label classes, an
 *       {@link IntWritable} holding the index of {@link #getLabel(String)} in {@link #labels};</li>
 *   <li>otherwise, the writable returned by {@link #labelGenerator} for the file path.</li>
 * </ul>
 */
public abstract class BaseImageRecordReader extends BaseRecordReader {
    /** Set once the single image of an {@link InputStreamInputSplit} has been returned. */
    protected boolean finishedInputStreamSplit;
    /** Iterator over the files of a file-based split; {@code null} for stream splits. */
    protected Iterator<File> iter;
    protected Configuration conf;
    /** The file most recently returned by {@link #next()} / {@link #next(int)}. */
    protected File currentFile;
    protected PathLabelGenerator labelGenerator = null;
    protected PathMultiLabelGenerator labelMultiGenerator = null;
    /** Known label classes; sorted at the end of {@link #initialize(InputSplit)}. */
    protected List<String> labels = new ArrayList<>();
    protected boolean appendLabel = false;
    /** Set by {@link #setLabels(List)}; also enables label output. */
    protected boolean writeLabel = false;
    /** Single pre-set record served by {@link #next()} when no file iterator exists. */
    protected List<Writable> record;
    /** Whether {@link #record} has already been returned. */
    protected boolean hitImage = false;
    protected long height = 28;
    protected long width = 28;
    protected long channels = 1;
    protected boolean cropImage = false;
    protected ImageTransform imageTransform;
    protected BaseImageLoader imageLoader;
    protected InputSplit inputSplit;
    /** Optional file-path -> label overrides consulted by {@link #getLabel(String)}. */
    protected Map<String, String> fileNameMap = new LinkedHashMap<>();
    /** Optional regex used in {@link #initialize(InputSplit)} to split inferred labels. */
    protected String pattern;
    /** Index of the {@link #pattern}-split token kept as the label. */
    protected int patternPosition = 0;
    protected boolean logLabelCountOnInit = true;

    private static final Logger log = LoggerFactory.getLogger(BaseImageRecordReader.class);

    public static final String HEIGHT = NAME_SPACE + ".height";
    public static final String WIDTH = NAME_SPACE + ".width";
    public static final String CHANNELS = NAME_SPACE + ".channels";
    public static final String CROP_IMAGE = NAME_SPACE + ".cropimage";
    public static final String IMAGE_LOADER = NAME_SPACE + ".imageloader";

    /** Creates a reader with the field defaults (28 x 28 x 1, no labels, no transform). */
    public BaseImageRecordReader() {
        // All defaults come from the field initializers.
    }

    /**
     * @param height         image height passed to the image loader
     * @param width          image width passed to the image loader
     * @param channels       number of image channels
     * @param labelGenerator label generator; a non-null value enables label output
     */
    public BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) {
        this(height, width, channels, labelGenerator, null);
    }

    /**
     * @param height              image height passed to the image loader
     * @param width               image width passed to the image loader
     * @param channels            number of image channels
     * @param multiLabelGenerator multi-label generator; a non-null value enables label output
     */
    public BaseImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator multiLabelGenerator) {
        this(height, width, channels, null, multiLabelGenerator, null);
    }

    /**
     * @param height         image height passed to the image loader
     * @param width          image width passed to the image loader
     * @param channels       number of image channels
     * @param labelGenerator label generator; a non-null value enables label output
     * @param imageTransform transform handed to the default {@link NativeImageLoader}; may be null
     */
    public BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
                    ImageTransform imageTransform) {
        this(height, width, channels, labelGenerator, null, imageTransform);
    }

    protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
                    PathMultiLabelGenerator multiLabelGenerator, ImageTransform imageTransform) {
        this.height = height;
        this.width = width;
        this.channels = channels;
        this.labelGenerator = labelGenerator;
        this.labelMultiGenerator = multiLabelGenerator;
        this.imageTransform = imageTransform;
        this.appendLabel = labelGenerator != null || multiLabelGenerator != null;
    }

    /** Returns true if {@code str} ends with "." followed by one of the loader's allowed formats. */
    protected boolean containsFormat(String str) {
        for (String format : imageLoader.getAllowedFormats()) {
            if (str.endsWith("." + format)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Prepares the reader for {@code split}.
     *
     * <p>Creates a default {@link NativeImageLoader} if none is set. For an
     * {@link InputStreamInputSplit}, the reader serves exactly one image from the stream. Otherwise,
     * when labels are appended and the label generator infers classes, the label set is built
     * from the split's locations. The file iterator is then created and the labels are sorted.
     *
     * @throws IllegalArgumentException if a file-based split has no locations
     */
    @Override
    public void initialize(InputSplit split) throws IOException {
        if (imageLoader == null) {
            imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
        }
        this.inputSplit = split;
        if (split instanceof InputStreamInputSplit) {
            finishedInputStreamSplit = false;
            // Drop any iterator left over from a previous file-based initialisation.
            iter = null;
            return;
        }
        URI[] locations = split.locations();
        if (locations == null || locations.length < 1) {
            throw new IllegalArgumentException("No path locations found in the split.");
        }
        if (appendLabel && labelGenerator != null && labelGenerator.inferLabelClasses()) {
            HashSet<String> inferred = new HashSet<>();
            for (URI uri : locations) {
                String label = labelGenerator.getLabelForPath(uri).toString();
                inferred.add(label);
                if (pattern != null) {
                    fileNameMap.put(new File(uri).toString(), label.split(pattern)[patternPosition]);
                }
            }
            labels.clear();
            labels.addAll(inferred);
            if (logLabelCountOnInit) {
                log.info("ImageRecordReader: {} label classes inferred using label generator {}",
                                inferred.size(), labelGenerator.getClass().getSimpleName());
            }
        }
        iter = new FileFromPathIterator(inputSplit.locationsPathIterator());
        if (split instanceof FileSplit) {
            // NOTE(review): labels holds Strings; if getRootDir() returns a File this removal
            // never matches anything. Confirm the intended behaviour before changing it.
            labels.remove(((FileSplit) split).getRootDir());
        }
        Collections.sort(labels);
    }

    /**
     * Reads append-label, labels, dimensions, crop flag and loader type from
     * {@code configuration}, then delegates to {@link #initialize(InputSplit)}.
     * A value of "imageio" for {@link #IMAGE_LOADER} selects {@link ImageLoader}; anything else
     * selects {@link NativeImageLoader}.
     */
    @Override
    public void initialize(Configuration configuration, InputSplit split) throws IOException, InterruptedException {
        appendLabel = configuration.getBoolean(APPEND_LABEL, appendLabel);
        labels = new ArrayList<>(configuration.getStringCollection(LABELS));
        height = configuration.getLong(HEIGHT, height);
        width = configuration.getLong(WIDTH, width);
        channels = configuration.getLong(CHANNELS, channels);
        cropImage = configuration.getBoolean(CROP_IMAGE, cropImage);
        if ("imageio".equals(configuration.get(IMAGE_LOADER))) {
            imageLoader = new ImageLoader(height, width, channels, cropImage);
        } else {
            imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
        }
        conf = configuration;
        initialize(split);
    }

    /** Re-initialises with a new transform; the image loader is rebuilt to use it. */
    public void initialize(InputSplit split, ImageTransform imageTransform) throws IOException {
        this.imageLoader = null;
        this.imageTransform = imageTransform;
        initialize(split);
    }

    /** Re-initialises from configuration with a new transform. */
    public void initialize(Configuration configuration, InputSplit split, ImageTransform imageTransform)
                    throws IOException, InterruptedException {
        this.imageLoader = null;
        this.imageTransform = imageTransform;
        initialize(configuration, split);
    }

    /**
     * Returns the next image as a record.
     *
     * <p>For a stream split, the stream is decoded into a single {@link NDArrayWritable} with no
     * label. Without a file iterator, the pre-set {@link #record} is returned. Otherwise,
     * directory entries are skipped and the next file is loaded, with labels appended if enabled.
     *
     * @throws IllegalStateException if the stream cannot be decoded, or there is neither an
     *                               iterator nor a pre-set record
     * @throws RuntimeException      wrapping any failure while loading a file or its label
     */
    @Override
    public List<Writable> next() {
        if (inputSplit instanceof InputStreamInputSplit) {
            try {
                INDArray image = imageLoader.asMatrix(((InputStreamInputSplit) inputSplit).getIs());
                finishedInputStreamSplit = true;
                return Arrays.<Writable>asList(new NDArrayWritable(image));
            } catch (IOException e) {
                // Previously swallowed with printStackTrace(); surface the cause instead.
                throw new IllegalStateException("Failed to read image from input stream", e);
            }
        }
        if (iter == null) {
            if (record == null) {
                throw new IllegalStateException("No more elements");
            }
            hitImage = true;
            invokeListeners(record);
            return record;
        }
        File file = iter.next();
        currentFile = file;
        // Directories carry no pixels; skip them iteratively so long runs cannot overflow the stack.
        while (file.isDirectory()) {
            file = iter.next();
            currentFile = file;
        }
        try {
            invokeListeners(file);
            INDArray image = imageLoader.asMatrix(file);
            Nd4j.getAffinityManager().ensureLocation(image, AffinityManager.Location.DEVICE);
            List<Writable> ret = RecordConverter.toRecord(image);
            if (appendLabel || writeLabel) {
                appendLabelWritables(ret, file.getPath());
            }
            return ret;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns whether another record is available.
     *
     * @throws IllegalStateException if there is neither a file iterator nor a pre-set record
     */
    @Override
    public boolean hasNext() {
        if (inputSplit instanceof InputStreamInputSplit) {
            // The stream split holds exactly one image, available until next() consumes it.
            return !finishedInputStreamSplit;
        }
        if (iter != null) {
            return iter.hasNext();
        }
        if (record != null) {
            return !hitImage;
        }
        throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
    }

    /** Batched reading via {@link #next(int)} is only possible with a {@link NativeImageLoader}. */
    @Override
    public boolean batchesSupported() {
        return imageLoader instanceof NativeImageLoader;
    }

    /**
     * Reads up to {@code num} files into one batch.
     *
     * <p>The first array of the batch holds the features, with shape
     * {@code [n, channels, height, width]}. If labels are enabled, label arrays follow:
     * <ul>
     *   <li>one array per output for multi-label generators;</li>
     *   <li>otherwise a single array, one-hot when class indices are used.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code num <= 0}
     * @throws IllegalStateException    if the reader has no file iterator or does not support
     *                                  batches (see {@link #batchesSupported()})
     */
    @Override
    public List<List<Writable>> next(int num) {
        Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got " + num);
        if (imageLoader == null) {
            imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
        }
        Preconditions.checkState(iter != null,
                        "next(int) requires a reader initialized with a file-based InputSplit");
        Preconditions.checkState(batchesSupported(),
                        "next(int) requires a NativeImageLoader, got " + imageLoader.getClass().getName());

        boolean withLabels = appendLabel || writeLabel;
        List<File> files = new ArrayList<>(num);
        List<Integer> classIndices = new ArrayList<>();
        List<Writable> singleLabels = new ArrayList<>();
        List<List<? extends Writable>> multiLabels = new ArrayList<>();
        while (files.size() < num && iter.hasNext()) {
            currentFile = iter.next();
            files.add(currentFile);
            invokeListeners(currentFile);
            if (!withLabels) {
                continue;
            }
            String path = currentFile.getPath();
            if (labelMultiGenerator != null) {
                multiLabels.add(labelMultiGenerator.getLabels(path));
            } else if (usesClassIndexLabels()) {
                classIndices.add(labels.indexOf(getLabel(path)));
            } else {
                singleLabels.add(labelGenerator.getLabelForPath(path));
            }
        }

        List<INDArray> arrays = new ArrayList<>();
        arrays.add(loadImageBatch(files));
        if (withLabels) {
            if (labelMultiGenerator != null) {
                addMultiLabelArrays(arrays, multiLabels);
            } else {
                INDArray labelArray;
                if (usesClassIndexLabels()) {
                    labelArray = Nd4j.create(files.size(), labels.size(), 'c');
                    Nd4j.getAffinityManager().tagLocation(labelArray, AffinityManager.Location.HOST);
                    for (int row = 0; row < classIndices.size(); row++) {
                        labelArray.putScalar(row, classIndices.get(row).intValue(), 1.0d);
                    }
                } else if (!singleLabels.isEmpty() && singleLabels.get(0) instanceof NDArrayWritable) {
                    INDArray[] parts = new INDArray[singleLabels.size()];
                    for (int k = 0; k < parts.length; k++) {
                        parts[k] = ((NDArrayWritable) singleLabels.get(k)).get();
                    }
                    labelArray = Nd4j.concat(0, parts);
                } else {
                    labelArray = RecordConverter.toMinibatchArray(singleLabels);
                }
                arrays.add(labelArray);
            }
        }
        return new NDArrayRecordBatch(arrays);
    }

    /** Loads {@code files} into a new {@code [n, channels, height, width]} feature array. */
    private INDArray loadImageBatch(List<File> files) {
        INDArray features = Nd4j.createUninitialized(new long[] {files.size(), channels, height, width}, 'c');
        Nd4j.getAffinityManager().tagLocation(features, AffinityManager.Location.HOST);
        NativeImageLoader loader = (NativeImageLoader) imageLoader;
        for (int i = 0; i < files.size(); i++) {
            File file = files.get(i);
            try {
                loader.asMatrixView(file, features.tensorAlongDimension(i, 1, 2, 3));
            } catch (Exception e) {
                throw new RuntimeException("Image file failed during load: " + file.getAbsolutePath(), e);
            }
        }
        Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE);
        return features;
    }

    /** Transposes per-example multi-label lists into one minibatch array per label output. */
    private static void addMultiLabelArrays(List<INDArray> out, List<List<? extends Writable>> perExample) {
        if (perExample.isEmpty()) {
            return;
        }
        int numOutputs = perExample.get(0).size();
        for (int j = 0; j < numOutputs; j++) {
            List<Writable> column = new ArrayList<>(perExample.size());
            for (List<? extends Writable> exampleLabels : perExample) {
                column.add(exampleLabels.get(j));
            }
            out.add(RecordConverter.toMinibatchArray(column));
        }
    }

    /** True when labels are emitted as class indices into {@link #labels}. */
    private boolean usesClassIndexLabels() {
        return labelGenerator == null || labelGenerator.inferLabelClasses();
    }

    /** Appends the label writable(s) for {@code path} to {@code target}, as documented on the class. */
    private void appendLabelWritables(List<Writable> target, String path) {
        if (labelMultiGenerator != null) {
            target.addAll(labelMultiGenerator.getLabels(path));
        } else if (usesClassIndexLabels()) {
            target.add(new IntWritable(labels.indexOf(getLabel(path))));
        } else {
            target.add(labelGenerator.getLabelForPath(path));
        }
    }

    @Override
    public void close() throws IOException {
        // Nothing to release: files are opened and closed by the image loader per read.
    }

    @Override
    public void setConf(Configuration configuration) {
        this.conf = configuration;
    }

    @Override
    public Configuration getConf() {
        return conf;
    }

    /**
     * Returns the label for {@code path}. Candidates are tried in order:
     * <ol>
     *   <li>the label generator's value;</li>
     *   <li>a {@link #fileNameMap} entry;</li>
     *   <li>the name of the file's parent directory.</li>
     * </ol>
     */
    public String getLabel(String path) {
        if (labelGenerator != null) {
            return labelGenerator.getLabelForPath(path).toString();
        }
        if (fileNameMap != null && fileNameMap.containsKey(path)) {
            return fileNameMap.get(path);
        }
        return new File(path).getParentFile().getName();
    }

    /** Adds the label of {@code path} to {@link #labels} if not already present. */
    protected void accumulateLabel(String path) {
        String label = getLabel(path);
        if (!labels.contains(label)) {
            labels.add(label);
        }
    }

    public File getCurrentFile() {
        return currentFile;
    }

    public void setCurrentFile(File file) {
        this.currentFile = file;
    }

    @Override
    public List<String> getLabels() {
        return labels;
    }

    /** Replaces the label list and enables label output. */
    public void setLabels(List<String> list) {
        this.labels = list;
        this.writeLabel = true;
    }

    /**
     * Rewinds the reader to the start of its split.
     *
     * @throws UnsupportedOperationException if the reader was never initialised
     */
    @Override
    public void reset() {
        if (inputSplit == null) {
            throw new UnsupportedOperationException("Cannot reset without first initializing");
        }
        inputSplit.reset();
        if (inputSplit instanceof InputStreamInputSplit) {
            // After a successful split reset the single stream image is available again.
            finishedInputStreamSplit = false;
        }
        if (iter != null) {
            iter = new FileFromPathIterator(inputSplit.locationsPathIterator());
        } else if (record != null) {
            hitImage = false;
        }
    }

    @Override
    public boolean resetSupported() {
        return inputSplit != null && inputSplit.resetSupported();
    }

    public int numLabels() {
        return labels.size();
    }

    /**
     * Decodes one image from {@code dataInputStream} and appends labels for {@code uri}'s path
     * using the same rules as {@link #next()}. This keeps metadata-reloaded records consistent
     * with those returned by {@link #next()}.
     */
    @Override
    public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
        invokeListeners(uri);
        if (imageLoader == null) {
            imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
        }
        List<Writable> ret = RecordConverter.toRecord(imageLoader.asMatrix(dataInputStream));
        if (appendLabel || writeLabel) {
            appendLabelWritables(ret, uri.getPath());
        }
        return ret;
    }

    /** Returns {@link #next()} wrapped with URI metadata for {@link #currentFile}. */
    @Override
    public Record nextRecord() {
        List<Writable> values = next();
        return new org.datavec.api.records.impl.Record(values,
                        new RecordMetaDataURI(URIUtil.fileToURI(currentFile), BaseImageRecordReader.class));
    }

    @Override
    public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return loadFromMetaData(Collections.singletonList(recordMetaData)).get(0);
    }

    /**
     * Re-reads each record's file via {@link #record(URI, DataInputStream)}. Every stream is
     * closed even if decoding fails.
     */
    @Override
    public List<Record> loadFromMetaData(List<RecordMetaData> list) throws IOException {
        List<Record> out = new ArrayList<>(list.size());
        for (RecordMetaData meta : list) {
            URI uri = meta.getURI();
            try (DataInputStream in = new DataInputStream(
                            new BufferedInputStream(new FileInputStream(new File(uri))))) {
                List<Writable> values = record(uri, in);
                out.add(new org.datavec.api.records.impl.Record(values, meta));
            }
        }
        return out;
    }

    public boolean isLogLabelCountOnInit() {
        return logLabelCountOnInit;
    }

    public void setLogLabelCountOnInit(boolean logLabelCountOnInit) {
        this.logLabelCountOnInit = logLabelCountOnInit;
    }
}
