package org.flinkextended.flink.ml.pytorch;

import java.util.Map;
import org.apache.curator.test.TestingServer;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.StatementSet;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.flinkextended.flink.ml.operator.client.TableTestUtil;
import org.flinkextended.flink.ml.util.TestUtil;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

/**
 * Smoke tests for {@link PyTorchUtil#train} in its DataStream-based and Table-based forms.
 *
 * <p>Before each test an in-process Curator {@link TestingServer} (ZooKeeper) is started on port
 * {@value #ZOOKEEPER_PORT}; it is closed, and its temporary data directory deleted, after each
 * test. Both tests run the {@code map_func} entry point of {@code greeter.py} and pass if the job
 * completes without throwing; they do not assert on any output.
 */
public class PyTorchUtilTest {
    /**
     * Port for the embedded ZooKeeper server. NOTE(review): presumably the port the ML cluster
     * connects to by default, since the config passed to train() does not set one — confirm.
     */
    private static final int ZOOKEEPER_PORT = 2181;

    /** Directory containing the Python scripts used by these tests. */
    private static final String ROOT_PATH =
            TestUtil.getProjectRootPath() + "/dl-on-flink-pytorch/src/test/python/";

    // Instance (not static) field: JUnit 4 creates a new test-class instance per test method,
    // and this field is set in @Before / cleared in @After for exactly one test.
    private TestingServer testingServer;

    /** Starts the embedded ZooKeeper server before each test. */
    @Before
    public void setUp() throws Exception {
        testingServer = new TestingServer(ZOOKEEPER_PORT, true);
    }

    /**
     * Shuts down the embedded ZooKeeper server after each test.
     *
     * <p>Uses {@code close()} rather than {@code stop()}: {@code stop()} leaves the server's
     * temporary data directory behind, while {@code close()} also deletes it. The null check keeps
     * a failure in {@link #setUp()} from being masked by a NullPointerException here.
     */
    @After
    public void tearDown() throws Exception {
        if (testingServer != null) {
            try {
                testingServer.close();
            } finally {
                testingServer = null;
            }
        }
    }

    /** Runs {@code greeter.py:map_func} through the DataStream-based train() overload. */
    @Test
    public void trainStream() throws Exception {
        PyTorchConfig pyTorchConfig = newGreeterConfig();
        StreamExecutionEnvironment executionEnvironment =
                StreamExecutionEnvironment.getExecutionEnvironment();
        // No input stream and no output type are supplied. NOTE(review): presumably this builds a
        // job with no sink — confirm against PyTorchUtil.train.
        PyTorchUtil.train(
                executionEnvironment, (DataStream) null, pyTorchConfig, (TypeInformation) null);
        executionEnvironment.execute();
    }

    /** Runs {@code greeter.py:map_func} through the Table-API-based train() overload. */
    @Test
    public void trainTable() throws Exception {
        PyTorchConfig pyTorchConfig = newGreeterConfig();
        StreamExecutionEnvironment executionEnvironment =
                StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tableEnvironment = StreamTableEnvironment.create(executionEnvironment);
        StatementSet statementSet = tableEnvironment.createStatementSet();
        // No input table and no output schema are supplied.
        PyTorchUtil.train(
                executionEnvironment,
                tableEnvironment,
                statementSet,
                (Table) null,
                pyTorchConfig,
                (Schema) null);
        TableTestUtil.execTableJobCustom(
                pyTorchConfig.getMlConfig(), executionEnvironment, tableEnvironment, statementSet);
    }

    /**
     * Builds the config shared by both tests: entry point {@code map_func} in {@code greeter.py}.
     * The leading argument {@code 3} is passed through unchanged. NOTE(review): presumably the
     * worker count — confirm against the PyTorchConfig constructor.
     */
    private static PyTorchConfig newGreeterConfig() {
        // Raw/null casts select the intended constructor overload; kept from the original.
        return new PyTorchConfig(3, (Map) null, ROOT_PATH + "greeter.py", "map_func", (String) null);
    }
}
