package org.flinkextended.flink.ml.examples.tensorflow.ut;

import com.google.common.base.Joiner;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.StringJoiner;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.curator.test.TestingServer;
import org.apache.flink.core.execution.JobClient;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.StatementSet;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableDescriptor;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.api.Types;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.flinkextended.flink.ml.examples.tensorflow.mnist.MnistDataUtil;
import org.flinkextended.flink.ml.examples.tensorflow.mnist.MnistJavaInference;
import org.flinkextended.flink.ml.examples.tensorflow.mnist.ops.MnistTFRExtractRowForJavaFunction;
import org.flinkextended.flink.ml.examples.tensorflow.mnist.ops.MnistTFRPojo;
import org.flinkextended.flink.ml.examples.tensorflow.ops.MnistTFRExtractPojoMapOp;
import org.flinkextended.flink.ml.operator.ops.sink.LogSink;
import org.flinkextended.flink.ml.operator.util.DataTypes;
import org.flinkextended.flink.ml.operator.util.FlinkUtil;
import org.flinkextended.flink.ml.operator.util.TypeUtil;
import org.flinkextended.flink.ml.tensorflow.client.TFConfig;
import org.flinkextended.flink.ml.tensorflow.client.TFUtils;
import org.flinkextended.flink.ml.tensorflow.coding.ExampleCoding;
import org.flinkextended.flink.ml.tensorflow.coding.ExampleCodingConfig;
import org.flinkextended.flink.ml.tensorflow.io.TFRToRowSourceFunc;
import org.flinkextended.flink.ml.tensorflow.io.TFRToRowTableSourceFactory;
import org.flinkextended.flink.ml.tensorflow.io.TFRecordSource;
import org.flinkextended.flink.ml.util.IpHostUtil;
import org.flinkextended.flink.ml.util.SysUtil;
import org.flinkextended.flink.ml.util.TestUtil;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/flinkextended/flink/ml/examples/tensorflow/ut/TFMnistInferenceTest.class */
/**
 * Tests MNIST inference with TensorFlow on Flink. It covers DataStream-based inference and
 * Table-based inference driven by a Python script. It also covers Java-side inference against the
 * exported model.
 *
 * <p>Before each test, MNIST data is prepared and a Curator {@link TestingServer} is started. The
 * server's connect string is passed to the jobs as {@code zookeeper_connect_str}. A model is then
 * trained and exported to {@link #exportPath}, unless that directory already exists.
 */
public class TFMnistInferenceTest {
    /** Test server whose connect string is handed to the TF jobs; started/stopped per test. */
    private static TestingServer server;

    private static final String MNIST_DIST_WITH_INPUT = "mnist_dist_with_input.py";
    private static final String MNIST_INFERENCE_WITH_INPUT = "mnist_table_inference.py";
    private static final String VERSION = "0";
    private static final String ROOT_PATH = TestUtil.getProjectRootPath() + "/dl-on-flink-examples";
    private static final String CHECKPOINT_PATH = ROOT_PATH + "/target/ckpt/" + VERSION;
    /** Directory the trained model is exported to (and read from for Java inference). */
    public static final String exportPath = ROOT_PATH + "/target/export/" + VERSION;
    /** Directory holding the test-split TFRecord files. */
    public static final String testDataPath = ROOT_PATH + "/target/data/test";
    /** Directory holding the train-split TFRecord files; shared by the config and the source. */
    private static final String TRAIN_DATA_PATH = ROOT_PATH + "/target/data/train";

    @Before
    public void setUp() throws Exception {
        MnistDataUtil.prepareData();
        server = new TestingServer(IpHostUtil.getFreePort(), true);
        generateModelIfNeeded();
    }

    @After
    public void tearDown() throws Exception {
        if (server != null) {
            server.close();
            server = null;
        }
    }

