package org.flinkextended.flink.ml.tensorflow.util;

import com.google.common.base.Preconditions;
import com.google.common.io.Files;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.flinkextended.flink.ml.cluster.MLConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;

/**
 * Runs batched inference against a TensorFlow SavedModel through the TensorFlow Java API.
 *
 * <p>The model is loaded with the {@code "serve"} tag, and its {@code "serving_default"}
 * signature maps the configured input/output names to graph nodes. A model path with no
 * scheme or with the {@code file} scheme is loaded in place. An {@code hdfs} path is first
 * copied into a local temporary directory, which {@link #close()} deletes. Any other scheme
 * is rejected.
 *
 * <p>Required configuration keys:
 * <ul>
 *   <li>{@code tf.inference.export.path}: location of the SavedModel.</li>
 *   <li>{@code tf.inference.input.tensor.names}: comma-separated signature input names. Each
 *       name must also be a field of the input row.</li>
 *   <li>{@code tf.inference.output.tensor.names}: comma-separated signature output names.</li>
 *   <li>{@code tf.inference.output.row.fields}: comma-separated output row fields. Each field
 *       is either an output tensor name or an input row field that is passed through.</li>
 * </ul>
 *
 * <p>NOTE(review): no synchronization is used here; presumably each instance is confined to
 * a single task thread. Confirm with callers.
 */
public class JavaInference implements Closeable {
    private static final Logger LOG = LoggerFactory.getLogger(JavaInference.class);
    private static final String SPLITTER = ",";
    private static final String TAG = "serve";
    private static final String SIGNATURE_KEY = "serving_default";
    private static final String TEMP_DIR_PREFIX = "tf-inference-model-";

    private final String[] inRowFieldNames;
    private final SavedModelBundle model;
    private final SignatureDef modelSig;
    private final Set<String> inputTensorNameSet;
    private final String[] outputTensorNames;
    private final String[] outputFieldNames;
    /** Local copy of a remote (HDFS) model, or {@code null} when the model was loaded in place. */
    private File downloadModelPath;

    /**
     * Creates an inference helper from an {@link MLConfig} and the input/output table schemas.
     *
     * @param mlConfig config whose properties carry the {@code tf.inference.*} keys
     * @param inputSchema schema of the rows passed to {@link #generateRowsOneBatch}
     * @param outputSchema schema of the produced rows
     * @throws Exception if the model cannot be loaded or the configuration is invalid
     */
    public JavaInference(MLConfig mlConfig, TableSchema inputSchema, TableSchema outputSchema) throws Exception {
        this((Map<String, String>) mlConfig.getProperties(), inputSchema.getFieldNames(), outputSchema.getFieldNames());
    }

    /**
     * Loads the model and validates the configured tensor and field names against its signature.
     *
     * <p>If validation fails after the model was loaded, the model is closed and any downloaded
     * copy is deleted before the exception propagates.
     *
     * @param properties configuration holding the {@code tf.inference.*} keys
     * @param inRowFieldNames field names of the input rows, in column order
     * @param outRowFieldNames field names of the output rows; only the count is checked here
     * @throws IllegalArgumentException if a required key is missing, a name is unknown, or the
     *     model URI scheme is unsupported
     * @throws Exception if the model cannot be downloaded or loaded
     */
    public JavaInference(Map<String, String> properties, String[] inRowFieldNames, String[] outRowFieldNames)
            throws Exception {
        this.inRowFieldNames = inRowFieldNames;
        this.model = loadModel(requireConfig(properties, "tf.inference.export.path"));
        try {
            this.modelSig = MetaGraphDef.parseFrom(this.model.metaGraphDef().toByteArray())
                    .getSignatureDefOrThrow(SIGNATURE_KEY);
            logSignature();

            String[] inputTensorNames = requireConfig(properties, "tf.inference.input.tensor.names").split(SPLITTER);
            List<String> inputTensorNameList = Arrays.asList(inputTensorNames);
            List<String> inRowFieldNameList = Arrays.asList(inRowFieldNames);
            // Every input tensor must exist in the signature AND be fed from an input row column.
            Preconditions.checkArgument(
                    this.modelSig.getInputsMap().keySet().containsAll(inputTensorNameList)
                            && inRowFieldNameList.containsAll(inputTensorNameList),
                    "Invalid input tensor names: " + Arrays.toString(inputTensorNames));
            this.inputTensorNameSet = new HashSet<>(inputTensorNameList);

            this.outputTensorNames = requireConfig(properties, "tf.inference.output.tensor.names").split(SPLITTER);
            List<String> outputTensorNameList = Arrays.asList(this.outputTensorNames);
            Preconditions.checkArgument(
                    this.modelSig.getOutputsMap().keySet().containsAll(outputTensorNameList),
                    "Invalid output tensor names: " + Arrays.toString(this.outputTensorNames));

            this.outputFieldNames = requireConfig(properties, "tf.inference.output.row.fields").split(SPLITTER);
            Preconditions.checkArgument(this.outputFieldNames.length == outRowFieldNames.length, "Output fields length mismatch");
            // Each output field is either produced by the model or passed through from the input row.
            for (String field : this.outputFieldNames) {
                Preconditions.checkArgument(
                        outputTensorNameList.contains(field) || inRowFieldNameList.contains(field),
                        "Unknown output field name: " + field);
            }
        } catch (Throwable t) {
            // The native model (and possibly a downloaded copy) is already held; release it so a
            // rejected configuration does not leak. Precise rethrow keeps the original exception type.
            releaseAfterFailure(t);
            throw t;
        }
    }

