package org.flinkextended.flink.ml.cluster.node.runner;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.apache.commons.io.IOUtils;
import org.flinkextended.flink.ml.cluster.node.MLContext;
import org.flinkextended.flink.ml.cluster.rpc.AMClient;
import org.flinkextended.flink.ml.cluster.rpc.AMRegistry;
import org.flinkextended.flink.ml.cluster.rpc.NodeServer;
import org.flinkextended.flink.ml.cluster.rpc.RpcCode;
import org.flinkextended.flink.ml.proto.AMStatus;
import org.flinkextended.flink.ml.proto.GetClusterInfoResponse;
import org.flinkextended.flink.ml.proto.MLClusterDef;
import org.flinkextended.flink.ml.proto.NodeSpec;
import org.flinkextended.flink.ml.proto.SimpleResponse;
import org.flinkextended.flink.ml.util.IpHostUtil;
import org.flinkextended.flink.ml.util.MLConstants;
import org.flinkextended.flink.ml.util.MLException;
import org.flinkextended.flink.ml.util.ProtoUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/flinkextended/flink/ml/cluster/node/runner/CommonMLRunner.class */
/**
 * Default {@link MLRunner} implementation that drives a single node through its life cycle.
 *
 * <p>{@link #run()} performs the following steps in order: create the AM client, fetch the job
 * version, obtain a task index, register with the application master (AM), start the heartbeat,
 * wait for the AM to report {@code AM_RUNNING}, fetch the cluster definition, publish it into the
 * {@link MLContext}, and finally run the user script. Whatever the outcome, {@link
 * #stopExecution(boolean)} is invoked once at the end.
 *
 * <p>{@link #notifyStop()} requests termination; the request is observed by {@link #checkEnd()},
 * which is polled between the steps above.
 */
public class CommonMLRunner implements MLRunner {
    private static final Logger LOG = LoggerFactory.getLogger(CommonMLRunner.class);

    /** Pause between attempts while registering or waiting for the cluster definition. */
    private static final long RETRY_INTERVAL_MS = 3000L;
    /** Pause between AM status polls in {@link #waitClusterRunning()}. */
    private static final long CLUSTER_RUNNING_POLL_INTERVAL_MS = 5000L;
    /** Fixed pause in {@link #run()} between starting the heartbeat and polling the AM status. */
    private static final long POST_HEARTBEAT_START_DELAY_MS = 10000L;
    /** How long a single wait for the heartbeat executor to terminate may last. */
    private static final long HEARTBEAT_TERMINATION_TIMEOUT_SECONDS = 10L;

    protected volatile AMClient amClient;
    protected NodeSpec nodeSpec;
    protected long version = 0;
    protected String localIp;
    protected NodeServer server;
    protected volatile MLContext mlContext;
    protected ScriptRunner scriptRunner;
    /**
     * Externally visible status. Written by {@link #notifyStop()} and polled by {@link
     * #checkEnd()}; volatile so that a stop request issued from another thread (presumably the
     * usual case - confirm against callers) becomes visible to the running thread.
     */
    protected volatile ExecutionStatus resultStatus;
    /** Status of the current {@link #run()} invocation; copied to {@link #resultStatus} at the end. */
    protected ExecutionStatus currentResultStatus;
    protected ExecutorService heartbeatService;
    protected MLClusterDef mlClusterDef;
    private Future<?> heartBeatRunnerFuture;
    private NodeHeartBeatRunner heartBeatRunner;

    /**
     * Creates a runner bound to the given context and node server.
     *
     * @param mlContext context of the node this runner executes
     * @param nodeServer server whose port is advertised to the AM
     */
    public CommonMLRunner(MLContext mlContext, NodeServer nodeServer) {
        this.mlContext = mlContext;
        this.server = nodeServer;
    }

