package org.flinkextended.flink.ml.tensorflow.client;

import com.google.common.base.Preconditions;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.StatementSet;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableDescriptor;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableEnvironmentInternal;
import org.apache.flink.types.Row;
import org.flinkextended.flink.ml.cluster.ExecutionMode;
import org.flinkextended.flink.ml.cluster.role.PsRole;
import org.flinkextended.flink.ml.cluster.role.WorkerRole;
import org.flinkextended.flink.ml.operator.client.RoleUtils;
import org.flinkextended.flink.ml.operator.util.PythonFileUtil;
import org.flinkextended.flink.ml.operator.util.TypeUtil;
import org.flinkextended.flink.ml.tensorflow.cluster.ChiefRole;
import org.flinkextended.flink.ml.tensorflow.cluster.TFAMStateMachineImpl;
import org.flinkextended.flink.ml.tensorflow.cluster.TensorBoardRole;
import org.flinkextended.flink.ml.tensorflow.cluster.node.runner.TFMLRunner;
import org.flinkextended.flink.ml.tensorflow.cluster.node.runner.TensorBoardPythonRunner;
import org.flinkextended.flink.ml.tensorflow.data.TFRecordReaderImpl;
import org.flinkextended.flink.ml.tensorflow.data.TFRecordWriterImpl;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers that add the roles of a TensorFlow cluster (AM, PS, worker, chief and optional
 * TensorBoard roles) to a Flink job, through either the DataStream API or the Table API.
 *
 * <p>The {@code train}/{@code inference}/{@code run} methods mutate the {@link TFConfigBase} they
 * receive:
 * <ul>
 *   <li>{@code job_has_input} is set when an input is supplied.</li>
 *   <li>The TF-specific runner, AM state machine, record reader and record writer class names are
 *       written into its properties.</li>
 * </ul>
 * The {@code startTensorBoard} methods work on a deep copy instead.
 */
public class TFUtils {
    private static final Logger LOG = LoggerFactory.getLogger(TFUtils.class);

    /**
     * Single STRING-column schema (column {@code a}). It is used as the output type of Table-API
     * jobs that declare no output schema, and as the schema of the dummy sink tables.
     */
    static final Schema DUMMY_SCHEMA = Schema.newBuilder().column("a", DataTypes.STRING()).build();

    /** Source of unique suffixes for the names of the temporary dummy sink tables. */
    private static final AtomicInteger DUMMY_SINK_COUNTER = new AtomicInteger(0);

    /**
     * Adds a training cluster with no input and no declared output type.
     *
     * @see #run(StreamExecutionEnvironment, ExecutionMode, DataStream, TFConfigBase, TypeInformation)
     */
    public static <OUT> DataStream<OUT> train(
            StreamExecutionEnvironment streamExecutionEnvironment, TFConfigBase tFConfigBase)
            throws IOException {
        return train(streamExecutionEnvironment, tFConfigBase, (TypeInformation<OUT>) null);
    }

    /**
     * Adds a training cluster with no input. The output type is derived from {@code cls}, or left
     * undeclared when {@code cls} is {@code null}.
     *
     * @see #run(StreamExecutionEnvironment, ExecutionMode, DataStream, TFConfigBase, TypeInformation)
     */
    public static <OUT> DataStream<OUT> train(
            StreamExecutionEnvironment streamExecutionEnvironment,
            TFConfigBase tFConfigBase,
            Class<OUT> cls)
            throws IOException {
        return train(streamExecutionEnvironment, tFConfigBase, getTypeInfo(cls));
    }

    /**
     * Adds a training cluster with no input and the given (possibly {@code null}) output type.
     *
     * @see #run(StreamExecutionEnvironment, ExecutionMode, DataStream, TFConfigBase, TypeInformation)
     */
    public static <OUT> DataStream<OUT> train(
            StreamExecutionEnvironment streamExecutionEnvironment,
            TFConfigBase tFConfigBase,
            TypeInformation<OUT> typeInformation)
            throws IOException {
        return run(
                streamExecutionEnvironment,
                ExecutionMode.TRAIN,
                (DataStream<Object>) null,
                tFConfigBase,
                typeInformation);
    }