    /**
     * Runs inference on at most {@code batchSize} leading entries of {@code cache}.
     *
     * @param cache input rows as column-ordered value arrays matching {@code inRowFieldNames}
     * @param batchSize maximum number of rows to process
     * @return one output row per processed input row; empty if there is nothing to process
     */
    public Row[] generateRowsOneBatch(List<Object[]> cache, int batchSize) {
        int size = Math.min(batchSize, cache.size());
        return size <= 0 ? new Row[0] : generateRows(cache, size);
    }

    /**
     * Feeds the first {@code size} rows to the model and assembles the output rows.
     *
     * <p>All tensors created or fetched here are closed before returning, including when an
     * exception is thrown.
     */
    private Row[] generateRows(List<Object[]> cache, int size) {
        Map<String, Object[]> columnsByField = new HashMap<>(this.inRowFieldNames.length * 2);
        for (int col = 0; col < this.inRowFieldNames.length; col++) {
            columnsByField.put(this.inRowFieldNames[col], extractCols(cache, col, size));
        }

        List<Tensor> toClose = new ArrayList<>(this.inputTensorNameSet.size() + this.outputTensorNames.length);
        try {
            Session.Runner runner = this.model.session().runner();
            for (String field : this.inRowFieldNames) {
                if (this.inputTensorNameSet.contains(field)) {
                    TensorInfo info = this.modelSig.getInputsMap().get(field);
                    Tensor tensor = TFTensorConversion.toTensor(columnsByField.get(field), info);
                    toClose.add(tensor);
                    runner.feed(info.getName(), tensor);
                }
            }
            for (String name : this.outputTensorNames) {
                runner.fetch(this.modelSig.getOutputsMap().get(name).getName());
            }

            List<? extends Tensor> outputs = runner.run();
            toClose.addAll(outputs);
            // runner.run() returns fetched tensors in fetch order, i.e. the order of outputTensorNames.
            Map<String, Tensor> outputsByName = new HashMap<>(this.outputTensorNames.length * 2);
            for (int i = 0; i < this.outputTensorNames.length; i++) {
                outputsByName.put(this.outputTensorNames[i], outputs.get(i));
            }

            Row[] rows = new Row[size];
            for (int r = 0; r < size; r++) {
                rows[r] = new Row(this.outputFieldNames.length);
            }
            for (int f = 0; f < this.outputFieldNames.length; f++) {
                String field = this.outputFieldNames[f];
                Tensor output = outputsByName.get(field);
                // Model outputs take precedence; otherwise the input column is passed through.
                Object[] values = output != null ? TFTensorConversion.fromTensor(output) : columnsByField.get(field);
                Preconditions.checkState(values.length >= size,
                        "Output field %s has %s values but the batch has %s rows", field, values.length, size);
                for (int r = 0; r < size; r++) {
                    rows[r].setField(f, values[r]);
                }
            }
            return rows;
        } finally {
            for (Tensor tensor : toClose) {
                tensor.close();
            }
        }
    }

