package org.flinkextended.flink.ml.pytorch;

import java.io.IOException;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.StatementSet;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.flinkextended.flink.ml.cluster.ExecutionMode;
import org.flinkextended.flink.ml.cluster.role.WorkerRole;
import org.flinkextended.flink.ml.operator.client.RoleUtils;
import org.flinkextended.flink.ml.operator.util.PythonFileUtil;

/* loaded from: input_file:org/flinkextended/flink/ml/pytorch/PyTorchUtil.class */
/**
 * Entry points for wiring a PyTorch job into a Flink pipeline.
 *
 * <p>Each entry point:
 * <ol>
 *   <li>stores {@link PyTorchRunner} as the runner class in the supplied config;</li>
 *   <li>registers the configured Python files with the stream environment;</li>
 *   <li>adds an AM role and a {@link WorkerRole} through {@link RoleUtils}.</li>
 * </ol>
 *
 * <p>Overloads are provided for both the DataStream API and the Table API. Convenience
 * {@code train}/{@code inference} variants fix the {@link ExecutionMode}.
 *
 * <p>Note: every entry point mutates the properties of {@code pyTorchConfig.getMlConfig()}.
 */
public class PyTorchUtil {

    /** Config property key under which the ML runner implementation class name is stored. */
    private static final String ML_RUNNER_CLASS_KEY = "sys:ml_runner_class";

    /**
     * Runs a PyTorch job over a {@link DataStream}.
     *
     * @param streamExecutionEnvironment environment the Python files and roles are registered with
     * @param executionMode whether the worker role trains or runs inference
     * @param dataStream input stream handed to the worker role
     * @param pyTorchConfig job configuration; its ML config properties are modified in place
     * @param typeInformation type information of the returned stream's elements
     * @param <IN> input element type
     * @param <OUT> output element type
     * @return the output stream produced by {@link RoleUtils#addRole}
     * @throws IOException propagated from Python file registration or role setup
     */
    public static <IN, OUT> DataStream<OUT> run(StreamExecutionEnvironment streamExecutionEnvironment, ExecutionMode executionMode, DataStream<IN> dataStream, PyTorchConfig pyTorchConfig, TypeInformation<OUT> typeInformation) throws IOException {
        prepareConfig(streamExecutionEnvironment, pyTorchConfig);
        RoleUtils.addAMRole(streamExecutionEnvironment, pyTorchConfig.getMlConfig());
        return RoleUtils.addRole(streamExecutionEnvironment, executionMode, dataStream, pyTorchConfig.getMlConfig(), typeInformation, new WorkerRole());
    }

    /**
     * Runs a PyTorch job over a {@link Table}, adding its roles to the given {@link StatementSet}.
     *
     * @param streamExecutionEnvironment environment the Python files are registered with
     * @param tableEnvironment table environment the roles are added to
     * @param statementSet statement set the roles are added to
     * @param executionMode whether the worker role trains or runs inference
     * @param table input table handed to the worker role
     * @param pyTorchConfig job configuration; its ML config properties are modified in place
     * @param schema schema of the returned table
     * @return the output table produced by {@link RoleUtils#addRole}
     * @throws IOException propagated from Python file registration or role setup
     */
    public static Table run(StreamExecutionEnvironment streamExecutionEnvironment, TableEnvironment tableEnvironment, StatementSet statementSet, ExecutionMode executionMode, Table table, PyTorchConfig pyTorchConfig, Schema schema) throws IOException {
        prepareConfig(streamExecutionEnvironment, pyTorchConfig);
        RoleUtils.addAMRole(tableEnvironment, statementSet, pyTorchConfig.getMlConfig());
        return RoleUtils.addRole(tableEnvironment, statementSet, executionMode, table, pyTorchConfig.getMlConfig(), schema, new WorkerRole());
    }

    /** Shorthand for the DataStream {@code run} overload with {@link ExecutionMode#TRAIN}. */
    public static <IN, OUT> DataStream<OUT> train(StreamExecutionEnvironment streamExecutionEnvironment, DataStream<IN> dataStream, PyTorchConfig pyTorchConfig, TypeInformation<OUT> typeInformation) throws IOException {
        return run(streamExecutionEnvironment, ExecutionMode.TRAIN, dataStream, pyTorchConfig, typeInformation);
    }

    /** Shorthand for the DataStream {@code run} overload with {@link ExecutionMode#INFERENCE}. */
    public static <IN, OUT> DataStream<OUT> inference(StreamExecutionEnvironment streamExecutionEnvironment, DataStream<IN> dataStream, PyTorchConfig pyTorchConfig, TypeInformation<OUT> typeInformation) throws IOException {
        return run(streamExecutionEnvironment, ExecutionMode.INFERENCE, dataStream, pyTorchConfig, typeInformation);
    }

    /** Shorthand for the Table {@code run} overload with {@link ExecutionMode#TRAIN}. */
    public static Table train(StreamExecutionEnvironment streamExecutionEnvironment, TableEnvironment tableEnvironment, StatementSet statementSet, Table table, PyTorchConfig pyTorchConfig, Schema schema) throws IOException {
        return run(streamExecutionEnvironment, tableEnvironment, statementSet, ExecutionMode.TRAIN, table, pyTorchConfig, schema);
    }

    /** Shorthand for the Table {@code run} overload with {@link ExecutionMode#INFERENCE}. */
    public static Table inference(StreamExecutionEnvironment streamExecutionEnvironment, TableEnvironment tableEnvironment, StatementSet statementSet, Table table, PyTorchConfig pyTorchConfig, Schema schema) throws IOException {
        return run(streamExecutionEnvironment, tableEnvironment, statementSet, ExecutionMode.INFERENCE, table, pyTorchConfig, schema);
    }

    /**
     * Setup shared by both {@code run} overloads.
     *
     * <p>Records {@link PyTorchRunner} as the runner class in the ML config, then registers the
     * configured Python files with the stream environment. Kept in one place so the DataStream
     * and Table paths cannot drift apart.
     */
    private static void prepareConfig(StreamExecutionEnvironment streamExecutionEnvironment, PyTorchConfig pyTorchConfig) throws IOException {
        pyTorchConfig.getMlConfig().getProperties().put(ML_RUNNER_CLASS_KEY, PyTorchRunner.class.getCanonicalName());
        PythonFileUtil.registerPythonFiles(streamExecutionEnvironment, pyTorchConfig.getMlConfig());
    }
}
