package org.flinkextended.flink.ml.operator.ops;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;
import org.flinkextended.flink.ml.cluster.ClusterConfig;
import org.flinkextended.flink.ml.cluster.ExecutionMode;
import org.flinkextended.flink.ml.cluster.node.MLContext;
import org.flinkextended.flink.ml.cluster.rpc.NodeServer;
import org.flinkextended.flink.ml.data.DataExchange;
import org.flinkextended.flink.ml.operator.util.ColumnInfos;
import org.flinkextended.flink.ml.operator.util.PythonFileUtil;

/**
 * Stream operator that runs a {@link NodeServer} in a background thread and exchanges records
 * with it through a {@link DataExchange}.
 *
 * <p>Input rows are written to the data exchange in {@link #processElement}. A second background
 * thread ({@link DataExchangeConsumer}) reads results back from the data exchange and emits them
 * to this operator's output.
 */
public class NodeOperator<OUT> extends AbstractStreamOperator<OUT>
        implements OneInputStreamOperator<Row, OUT>, IterationListener<OUT> {

    /** Default upper bound for waiting on the NodeServer when the operator is closed. */
    private static final long DEFAULT_CLOSE_TIMEOUT_MS = 30_000L;

    /** How long to join the NodeServer thread before interrupting it again. */
    private static final long SERVER_JOIN_TIMEOUT_MS = 30_000L;

    /**
     * Side-output tag used to emit a record at each epoch while the NodeServer is still running.
     * Declared once in a static context so the anonymous subclass (needed to capture the type
     * information) does not capture an operator instance.
     */
    private static final OutputTag<Integer> TERMINATION_TAG =
            new OutputTag<Integer>("termination") {};

    /** Max time {@link #close()} waits for the NodeServer; non-final so tests can shorten it. */
    protected long closeTimeoutMs;

    private final String nodeType;
    private final ClusterConfig clusterConfig;
    private final Configuration flinkConfig;
    private MLContext mlContext;
    private DataExchange<Row, OUT> dataExchange;
    private FutureTask<Void> serverFuture;
    private FutureTask<Void> dataExchangeConsumerFuture;
    private Thread serverThread;
    private Thread dataExchangeConsumerThread;

    /**
     * Drains records from a {@link DataExchange} and forwards them to an {@link Output} until the
     * exchange returns {@code null} or reading fails.
     */
    private static class DataExchangeConsumer<IN, OUT> implements Runnable {
        private final DataExchange<IN, OUT> dataExchange;
        private final Output<StreamRecord<OUT>> output;

        DataExchangeConsumer(DataExchange<IN, OUT> dataExchange, Output<StreamRecord<OUT>> output) {
            this.dataExchange = dataExchange;
            this.output = output;
        }

        @Override
        public void run() {
            while (true) {
                final OUT record;
                try {
                    record = dataExchange.read(true);
                } catch (InterruptedIOException e) {
                    LOG.warn("Reading from data exchange is interrupted.", e);
                    // Preserve the interrupt status for whoever owns this thread.
                    Thread.currentThread().interrupt();
                    break;
                } catch (IOException e) {
                    // Stop consuming instead of spinning on a broken exchange.
                    LOG.error("Fail to read data from python.", e);
                    break;
                }
                if (record == null) {
                    // A null read marks the end of the data.
                    break;
                }
                output.collect(new StreamRecord<>(record));
            }
            LOG.info("DataExchange consumer consume all data, exiting...");
        }
    }

    /**
     * Creates an operator with an empty Flink {@link Configuration}.
     *
     * @param nodeType the node type this operator runs
     * @param clusterConfig the cluster configuration
     */
    public NodeOperator(String nodeType, ClusterConfig clusterConfig) {
        this(nodeType, clusterConfig, new Configuration());
    }

    /**
     * Creates an operator.
     *
     * @param nodeType the node type this operator runs
     * @param clusterConfig the cluster configuration
     * @param flinkConfig Flink configuration passed to the {@code PythonEnvironmentManager}
     */
    public NodeOperator(String nodeType, ClusterConfig clusterConfig, Configuration flinkConfig) {
        this.closeTimeoutMs = DEFAULT_CLOSE_TIMEOUT_MS;
        this.nodeType = nodeType;
        this.clusterConfig = clusterConfig;
        this.flinkConfig = flinkConfig;
    }

    /**
     * Builds the {@link MLContext} and prepares the Python files. Then starts the NodeServer
     * thread and the data-exchange consumer thread.
     */
    @Override
    public void open() throws Exception {
        super.open();
        final PythonEnvironmentManager pythonEnvironmentManager =
                new PythonEnvironmentManager(clusterConfig, flinkConfig);
        pythonEnvironmentManager.open(getRuntimeContext());

        final Map<String, String> properties = new HashMap<>(clusterConfig.getProperties());
        properties.put("gpu_info", ResourcesUtils.parseGpuInfo(getRuntimeContext()));
        properties.putAll(pythonEnvironmentManager.getPythonEnvProperties());

        mlContext =
                new MLContext(
                        ExecutionMode.OTHER,
                        nodeType,
                        getRuntimeContext().getIndexOfThisSubtask(),
                        clusterConfig.getNodeTypeCntMap(),
                        clusterConfig.getEntryFuncName(),
                        properties,
                        clusterConfig.getPythonVirtualEnvZipPath(),
                        ColumnInfos.dummy().getNameToTypeMap());
        preparePythonFiles();

        serverFuture = new FutureTask<>(createNodeServerRunnable(), null);
        serverThread = runRunnable(serverFuture, "NodeServer_" + maybeGetIdentity());

        dataExchange = new DataExchange<>(mlContext);
        dataExchangeConsumerFuture =
                new FutureTask<>(new DataExchangeConsumer<>(dataExchange, output), null);
        dataExchangeConsumerThread =
                runRunnable(
                        dataExchangeConsumerFuture,
                        "NodeServerDataExchangeConsumer_" + maybeGetIdentity());
    }

    /**
     * Writes the row to the data exchange, yielding and retrying while the exchange reports that
     * it cannot accept the row.
     *
     * <p>If the write fails with an {@link IOException} after the NodeServer has already
     * finished, the row is dropped silently. Otherwise the exception is rethrown.
     */
    @Override
    public void processElement(StreamRecord<Row> streamRecord) throws Exception {
        final Row row = streamRecord.getValue();
        try {
            while (!dataExchange.write(row)) {
                Thread.yield();
            }
        } catch (IOException e) {
            if (!serverFuture.isDone()) {
                throw e;
            }
            // The server has exited; there is nobody left to receive the row.
        }
    }

    /** Waits, without a timeout, for the background threads to finish and then stops them. */
    @Override
    public void finish() throws Exception {
        LOG.info("Start finishing NodeOperator {}", maybeGetIdentity());
        cleanup(false);
    }

    /**
     * Stops the background threads, bounding the NodeServer wait by {@link #closeTimeoutMs}.
     * Then it releases the {@link MLContext} and the base operator resources. Those releases run
     * even when the cleanup fails.
     */
    @Override
    public void close() throws Exception {
        LOG.info("Start closing NodeOperator {}", maybeGetIdentity());
        try {
            cleanup(true);
        } finally {
            try {
                closeMlContext();
            } finally {
                super.close();
            }
        }
    }

    /** Closes and clears {@link #mlContext}; an {@link IOException} is logged, not rethrown. */
    private void closeMlContext() {
        if (mlContext == null) {
            return;
        }
        try {
            mlContext.close();
        } catch (IOException e) {
            LOG.error("Fail to close mlContext.", e);
        }
        mlContext = null;
    }

    /**
     * Marks the output queue finished, waits for the background tasks and then stops their
     * threads.
     *
     * @param isClose {@code true} to wait for the NodeServer at most {@link #closeTimeoutMs};
     *     {@code false} to wait without a timeout
     * @throws RuntimeException wrapping the {@link ExecutionException} if a background task failed
     */
    private void cleanup(boolean isClose) {
        if (mlContext != null && mlContext.getOutputQueue() != null) {
            mlContext.getOutputQueue().markFinished();
        }
        boolean interrupted = false;
        try {
            awaitBackgroundTasks(isClose);
        } catch (InterruptedException e) {
            interrupted = true;
            LOG.warn("Fail to join node {}", maybeGetIdentity(), e);
        } catch (ExecutionException e) {
            LOG.error(maybeGetIdentity() + " node server failed", e);
            throw new RuntimeException(e);
        } catch (TimeoutException e) {
            LOG.warn("Timeout waiting for node {} to finish", maybeGetIdentity(), e);
        } finally {
            if (stopServerThread()) {
                interrupted = true;
            }
            if (stopDataExchangeConsumerThread()) {
                interrupted = true;
            }
            serverFuture = null;
            dataExchangeConsumerFuture = null;
            // dataExchange is null if open() failed before creating it.
            if (dataExchange != null) {
                LOG.info("Records output: {}", dataExchange.getReadRecords());
            }
            if (interrupted) {
                // Restore the interrupt swallowed while waiting so callers can observe it.
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Blocks until the NodeServer task and then the consumer task complete. Tasks that are null
     * or already cancelled are skipped.
     */
    private void awaitBackgroundTasks(boolean isClose)
            throws InterruptedException, ExecutionException, TimeoutException {
        if (serverFuture != null && !serverFuture.isCancelled()) {
            if (isClose) {
                serverFuture.get(closeTimeoutMs, TimeUnit.MILLISECONDS);
            } else {
                serverFuture.get();
            }
        }
        if (dataExchangeConsumerFuture != null && !dataExchangeConsumerFuture.isCancelled()) {
            dataExchangeConsumerFuture.get();
        }
    }

    /**
     * Cancels the NodeServer task and joins its thread. It re-interrupts the thread every
     * {@link #SERVER_JOIN_TIMEOUT_MS} until the thread exits.
     *
     * @return {@code true} if the calling thread was interrupted while joining
     */
    private boolean stopServerThread() {
        if (serverFuture == null) {
            return false;
        }
        serverFuture.cancel(true);
        if (serverThread == null) {
            return false;
        }
        boolean interrupted = false;
        while (true) {
            try {
                serverThread.join(SERVER_JOIN_TIMEOUT_MS);
            } catch (InterruptedException e) {
                interrupted = true;
                LOG.error("Fail to wait for NodeServer to exit", e);
            }
            if (!serverThread.isAlive()) {
                return interrupted;
            }
            LOG.warn("NodeServer fail to exit in 30 second, interrupting...");
            serverThread.interrupt();
        }
    }

    /**
     * Cancels the consumer task and joins its thread, retrying the join if the calling thread is
     * interrupted.
     *
     * @return {@code true} if the calling thread was interrupted while joining
     */
    private boolean stopDataExchangeConsumerThread() {
        if (dataExchangeConsumerFuture == null) {
            return false;
        }
        dataExchangeConsumerFuture.cancel(true);
        if (dataExchangeConsumerThread == null) {
            return false;
        }
        boolean interrupted = false;
        while (true) {
            try {
                dataExchangeConsumerThread.join();
                return interrupted;
            } catch (InterruptedException e) {
                interrupted = true;
                LOG.error("Fail to wait for DataExchangeConsumer to exit", e);
            }
        }
    }

    /**
     * Places a barrier on the output queue, then polls every 100 ms while the NodeServer is still
     * running and the queue reports it can be read. If the NodeServer has finished, the epoch is
     * only logged. Otherwise a {@code 0} is emitted on the {@code "termination"} side output.
     * NOTE(review): presumably this side output feeds the iteration's termination criteria, so
     * emitting keeps the iteration alive; confirm against the job graph.
     */
    @Override
    public void onEpochWatermarkIncremented(
            int epochWatermark, IterationListener.Context context, Collector<OUT> collector)
            throws Exception {
        mlContext.getOutputQueue().markBarrier();
        while (!serverFuture.isDone() && mlContext.getOutputQueue().canRead()) {
            Thread.sleep(100L);
        }
        if (serverFuture.isDone()) {
            LOG.info("{} finished at epoch {}", maybeGetIdentity(), epochWatermark);
        } else {
            context.output(TERMINATION_TAG, 0);
        }
    }

    /** No-op: this operator needs no action when the iteration terminates. */
    @Override
    public void onIterationTerminated(IterationListener.Context context, Collector<OUT> collector)
            throws Exception {}

    /**
     * Starts {@code runnable} on a new daemon thread with the given name.
     *
     * @throws IOException if the thread could not be started; the original exception is kept as
     *     the cause
     */
    private Thread runRunnable(Runnable runnable, String threadName) throws IOException {
        try {
            final Thread thread = new Thread(runnable);
            thread.setDaemon(true);
            thread.setName(threadName);
            thread.start();
            LOG.info("start: {}", threadName);
            return thread;
        } catch (Exception e) {
            LOG.error("Fail to start node service.", e);
            throw new IOException(e.getMessage(), e);
        }
    }

    public String getNodeType() {
        return nodeType;
    }

    @VisibleForTesting
    void preparePythonFiles() throws IOException {
        PythonFileUtil.preparePythonFilesForExec(getRuntimeContext(), mlContext);
    }

    @VisibleForTesting
    Runnable createNodeServerRunnable() {
        return new NodeServer(mlContext, nodeType);
    }

    @VisibleForTesting
    MLContext getMlContext() {
        return mlContext;
    }

    @VisibleForTesting
    FutureTask<Void> getServerFuture() {
        return serverFuture;
    }

    @VisibleForTesting
    DataExchange<Row, OUT> getDataExchange() {
        return dataExchange;
    }

    @VisibleForTesting
    FutureTask<Void> getDataExchangeConsumerFuture() {
        return dataExchangeConsumerFuture;
    }

    /** Returns the node identity from {@link MLContext}, or {@code ""} before it is created. */
    private String maybeGetIdentity() {
        return mlContext == null ? "" : mlContext.getIdentity();
    }
}