    /**
     * Performs a single registration attempt against the AM.
     *
     * <p>The node spec is rebuilt and the job version refreshed before every attempt.
     *
     * @return {@code true} if the AM answered with {@code RpcCode.OK}, {@code false} for any
     *     other code except {@code VERSION_ERROR}
     * @throws MLException if the AM answered with {@code RpcCode.VERSION_ERROR}
     * @throws Exception if building the node spec fails
     */
    protected boolean doRegisterAction() throws Exception {
        createNodeSpec(true);
        getCurrentJobVersion();
        SimpleResponse response = amClient.registerNode(version, nodeSpec);
        if (RpcCode.OK.ordinal() == response.getCode()) {
            return true;
        }
        if (RpcCode.VERSION_ERROR.ordinal() == response.getCode()) {
            throw new MLException(mlContext.getIdentity() + " version mismatch with AM");
        }
        LOG.warn("register to master failed code :{} message:{}", response.getCode(), response.getMessage());
        LOG.error(
                "Fail to register node. This node is {}:{}, am server is {}:{}",
                localIp,
                server.getPort(),
                amClient.getHost(),
                amClient.getPort());
        return false;
    }

    /**
     * Registers this node with the AM, retrying until it succeeds or {@code MLConstants.TIMEOUT}
     * elapses.
     *
     * <p>Registration against an already running AM is only attempted when the configured
     * failover strategy equals {@code MLConstants.FAILOVER_RESTART_INDIVIDUAL_STRATEGY}.
     */
    @Override // org.flinkextended.flink.ml.cluster.node.runner.MLRunner
    public void registerNode() throws Exception {
        String failoverStrategy =
                mlContext.getProperties().getOrDefault(MLConstants.FAILOVER_STRATEGY, "all");
        boolean registerWhileRunning =
                failoverStrategy.equalsIgnoreCase(MLConstants.FAILOVER_RESTART_INDIVIDUAL_STRATEGY);
        doRegisterAction(System.currentTimeMillis(), registerWhileRunning);
    }

    /**
     * Registration loop: polls the AM status and attempts to register whenever the status allows.
     *
     * @param startTime wall-clock time in milliseconds at which registration started; used for
     *     the timeout check
     * @param registerWhileRunning whether registration may be attempted while the AM reports
     *     {@code AM_RUNNING}
     * @throws MLException if registration did not succeed within {@code MLConstants.TIMEOUT}
     * @throws RuntimeException if the AM reports a status in which registration is impossible
     */
    protected void doRegisterAction(long startTime, boolean registerWhileRunning) throws Exception {
        do {
            checkEnd();
            AMStatus status = amClient.getAMStatus();
            LOG.info("{} registerNode status:{}", mlContext.getIdentity(), status);
            checkEnd();
            switch (status) {
                case AM_INIT:
                    if (doRegisterAction()) {
                        return;
                    }
                    break;
                case AM_RUNNING:
                    if (registerWhileRunning) {
                        if (doRegisterAction()) {
                            return;
                        }
                        break;
                    }
                    // Intentional fall-through: without individual restart a running AM is
                    // handled like the "wait for INIT" states below.
                case AM_UNKNOW:
                case AM_FAILOVER:
                    LOG.warn("master status is {} wait for INIT!", status.toString());
                    break;
                default:
                    throw new RuntimeException("AM status is " + status.toString() + " can not register node!");
            }
            Thread.sleep(RETRY_INTERVAL_MS);
        } while (System.currentTimeMillis() - startTime <= MLConstants.TIMEOUT);
        throw new MLException(mlContext.getIdentity() + " timed out registering to AM");
    }

    /**
     * Returns the node spec describing this node, building it when necessary.
     *
     * @param forceRebuild when {@code true} the spec is rebuilt even if one is cached
     * @return the (possibly freshly built) node spec
     */
    protected NodeSpec createNodeSpec(boolean forceRebuild) throws Exception {
        if (forceRebuild || nodeSpec == null) {
            nodeSpec =
                    NodeSpec.newBuilder()
                            .setIp(localIp)
                            .setIndex(mlContext.getIndex())
                            .setClientPort(server.getPort().intValue())
                            .setRoleName(mlContext.getRoleName())
                            .build();
        }
        return nodeSpec;
    }

