package org.flinkextended.flink.ml.examples.tensorflow.ut;

import java.io.File;
import java.util.HashMap;
import java.util.Map;
import org.apache.curator.test.TestingServer;
import org.apache.flink.core.execution.JobClient;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.StatementSet;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableDescriptor;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.flinkextended.flink.ml.examples.tensorflow.mnist.MnistDataUtil;
import org.flinkextended.flink.ml.examples.tensorflow.mnist.MnistJavaInference;
import org.flinkextended.flink.ml.examples.tensorflow.mnist.ops.MnistTFRPojo;
import org.flinkextended.flink.ml.examples.tensorflow.ops.MnistTFRExtractPojoMapOp;
import org.flinkextended.flink.ml.operator.util.DataTypes;
import org.flinkextended.flink.ml.operator.util.TypeUtil;
import org.flinkextended.flink.ml.tensorflow.client.TFConfig;
import org.flinkextended.flink.ml.tensorflow.client.TFUtils;
import org.flinkextended.flink.ml.tensorflow.coding.ExampleCoding;
import org.flinkextended.flink.ml.tensorflow.coding.ExampleCodingConfig;
import org.flinkextended.flink.ml.tensorflow.io.TFRToRowTableSourceFactory;
import org.flinkextended.flink.ml.tensorflow.io.TFRecordSource;
import org.flinkextended.flink.ml.util.SysUtil;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/flinkextended/flink/ml/examples/tensorflow/ut/TFMnistTest.class */
public class TFMnistTest {
    private static TestingServer server;
    private static final String mnist_dist = "mnist_dist.py";
    private static final String mnist_dist_with_input = "mnist_dist_with_input.py";

    @Before
    public void setUp() throws Exception {
        MnistDataUtil.prepareData();
        server = new TestingServer(2181, true);
    }

    @After
    public void tearDown() throws Exception {
        if (server != null) {
            server.stop();
        }
    }

    public TFConfig buildTFConfig(String str) {
        return buildTFConfig(str, String.valueOf(System.currentTimeMillis()));
    }

    private TFConfig buildTFConfig(String str, String str2) {
        System.out.println("Run Test: " + SysUtil._FUNC_());
        String absolutePath = new File("").getAbsolutePath();
        String str3 = absolutePath + "/src/test/python/" + str;
        System.out.println("Current version:" + str2);
        HashMap hashMap = new HashMap();
        hashMap.put("batch_size", "32");
        hashMap.put("input", absolutePath + "/target/data/train/");
        hashMap.put("epochs", "1");
        hashMap.put("checkpoint_dir", absolutePath + "/target/ckpt/" + str2);
        hashMap.put("export_dir", absolutePath + "/target/export/" + str2);
        return new TFConfig(2, 1, hashMap, str3, "map_fun", (String) null);
    }

    @Test
    public void testDataStreamApi() throws Exception {
        System.out.println("Run Test: " + SysUtil._FUNC_());
        TFConfig buildTFConfig = buildTFConfig(mnist_dist);
        StreamExecutionEnvironment executionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment();
        TFUtils.train(executionEnvironment, buildTFConfig);
        System.out.println(executionEnvironment.execute().getNetRuntime());
    }

    @Test
    public void testTableStreamApi() throws Exception {
        System.out.println("Run Test: " + SysUtil._FUNC_());
        StreamExecutionEnvironment executionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment();
        executionEnvironment.setParallelism(2);
        StreamTableEnvironment create = StreamTableEnvironment.create(executionEnvironment);
        StatementSet createStatementSet = create.createStatementSet();
        TFUtils.train(executionEnvironment, create, createStatementSet, (Table) null, buildTFConfig(mnist_dist), (Schema) null);
        ((JobClient) createStatementSet.execute().getJobClient().get()).getJobExecutionResult().get();
    }

