package org.flinkextended.flink.ml.cluster.rpc;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.protobuf.MessageOrBuilder;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
import org.flinkextended.flink.ml.cluster.master.AMEvent;
import org.flinkextended.flink.ml.cluster.master.AMEventType;
import org.flinkextended.flink.ml.cluster.master.AMService;
import org.flinkextended.flink.ml.cluster.master.HeartbeatMonitor;
import org.flinkextended.flink.ml.cluster.role.WorkerRole;
import org.flinkextended.flink.ml.proto.AMStatusMessage;
import org.flinkextended.flink.ml.proto.AppMasterServiceGrpc;
import org.flinkextended.flink.ml.proto.FinishNodeRequest;
import org.flinkextended.flink.ml.proto.GetAMStatusRequest;
import org.flinkextended.flink.ml.proto.GetClusterInfoRequest;
import org.flinkextended.flink.ml.proto.GetClusterInfoResponse;
import org.flinkextended.flink.ml.proto.GetFinishNodeResponse;
import org.flinkextended.flink.ml.proto.GetFinishedNodeRequest;
import org.flinkextended.flink.ml.proto.GetTaskIndexRequest;
import org.flinkextended.flink.ml.proto.GetTaskIndexResponse;
import org.flinkextended.flink.ml.proto.GetVersionRequest;
import org.flinkextended.flink.ml.proto.GetVersionResponse;
import org.flinkextended.flink.ml.proto.HeartBeatRequest;
import org.flinkextended.flink.ml.proto.MLClusterDef;
import org.flinkextended.flink.ml.proto.MLJobDef;
import org.flinkextended.flink.ml.proto.NodeRestartResponse;
import org.flinkextended.flink.ml.proto.NodeSpec;
import org.flinkextended.flink.ml.proto.NodeStopResponse;
import org.flinkextended.flink.ml.proto.RegisterFailedNodeRequest;
import org.flinkextended.flink.ml.proto.RegisterNodeRequest;
import org.flinkextended.flink.ml.proto.SimpleResponse;
import org.flinkextended.flink.ml.proto.StopAllWorkerRequest;
import org.flinkextended.flink.ml.util.ProtoUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/flinkextended/flink/ml/cluster/rpc/AppMasterServiceImpl.class */
/**
 * gRPC implementation of the application-master service, also used by the rest of the master
 * through the {@link AMService} interface.
 *
 * <p>What this class does:
 *
 * <ul>
 *   <li>keeps one {@link NodeClient} per registered node, keyed by {@code
 *       AppMasterServer.getNodeClientKey(NodeSpec)}, and uses it to restart or stop nodes;
 *   <li>keeps one {@link HeartbeatMonitor} per registered node, with timeouts scheduled on a
 *       thread pool owned by this instance;
 *   <li>answers version-checked requests whose version differs from {@link #version()} with
 *       {@code RpcCode.VERSION_ERROR};
 *   <li>forwards node lifecycle requests to the AM state machine as {@link AMEvent}s;
 *   <li>hands out per-scope task indices.
 * </ul>
 */
public class AppMasterServiceImpl extends AppMasterServiceGrpc.AppMasterServiceImplBase implements AMService {
    private static final Logger LOG = LoggerFactory.getLogger(AppMasterServiceImpl.class);

    private final AppMasterServer appMasterServer;
    /** Executor on which heartbeat timeouts are scheduled; shut down by {@link #stopService()}. */
    private final ScheduledExecutorService scheduledExecutor;
    /** Timeout handed to every {@link HeartbeatMonitor}; re-armed on each received heartbeat. */
    private final Duration heartbeatTimeout;
    /** Current cluster version; version-checked requests carrying another version are rejected. */
    private volatile long version = 0;
    /** One RPC client per registered node, keyed by node client key. */
    private final Map<String, NodeClient> nodeClientCache = new ConcurrentHashMap<>();
    /** scope -> (task key -> assigned index); only mutated by {@link #getTaskIndex}. */
    private final Map<String, Map<String, Integer>> nodeIndexMap = new ConcurrentHashMap<>();
    /** One heartbeat monitor per registered node, keyed by node client key. */
    private final Map<String, HeartbeatMonitor> heartbeatMonitors = new ConcurrentHashMap<>();
    /** True while {@link #restartAllNodes()} runs; heartbeats then bypass the version check. */
    private volatile boolean isRestart = false;