    /**
     * Polls the AM for the cluster definition until it is available or {@code
     * MLConstants.TIMEOUT} elapses.
     *
     * <p>On success {@link #mlClusterDef} holds the definition; on timeout it is set to {@code
     * null} and the method returns normally.
     */
    @Override // org.flinkextended.flink.ml.cluster.node.runner.MLRunner
    public void getClusterInfo() throws MLException, InterruptedException {
        long startTime = System.currentTimeMillis();
        while (System.currentTimeMillis() - startTime <= MLConstants.TIMEOUT) {
            checkEnd();
            GetClusterInfoResponse response = amClient.getClusterInfo(version);
            if (RpcCode.OK.ordinal() == response.getCode()) {
                mlClusterDef = response.getClusterDef();
                return;
            }
            LOG.info("wait for cluster info:{} message:{}", response.getCode(), response.getMessage());
            Thread.sleep(RETRY_INTERVAL_MS);
        }
        mlClusterDef = null;
    }

    /**
     * Aborts the current step if a stop was requested through {@link #notifyStop()}.
     *
     * @throws FlinkKillException if {@link #resultStatus} is {@code KILLED_BY_FLINK}
     */
    protected void checkEnd() throws MLException {
        if (resultStatus == ExecutionStatus.KILLED_BY_FLINK) {
            throw new FlinkKillException("Exit per request.");
        }
    }

    /**
     * Executes the whole node life cycle.
     *
     * <p>A {@link FlinkKillException} or {@link InterruptedException} ends the run with status
     * {@code KILLED_BY_FLINK}; any other exception increments the context's failure counter and
     * ends it with {@code FAILED}. Cleanup runs exactly once in the {@code finally} block.
     */
    @Override // java.lang.Runnable
    public void run() {
        resultStatus = ExecutionStatus.RUNNING;
        currentResultStatus = ExecutionStatus.RUNNING;
        try {
            initAMClient();
            checkEnd();
            LOG.info("init amClient.");
            getCurrentJobVersion();
            checkEnd();
            LOG.info("get current job version.");
            getTaskIndex();
            LOG.info("get task index.");
            registerNode();
            checkEnd();
            LOG.info("register node to application master.");
            startHeartBeat();
            LOG.info("start heart beat thread.");
            Thread.sleep(POST_HEARTBEAT_START_DELAY_MS);
            waitClusterRunning();
            LOG.info("wait for cluster to running status.");
            getClusterInfo();
            Preconditions.checkNotNull(mlClusterDef, "Cannot get cluster def from AM");
            checkEnd();
            LOG.info("get cluster info.");
            resetMLContext();
            checkEnd();
            LOG.info("reset machine learning context.");
            runScript();
            checkEnd();
            LOG.info("run script.");
            currentResultStatus = ExecutionStatus.SUCCEED;
        } catch (Exception e) {
            if ((e instanceof FlinkKillException) || (e instanceof InterruptedException)) {
                LOG.info("{} killed by flink.", mlContext.getIdentity());
                currentResultStatus = ExecutionStatus.KILLED_BY_FLINK;
            } else {
                LOG.error("Got exception during python running", e);
                mlContext.addFailNum();
                currentResultStatus = ExecutionStatus.FAILED;
            }
        } finally {
            // Single cleanup point: a failure inside stopExecution propagates to the caller
            // instead of being mistaken for a script failure and triggering a second cleanup.
            stopExecution(currentResultStatus == ExecutionStatus.SUCCEED);
            resultStatus = currentResultStatus;
        }
    }

    /** Obtains a script runner for the current context from the factory and runs the script. */
    @Override // org.flinkextended.flink.ml.cluster.node.runner.MLRunner
    public void runScript() throws Exception {
        scriptRunner = ScriptRunnerFactory.getScriptRunner(mlContext);
        scriptRunner.runScript();
    }