    /**
     * Adds a training cluster fed by {@code dataStream}, with no declared output type.
     *
     * @see #run(StreamExecutionEnvironment, ExecutionMode, DataStream, TFConfigBase, TypeInformation)
     */
    public static <IN, OUT> DataStream<OUT> train(
            StreamExecutionEnvironment streamExecutionEnvironment,
            DataStream<IN> dataStream,
            TFConfigBase tFConfigBase)
            throws IOException {
        return train(
                streamExecutionEnvironment, dataStream, tFConfigBase, (TypeInformation<OUT>) null);
    }

    /**
     * Adds a training cluster fed by {@code dataStream}. The output type is derived from
     * {@code cls} ({@code null} means undeclared).
     *
     * @see #run(StreamExecutionEnvironment, ExecutionMode, DataStream, TFConfigBase, TypeInformation)
     */
    public static <IN, OUT> DataStream<OUT> train(
            StreamExecutionEnvironment streamExecutionEnvironment,
            DataStream<IN> dataStream,
            TFConfigBase tFConfigBase,
            Class<OUT> cls)
            throws IOException {
        return train(streamExecutionEnvironment, dataStream, tFConfigBase, getTypeInfo(cls));
    }

    /**
     * Adds a training cluster fed by {@code dataStream}, with the given output type.
     *
     * @see #run(StreamExecutionEnvironment, ExecutionMode, DataStream, TFConfigBase, TypeInformation)
     */
    public static <IN, OUT> DataStream<OUT> train(
            StreamExecutionEnvironment streamExecutionEnvironment,
            DataStream<IN> dataStream,
            TFConfigBase tFConfigBase,
            TypeInformation<OUT> typeInformation)
            throws IOException {
        return run(
                streamExecutionEnvironment,
                ExecutionMode.TRAIN,
                dataStream,
                tFConfigBase,
                typeInformation);
    }

    /**
     * Adds an inference cluster fed by {@code dataStream}. The output type is derived from
     * {@code cls} ({@code null} means undeclared).
     *
     * @see #run(StreamExecutionEnvironment, ExecutionMode, DataStream, TFConfigBase, TypeInformation)
     */
    public static <IN, OUT> DataStream<OUT> inference(
            StreamExecutionEnvironment streamExecutionEnvironment,
            DataStream<IN> dataStream,
            TFConfigBase tFConfigBase,
            Class<OUT> cls)
            throws IOException {
        return inference(streamExecutionEnvironment, dataStream, tFConfigBase, getTypeInfo(cls));
    }

    /**
     * Adds an inference cluster fed by {@code dataStream}, with the given output type.
     *
     * @see #run(StreamExecutionEnvironment, ExecutionMode, DataStream, TFConfigBase, TypeInformation)
     */
    public static <IN, OUT> DataStream<OUT> inference(
            StreamExecutionEnvironment streamExecutionEnvironment,
            DataStream<IN> dataStream,
            TFConfigBase tFConfigBase,
            TypeInformation<OUT> typeInformation)
            throws IOException {
        return run(
                streamExecutionEnvironment,
                ExecutionMode.INFERENCE,
                dataStream,
                tFConfigBase,
                typeInformation);
    }

    /**
     * Same as the {@link TypeInformation} overload, with the output type derived from {@code cls}
     * ({@code null} means undeclared).
     *
     * @see #run(StreamExecutionEnvironment, ExecutionMode, DataStream, TFConfigBase, TypeInformation)
     */
    public static <IN, OUT> DataStream<OUT> run(
            StreamExecutionEnvironment streamExecutionEnvironment,
            ExecutionMode executionMode,
            DataStream<IN> dataStream,
            TFConfigBase tFConfigBase,
            Class<OUT> cls)
            throws IOException {
        return run(
                streamExecutionEnvironment,
                executionMode,
                dataStream,
                tFConfigBase,
                getTypeInfo(cls));
    }