    @Override // org.flinkextended.flink.ml.cluster.master.AMService
    public long version() {
        return this.version;
    }

    @Override // org.flinkextended.flink.ml.cluster.master.AMService
    public void setVersion(long newVersion) {
        this.version = newVersion;
    }

    /**
     * @param appMasterServer owning server, used for liveness bookkeeping, metadata and the state
     *     machine
     * @param heartbeatPoolSize core size of the scheduled pool that runs heartbeat timeouts
     * @param heartbeatTimeout timeout armed for each registered node's heartbeat monitor
     */
    public AppMasterServiceImpl(AppMasterServer appMasterServer, int heartbeatPoolSize, Duration heartbeatTimeout) {
        this.appMasterServer = appMasterServer;
        this.scheduledExecutor = Executors.newScheduledThreadPool(heartbeatPoolSize);
        this.heartbeatTimeout = heartbeatTimeout;
    }

    /**
     * Registers a node: caches a fresh {@link NodeClient} for it (closing any client it replaces),
     * starts its heartbeat monitor and emits {@code REGISTER_NODE} to the state machine.
     */
    @Override // org.flinkextended.flink.ml.proto.AppMasterServiceGrpc.AppMasterServiceImplBase
    public void registerNode(RegisterNodeRequest request, StreamObserver<SimpleResponse> responseObserver) {
        if (checkVersionError(request.getVersion(), responseObserver)) {
            return;
        }
        appMasterServer.updateRpcLastContact();
        SimpleResponse.Builder builder = SimpleResponse.newBuilder();
        try {
            NodeSpec nodeSpec = request.getNodeSpec();
            String nodeClientKey = AppMasterServer.getNodeClientKey(nodeSpec);
            // Build the new client before touching the old one, so a failing constructor cannot
            // leave an already-closed client in the cache.
            NodeClient previous =
                    nodeClientCache.put(nodeClientKey, new NodeClient(nodeSpec.getIp(), nodeSpec.getClientPort()));
            if (previous != null) {
                previous.close();
            }
            LOG.info("register node:{}", nodeClientKey);
            startHeartBeatMonitor(nodeSpec, request.getVersion());
            appMasterServer.getAmStateMachine()
                    .handle(new AMEvent(AMEventType.REGISTER_NODE, request, request.getVersion()));
            builder.setCode(RpcCode.OK.ordinal());
            builder.setMessage("");
        } catch (Exception e) {
            builder.setCode(RpcCode.ERROR.ordinal());
            builder.setMessage(errorMessage(e));
            handleStateTransitionError(request, e);
        }
        responseObserver.onNext(builder.m1973build());
        responseObserver.onCompleted();
    }

    /**
     * Creates and arms a heartbeat monitor for {@code nodeSpec}. Any monitor previously registered
     * under the same key is cancelled; otherwise its armed timeout could never be refreshed or
     * cancelled once it is no longer in the map.
     */
    @Override // org.flinkextended.flink.ml.cluster.master.AMService
    public void startHeartBeatMonitor(NodeSpec nodeSpec, long nodeVersion) {
        String nodeClientKey = AppMasterServer.getNodeClientKey(nodeSpec);
        HeartbeatMonitor monitor =
                new HeartbeatMonitor(new HeartbeatListenerImpl(appMasterServer, nodeSpec, nodeVersion));
        monitor.updateTimeout(heartbeatTimeout, scheduledExecutor);
        HeartbeatMonitor replaced = heartbeatMonitors.put(nodeClientKey, monitor);
        if (replaced != null) {
            replaced.cancel();
        }
        LOG.info("Started monitoring heartbeat for {}", nodeClientKey);
    }

    /** Marks the server as ended and shuts down the heartbeat executor immediately. */
    @Override // org.flinkextended.flink.ml.cluster.master.AMService
    public void stopService() {
        appMasterServer.setEnd(true);
        scheduledExecutor.shutdownNow();
        LOG.info("stop heartbeat thread pool!");
    }