    /**
     * Publishes the cluster definition (as JSON) and this node's server address into the
     * {@link MLContext}.
     */
    @Override // org.flinkextended.flink.ml.cluster.node.runner.MLRunner
    public void resetMLContext() {
        String clusterJson = ProtoUtil.protoToJson(mlClusterDef);
        LOG.info("java cluster:" + clusterJson);
        mlContext.getProperties().put(MLConstants.CONFIG_CLUSTER_PATH, clusterJson);
        mlContext.setNodeServerIP(localIp);
        mlContext.setNodeServerPort(server.getPort().intValue());
    }

    /** Starts the heartbeat task on a dedicated single daemon thread. */
    @Override // org.flinkextended.flink.ml.cluster.node.runner.MLRunner
    public void startHeartBeat() throws Exception {
        heartbeatService =
                Executors.newFixedThreadPool(
                        1,
                        task -> {
                            Thread heartbeatThread = new Thread(task);
                            heartbeatThread.setName(mlContext.getIdentity() + "-HeartBeat");
                            heartbeatThread.setDaemon(true);
                            return heartbeatThread;
                        });
        heartBeatRunner = new NodeHeartBeatRunner(mlContext, server, nodeSpec, version);
        heartBeatRunnerFuture = heartbeatService.submit(heartBeatRunner);
    }