    /**
     * Same as the {@link TypeInformation} overload, without a declared output type.
     *
     * @see #run(StreamExecutionEnvironment, ExecutionMode, DataStream, TFConfigBase, TypeInformation)
     */
    public static <IN, OUT> DataStream<OUT> run(
            StreamExecutionEnvironment streamExecutionEnvironment,
            ExecutionMode executionMode,
            DataStream<IN> dataStream,
            TFConfigBase tFConfigBase)
            throws IOException {
        return run(
                streamExecutionEnvironment,
                executionMode,
                dataStream,
                tFConfigBase,
                (TypeInformation<OUT>) null);
    }

    /**
     * Writes the TensorFlow-specific implementation class names into the config's properties:
     * <ul>
     *   <li>the ML runner ({@link TFMLRunner})</li>
     *   <li>the AM state machine ({@link TFAMStateMachineImpl})</li>
     *   <li>the record reader ({@link TFRecordReaderImpl})</li>
     *   <li>the record writer ({@link TFRecordWriterImpl})</li>
     * </ul>
     */
    private static void setTFDefaultConfig(TFConfigBase tFConfigBase) {
        tFConfigBase.getProperties().put("sys:ml_runner_class", TFMLRunner.class.getCanonicalName());
        tFConfigBase.getProperties().put("am_state_machine_class", TFAMStateMachineImpl.class.getCanonicalName());
        tFConfigBase.getProperties().put("sys:record_reader_class", TFRecordReaderImpl.class.getCanonicalName());
        tFConfigBase.getProperties().put("sys:record_writer_class", TFRecordWriterImpl.class.getCanonicalName());
    }

    /**
     * Adds a single TensorBoard role to a DataStream job. The role is built from a deep copy of
     * {@code tFConfigBase}, so the caller's config is not modified.
     *
     * @param streamExecutionEnvironment environment the role is added to
     * @param tFConfigBase base configuration to copy
     * @throws IOException propagated from python-file registration or role creation
     */
    public static void startTensorBoard(
            StreamExecutionEnvironment streamExecutionEnvironment, TFConfigBase tFConfigBase)
            throws IOException {
        final TFConfigBase tensorBoardConfig =
                buildTensorBoardConfig(streamExecutionEnvironment, tFConfigBase);
        RoleUtils.addRole(
                streamExecutionEnvironment,
                ExecutionMode.OTHER,
                (DataStream<Object>) null,
                tensorBoardConfig.getMlConfig(),
                (TypeInformation<Object>) null,
                new TensorBoardRole());
    }

    /**
     * Builds the TensorBoard configuration. The steps are:
     * <ul>
     *   <li>Deep-copy {@code tFConfigBase}.</li>
     *   <li>Select {@link TensorBoardPythonRunner} as the script runner.</li>
     *   <li>Set the TensorBoard role's parallelism to 1.</li>
     *   <li>Register the copy's python files with the environment.</li>
     * </ul>
     */
    private static TFConfigBase buildTensorBoardConfig(
            StreamExecutionEnvironment streamExecutionEnvironment, TFConfigBase tFConfigBase)
            throws IOException {
        final TFConfigBase tensorBoardConfig = tFConfigBase.deepCopy();
        tensorBoardConfig.getProperties().put("script_runner_class", TensorBoardPythonRunner.class.getCanonicalName());
        tensorBoardConfig.getMlConfig().getRoleParallelismMap().put(new TensorBoardRole().name(), 1);
        PythonFileUtil.registerPythonFiles(streamExecutionEnvironment, tensorBoardConfig.getMlConfig());
        return tensorBoardConfig;
    }