    /**
     * Closes the model and deletes the downloaded model copy, if any.
     *
     * <p>The local copy is removed even if closing the model throws.
     */
    @Override
    public void close() throws IOException {
        try {
            if (this.model != null) {
                this.model.close();
                LOG.info("Model closed");
            }
        } finally {
            if (this.downloadModelPath != null) {
                FileUtils.deleteQuietly(this.downloadModelPath);
            }
        }
    }

    /**
     * Loads the SavedModel from a local path or, for {@code hdfs} URIs, from a local copy.
     *
     * <p>On the HDFS path this sets {@link #downloadModelPath}. If the download or load fails,
     * the temporary directory is deleted before the exception propagates.
     */
    private SavedModelBundle loadModel(String exportPath) throws Exception {
        Path path = new Path(exportPath);
        String scheme = path.toUri().getScheme();
        ConfigProto config = ConfigProto.newBuilder().setAllowSoftPlacement(true).build();
        if (StringUtils.isEmpty(scheme) || scheme.equals("file")) {
            return SavedModelBundle.loader(exportPath).withConfigProto(config).withTags(new String[]{TAG}).load();
        }
        if (!scheme.equals("hdfs")) {
            throw new IllegalArgumentException("Model URI not supported: " + exportPath);
        }
        FileSystem fileSystem = path.getFileSystem(new Configuration());
        // java.nio creates the directory owner-only on POSIX; Guava's createTempDir is deprecated.
        this.downloadModelPath = java.nio.file.Files.createTempDirectory(TEMP_DIR_PREFIX).toFile();
        try {
            Path localPath = new Path(this.downloadModelPath.getPath(), path.getName());
            LOG.info("Downloading model from {} to {}", path, localPath);
            fileSystem.copyToLocalFile(path, localPath);
            return SavedModelBundle.loader(localPath.toString()).withConfigProto(config).withTags(new String[]{TAG}).load();
        } catch (Throwable t) {
            FileUtils.deleteQuietly(this.downloadModelPath);
            this.downloadModelPath = null;
            throw t;
        }
    }

    /** Best-effort cleanup after a constructor failure; secondary failures are attached as suppressed. */
    private void releaseAfterFailure(Throwable cause) {
        try {
            close();
        } catch (Throwable suppressed) {
            if (suppressed != cause) {
                cause.addSuppressed(suppressed);
            }
        }
    }

    /**
     * Returns the non-empty value for {@code key}.
     *
     * @throws IllegalArgumentException if the key is absent or maps to an empty string
     */
    private static String requireConfig(Map<String, String> properties, String key) {
        String value = properties.get(key);
        Preconditions.checkArgument(!StringUtils.isEmpty(value), "Need to specify proper " + key);
        return value;
    }

    /** Logs the inputs and outputs of the serving signature at INFO level. */
    private void logSignature() {
        StringBuilder sb = new StringBuilder();
        sb.append("\nMODEL SIGNATURE\n");
        sb.append("Inputs:\n");
        appendTensorInfos(sb, this.modelSig.getInputsMap());
        sb.append("Outputs:\n");
        appendTensorInfos(sb, this.modelSig.getOutputsMap());
        sb.append("-----------------------------------------------");
        LOG.info(sb.toString());
    }

    /** Appends one numbered line per signature entry: key, graph node name, and dtype. */
    private static void appendTensorInfos(StringBuilder sb, Map<String, TensorInfo> tensors) {
        int total = tensors.size();
        int index = 1;
        for (Map.Entry<String, TensorInfo> entry : tensors.entrySet()) {
            TensorInfo info = entry.getValue();
            sb.append(String.format("%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                    index++, total, entry.getKey(), info.getName(), info.getDtype()));
        }
    }

    /** Collects column {@code col} of the first {@code size} rows into a new array. */
    private static Object[] extractCols(List<Object[]> rows, int col, int size) {
        Object[] column = new Object[size];
        for (int r = 0; r < size; r++) {
            column[r] = rows.get(r)[col];
        }
        return column;
    }
}