    /**
     * Re-arms the sender's heartbeat monitor. The version check is skipped while
     * {@link #restartAllNodes()} is in progress.
     */
    @Override // org.flinkextended.flink.ml.proto.AppMasterServiceGrpc.AppMasterServiceImplBase
    public void heartBeatNode(HeartBeatRequest request, StreamObserver<SimpleResponse> responseObserver) {
        if (!isRestart && checkVersionError(request.getVersion(), responseObserver)) {
            return;
        }
        appMasterServer.updateRpcLastContact();
        HeartbeatMonitor monitor = heartbeatMonitors.get(AppMasterServer.getNodeClientKey(request.getNodeSpec()));
        if (monitor != null) {
            monitor.updateTimeout(heartbeatTimeout, scheduledExecutor);
        }
        responseObserver.onNext(SimpleResponse.newBuilder().setCode(RpcCode.OK.ordinal()).setMessage("").m1973build());
        responseObserver.onCompleted();
    }

    /**
     * Handles a node reporting completion: drops and closes its client, stops its heartbeat
     * monitor and emits {@code FINISH_NODE} to the state machine.
     */
    @Override // org.flinkextended.flink.ml.proto.AppMasterServiceGrpc.AppMasterServiceImplBase
    public void nodeFinish(FinishNodeRequest request, StreamObserver<SimpleResponse> responseObserver) {
        if (checkVersionError(request.getVersion(), responseObserver)) {
            return;
        }
        appMasterServer.updateRpcLastContact();
        SimpleResponse.Builder builder = SimpleResponse.newBuilder();
        try {
            String nodeClientKey = AppMasterServer.getNodeClientKey(request.getNodeSpec());
            NodeClient removed = nodeClientCache.remove(nodeClientKey);
            if (removed != null) {
                removed.close();
            }
            stopHeartBeatMonitorNode(nodeClientKey);
            appMasterServer.getAmStateMachine()
                    .handle(new AMEvent(AMEventType.FINISH_NODE, request, request.getVersion()));
            builder.setCode(RpcCode.OK.ordinal());
            builder.setMessage("");
        } catch (Exception e) {
            builder.setCode(RpcCode.ERROR.ordinal());
            builder.setMessage(errorMessage(e));
            handleStateTransitionError(request, e);
        }
        responseObserver.onNext(builder.m1973build());
        responseObserver.onCompleted();
    }

    /**
     * Replies with {@code VERSION_ERROR} and completes the call when {@code requestVersion} differs
     * from the current version.
     *
     * @return true if a version-error response was sent (the caller must not respond again)
     */
    private boolean checkVersionError(long requestVersion, StreamObserver<SimpleResponse> responseObserver) {
        long current = this.version; // read the volatile once so comparison and message agree
        if (current == requestVersion) {
            return false;
        }
        SimpleResponse.Builder builder = SimpleResponse.newBuilder();
        builder.setCode(RpcCode.VERSION_ERROR.ordinal()).setMessage(versionErrorMessage(current, requestVersion));
        responseObserver.onNext(builder.m1973build());
        responseObserver.onCompleted();
        return true;
    }

    /**
     * Asks one node to restart and, on success, stops monitoring its heartbeat. A client is created
     * and cached for the node if none exists yet.
     *
     * @throws Exception if the node answers with a non-OK code, or the RPC fails
     */
    @Override // org.flinkextended.flink.ml.cluster.master.AMService
    public void restartNode(NodeSpec nodeSpec) throws Exception {
        String nodeClientKey = AppMasterServer.getNodeClientKey(nodeSpec);
        NodeClient client = getOrCreateNodeClient(nodeClientKey, nodeSpec);
        try {
            NodeRestartResponse response = (NodeRestartResponse) client.restartNode().get();
            if (response.getCode() != RpcCode.OK.ordinal()) {
                LOG.info(response.getMessage());
                throw new Exception(response.getMessage());
            }
            LOG.info("restart response:{}", response.getMessage());
            stopHeartBeatMonitorNode(nodeClientKey);
        } catch (ExecutionException e) {
            LOG.error("Failed to restart node {}", nodeClientKey, e);
            throw e;
        }
    }