    /**
     * Adds a single TensorBoard role to a Table-API job through {@code statementSet}. The role is
     * built from a deep copy of {@code tFConfigBase}, so the caller's config is not modified.
     *
     * @throws IOException propagated from python-file registration or role creation
     */
    public static void startTensorBoard(
            StreamExecutionEnvironment streamExecutionEnvironment,
            TableEnvironment tableEnvironment,
            StatementSet statementSet,
            TFConfigBase tFConfigBase)
            throws IOException {
        final TFConfigBase tensorBoardConfig =
                buildTensorBoardConfig(streamExecutionEnvironment, tFConfigBase);
        RoleUtils.addRole(
                tableEnvironment,
                statementSet,
                ExecutionMode.OTHER,
                (Table) null,
                tensorBoardConfig.getMlConfig(),
                (Schema) null,
                new TensorBoardRole());
    }

    /**
     * Builds a TensorFlow cluster on the DataStream API. In order, it:
     * <ul>
     *   <li>marks the job as having input when {@code dataStream} is non-null;</li>
     *   <li>writes the TF defaults;</li>
     *   <li>registers the python files;</li>
     *   <li>adds the AM role (from the original config);</li>
     *   <li>adds the PS roles when {@code getPsNum() > 0};</li>
     *   <li>adds the worker role and, when worker 0 runs alone, the chief role (from the
     *       chief-adjusted config).</li>
     * </ul>
     *
     * @param streamExecutionEnvironment environment the roles are added to
     * @param executionMode execution mode passed to the PS and worker/chief roles
     * @param dataStream optional input fed to the workers; {@code null} for input-less jobs
     * @param tFConfigBase cluster configuration; mutated (see class documentation)
     * @param typeInformation type of the worker output, or {@code null}
     * @return the worker output stream. It may be {@code null} when worker 0 runs alone as chief
     *     and no further worker role is created. The chief's own output stream is not returned.
     * @throws IOException propagated from python-file registration or role creation
     */
    public static <IN, OUT> DataStream<OUT> run(
            StreamExecutionEnvironment streamExecutionEnvironment,
            ExecutionMode executionMode,
            DataStream<IN> dataStream,
            TFConfigBase tFConfigBase,
            TypeInformation<OUT> typeInformation)
            throws IOException {
        if (dataStream != null) {
            tFConfigBase.addProperty("job_has_input", "true");
        }
        setTFDefaultConfig(tFConfigBase);
        // Registered before the chief copy is taken, so both configs carry the registration.
        PythonFileUtil.registerPythonFiles(streamExecutionEnvironment, tFConfigBase.getMlConfig());
        final TFConfigBase chiefTypeConfig = toChiefTypeConfig(tFConfigBase);
        RoleUtils.addAMRole(streamExecutionEnvironment, tFConfigBase.getMlConfig());
        if (tFConfigBase.getPsNum() > 0) {
            RoleUtils.addRole(
                    streamExecutionEnvironment,
                    executionMode,
                    (DataStream<Object>) null,
                    chiefTypeConfig.getMlConfig(),
                    (TypeInformation<Object>) null,
                    new PsRole());
        }
        final Pair<DataStream<OUT>, DataStream<OUT>> workerAndChief =
                getWorkerDataStream(
                        streamExecutionEnvironment,
                        executionMode,
                        dataStream,
                        chiefTypeConfig,
                        typeInformation);
        return workerAndChief.getLeft();
    }