    /**
     * Stops the heartbeat task and waits for its executor to terminate.
     *
     * <p>Does nothing if the heartbeat was never started or its executor is already shut down.
     */
    protected void stopHeartBeat() {
        if (heartbeatService == null || heartbeatService.isShutdown()) {
            return;
        }
        // The runner is created after the executor; it is still null if its construction failed.
        if (heartBeatRunner != null) {
            heartBeatRunner.setStopFlag(true);
        }
        heartbeatService.shutdownNow();
        try {
            while (!heartbeatService.awaitTermination(HEARTBEAT_TERMINATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                LOG.info("CommonMLRunner {} timed out waiting for Heartbeat service to terminate", mlContext.getIdentity());
                heartbeatService.shutdownNow();
            }
        } catch (InterruptedException e) {
            // Stop waiting once interrupted so that the remaining cleanup can proceed.
            // NOTE(review): the interrupt status is not restored, as in the original handler;
            // confirm whether callers rely on it.
            LOG.warn("stop heart beat exception", e);
            heartbeatService.shutdownNow();
        }
    }

    /** Refreshes {@link #version} with the job version reported by the AM. */
    @Override // org.flinkextended.flink.ml.cluster.node.runner.MLRunner
    public void getCurrentJobVersion() {
        version = amClient.getVersion().getVersion();
    }

    /** Resolves the local IP address and obtains the AM client for the current context. */
    @Override // org.flinkextended.flink.ml.cluster.node.runner.MLRunner
    public void initAMClient() throws Exception {
        localIp = IpHostUtil.getIpAddress();
        checkEnd();
        amClient = AMRegistry.getAMClient(mlContext);
        LOG.info(
                "{} at {}:{}, am server at {}:{}",
                mlContext.getIdentity(),
                localIp,
                server.getPort(),
                amClient.getHost(),
                amClient.getPort());
    }

    /**
     * Blocks until the AM reports {@code AM_RUNNING}.
     *
     * @throws MLException if the status is not reached within {@code MLConstants.TIMEOUT}, or
     *     (as {@link FlinkKillException}) if a stop was requested meanwhile
     */
    @Override // org.flinkextended.flink.ml.cluster.node.runner.MLRunner
    public void waitClusterRunning() throws InterruptedException, MLException {
        long startTime = System.currentTimeMillis();
        while (AMStatus.AM_RUNNING != amClient.getAMStatus()) {
            Thread.sleep(CLUSTER_RUNNING_POLL_INTERVAL_MS);
            checkEnd();
            if (System.currentTimeMillis() - startTime > MLConstants.TIMEOUT) {
                throw new MLException("Timed out waiting for job to start running");
            }
        }
    }

    /**
     * Asks the AM for a task index when the context does not carry one yet (index &lt; 0).
     *
     * @throws RuntimeException if the index cannot be obtained; the original failure is kept as
     *     the cause
     * @throws MLException (as {@link FlinkKillException}) if a stop was requested meanwhile
     */
    public void getTaskIndex() throws MLException, InterruptedException {
        if (mlContext.getIndex() >= 0) {
            return;
        }
        try {
            String nodeAddress = IpHostUtil.getIpAddress() + ":" + server.getPort();
            mlContext.setIndex(amClient.getTaskIndex(version, mlContext.getRoleName(), nodeAddress));
        } catch (Exception e) {
            // Keep the cause so the stack trace of the real failure is not lost.
            throw new RuntimeException(e.getMessage(), e);
        }
        // Checked outside the try block so that a stop request surfaces as a kill, not as a
        // generic failure wrapped in RuntimeException.
        checkEnd();
    }

    /**
     * Releases everything acquired during {@link #run()}: closes the script runner, stops the
     * heartbeat, notifies the AM and closes the AM client.
     *
     * @param success whether the run succeeded; the {@link MLContext} is reset only when it did
     *     not
     */
    protected void stopExecution(boolean success) {
        if (scriptRunner != null) {
            IOUtils.closeQuietly(scriptRunner);
            scriptRunner = null;
        }
        stopHeartBeat();
        notifyAmWorkerFinish(success);
        if (amClient != null) {
            LOG.info("{} closing AM connection", mlContext.getIdentity());
            amClient.close();
            amClient = null;
        }
        if (!success) {
            mlContext.reset();
        }
    }

    /**
     * Reports the outcome of this node to the AM and closes the AM client afterwards.
     *
     * <p>A successful run is reported via {@code nodeFinish}, a {@code FAILED} run via {@code
     * reportFailedNode}; for any other status nothing is reported. Nothing happens at all when
     * there is no AM client or no node spec.
     *
     * @param success whether the run succeeded
     * @throws RuntimeException wrapping any exception raised while talking to the AM
     */
    protected void notifyAmWorkerFinish(boolean success) {
        AMClient client = amClient;
        if (client == null || nodeSpec == null) {
            return;
        }
        try {
            if (success) {
                LOG.info("report node finish:" + mlContext.getIdentity());
                SimpleResponse finishResponse = client.nodeFinish(version, nodeSpec);
                boolean accepted =
                        RpcCode.OK.ordinal() == finishResponse.getCode()
                                || RpcCode.VERSION_ERROR.ordinal() == finishResponse.getCode();
                if (!accepted) {
                    LOG.error("Fail to report node finish status to AM.");
                }
            } else if (currentResultStatus == ExecutionStatus.FAILED) {
                SimpleResponse failedResponse = client.reportFailedNode(version, nodeSpec);
                LOG.info("report failed node:" + mlContext.getIdentity());
                if (RpcCode.OK.ordinal() != failedResponse.getCode()) {
                    LOG.error("Fail to report node failed status to AM.");
                }
            }
        } catch (Exception e) {
            LOG.error("{} failed to notify AM of finished node", mlContext.getIdentity(), e);
            throw new RuntimeException(e);
        } finally {
            client.close();
        }
    }

    /** Returns the externally visible status of this runner. */
    @Override // org.flinkextended.flink.ml.cluster.node.runner.MLRunner
    public ExecutionStatus getResultStatus() {
        return resultStatus;
    }

    /**
     * Requests termination: forwards a kill signal to the script runner, if any, and marks the
     * status as {@code KILLED_BY_FLINK} so that the next {@link #checkEnd()} aborts the run.
     */
    @Override // org.flinkextended.flink.ml.cluster.node.runner.MLRunner
    public void notifyStop() {
        // Read the field once: stopExecution() may set it to null in the meantime.
        ScriptRunner runner = scriptRunner;
        if (runner != null) {
            runner.notifyKillSignal();
        }
        resultStatus = ExecutionStatus.KILLED_BY_FLINK;
    }

    @VisibleForTesting
    Future<?> getHeartBeatRunnerFuture() {
        return heartBeatRunnerFuture;
    }
}
