package org.flinkextended.flink.ml.lib.tensorflow;

import java.io.File;
import java.util.Properties;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.core.execution.JobClient;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.TableDescriptor;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.flinkextended.flink.ml.lib.tensorflow.table.descriptor.TableDebugRowOptions;
import org.flinkextended.flink.ml.lib.tensorflow.util.ShellExec;
import org.flinkextended.flink.ml.operator.util.TypeUtil;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Integration tests for {@link TFInferenceUDTF}.
 *
 * <p>Each test points the UDTF at a TensorFlow model directory under the test-classes root.
 * If a directory is missing, {@link #setUp()} generates it by running a Python script from
 * {@code src/test/python/}. Each test then registers the UDTF as {@code inference}, reads rows
 * from the {@code TableDebugRow} connector, and writes the results to Flink's {@code print}
 * connector. The tests only check that the job completes without throwing; the printed output
 * is not asserted.
 */
public class TFInferenceUDTFTest {

    /** Model directory (relative to the test-classes root) generated by {@code add_saved_model.py}. */
    private static final String ADD_MODEL_DIR = "export";

    /** Model directory (relative to the test-classes root) generated by {@code build_model.py}. */
    private static final String BUILD_MODEL_DIR = "export2";

    /**
     * Generates the test models if they are not already present.
     *
     * <p>Existing directories are reused, so the Python scripts only run on the first
     * invocation.
     */
    @Before
    public void setUp() throws Exception {
        String root = testResourceRoot();
        String scriptDir = root + "../../src/test/python/";
        generateModelIfMissing(root + ADD_MODEL_DIR, scriptDir + "add_saved_model.py");
        generateModelIfMissing(root + BUILD_MODEL_DIR, scriptDir + "build_model.py");
    }

    @After
    public void tearDown() throws Exception {
        // Nothing to clean up: generated model directories are kept for reuse by later runs.
    }

    /** Scalar float inputs (rank 0) producing a single scalar float output. */
    @Test
    public void eval() throws Exception {
        TFInferenceUDTF udtf = new TFInferenceUDTF(
                modelUri(ADD_MODEL_DIR),
                "a,b", "DT_FLOAT, DT_FLOAT", "0, 0",
                "d", "DT_FLOAT", "0",
                new Properties(), 5);
        RowTypeInfo sinkType = new RowTypeInfo(
                new TypeInformation<?>[] {BasicTypeInfo.FLOAT_TYPE_INFO},
                new String[] {"d"});

        StreamTableEnvironment tEnv = createTableEnv();
        tEnv.createTemporaryTable("source", TableDescriptor.forConnector("TableDebugRow")
                .schema(Schema.newBuilder()
                        .column("a", DataTypes.FLOAT())
                        .column("b", DataTypes.FLOAT())
                        .build())
                .build());
        tEnv.registerFunction("inference", udtf);
        tEnv.createTemporaryTable("sink", TableDescriptor.forConnector("print")
                .schema(TypeUtil.rowTypeInfoToSchema(sinkType))
                .build());

        executeAndWait(tEnv,
                "INSERT INTO sink SELECT d FROM source, LATERAL TABLE(inference(a, b)) as T(d)");
    }

    /** Scalar float inputs producing two scalar float outputs. */
    @Test
    public void eval2() throws Exception {
        TFInferenceUDTF udtf = new TFInferenceUDTF(
                modelUri(ADD_MODEL_DIR),
                "a,b", "DT_FLOAT, DT_FLOAT", "0, 0",
                "d,a", "DT_FLOAT, DT_FLOAT", "0, 0",
                new Properties(), 5);
        RowTypeInfo sinkType = new RowTypeInfo(
                new TypeInformation<?>[] {BasicTypeInfo.FLOAT_TYPE_INFO, BasicTypeInfo.FLOAT_TYPE_INFO},
                new String[] {"d", "a"});

        StreamTableEnvironment tEnv = createTableEnv();
        tEnv.createTemporaryTable("source", TableDescriptor.forConnector("TableDebugRow")
                .schema(Schema.newBuilder()
                        .column("a", DataTypes.FLOAT())
                        .column("b", DataTypes.FLOAT())
                        .build())
                .build());
        tEnv.registerFunction("inference", udtf);
        tEnv.createTemporaryTable("sink", TableDescriptor.forConnector("print")
                .schema(TypeUtil.rowTypeInfoToSchema(sinkType))
                .build());

        // The UDTF columns are aliased (d, e); INSERT maps them onto sink columns (d, a) by position.
        executeAndWait(tEnv,
                "INSERT INTO sink SELECT d, e FROM source, LATERAL TABLE(inference(a, b)) as T(d, e)");
    }

    /** Rank-1 float array inputs producing two rank-1 float array outputs. */
    @Test
    public void eval3() throws Exception {
        TFInferenceUDTF udtf = new TFInferenceUDTF(
                modelUri(ADD_MODEL_DIR),
                "a,b", "DT_FLOAT, DT_FLOAT", "1, 1",
                "d,a", "DT_FLOAT, DT_FLOAT", "1, 1",
                new Properties(), 5);
        RowTypeInfo sinkType = new RowTypeInfo(
                new TypeInformation<?>[] {TypeInformation.of(float[].class), TypeInformation.of(float[].class)},
                new String[] {"d", "a"});

        StreamTableEnvironment tEnv = createTableEnv();
        tEnv.createTemporaryTable("source", TableDescriptor.forConnector("TableDebugRow")
                .option(TableDebugRowOptions.CONNECTOR_RANK_OPTION, 1)
                .schema(Schema.newBuilder()
                        .column("a", DataTypes.ARRAY(DataTypes.FLOAT()))
                        .column("b", DataTypes.ARRAY(DataTypes.FLOAT()))
                        .build())
                .build());
        tEnv.registerFunction("inference", udtf);
        tEnv.createTemporaryTable("sink", TableDescriptor.forConnector("print")
                .schema(TypeUtil.rowTypeInfoToSchema(sinkType))
                .build());

        executeAndWait(tEnv,
                "INSERT INTO sink SELECT d, e FROM source, LATERAL TABLE(inference(a, b)) as T(d, e)");
    }

    /** Rank-1 float and string array inputs against the second model. */
    @Test
    public void eval4() throws Exception {
        TFInferenceUDTF udtf = new TFInferenceUDTF(
                modelUri(BUILD_MODEL_DIR),
                "a,b,e", "DT_FLOAT, DT_FLOAT, DT_STRING", "1, 1, 1",
                "d,a,e4", "DT_FLOAT, DT_FLOAT, DT_STRING", "1,1,1",
                new Properties(), 5);

        StreamTableEnvironment tEnv = createTableEnv();
        tEnv.createTemporaryTable("source", TableDescriptor.forConnector("TableDebugRow")
                .option(TableDebugRowOptions.CONNECTOR_RANK_OPTION, 1)
                .option(TableDebugRowOptions.CONNECTOR_HAS_STRING_OPTION, true)
                .schema(Schema.newBuilder()
                        .column("a", DataTypes.ARRAY(DataTypes.FLOAT()))
                        .column("b", DataTypes.ARRAY(DataTypes.FLOAT()))
                        .column("c", DataTypes.ARRAY(DataTypes.STRING()))
                        .build())
                .build());
        tEnv.registerFunction("inference", udtf);
        tEnv.createTemporaryTable("sink", TableDescriptor.forConnector("print")
                .schema(Schema.newBuilder()
                        .column("d", DataTypes.ARRAY(DataTypes.FLOAT()))
                        .column("f", DataTypes.ARRAY(DataTypes.FLOAT()))
                        .column("h", DataTypes.ARRAY(DataTypes.STRING()))
                        .build())
                .build());

        executeAndWait(tEnv,
                "INSERT INTO sink SELECT d, f, h FROM source, LATERAL TABLE(inference(a, b, c)) as T(d, f, h)");
    }

    /** Rank-2 (nested array) float and string inputs against the second model. */
    @Test
    public void eval5() throws Exception {
        TFInferenceUDTF udtf = new TFInferenceUDTF(
                modelUri(BUILD_MODEL_DIR),
                "a,b,e", "DT_FLOAT, DT_FLOAT, DT_STRING", "2, 2, 2",
                "d,a,e4", "DT_FLOAT, DT_FLOAT, DT_STRING", "2,2,2",
                new Properties(), 5);

        StreamTableEnvironment tEnv = createTableEnv();
        tEnv.createTemporaryTable("source", TableDescriptor.forConnector("TableDebugRow")
                .option(TableDebugRowOptions.CONNECTOR_RANK_OPTION, 2)
                .option(TableDebugRowOptions.CONNECTOR_HAS_STRING_OPTION, true)
                .schema(Schema.newBuilder()
                        .column("a", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.FLOAT())))
                        .column("b", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.FLOAT())))
                        .column("c", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING())))
                        .build())
                .build());
        tEnv.registerFunction("inference", udtf);
        tEnv.createTemporaryTable("sink", TableDescriptor.forConnector("print")
                .schema(Schema.newBuilder()
                        .column("d", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.FLOAT())))
                        .column("f", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.FLOAT())))
                        .column("h", DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING())))
                        .build())
                .build());

        executeAndWait(tEnv,
                "INSERT INTO sink SELECT d, f, h FROM source, LATERAL TABLE(inference(a, b, c)) as T(d, f, h)");
    }

    /** Returns the filesystem path of the test-classes root, as resolved by the class loader. */
    private String testResourceRoot() {
        return getClass().getClassLoader().getResource("").getPath();
    }

    /** Returns a {@code file://} URI for a model directory under the test-classes root. */
    private String modelUri(String modelDir) {
        return "file://" + testResourceRoot() + modelDir;
    }

    /**
     * Runs {@code script} with {@code python} unless {@code outputDir} already exists.
     *
     * @throws AssertionError if the script reports failure
     */
    private static void generateModelIfMissing(String outputDir, String script) throws Exception {
        if (new File(outputDir).exists()) {
            return;
        }
        Assert.assertTrue("failed to run " + script, ShellExec.run("python " + script));
    }

    /** Creates a table environment backed by a parallelism-1 stream execution environment. */
    private static StreamTableEnvironment createTableEnv() {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
        env.setParallelism(1);
        return tEnv;
    }

    /**
     * Submits {@code sql} and blocks until the resulting job finishes.
     *
     * @throws IllegalStateException if the statement did not produce a job
     */
    private static void executeAndWait(StreamTableEnvironment tEnv, String sql) throws Exception {
        JobClient jobClient = tEnv.executeSql(sql)
                .getJobClient()
                .orElseThrow(() -> new IllegalStateException("no job was submitted for: " + sql));
        jobClient.getJobExecutionResult().get();
    }
}
