package org.marketcetera.tensorflow;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.marketcetera.module.DataRequest;
import org.marketcetera.tensorflow.model.TensorFlowRunner;
import org.marketcetera.util.log.SLF4JLoggerProxy;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

/**
 * {@link TensorFlowRunner} that labels an input tensor using a serialized TensorFlow graph.
 *
 * <p>Each call to {@link #fetch(DataRequest, Tensor)} does the following:
 * <ol>
 *   <li>reads the label list, one label per line in UTF-8, from a labels file;</li>
 *   <li>imports the graph definition into a fresh {@link Graph};</li>
 *   <li>feeds the supplied tensor to the graph's {@code "input"} operation;</li>
 *   <li>reads the {@code "output"} operation, which must be shaped {@code [1 N]};</li>
 *   <li>returns the label whose score is highest.</li>
 * </ol>
 */
public class ImageLabelTensorFlowRunner implements TensorFlowRunner {
    /** Labels file used by the single-argument constructor; resolved relative to the working directory. */
    private static final Path DEFAULT_LABELS_FILE =
            Paths.get("src/test/sample_data", "imagenet_comp_graph_label_strings.txt");

    /** Serialized graph definition, imported into a new {@link Graph} on every fetch. */
    private final byte[] graphDef;

    /** File holding one label per line; line {@code i} names score index {@code i} of the model output. */
    private final Path labelsFile;

    /**
     * Creates a runner that reads its labels from the default sample-data labels file.
     *
     * @param graphDef serialized graph definition to execute
     * @throws IOException declared for compatibility; this constructor does not read any file
     */
    public ImageLabelTensorFlowRunner(byte[] graphDef) throws IOException {
        this(graphDef, DEFAULT_LABELS_FILE);
    }

    /**
     * Creates a runner that reads its labels from the given file.
     *
     * @param graphDef serialized graph definition to execute
     * @param labelsFile file with one label per line (UTF-8), read on every fetch
     * @throws NullPointerException if {@code labelsFile} is {@code null}
     */
    public ImageLabelTensorFlowRunner(byte[] graphDef, Path labelsFile) {
        this.graphDef = graphDef;
        this.labelsFile = Objects.requireNonNull(labelsFile, "labelsFile");
    }

    /**
     * Runs the graph on {@code tensor} and returns the best-scoring label.
     *
     * @param dataRequest the originating request (not used by this runner)
     * @param tensor the tensor fed to the graph's {@code "input"} operation
     * @return the label ({@link String}) with the highest score
     * @throws UncheckedIOException if the labels file cannot be read
     * @throws IllegalStateException if the model produces no scores, or its best index has no label
     * @throws RuntimeException if the model output is not shaped {@code [1 N]}
     */
    @Override // org.marketcetera.tensorflow.model.TensorFlowRunner
    public Object fetch(DataRequest dataRequest, Tensor tensor) {
        List<String> labels;
        try {
            labels = Files.readAllLines(labelsFile, StandardCharsets.UTF_8);
        } catch (IOException e) {
            SLF4JLoggerProxy.warn(this, e);
            throw new UncheckedIOException(e);
        }
        float[] scores = executeInceptionGraph(tensor);
        if (scores.length == 0) {
            // A [1 0] output has no best match; fail clearly instead of indexing an empty array.
            throw new IllegalStateException("Model produced an output with no label scores");
        }
        int bestIndex = maxIndex(scores);
        if (bestIndex >= labels.size()) {
            throw new IllegalStateException(String.format(
                    "Best match index %d has no label: labels file %s has only %d entries",
                    bestIndex, labelsFile, labels.size()));
        }
        String bestLabel = labels.get(bestIndex);
        SLF4JLoggerProxy.trace(this, "BEST MATCH: {} ({} likely)",
                new Object[]{bestLabel, Float.valueOf(scores[bestIndex] * 100.0f)});
        return bestLabel;
    }

    /**
     * Returns the index of the largest value; on ties the earliest index wins.
     *
     * @param values scores to search; returns 0 for an empty array
     * @return index of the first maximal element
     */
    private int maxIndex(float[] values) {
        int best = 0;
        for (int candidate = 1; candidate < values.length; candidate++) {
            if (values[candidate] > values[best]) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Imports {@link #graphDef}, feeds {@code tensor} to {@code "input"} and returns the
     * {@code "output"} scores.
     *
     * <p>The graph, the session and the result tensor are all closed on every path, including when
     * an exception is thrown.
     *
     * @param tensor value fed to the graph's {@code "input"} operation
     * @return the single row of the {@code [1 N]} output tensor
     * @throws RuntimeException if the output is not a rank-2 tensor with first dimension 1
     * @throws ArithmeticException if {@code N} does not fit in an {@code int}
     */
    private float[] executeInceptionGraph(Tensor tensor) {
        try (Graph graph = new Graph()) {
            graph.importGraphDef(this.graphDef);
            try (Session session = new Session(graph);
                 Tensor result = (Tensor) session.runner().feed("input", tensor).fetch("output").run().get(0)) {
                long[] shape = result.shape();
                if (result.numDimensions() != 2 || shape[0] != 1) {
                    throw new RuntimeException(String.format("Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", Arrays.toString(shape)));
                }
                // toIntExact rejects label counts that a plain (int) cast would silently truncate.
                int labelCount = Math.toIntExact(shape[1]);
                return ((float[][]) result.copyTo(new float[1][labelCount]))[0];
            }
        }
    }
}
