package org.flinkextended.flink.ml.pytorch;

import org.apache.flink.configuration.Configuration;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationConfig;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.ReplayableDataStreamList;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.StatementSet;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableDescriptor;
import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
import org.apache.flink.table.api.internal.StatementSetImpl;
import org.flinkextended.flink.ml.operator.client.NodeUtils;
import org.flinkextended.flink.ml.operator.util.ReflectionUtils;

/**
 * Static helpers that add a PyTorch cluster, described by a {@link PyTorchClusterConfig}, to a
 * Flink Table API job.
 *
 * <p>Every entry point schedules the AM node through {@link NodeUtils#scheduleAMNode}. The {@link
 * PyTorchClusterConfig#WORKER_NODE_TYPE} nodes are scheduled either directly through {@code
 * NodeUtils.scheduleNodes(...)} or, for iterative training, presumably inside a {@link
 * PyTorchNodeIterationBody} (NOTE(review): confirm against that class).
 */
public class PyTorchUtils {

    /**
     * Schedules a PyTorch training cluster that is not fed an input table.
     *
     * <p>Adds the AM node and the worker nodes described by {@code pyTorchClusterConfig} to {@code
     * statementSet}.
     *
     * @param statementSet the statement set the cluster nodes are added to
     * @param pyTorchClusterConfig configuration of the PyTorch cluster to run
     */
    public static void train(StatementSet statementSet, PyTorchClusterConfig pyTorchClusterConfig) {
        NodeUtils.scheduleAMNode(statementSet, pyTorchClusterConfig);
        NodeUtils.scheduleNodes(
                statementSet, pyTorchClusterConfig, PyTorchClusterConfig.WORKER_NODE_TYPE);
    }

    /**
     * Schedules a PyTorch training cluster whose worker nodes are given {@code table} as input.
     *
     * <p>Adds the AM node and the worker nodes described by {@code pyTorchClusterConfig} to {@code
     * statementSet}.
     *
     * @param statementSet the statement set the cluster nodes are added to
     * @param table the training data passed to the worker nodes
     * @param pyTorchClusterConfig configuration of the PyTorch cluster to run
     */
    public static void train(
            StatementSet statementSet, Table table, PyTorchClusterConfig pyTorchClusterConfig) {
        NodeUtils.scheduleAMNode(statementSet, pyTorchClusterConfig);
        NodeUtils.scheduleNodes(
                statementSet, table, pyTorchClusterConfig, PyTorchClusterConfig.WORKER_NODE_TYPE);
    }

    /**
     * Schedules PyTorch training as a bounded Flink ML iteration over {@code table}.
     *
     * <p>The input table is converted to a {@link DataStream} and wrapped with {@link
     * ReplayableDataStreamList#replay} so the iteration re-feeds it on each round. The iteration is
     * seeded with a single-element stream containing {@code 0}, and its rounds are driven by a
     * {@link PyTorchNodeIterationBody}. The first output stream of the iteration is written to a
     * {@code blackhole} sink so that the iteration becomes part of the job built from {@code
     * statementSet}.
     *
     * <p>The table environment is obtained by reading the {@code tableEnvironment} field declared
     * on {@link StatementSetImpl} via reflection and casting it to {@link
     * StreamTableEnvironmentImpl}. {@code statementSet} must therefore be a {@link StatementSetImpl}
     * created from a stream table environment. Otherwise the cast fails with a {@link
     * ClassCastException}, or {@code ReflectionUtils} fails in a way not visible here.
     *
     * @param statementSet the statement set the AM node and the blackhole insert are added to
     * @param table the training data, replayed on every iteration round
     * @param pyTorchClusterConfig configuration of the PyTorch cluster to run
     * @param num forwarded unchanged to {@link PyTorchNodeIterationBody}; presumably the number of
     *     training epochs — TODO confirm against that class
     */
    public static void train(
            StatementSet statementSet,
            Table table,
            PyTorchClusterConfig pyTorchClusterConfig,
            Integer num) {
        // Reflection is needed because StatementSet does not expose its table environment.
        StreamTableEnvironmentImpl tEnv =
                (StreamTableEnvironmentImpl)
                        ReflectionUtils.getFieldValue(
                                statementSet, StatementSetImpl.class, "tableEnvironment");
        StreamExecutionEnvironment env = tEnv.execEnv();
        DataStream<?> trainingData = tEnv.toDataStream(table);
        Configuration flinkConfig = NodeUtils.mergeConfiguration(env, tEnv.getConfig());

        NodeUtils.scheduleAMNode(statementSet, pyTorchClusterConfig);

        // Single-element seed for the iteration's variable stream.
        DataStream<Integer> initVariable = env.fromElements(0);
        DataStreamList iterationOutputs =
                Iterations.iterateBoundedStreamsUntilTermination(
                        DataStreamList.of(initVariable),
                        ReplayableDataStreamList.replay(trainingData),
                        IterationConfig.newBuilder().build(),
                        new PyTorchNodeIterationBody(
                                env, pyTorchClusterConfig, num, flinkConfig));

        // The iteration output is not used downstream. Inserting it into a blackhole sink (which
        // discards rows) attaches the iteration to the statement set's job.
        Table iterationResult = tEnv.fromDataStream(iterationOutputs.get(0));
        statementSet.addInsert(TableDescriptor.forConnector("blackhole").build(), iterationResult);
    }

    /**
     * Schedules a PyTorch inference cluster over {@code table} and returns the worker output.
     *
     * <p>Only the AM node is added to {@code statementSet}. The worker nodes are scheduled through
     * {@code NodeUtils.scheduleNodes(...)} without the statement set, so presumably the caller must
     * consume the returned table for them to run (NOTE(review): confirm).
     *
     * @param statementSet the statement set the AM node is added to
     * @param table the input rows handed to the worker nodes
     * @param pyTorchClusterConfig configuration of the PyTorch cluster to run
     * @param schema presumably the schema of the returned table — confirm against NodeUtils
     * @return the table produced by {@code NodeUtils.scheduleNodes(...)} for the worker nodes
     */
    public static Table inference(
            StatementSet statementSet,
            Table table,
            PyTorchClusterConfig pyTorchClusterConfig,
            Schema schema) {
        NodeUtils.scheduleAMNode(statementSet, pyTorchClusterConfig);
        return NodeUtils.scheduleNodes(
                table, pyTorchClusterConfig, schema, PyTorchClusterConfig.WORKER_NODE_TYPE);
    }
}