    /**
     * Bumps the cluster version, asks every currently cached node to restart, then drops, closes and
     * stops monitoring exactly those nodes. Per-node RPC failures are logged, not rethrown.
     * {@code isRestart} is always cleared on exit, even when an exception escapes.
     */
    @Override // org.flinkextended.flink.ml.cluster.master.AMService
    public void restartAllNodes() throws Exception {
        isRestart = true;
        try {
            version = System.currentTimeMillis();
            LOG.info("current version:{}", version);
            List<String> restartedKeys = new ArrayList<>();
            List<ListenableFuture<?>> pending = new ArrayList<>();
            for (Map.Entry<String, NodeClient> entry : nodeClientCache.entrySet()) {
                pending.add(entry.getValue().restartNode());
                LOG.info("send restart to node:{}", entry.getKey());
                restartedKeys.add(entry.getKey());
            }
            for (ListenableFuture<?> future : pending) {
                try {
                    NodeRestartResponse response = (NodeRestartResponse) future.get();
                    if (response.getCode() == RpcCode.OK.ordinal()) {
                        LOG.info("restart response:{}", response.getMessage());
                    } else {
                        LOG.info(response.getMessage());
                    }
                } catch (ExecutionException e) {
                    LOG.info("restart err:{}", e.getMessage(), e);
                }
            }
            // Only release the clients this call sent restarts to; clients cached under other keys
            // in the meantime must stay open since they stay in the cache.
            for (String key : restartedKeys) {
                NodeClient client = nodeClientCache.remove(key);
                if (client != null) {
                    client.close();
                }
                stopHeartBeatMonitorNode(key);
            }
        } finally {
            isRestart = false;
        }
    }

    /** Removes and cancels the heartbeat monitor registered under {@code nodeClientKey}, if any. */
    @Override // org.flinkextended.flink.ml.cluster.master.AMService
    public void stopHeartBeatMonitorNode(String nodeClientKey) {
        HeartbeatMonitor removed = heartbeatMonitors.remove(nodeClientKey);
        if (removed != null) {
            removed.cancel();
            LOG.info("Stopped monitoring heartbeat for {}", nodeClientKey);
        }
    }

    /** Removes and cancels every heartbeat monitor. */
    @Override // org.flinkextended.flink.ml.cluster.master.AMService
    public void stopHeartBeatMonitorAllNode() {
        // ConcurrentHashMap key-set iteration is weakly consistent, so removing while iterating is safe.
        for (String nodeClientKey : heartbeatMonitors.keySet()) {
            stopHeartBeatMonitorNode(nodeClientKey);
        }
    }

    /**
     * Asks one node to stop. A client is created and cached for the node if none exists yet.
     *
     * @throws Exception if the node answers with a non-OK code, or the RPC fails
     */
    @Override // org.flinkextended.flink.ml.cluster.master.AMService
    public void stopNode(NodeSpec nodeSpec) throws Exception {
        String nodeClientKey = AppMasterServer.getNodeClientKey(nodeSpec);
        NodeClient client = getOrCreateNodeClient(nodeClientKey, nodeSpec);
        try {
            NodeStopResponse response = (NodeStopResponse) client.stopNode().get();
            if (response.getCode() != RpcCode.OK.ordinal()) {
                LOG.info(response.getMessage());
                throw new Exception(response.getMessage());
            }
            LOG.info("stop response:{}", response.getMessage());
        } catch (ExecutionException e) {
            LOG.error("Failed to stop node {}", nodeClientKey, e);
            throw e;
        }
    }

    /** Returns the cached client for {@code nodeClientKey}, creating and caching one if absent. */
    private NodeClient getOrCreateNodeClient(String nodeClientKey, NodeSpec nodeSpec) throws Exception {
        NodeClient client = nodeClientCache.get(nodeClientKey);
        if (client == null) {
            client = new NodeClient(nodeSpec.getIp(), nodeSpec.getClientPort());
            nodeClientCache.put(nodeClientKey, client);
        }
        return client;
    }

