package org.flinkextended.flink.ml.tensorflow.client;

import java.util.Map;
import org.apache.curator.test.TestingServer;
import org.apache.flink.core.execution.JobClient;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.StatementSet;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableDescriptor;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.flinkextended.flink.ml.cluster.MLConfig;
import org.flinkextended.flink.ml.operator.coding.RowCSVCoding;
import org.flinkextended.flink.ml.operator.source.DebugRowSource;
import org.flinkextended.flink.ml.operator.util.DataTypes;
import org.flinkextended.flink.ml.operator.util.TypeUtil;
import org.flinkextended.flink.ml.tensorflow.hooks.DebugHook;
import org.flinkextended.flink.ml.tensorflow.util.TFConstants;
import org.flinkextended.flink.ml.util.SysUtil;
import org.flinkextended.flink.ml.util.TestUtil;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

/**
 * Integration tests for {@link TFUtils}, exercising both the DataStream-based and the
 * Table/StatementSet-based {@code train} and {@code startTensorBoard} entry points against the
 * Python scripts under {@code src/test/python}.
 *
 * <p>Every test starts an embedded ZooKeeper {@link TestingServer} on port {@value #ZK_PORT} and
 * closes it afterwards. NOTE(review): the fixed port presumably matches the default ZooKeeper
 * address the ML cluster connects to — confirm before changing it.
 */
public class TFUtilsTest {
    /** Port the embedded ZooKeeper server listens on for every test. */
    private static final int ZK_PORT = 2181;

    /** Function name handed to every {@link TFConfig} built by these tests. */
    private static final String MAP_FUNC = "map_func";

    private static final String pythonPath = TestUtil.getProjectRootPath() + "/dl-on-flink-tensorflow-2.x/src/test/python/";
    private static final String add = pythonPath + "add.py";
    private static final String workerZeroFinishScript = pythonPath + "worker_0_finish.py";
    private static final String addTBScript = pythonPath + "add_withtb.py";
    private static final String inputOutputScript = pythonPath + "input_output.py";
    private static final String tensorboardScript = pythonPath + "tensorboard.py";
    // Keep checkpoints inside this module's own target directory, next to the scripts above.
    // It previously pointed at the sibling "dl-on-flink-tensorflow" module, which a clean of
    // this module never removes.
    private static final String ckptDir = TestUtil.getProjectRootPath() + "/dl-on-flink-tensorflow-2.x/target/tmp/add_withtb/";

    /** Embedded ZooKeeper for the current test. It is per-test state, so it is not static. */
    private TestingServer server;

    @Before
    public void setUp() throws Exception {
        server = new TestingServer(ZK_PORT, true);
    }

    @After
    public void tearDown() throws Exception {
        // setUp may have failed before assigning the server. close() (unlike stop()) also
        // removes the server's temporary data directory.
        if (server != null) {
            server.close();
            server = null;
        }
    }

    /**
     * Creates a {@link TFConfig} for {@code script} with function {@value #MAP_FUNC}. The third
     * and last constructor arguments are null, as in the original tests. NOTE(review):
     * presumably {@code workerNum}/{@code psNum} are the worker and parameter-server counts —
     * confirm against {@link TFConfig}.
     */
    private static TFConfig newConfig(int workerNum, int psNum, String script) {
        return new TFConfig(workerNum, psNum, null, script, MAP_FUNC, null);
    }

    /** Runs add.py through the DataStream API with no input stream. */
    @Test
    public void addTrainStream() throws Exception {
        System.out.println(SysUtil._FUNC_());
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        TFUtils.train(env, (DataStream) null, newConfig(2, 1, add));
        System.out.println(env.execute().getNetRuntime());
    }

    /** Runs add.py through the Table API with no input table and no output schema. */
    @Test
    public void addTrainTable() throws Exception {
        System.out.println(SysUtil._FUNC_());
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
        StatementSet statementSet = tableEnv.createStatementSet();
        TFConfig config = newConfig(2, 1, add);
        TFUtils.train(env, tableEnv, statementSet, (Table) null, config, (Schema) null);
        execTableJobCustom(config.getMlConfig(), env, tableEnv, statementSet);
    }

    /** Same as {@link #addTrainStream()} but with {@link TFConstants#TF_IS_CHIEF_ALONE} set. */
    @Test
    public void addTrainChiefAloneStream() throws Exception {
        System.out.println(SysUtil._FUNC_());
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        TFConfig config = newConfig(2, 1, add);
        config.addProperty(TFConstants.TF_IS_CHIEF_ALONE, "true");
        TFUtils.train(env, (DataStream) null, config);
        System.out.println(env.execute().getNetRuntime());
    }

    /** Same as {@link #addTrainTable()} but with {@link TFConstants#TF_IS_CHIEF_ALONE} set. */
    @Test
    public void addTrainChiefAloneTable() throws Exception {
        System.out.println(SysUtil._FUNC_());
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
        StatementSet statementSet = tableEnv.createStatementSet();
        TFConfig config = newConfig(2, 1, add);
        config.addProperty(TFConstants.TF_IS_CHIEF_ALONE, "true");
        TFUtils.train(env, tableEnv, statementSet, (Table) null, config, (Schema) null);
        execTableJobCustom(config.getMlConfig(), env, tableEnv, statementSet);
    }