    /**
     * Adds the worker (and possibly chief) roles and returns their output streams.
     *
     * <ul>
     *   <li>Input present and python files configured: a {@link WorkerRole} consuming the input.</li>
     *   <li>Input present, no python files: a plain {@code flatMap} with the config's inference
     *       function. It uses parallelism {@code getWorkerNum()} and is named after the worker
     *       role.</li>
     *   <li>No input, worker 0 alone: a {@link ChiefRole}, plus a {@link WorkerRole} when
     *       {@code getWorkerNum() > 0}.</li>
     *   <li>No input otherwise: a single {@link WorkerRole}.</li>
     * </ul>
     *
     * <p>NOTE(review): the input branch ignores {@code isWorkerZeroAlone()}, so no chief role is
     * added even though {@link #toChiefTypeConfig} reserved one. Confirm this is intended.
     *
     * @return pair of (worker output, chief output); either side may be {@code null}
     */
    private static <IN, OUT> Pair<DataStream<OUT>, DataStream<OUT>> getWorkerDataStream(
            StreamExecutionEnvironment streamExecutionEnvironment,
            ExecutionMode executionMode,
            DataStream<IN> dataStream,
            TFConfigBase tFConfigBase,
            TypeInformation<OUT> typeInformation)
            throws IOException {
        DataStream<OUT> worker = null;
        DataStream<OUT> chief = null;
        final boolean isWorkerZeroAlone = tFConfigBase.isWorkerZeroAlone();
        if (dataStream != null) {
            if (hasScript(tFConfigBase)) {
                worker =
                        RoleUtils.addRole(
                                streamExecutionEnvironment,
                                executionMode,
                                dataStream,
                                tFConfigBase.getMlConfig(),
                                typeInformation,
                                new WorkerRole());
            } else {
                // Script-less inference: run the model inside an ordinary flatMap operator.
                worker =
                        dataStream
                                .flatMap(
                                        tFConfigBase.getInferenceFlatMapFunction(
                                                new WorkerRole(),
                                                tFConfigBase.getMlConfig(),
                                                dataStream.getType(),
                                                typeInformation))
                                .setParallelism(tFConfigBase.getWorkerNum())
                                .name(new WorkerRole().name());
            }
        } else if (isWorkerZeroAlone) {
            chief =
                    RoleUtils.addRole(
                            streamExecutionEnvironment,
                            executionMode,
                            (DataStream<IN>) null,
                            tFConfigBase.getMlConfig(),
                            typeInformation,
                            new ChiefRole());
            if (tFConfigBase.getWorkerNum() > 0) {
                worker =
                        RoleUtils.addRole(
                                streamExecutionEnvironment,
                                executionMode,
                                (DataStream<IN>) null,
                                tFConfigBase.getMlConfig(),
                                typeInformation,
                                new WorkerRole());
            }
        } else {
            worker =
                    RoleUtils.addRole(
                            streamExecutionEnvironment,
                            executionMode,
                            (DataStream<IN>) null,
                            tFConfigBase.getMlConfig(),
                            typeInformation,
                            new WorkerRole());
        }
        return Pair.of(worker, chief);
    }

    /**
     * Table-API training entry point.
     *
     * @see #run(StreamExecutionEnvironment, TableEnvironment, StatementSet, ExecutionMode, Table,
     *     TFConfigBase, Schema)
     */
    public static Table train(
            StreamExecutionEnvironment streamExecutionEnvironment,
            TableEnvironment tableEnvironment,
            StatementSet statementSet,
            Table table,
            TFConfigBase tFConfigBase,
            Schema schema)
            throws IOException {
        return run(
                streamExecutionEnvironment,
                tableEnvironment,
                statementSet,
                ExecutionMode.TRAIN,
                table,
                tFConfigBase,
                schema);
    }

    /**
     * Table-API inference entry point.
     *
     * @see #run(StreamExecutionEnvironment, TableEnvironment, StatementSet, ExecutionMode, Table,
     *     TFConfigBase, Schema)
     */
    public static Table inference(
            StreamExecutionEnvironment streamExecutionEnvironment,
            TableEnvironment tableEnvironment,
            StatementSet statementSet,
            Table table,
            TFConfigBase tFConfigBase,
            Schema schema)
            throws IOException {
        return run(
                streamExecutionEnvironment,
                tableEnvironment,
                statementSet,
                ExecutionMode.INFERENCE,
                table,
                tFConfigBase,
                schema);
    }