    /**
     * Builds a config with 4 workers and 1 PS. The config runs {@code map_fun} from a script under
     * {@code src/test/python}.
     *
     * @param scriptName file name of the Python script
     * @return a new config wired to the running test server
     * @throws IllegalStateException if the test server has not been started
     */
    private static TFConfig buildTFConfig(String scriptName) {
        if (server == null) {
            throw new IllegalStateException("ZooKeeper test server is not running");
        }
        String scriptPath = ROOT_PATH + "/src/test/python/" + scriptName;
        System.out.println("Current version:" + VERSION);
        Map<String, String> properties = new HashMap<>();
        properties.put("batch_size", "32");
        properties.put("input", TRAIN_DATA_PATH + "/");
        properties.put("epochs", "1");
        properties.put("checkpoint_dir", CHECKPOINT_PATH);
        properties.put("export_dir", exportPath);
        properties.put("zookeeper_connect_str", server.getConnectString());
        return new TFConfig(4, 1, properties, scriptPath, "map_fun", (String) null);
    }

    /**
     * Trains and exports the model unless {@link #exportPath} already exists.
     *
     * <p>If no test server is running, a temporary one is started for the training job. That server
     * is always closed afterwards, even when training fails.
     */
    public static void generateModelIfNeeded() throws Exception {
        if (new File(exportPath).exists()) {
            return;
        }
        boolean ownsServer = server == null;
        if (ownsServer) {
            server = new TestingServer(IpHostUtil.getFreePort(), true);
        }
        try {
            dataStreamHaveInput();
        } finally {
            if (ownsServer) {
                server.close();
                server = null;
            }
        }
    }

    /** Runs the distributed training job over the two train-split TFRecord files. */
    private static void dataStreamHaveInput() throws Exception {
        TFConfig config = buildTFConfig(MNIST_DIST_WITH_INPUT);
        TFMnistTest.setExampleCodingType(config);
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        // Resolve against the project root (not the JVM working directory) so these paths agree
        // with the "input" property written by buildTFConfig.
        String[] paths = {TRAIN_DATA_PATH + "/0.tfrecords", TRAIN_DATA_PATH + "/1.tfrecords"};
        DataStreamSource source =
                env.addSource(TFRecordSource.createSource(paths, 1)).setParallelism(paths.length);
        TFUtils.train(
                env,
                source.flatMap(new MnistTFRExtractPojoMapOp()).setParallelism(source.getParallelism()),
                config);
        env.execute();
    }

    /** Configures example coding: ROW input ({@code image_raw, label}), POJO output. */
    protected static void setExampleCodingTypeWithPojoOut(TFConfig tFConfig) {
        setExampleCoding(
                tFConfig,
                new String[] {"image_raw", "label"},
                new DataTypes[] {DataTypes.STRING, DataTypes.INT_32},
                new String[] {"predict_label", "label_org"},
                new DataTypes[] {DataTypes.INT_32, DataTypes.INT_32},
                ExampleCodingConfig.ObjectType.POJO);
    }

    /** Configures example coding: ROW input ({@code image_raw, label}), ROW output. */
    protected static void setExampleCodingTypeWithRowOut(TFConfig tFConfig) {
        setExampleCoding(
                tFConfig,
                new String[] {"image_raw", "label"},
                new DataTypes[] {DataTypes.STRING, DataTypes.INT_32},
                new String[] {"predict_label", "label_org"},
                new DataTypes[] {DataTypes.INT_32, DataTypes.INT_32},
                ExampleCodingConfig.ObjectType.ROW);
    }

    /**
     * Writes the four example-coding properties shared by all coding variants. The input is always
     * encoded as ROW using {@link MnistTFRPojo}. The output uses {@link InferenceOutPojo} with the
     * given object type.
     */
    private static void setExampleCoding(
            TFConfig config,
            String[] inputNames,
            DataTypes[] inputTypes,
            String[] outputNames,
            DataTypes[] outputTypes,
            ExampleCodingConfig.ObjectType outputObjectType) {
        config.getProperties().put(
                "sys:input_tf_example_config",
                ExampleCodingConfig.createExampleConfigStr(
                        inputNames, inputTypes, ExampleCodingConfig.ObjectType.ROW, MnistTFRPojo.class));
        config.getProperties().put(
                "sys:output_tf_example_config",
                ExampleCodingConfig.createExampleConfigStr(
                        outputNames, outputTypes, outputObjectType, InferenceOutPojo.class));
        String codingClass = ExampleCoding.class.getCanonicalName();
        config.getProperties().put("sys:encoding_class", codingClass);
        config.getProperties().put("sys:decoding_class", codingClass);
    }