    @Test
    public void testDataStreamHaveInput() throws Exception {
        System.out.println("Run Test: " + SysUtil._FUNC_());
        String valueOf = String.valueOf(System.currentTimeMillis());
        StreamExecutionEnvironment executionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment();
        String absolutePath = new File("").getAbsolutePath();
        String[] strArr = {absolutePath + "/target/data/train/0.tfrecords", absolutePath + "/target/data/train/1.tfrecords"};
        TFConfig buildTFConfig = buildTFConfig(mnist_dist_with_input, valueOf);
        buildTFConfig.setWorkerNum(strArr.length);
        DataStreamSource parallelism = executionEnvironment.addSource(TFRecordSource.createSource(strArr, 1)).setParallelism(strArr.length);
        SingleOutputStreamOperator parallelism2 = parallelism.flatMap(new MnistTFRExtractPojoMapOp()).setParallelism(parallelism.getParallelism());
        setExampleCodingType(buildTFConfig);
        TFUtils.train(executionEnvironment, parallelism2, buildTFConfig);
        System.out.println("Run Finish:" + executionEnvironment.execute().getNetRuntime());
    }

    public static void setExampleCodingType(TFConfig tFConfig) {
        String createExampleConfigStr = ExampleCodingConfig.createExampleConfigStr(new String[]{"image_raw", "label"}, new DataTypes[]{DataTypes.STRING, DataTypes.INT_32}, ExampleCodingConfig.ObjectType.POJO, MnistTFRPojo.class);
        tFConfig.getProperties().put("sys:input_tf_example_config", createExampleConfigStr);
        tFConfig.getProperties().put("sys:output_tf_example_config", createExampleConfigStr);
        tFConfig.getProperties().put("sys:encoding_class", ExampleCoding.class.getCanonicalName());
        tFConfig.getProperties().put("sys:decoding_class", ExampleCoding.class.getCanonicalName());
    }

    public static void setExampleCodingRowType(TFConfig tFConfig) {
        String createExampleConfigStr = ExampleCodingConfig.createExampleConfigStr(new String[]{"image_raw", "label"}, new DataTypes[]{DataTypes.STRING, DataTypes.INT_32}, ExampleCodingConfig.ObjectType.ROW, MnistTFRPojo.class);
        tFConfig.getProperties().put("sys:input_tf_example_config", createExampleConfigStr);
        tFConfig.getProperties().put("sys:output_tf_example_config", createExampleConfigStr);
        tFConfig.getProperties().put("sys:encoding_class", ExampleCoding.class.getCanonicalName());
        tFConfig.getProperties().put("sys:decoding_class", ExampleCoding.class.getCanonicalName());
    }

    @Test
    public void testTableStreamHaveInput() throws Exception {
        System.out.println("Run Test: " + SysUtil._FUNC_());
        TFConfig buildTFConfig = buildTFConfig(mnist_dist_with_input, String.valueOf(System.currentTimeMillis()));
        setExampleCodingRowType(buildTFConfig);
        StreamExecutionEnvironment executionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment();
        executionEnvironment.setParallelism(2);
        StreamTableEnvironment create = StreamTableEnvironment.create(executionEnvironment);
        StatementSet createStatementSet = create.createStatementSet();
        String absolutePath = new File("").getAbsolutePath();
        create.createTemporaryTable("input", TableDescriptor.forConnector("TFRToRow").schema(TypeUtil.rowTypeInfoToSchema(MnistJavaInference.OUT_ROW_TYPE)).option(TFRToRowTableSourceFactory.CONNECTOR_PATH_OPTION, absolutePath + "/target/data/train/0.tfrecords," + absolutePath + "/target/data/train/1.tfrecords").option(TFRToRowTableSourceFactory.CONNECTOR_EPOCHS_OPTION, "1").option(TFRToRowTableSourceFactory.CONNECTOR_CONVERTERS_OPTION, MnistJavaInference.CONVERTERS_STRING).build());
        TFUtils.train(executionEnvironment, create, createStatementSet, create.from("input"), buildTFConfig, (Schema) null);
        ((JobClient) createStatementSet.execute().getJobClient().get()).getJobExecutionResult().get();
    }
}
