package org.flinkextended.flink.ml.tensorflow.client;

import java.lang.invoke.SerializedLambda;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.concurrent.ExecutionException;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.TableDescriptor;
import org.apache.flink.table.api.bridge.java.StreamStatementSet;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.flinkextended.flink.ml.operator.coding.RowCSVCoding;
import org.flinkextended.flink.ml.operator.source.DebugRowSource;
import org.flinkextended.flink.ml.operator.util.DataTypes;
import org.flinkextended.flink.ml.operator.util.TypeUtil;
import org.flinkextended.flink.ml.tensorflow.hooks.DebugHook;
import org.flinkextended.flink.ml.util.SysUtil;
import org.flinkextended.flink.ml.util.TestUtil;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/client/TFUtilsTest.class */
public class TFUtilsTest {
    private static final String ckptDir = TestUtil.getProjectRootPath() + "/dl-on-flink-tensorflow/target/tmp/add_withtb/";

    @Test
    public void testTrainAdd() throws Exception {
        System.out.println(SysUtil._FUNC_());
        StreamStatementSet createStatementSet = StreamTableEnvironment.create(StreamExecutionEnvironment.getExecutionEnvironment()).createStatementSet();
        TFUtils.train(createStatementSet, TFClusterConfig.newBuilder().setWorkerCount(2).setPsCount(1).setNodeEntry(getScriptPathFromResources("add.py"), "map_func").build());
        createStatementSet.execute().await();
    }

    @Test
    public void testIterationTrain() throws ExecutionException, InterruptedException {
        System.out.println(SysUtil._FUNC_());
        StreamExecutionEnvironment executionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment create = StreamTableEnvironment.create(executionEnvironment);
        StreamStatementSet createStatementSet = create.createStatementSet();
        TFUtils.train(createStatementSet, create.fromDataStream(executionEnvironment.fromElements(new Integer[]{1, 2, 3, 4}).broadcast().map(num -> {
            return num;
        }).setParallelism(2)), TFClusterConfig.newBuilder().setNodeEntry(getScriptPathFromResources("print_input_iter.py"), "map_func").setWorkerCount(2).setProperty("sys:encoding_class", RowCSVCoding.class.getName()).setProperty("input_types", "INT_32").build(), 4);
        createStatementSet.execute().await();
    }

    @Test
    public void testIterationTrainWithEarlyTermination() throws ExecutionException, InterruptedException {
        System.out.println(SysUtil._FUNC_());
        StreamExecutionEnvironment executionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment create = StreamTableEnvironment.create(executionEnvironment);
        StreamStatementSet createStatementSet = create.createStatementSet();
        TFUtils.train(createStatementSet, create.fromDataStream(executionEnvironment.fromElements(new Integer[]{1, 2, 3, 4}).broadcast().map(num -> {
            return num;
        }).setParallelism(2)), TFClusterConfig.newBuilder().setNodeEntry(getScriptPathFromResources("print_input_iter.py"), "map_func").setWorkerCount(2).setProperty("sys:encoding_class", RowCSVCoding.class.getName()).setProperty("input_types", "INT_32").build(), Integer.MAX_VALUE);
        createStatementSet.execute().await();
    }

    @Test
    public void inferenceTable() throws Exception {
        System.out.println(SysUtil._FUNC_());
        StreamTableEnvironment create = StreamTableEnvironment.create(StreamExecutionEnvironment.getExecutionEnvironment());
        StreamStatementSet createStatementSet = create.createStatementSet();
        StringBuilder sb = new StringBuilder();
        sb.append(DataTypes.INT_32.name()).append(",");
        sb.append(DataTypes.INT_64.name()).append(",");
        sb.append(DataTypes.FLOAT_32.name()).append(",");
        sb.append(DataTypes.FLOAT_64.name()).append(",");
        sb.append(DataTypes.STRING.name());
        TFClusterConfig build = TFClusterConfig.newBuilder().setWorkerCount(2).setPsCount(1).setNodeEntry(getScriptPathFromResources("input_output.py"), "map_func").setProperty("sys:encoding_class", RowCSVCoding.class.getName()).setProperty("sys:decoding_class", RowCSVCoding.class.getName()).setProperty("input_types", sb.toString()).setProperty("output_types", sb.toString()).build();
        create.createTemporaryTable("debug_source", TableDescriptor.forConnector("TableDebug").schema(TypeUtil.rowTypeInfoToSchema(DebugRowSource.typeInfo)).build());
        createStatementSet.addInsert(TableDescriptor.forConnector("TableDebug").schema(TypeUtil.rowTypeInfoToSchema(DebugRowSource.typeInfo)).build(), TFUtils.inference(createStatementSet, create.from("debug_source"), build, TypeUtil.rowTypeInfoToSchema(DebugRowSource.typeInfo)));
        createStatementSet.execute().await();
    }

    @Test
    public void testTensorBoardTable() throws Exception {
        System.out.println(SysUtil._FUNC_());
        StreamStatementSet createStatementSet = StreamTableEnvironment.create(StreamExecutionEnvironment.getExecutionEnvironment()).createStatementSet();
        TFClusterConfig build = TFClusterConfig.newBuilder().setWorkerCount(2).setPsCount(1).setNodeEntry(getScriptPathFromResources("add_withtb.py"), "map_func").setProperty("flink_hook_class_names", DebugHook.class.getName()).setProperty("checkpoint_dir", ckptDir + System.currentTimeMillis()).build();
        TFUtils.train(createStatementSet, build);
        TFUtils.tensorBoard(createStatementSet, build);
        createStatementSet.execute().await();
    }

    @Test
    public void testWorkerZeroFinish() throws Exception {
        System.out.println(SysUtil._FUNC_());
        StreamStatementSet createStatementSet = StreamTableEnvironment.create(StreamExecutionEnvironment.getExecutionEnvironment()).createStatementSet();
        TFUtils.train(createStatementSet, TFClusterConfig.newBuilder().setWorkerCount(3).setPsCount(2).setNodeEntry(getScriptPathFromResources("worker_0_finish.py"), "map_func").build());
        createStatementSet.execute().await();
    }

    private String getScriptPathFromResources(String str) {
        URL resource = Thread.currentThread().getContextClassLoader().getResource(str);
        Assert.assertNotNull(resource);
        return resource.getPath();
    }

    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case 1723351535:
                if (implMethodName.equals("lambda$testIterationTrainWithEarlyTermination$920cfdd9$1")) {
                    z = false;
                    break;
                }
                break;
            case 1935661046:
                if (implMethodName.equals("lambda$testIterationTrain$920cfdd9$1")) {
                    z = true;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/flinkextended/flink/ml/tensorflow/client/TFUtilsTest") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Integer;)Ljava/lang/Integer;")) {
                    return num -> {
                        return num;
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/flinkextended/flink/ml/tensorflow/client/TFUtilsTest") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Integer;)Ljava/lang/Integer;")) {
                    return num2 -> {
                        return num2;
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