    /**
     * Best-effort stop of every cached node: sends all stop requests, waits for the replies (logging
     * failures), then closes and clears all clients and heartbeat monitors. If the calling thread is
     * interrupted, its interrupt status is restored and the cleanup still runs.
     */
    @Override // org.flinkextended.flink.ml.cluster.master.AMService
    public void stopAllNodes() {
        if (nodeClientCache.isEmpty()) {
            return;
        }
        LOG.info("client size:{}", nodeClientCache.size());
        List<ListenableFuture<?>> pending = new ArrayList<>();
        for (Map.Entry<String, NodeClient> entry : nodeClientCache.entrySet()) {
            pending.add(entry.getValue().stopNode());
            LOG.info("send stop to node:{}", entry.getKey());
        }
        for (ListenableFuture<?> future : pending) {
            try {
                NodeStopResponse response = (NodeStopResponse) future.get();
                if (response.getCode() != RpcCode.OK.ordinal()) {
                    LOG.info(response.getMessage());
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                LOG.debug("Stop node server.", e);
            } catch (ExecutionException e) {
                LOG.debug("Stop node server.", e);
            }
        }
        for (NodeClient client : nodeClientCache.values()) {
            client.close();
        }
        nodeClientCache.clear();
        stopHeartBeatMonitorAllNode();
    }

    /** {@link #checkVersionError(long, StreamObserver)} for {@code getClusterInfo} responses. */
    private boolean checkVersionError(
            long requestVersion,
            StreamObserver<GetClusterInfoResponse> responseObserver,
            GetClusterInfoResponse.Builder builder) {
        long current = this.version;
        if (current == requestVersion) {
            return false;
        }
        builder.setCode(RpcCode.VERSION_ERROR.ordinal())
                .setMessage(versionErrorMessage(current, requestVersion))
                .setClusterDef(MLClusterDef.newBuilder());
        responseObserver.onNext(builder.m788build());
        responseObserver.onCompleted();
        return true;
    }

    /**
     * Returns the stored cluster def merged with finished-node info, or {@code NOT_READY} if no
     * cluster def has been stored yet.
     */
    @Override // org.flinkextended.flink.ml.proto.AppMasterServiceGrpc.AppMasterServiceImplBase
    public void getClusterInfo(GetClusterInfoRequest request, StreamObserver<GetClusterInfoResponse> responseObserver) {
        GetClusterInfoResponse.Builder builder = GetClusterInfoResponse.newBuilder();
        if (checkVersionError(request.getVersion(), responseObserver, builder)) {
            return;
        }
        appMasterServer.updateRpcLastContact();
        try {
            MLClusterDef clusterDef = appMasterServer.getAmMeta().restoreClusterDef();
            if (clusterDef != null) {
                MLClusterDef finished = appMasterServer.getAmMeta().restoreFinishClusterDef();
                builder.setCode(RpcCode.OK.ordinal()).setClusterDef(mergeFinishedClusterDef(clusterDef, finished));
            } else {
                builder.setCode(RpcCode.NOT_READY.ordinal()).setMessage("cluster is null!");
            }
        } catch (IOException e) {
            LOG.error("Failed to restore cluster def", e);
            builder.setCode(RpcCode.ERROR.ordinal()).setMessage(errorMessage(e));
        }
        responseObserver.onNext(builder.m788build());
        responseObserver.onCompleted();
    }

    /** Returns the current cluster version (no version check). */
    @Override // org.flinkextended.flink.ml.proto.AppMasterServiceGrpc.AppMasterServiceImplBase
    public void getVersion(GetVersionRequest request, StreamObserver<GetVersionResponse> responseObserver) {
        appMasterServer.updateRpcLastContact();
        responseObserver.onNext(GetVersionResponse.newBuilder().setVersion(version).m1070build());
        responseObserver.onCompleted();
    }

    /**
     * Acknowledges immediately, then emits {@code STOP_JOB} to the state machine. A state-machine
     * failure is therefore not reflected in the response.
     */
    @Override // org.flinkextended.flink.ml.proto.AppMasterServiceGrpc.AppMasterServiceImplBase
    public void stopAllWorker(StopAllWorkerRequest request, StreamObserver<SimpleResponse> responseObserver) {
        appMasterServer.updateRpcLastContact();
        responseObserver.onNext(SimpleResponse.newBuilder().setMessage("").setCode(RpcCode.OK.ordinal()).m1973build());
        responseObserver.onCompleted();
        try {
            appMasterServer.getAmStateMachine()
                    .handle(new AMEvent(AMEventType.STOP_JOB, request, request.getVersion()));
        } catch (Exception e) {
            handleStateTransitionError(request, e);
        }
    }

    /** Returns the state machine's current internal state. */
    @Override // org.flinkextended.flink.ml.proto.AppMasterServiceGrpc.AppMasterServiceImplBase
    public void getAMStatus(GetAMStatusRequest request, StreamObserver<AMStatusMessage> responseObserver) {
        appMasterServer.updateRpcLastContact();
        responseObserver.onNext(
                AMStatusMessage.newBuilder().setStatus(appMasterServer.getAmStateMachine().getInternalState()).m73build());
        responseObserver.onCompleted();
    }

    /**
     * Emits {@code FAIL_NODE} to the state machine. The response is sent before a state-machine
     * failure is reported through {@link #handleStateTransitionError}.
     */
    @Override // org.flinkextended.flink.ml.proto.AppMasterServiceGrpc.AppMasterServiceImplBase
    public void registerFailNode(RegisterFailedNodeRequest request, StreamObserver<SimpleResponse> responseObserver) {
        if (checkVersionError(request.getVersion(), responseObserver)) {
            return;
        }
        appMasterServer.updateRpcLastContact();
        SimpleResponse.Builder builder = SimpleResponse.newBuilder();
        Exception failure = null;
        try {
            appMasterServer.getAmStateMachine()
                    .handle(new AMEvent(AMEventType.FAIL_NODE, request, request.getVersion()));
            builder.setCode(RpcCode.OK.ordinal()).setMessage("");
        } catch (Exception e) {
            builder.setCode(RpcCode.ERROR.ordinal()).setMessage(errorMessage(e));
            failure = e;
        }
        responseObserver.onNext(builder.m1973build());
        responseObserver.onCompleted();
        if (failure != null) {
            handleStateTransitionError(request, failure);
        }
    }

    /** {@link #checkVersionError(long, StreamObserver)} for {@code getTaskIndex} responses. */
    private boolean checkVersionError(
            long requestVersion,
            StreamObserver<GetTaskIndexResponse> responseObserver,
            GetTaskIndexResponse.Builder builder) {
        long current = this.version;
        if (current == requestVersion) {
            return false;
        }
        builder.setCode(RpcCode.VERSION_ERROR.ordinal())
                .setMessage(versionErrorMessage(current, requestVersion))
                .setIndex(0);
        responseObserver.onNext(builder.m976build());
        responseObserver.onCompleted();
        return true;
    }

    /**
     * Returns a stable index for {@code request.getKey()} within {@code request.getScope()}. A new
     * key receives the number of keys already known in its scope, so indices are 0, 1, 2, … in
     * first-request order; a repeated key gets its previously assigned index. The method is
     * {@code synchronized} so the size read and the insert happen atomically.
     */
    @Override // org.flinkextended.flink.ml.proto.AppMasterServiceGrpc.AppMasterServiceImplBase
    public synchronized void getTaskIndex(GetTaskIndexRequest request, StreamObserver<GetTaskIndexResponse> responseObserver) {
        GetTaskIndexResponse.Builder builder = GetTaskIndexResponse.newBuilder();
        if (checkVersionError(request.getVersion(), responseObserver, builder)) {
            return;
        }
        appMasterServer.updateRpcLastContact();
        Map<String, Integer> scopeIndices =
                nodeIndexMap.computeIfAbsent(request.getScope(), scope -> new ConcurrentHashMap<>());
        int index = scopeIndices.computeIfAbsent(request.getKey(), key -> scopeIndices.size());
        builder.setIndex(index);
        builder.setCode(RpcCode.OK.ordinal());
        responseObserver.onNext(builder.m976build());
        responseObserver.onCompleted();
    }

    /** Lists the task indices of finished worker-role tasks from the stored finished cluster def. */
    @Override // org.flinkextended.flink.ml.proto.AppMasterServiceGrpc.AppMasterServiceImplBase
    public void getFinishedNode(GetFinishedNodeRequest request, StreamObserver<GetFinishNodeResponse> responseObserver) {
        appMasterServer.updateRpcLastContact();
        GetFinishNodeResponse.Builder builder = GetFinishNodeResponse.newBuilder();
        try {
            // NOTE(review): literal codes 0 (success) / 1 (failure) are kept as-is; confirm whether
            // they are intended to line up with RpcCode ordinals.
            builder.setCode(0).setMessage("");
            MLClusterDef finished = appMasterServer.getAmMeta().restoreFinishClusterDef();
            if (finished != null) {
                String workerRoleName = new WorkerRole().name();
                for (MLJobDef job : finished.getJobList()) {
                    if (job.getName().equals(workerRoleName)) {
                        for (Integer taskIndex : job.getTasksMap().keySet()) {
                            builder.addWorkers(taskIndex);
                        }
                    }
                }
            }
        } catch (IOException e) {
            LOG.error("Failed to restore finished cluster def", e);
            builder.setCode(1).setMessage(errorMessage(e));
        }
        responseObserver.onNext(builder.m835build());
        responseObserver.onCompleted();
    }

    /**
     * Merges finished-task info into the current cluster def.
     *
     * <p>For each finished job: if the current def has a job of the same name, that job is copied
     * and the finished job's tasks whose index it lacks are added; otherwise the finished job is
     * added as-is. Current jobs with no finished counterpart are appended afterwards, in their
     * original order.
     *
     * @param current current cluster def (job names assumed unique)
     * @param finished finished cluster def, or null
     * @return {@code current} when {@code finished} is null, otherwise the merged def
     * @throws IllegalStateException if {@code current} contains duplicate job names
     */
    private MLClusterDef mergeFinishedClusterDef(MLClusterDef current, MLClusterDef finished) {
        if (finished == null) {
            return current;
        }
        // Explicit mutable map: entries are removed below as they get merged.
        Map<String, MLJobDef> unmatchedJobs = new LinkedHashMap<>();
        for (MLJobDef job : current.getJobList()) {
            if (unmatchedJobs.putIfAbsent(job.getName(), job) != null) {
                throw new IllegalStateException("Duplicate job name in cluster def: " + job.getName());
            }
        }
        MLClusterDef.Builder merged = MLClusterDef.newBuilder();
        for (MLJobDef finishedJob : finished.getJobList()) {
            MLJobDef currentJob = unmatchedJobs.remove(finishedJob.getName());
            if (currentJob == null) {
                merged.addJob(finishedJob);
                continue;
            }
            MLJobDef.Builder jobBuilder = MLJobDef.newBuilder();
            jobBuilder.mergeFrom(currentJob);
            finishedJob.getTasksMap().forEach((taskIndex, task) -> {
                if (!currentJob.getTasksMap().containsKey(taskIndex)) {
                    jobBuilder.putTasks(taskIndex, task);
                }
            });
            merged.addJob(jobBuilder.m1259build());
        }
        merged.addAllJob(unmatchedJobs.values());
        return merged.m1164build();
    }

    /** Logs a failed request (as JSON when available) and forwards the error to the server. */
    @Override // org.flinkextended.flink.ml.cluster.master.AMService
    public void handleStateTransitionError(MessageOrBuilder request, Throwable error) {
        String description = request != null
                ? String.format("Failed to handle request %s:\n%s", request.getClass().getName(), ProtoUtil.protoToJson(request))
                : "State transition failed";
        LOG.error(description, error);
        appMasterServer.onError(error);
    }

    /** Caches {@code nodeClient} under {@code nodeClientKey}, replacing any existing entry. */
    @Override // org.flinkextended.flink.ml.cluster.master.AMService
    public void updateNodeClient(String nodeClientKey, NodeClient nodeClient) {
        nodeClientCache.put(nodeClientKey, nodeClient);
    }

    /** Formats the message sent with {@code RpcCode.VERSION_ERROR}. */
    private static String versionErrorMessage(long current, long requested) {
        return String.format("version change current:%d request:%d", current, requested);
    }

    /**
     * Returns a non-null description of {@code error} for an error response. Protobuf string
     * setters reject null, and {@link Throwable#getMessage()} may be null.
     */
    private static String errorMessage(Throwable error) {
        String message = error.getMessage();
        return message != null ? message : error.getClass().getName();
    }
}