    /**
     * Builds a TensorFlow cluster for a Table-API job.
     *
     * <p>When python files are configured, the method:
     * <ul>
     *   <li>registers them;</li>
     *   <li>adds the AM role (from the original config);</li>
     *   <li>adds PS roles when {@code getPsNum() > 0}.</li>
     * </ul>
     * Without python files, only script-less inference over a non-null {@code table} is allowed.
     * The worker/chief roles are then added on the DataStream side and converted back to tables.
     *
     * <p>When {@code schema} is {@code null}, both outputs are sent to dummy sinks in
     * {@code statementSet}. In that case the job's output type is {@link #DUMMY_SCHEMA}.
     *
     * <p>NOTE(review): when {@code schema} is non-null, the chief output table is neither
     * returned nor sunk. Confirm this is intended.
     *
     * @param tableEnvironment must be a {@link StreamTableEnvironment} that is also a
     *     {@link TableEnvironmentInternal} (both casts are performed)
     * @param table optional input table; {@code null} for input-less jobs
     * @param tFConfigBase cluster configuration; mutated (see class documentation)
     * @param schema output schema, or {@code null} for jobs without output
     * @return the worker output table; may be {@code null} when no worker role was created
     * @throws IllegalArgumentException if no python files are configured and either the mode is
     *     not inference or {@code table} is {@code null}
     * @throws IOException propagated from python-file registration or role creation
     */
    public static Table run(
            StreamExecutionEnvironment streamExecutionEnvironment,
            TableEnvironment tableEnvironment,
            StatementSet statementSet,
            ExecutionMode executionMode,
            Table table,
            TFConfigBase tFConfigBase,
            Schema schema)
            throws IOException {
        final boolean hasScript = hasScript(tFConfigBase);
        Preconditions.checkArgument(hasScript || executionMode == ExecutionMode.INFERENCE, "Python script can be omitted only for inference");
        Preconditions.checkArgument(hasScript || table != null, "Input table and python script can't both be null");
        if (table != null) {
            tFConfigBase.addProperty("job_has_input", "true");
        }
        setTFDefaultConfig(tFConfigBase);
        final TFConfigBase chiefTypeConfig = toChiefTypeConfig(tFConfigBase);
        final DataStream<Row> input = tableToDS(table, tableEnvironment);
        if (hasScript) {
            PythonFileUtil.registerPythonFiles(streamExecutionEnvironment, chiefTypeConfig.getMlConfig());
            RoleUtils.addAMRole(tableEnvironment, statementSet, tFConfigBase.getMlConfig());
            if (chiefTypeConfig.getPsNum() > 0) {
                RoleUtils.addRole(
                        tableEnvironment,
                        statementSet,
                        executionMode,
                        (Table) null,
                        chiefTypeConfig.getMlConfig(),
                        (Schema) null,
                        new PsRole());
            }
        }
        final Schema outputSchema = schema != null ? schema : DUMMY_SCHEMA;
        final TypeInformation<Row> outputType =
                TypeUtil.schemaToRowTypeInfo(
                        outputSchema.resolve(
                                ((TableEnvironmentInternal) tableEnvironment)
                                        .getCatalogManager()
                                        .getSchemaResolver()));
        final Pair<DataStream<Row>, DataStream<Row>> workerAndChief =
                getWorkerDataStream(
                        streamExecutionEnvironment,
                        executionMode,
                        input,
                        chiefTypeConfig,
                        outputType);
        final Table workerTable =
                workerAndChief.getLeft() == null
                        ? null
                        : dsToTable(workerAndChief.getLeft(), tableEnvironment);
        final Table chiefTable =
                workerAndChief.getRight() == null
                        ? null
                        : dsToTable(workerAndChief.getRight(), tableEnvironment);
        if (schema == null) {
            // Without a declared output nothing would consume these tables; sink them so they
            // become part of the statement set.
            if (workerTable != null) {
                writeToDummySink(workerTable, tableEnvironment, statementSet);
            }
            if (chiefTable != null) {
                writeToDummySink(chiefTable, tableEnvironment, statementSet);
            }
        }
        return workerTable;
    }

