package org.flinkextended.flink.ml.lib.tensorflow;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import org.apache.flink.types.Row;
import org.flinkextended.flink.ml.lib.tensorflow.util.ShellExec;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.tensorflow.proto.framework.DataType;

/**
 * Smoke test for {@link TFInference}: exports a small TensorFlow SavedModel with
 * {@code src/test/python/add_saved_model.py} and runs a few float records through it.
 */
public class TFInferenceTest {
    /** SavedModel directory name, relative to the test class-path root. */
    private static final String EXPORT_DIR = "export";

    /** Script that writes the SavedModel, relative to the test class-path root. */
    private static final String MODEL_SCRIPT = "../../src/test/python/add_saved_model.py";

    /** Number of input records fed to the model. */
    private static final int NUM_RECORDS = 3;

    /** Returns the filesystem path of the test class-path root ({@code getResource("")}). */
    private String resourceRoot() {
        return getClass().getClassLoader().getResource("").getPath();
    }

    /**
     * Generates the SavedModel by running the Python export script. The script is skipped
     * when the export directory already exists (e.g. from a previous run).
     */
    @Before
    public void setUp() throws Exception {
        String root = resourceRoot();
        if (new File(root + EXPORT_DIR).exists()) {
            return;
        }
        Assert.assertTrue(ShellExec.run("python " + root + MODEL_SCRIPT));
    }

    @After
    public void tearDown() throws Exception {
        // Nothing to clean up: the exported model is left in place and setUp() reuses it.
    }

    /**
     * Feeds {@value #NUM_RECORDS} records of two float inputs ({@code a}, {@code b}) through
     * the model and checks that one non-null output row comes back per input record.
     */
    @Test
    public void inferenceTest() throws Exception {
        // NOTE(review): the int[] arguments presumably give per-tensor ranks (0 = scalar);
        // confirm against the TFInference constructor.
        TFInference inference = new TFInference(
                "file://" + resourceRoot() + EXPORT_DIR,
                new String[] {"a", "b"},
                new DataType[] {DataType.DT_FLOAT, DataType.DT_FLOAT},
                new int[] {0, 0},
                new String[] {"d"},
                new DataType[] {DataType.DT_FLOAT},
                new int[] {0},
                new Properties());
        // Always release the model, even when inference fails.
        try {
            List<Object[]> inputs = new ArrayList<>(NUM_RECORDS);
            for (int i = 1; i <= NUM_RECORDS; i++) {
                inputs.add(new Object[] {1.0f * i, 2.0f * i});
            }
            int outputCount = 0;
            for (Row row : inference.inference(inputs)) {
                Assert.assertNotNull(row);
                System.out.println(row);
                outputCount++;
            }
            // Assumes batch inference yields exactly one output row per input record.
            Assert.assertEquals(inputs.size(), outputCount);
        } finally {
            inference.close();
        }
    }
}