    /**
     * Feeds the "TableDebug" source table through input_output.py with CSV row coding and
     * inserts the result into a "TableDebug" sink table that has the same schema.
     */
    @Test
    public void inputOutputTable() throws Exception {
        System.out.println(SysUtil._FUNC_());
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        TFConfig config = newConfig(2, 1, inputOutputScript);
        config.getProperties().put("sys:encoding_class", RowCSVCoding.class.getCanonicalName());
        config.getProperties().put("sys:decoding_class", RowCSVCoding.class.getCanonicalName());
        // Encoding and decoding use the same column type list.
        String columnTypes = String.join(",",
                DataTypes.INT_32.name(),
                DataTypes.INT_64.name(),
                DataTypes.FLOAT_32.name(),
                DataTypes.FLOAT_64.name(),
                DataTypes.STRING.name());
        config.getProperties().put(RowCSVCoding.ENCODE_TYPES, columnTypes);
        config.getProperties().put(RowCSVCoding.DECODE_TYPES, columnTypes);

        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
        StatementSet statementSet = tableEnv.createStatementSet();
        Schema rowSchema = TypeUtil.rowTypeInfoToSchema(DebugRowSource.typeInfo);
        tableEnv.createTemporaryTable("debug_source", TableDescriptor.forConnector("TableDebug").schema(rowSchema).build());
        Table source = tableEnv.from("debug_source");
        tableEnv.createTemporaryTable("table_row_sink", TableDescriptor.forConnector("TableDebug").schema(rowSchema).build());
        Table output = TFUtils.train(env, tableEnv, statementSet, source, config, rowSchema);
        statementSet.addInsert("table_row_sink", output);
        execTableJobCustom(config.getMlConfig(), env, tableEnv, statementSet);
    }

    /**
     * Runs add_withtb.py through the DataStream API together with a TensorBoard job. The
     * TensorBoard job runs on a deep copy of the config whose Python files are replaced by
     * tensorboard.py.
     */
    @Test
    public void testTensorBoard() throws Exception {
        System.out.println(SysUtil._FUNC_());
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        TFConfig config = newConfig(2, 1, addTBScript);
        config.getProperties().put("flink_hook_class_names", DebugHook.class.getCanonicalName());
        // A timestamp suffix gives every run its own checkpoint directory.
        config.addProperty("checkpoint_dir", ckptDir + System.currentTimeMillis());
        TFUtils.train(env, (DataStream) null, config);
        TFConfig tensorBoardConfig = config.deepCopy();
        tensorBoardConfig.setPythonFiles(new String[] {tensorboardScript});
        TFUtils.startTensorBoard(env, tensorBoardConfig);
        env.execute();
    }

    /** Table-API variant of {@link #testTensorBoard()}. */
    @Test
    public void testTensorBoardTable() throws Exception {
        System.out.println(SysUtil._FUNC_());
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
        StatementSet statementSet = tableEnv.createStatementSet();
        TFConfig config = newConfig(2, 1, addTBScript);
        config.getProperties().put("flink_hook_class_names", DebugHook.class.getCanonicalName());
        config.addProperty("checkpoint_dir", ckptDir + System.currentTimeMillis());
        TFUtils.train(env, tableEnv, statementSet, (Table) null, config, (Schema) null);
        TFConfig tensorBoardConfig = config.deepCopy();
        tensorBoardConfig.setPythonFiles(new String[] {tensorboardScript});
        TFUtils.startTensorBoard(env, tableEnv, statementSet, tensorBoardConfig);
        execTableJobCustom(config.getMlConfig(), env, tableEnv, statementSet);
    }

    /** Runs worker_0_finish.py through the Table API with {@code TFConfig(3, 2, ...)}. */
    @Test
    public void testWorkerZeroFinish() throws Exception {
        System.out.println(SysUtil._FUNC_());
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
        StatementSet statementSet = tableEnv.createStatementSet();
        TFConfig config = newConfig(3, 2, workerZeroFinishScript);
        TFUtils.train(env, tableEnv, statementSet, (Table) null, config, (Schema) null);
        execTableJobCustom(config.getMlConfig(), env, tableEnv, statementSet);
    }

    /**
     * Submits {@code statementSet} and blocks until the resulting Flink job completes.
     *
     * <p>Only {@code statementSet} is used. The other parameters are ignored and kept only so
     * that existing callers keep compiling.
     *
     * @param mLConfig unused
     * @param streamExecutionEnvironment unused
     * @param tableEnvironment unused
     * @param statementSet the statements to execute
     * @throws IllegalStateException if executing the statement set yields no {@link JobClient}
     * @throws Exception if submission fails or waiting on the job result fails
     */
    public static void execTableJobCustom(MLConfig mLConfig, StreamExecutionEnvironment streamExecutionEnvironment, TableEnvironment tableEnvironment, StatementSet statementSet) throws Exception {
        JobClient jobClient = statementSet.execute().getJobClient()
                .orElseThrow(() -> new IllegalStateException("Statement set execution did not provide a JobClient"));
        jobClient.getJobExecutionResult().get();
    }
}