    /**
     * Returns the config to use for PS/worker/chief roles.
     *
     * <p>When {@code isWorkerZeroAlone()} is false, the given config is returned unchanged.
     * Otherwise a deep copy is returned in which:
     * <ul>
     *   <li>the chief role has parallelism 1;</li>
     *   <li>the worker role has {@code getWorkerNum() - 1}, or is removed when that would be
     *       below 1.</li>
     * </ul>
     */
    private static TFConfigBase toChiefTypeConfig(TFConfigBase tFConfigBase) {
        if (!tFConfigBase.isWorkerZeroAlone()) {
            return tFConfigBase;
        }
        final TFConfigBase chiefConfig = tFConfigBase.deepCopy();
        chiefConfig.getMlConfig().getRoleParallelismMap().put(new ChiefRole().name(), 1);
        if (tFConfigBase.getWorkerNum() > 1) {
            chiefConfig.getMlConfig().getRoleParallelismMap().put(new WorkerRole().name(), tFConfigBase.getWorkerNum() - 1);
        } else {
            chiefConfig.getMlConfig().getRoleParallelismMap().remove(new WorkerRole().name());
        }
        return chiefConfig;
    }

    /**
     * Converts a row stream to a table. {@code tableEnvironment} must be a
     * {@link StreamTableEnvironment}; otherwise the cast throws.
     */
    private static Table dsToTable(DataStream<Row> dataStream, TableEnvironment tableEnvironment) {
        return ((StreamTableEnvironment) tableEnvironment).fromDataStream(dataStream);
    }

    /**
     * Converts a table to an append-only row stream, or returns {@code null} for a {@code null}
     * table. {@code tableEnvironment} must be a {@link StreamTableEnvironment}.
     *
     * <p>NOTE(review): {@code toAppendStream} and {@code Table#getSchema()} are deprecated in
     * recent Flink releases. Consider {@code toDataStream} / {@code getResolvedSchema()}.
     */
    private static DataStream<Row> tableToDS(Table table, TableEnvironment tableEnvironment) {
        if (table == null) {
            return null;
        }
        return ((StreamTableEnvironment) tableEnvironment)
                .toAppendStream(table, TypeUtil.schemaToRowTypeInfo(table.getSchema()));
    }

    /**
     * Returns {@code TypeInformation.of(cls)}, or {@code null} when {@code cls} is {@code null}.
     */
    private static <OUT> TypeInformation<OUT> getTypeInfo(Class<OUT> cls) {
        return cls == null ? null : TypeInformation.of(cls);
    }

    /**
     * Registers a temporary {@code DummyTable} connector table with {@link #DUMMY_SCHEMA}, under
     * the name {@code dummy_sink_<n>}. It then adds an insert of {@code table} into that sink to
     * {@code statementSet}.
     */
    private static void writeToDummySink(
            Table table, TableEnvironment tableEnvironment, StatementSet statementSet) {
        final String sinkName =
                String.format("dummy_sink_%s", DUMMY_SINK_COUNTER.getAndIncrement());
        tableEnvironment.createTemporaryTable(
                sinkName, TableDescriptor.forConnector("DummyTable").schema(DUMMY_SCHEMA).build());
        statementSet.addInsert(sinkName, table);
    }

    /** Returns whether the config lists at least one python file. */
    private static boolean hasScript(TFConfigBase tFConfigBase) {
        return tFConfigBase.getPythonFiles() != null && tFConfigBase.getPythonFiles().length > 0;
    }
}