    @Test
    public void inferenceDataStreamWithInput() throws Exception {
        System.out.println("Run Test: " + SysUtil._FUNC_());
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        String[] paths = {testDataPath + "/0.tfrecords", testDataPath + "/1.tfrecords"};
        DataStreamSource source = env.addSource(
                        new TFRToRowSourceFunc(
                                paths, 1, MnistJavaInference.OUT_ROW_TYPE, MnistJavaInference.CONVERTERS))
                .setParallelism(paths.length);
        TFConfig config = buildTFConfig(MNIST_INFERENCE_WITH_INPUT);
        // One worker per input file, no parameter servers for inference.
        config.setWorkerNum(paths.length);
        config.setPsNum(0);
        setExampleCodingTypeWithPojoOut(config);
        TFUtils.inference(env, source, config, InferenceOutPojo.class)
                .addSink(new LogSink())
                .setParallelism(config.getWorkerNum());
        env.execute();
    }

    @Test
    public void tableStreamWithInput() throws Exception {
        System.out.println("Run Test: " + SysUtil._FUNC_());
        inferenceWithTable();
    }

    /** Runs Python-script inference over a TFRecord-backed table and logs the predictions. */
    private void inferenceWithTable() throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(2);
        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
        StatementSet statementSet = tableEnv.createStatementSet();
        TFConfig config = buildTFConfig(MNIST_INFERENCE_WITH_INPUT);
        config.setPsNum(0);
        config.setWorkerNum(3);
        setExampleCodingTypeWithRowOut(config);
        tableEnv.createTemporaryTable(
                "tfr_input_table",
                TableDescriptor.forConnector("TFRToRow")
                        .schema(TypeUtil.rowTypeInfoToSchema(MnistJavaInference.OUT_ROW_TYPE))
                        .option(TFRToRowTableSourceFactory.CONNECTOR_PATH_OPTION, testDataPath + "/0.tfrecords")
                        .option(
                                TFRToRowTableSourceFactory.CONNECTOR_CONVERTERS_OPTION,
                                MnistJavaInference.CONVERTERS_STRING)
                        .option(TFRToRowTableSourceFactory.CONNECTOR_EPOCHS_OPTION, "1")
                        .build());
        Table inputTable = tableEnv.from("tfr_input_table");
        Schema outputSchema = Schema.newBuilder()
                .column("label_org", org.apache.flink.table.api.DataTypes.INT())
                .column("predict_label", org.apache.flink.table.api.DataTypes.INT())
                .build();
        Table predictions = TFUtils.inference(env, tableEnv, statementSet, inputTable, config, outputSchema);
        tableEnv.createTemporaryView("predict_tbl", predictions);
        tableEnv.createTemporaryTable(
                "predict_sink", TableDescriptor.forConnector("LogTable").schema(outputSchema).build());
        statementSet.addInsert("predict_sink", predictions);
        executeAndWait(statementSet);
    }

    @Test
    public void testInferenceJavaFunction() throws Exception {
        System.out.println("Run Test: " + SysUtil._FUNC_());
        inferenceWithJava(1, false);
    }

    @Test
    public void testInferenceJavaFunctionBatching() throws Exception {
        System.out.println("Run Test: " + SysUtil._FUNC_());
        int batchSize = ThreadLocalRandom.current().nextInt(500) + 2;
        System.out.println("Batch size set to " + batchSize);
        inferenceWithJava(batchSize, false);
    }

    @Test
    public void testJavaInferenceTableToStream() throws Exception {
        System.out.println("Run Test: " + SysUtil._FUNC_());
        int batchSize = ThreadLocalRandom.current().nextInt(500) + 2;
        System.out.println("Batch size set to " + batchSize);
        inferenceWithJava(batchSize, true);
    }

    /** Configures example coding for Java inference: ROW input and ROW output. */
    protected static void setExampleCodingTypeRow(TFConfig tFConfig) {
        setExampleCoding(
                tFConfig,
                new String[] {"image_raw", "org_label"},
                new DataTypes[] {DataTypes.FLOAT_32_ARRAY, DataTypes.INT_32},
                new String[] {"real_label", "predicted_label"},
                new DataTypes[] {DataTypes.INT_32, DataTypes.INT_32},
                ExampleCodingConfig.ObjectType.ROW);
    }

    /**
     * Runs Java-side inference against the exported model over every file in {@link #testDataPath}.
     *
     * @param batchSize value for {@code tf.inference.batch.size}
     * @param tableToStream NOTE(review): currently not read. As a result,
     *     {@code testJavaInferenceTableToStream} exercises the same path as the batching test.
     *     Confirm the intended table-to-stream variant.
     * @throws IllegalStateException if the test data directory is missing or empty
     */
    private void inferenceWithJava(int batchSize, boolean tableToStream) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(2);
        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
        StatementSet statementSet = tableEnv.createStatementSet();
        TFConfig config =
                new TFConfig(2, 0, (Map<String, String>) null, (String) null, (String) null, (String) null);
        config.setPsNum(0);
        config.setWorkerNum(2);
        config.addProperty("tf.inference.batch.size", String.valueOf(batchSize));

        // listFiles() returns null when the path is not a readable directory.
        File[] testFiles = new File(testDataPath).listFiles();
        if (testFiles == null || testFiles.length == 0) {
            throw new IllegalStateException("No test data found under " + testDataPath);
        }
        StringJoiner inputPaths = new StringJoiner(",");
        for (File testFile : testFiles) {
            inputPaths.add(testFile.getAbsolutePath());
        }
        tableEnv.createTemporaryTable(
                "tfr_input_table",
                TableDescriptor.forConnector("TFRToRow")
                        .schema(TypeUtil.rowTypeInfoToSchema(MnistJavaInference.OUT_ROW_TYPE))
                        .option(TFRToRowTableSourceFactory.CONNECTOR_PATH_OPTION, inputPaths.toString())
                        .option(TFRToRowTableSourceFactory.CONNECTOR_EPOCHS_OPTION, "1")
                        .option(
                                TFRToRowTableSourceFactory.CONNECTOR_CONVERTERS_OPTION,
                                MnistJavaInference.CONVERTERS_STRING)
                        .build());
        Table inputTable = tableEnv.from("tfr_input_table");
        FlinkUtil.registerTableFunction(tableEnv, "tfr_extract", new MnistTFRExtractRowForJavaFunction());
        // Columns produced by the tfr_extract table function.
        String extractedFields = Joiner.on(",").join(TableSchema.builder()
                .field("image", Types.PRIMITIVE_ARRAY(Types.FLOAT()))
                .field("org_label", Types.LONG())
                .build()
                .getFieldNames());
        Table extracted = tableEnv.sqlQuery(String.format(
                "select %s from %s, LATERAL TABLE(%s(%s)) as T(%s)",
                extractedFields,
                "tfr_input_table",
                "tfr_extract",
                Joiner.on(",").join(inputTable.getSchema().getFieldNames()),
                extractedFields));
        Schema outputSchema = Schema.newBuilder()
                .column("label_org", org.apache.flink.table.api.DataTypes.INT())
                .column("predict_label", org.apache.flink.table.api.DataTypes.INT())
                .build();
        // Same directory the training job exports to; avoids a second hard-coded version.
        config.addProperty("tf.inference.export.path", exportPath);
        config.addProperty("tf.inference.input.tensor.names", "image");
        config.addProperty("tf.inference.output.tensor.names", "prediction");
        config.addProperty("tf.inference.output.row.fields", Joiner.on(",").join(new String[] {"org_label", "prediction"}));
        setExampleCodingTypeRow(config);
        Table predictions = TFUtils.inference(env, tableEnv, statementSet, extracted, config, outputSchema);
        tableEnv.createTemporaryTable(
                "inference_sink", TableDescriptor.forConnector("LogTable").schema(outputSchema).build());
        statementSet.addInsert("inference_sink", predictions);
        executeAndWait(statementSet);
    }

    /**
     * Submits every insert in the statement set and blocks until the job finishes.
     *
     * @throws IllegalStateException if execution did not yield a job client
     */
    private static void executeAndWait(StatementSet statementSet) throws Exception {
        JobClient jobClient = statementSet.execute()
                .getJobClient()
                .orElseThrow(() -> new IllegalStateException("Statement set execution returned no JobClient"));
        jobClient.getJobExecutionResult().get();
    }
}
