package org.flinkextended.flink.ml.cluster.rpc;

import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.FutureTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.flinkextended.flink.ml.cluster.ExecutionMode;
import org.flinkextended.flink.ml.cluster.MLConfig;
import org.flinkextended.flink.ml.cluster.node.MLContext;
import org.flinkextended.flink.ml.cluster.role.AMRole;
import org.flinkextended.flink.ml.cluster.role.PsRole;
import org.flinkextended.flink.ml.cluster.role.WorkerRole;
import org.flinkextended.flink.ml.cluster.storage.StorageFactory;
import org.flinkextended.flink.ml.proto.AMStatus;
import org.flinkextended.flink.ml.proto.MLClusterDef;
import org.flinkextended.flink.ml.proto.MLJobDef;
import org.flinkextended.flink.ml.proto.NodeRestartRequest;
import org.flinkextended.flink.ml.proto.NodeRestartResponse;
import org.flinkextended.flink.ml.proto.NodeServiceGrpc;
import org.flinkextended.flink.ml.proto.NodeSpec;
import org.flinkextended.flink.ml.proto.NodeStopRequest;
import org.flinkextended.flink.ml.proto.NodeStopResponse;
import org.flinkextended.flink.ml.proto.SimpleResponse;
import org.flinkextended.flink.ml.util.DummyContext;
import org.flinkextended.flink.ml.util.IpHostUtil;
import org.flinkextended.flink.ml.util.MLException;
import org.flinkextended.flink.ml.util.SysUtil;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/flinkextended/flink/ml/cluster/rpc/AppMasterServerTest.class */
public class AppMasterServerTest {
    private static String ip;
    private static final Logger LOG = LoggerFactory.getLogger(AppMasterServerTest.class);

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/flinkextended/flink/ml/cluster/rpc/AppMasterServerTest$DummyNodeServer.class */
    /**
     * Test stand-in for a cluster node.
     *
     * <p>It starts a gRPC {@link Server} hosting the supplied node service (built with
     * {@code ServerBuilder.forPort(0)}) and talks to the application master through an
     * {@link AMClient} obtained from {@link AMRegistry}.
     */
    public static class DummyNodeServer {
        /** Upper bound, in milliseconds, on how long {@link #waitForAMStatus(AMStatus)} polls. */
        private static final long TIMEOUT_MILLIS = Duration.ofSeconds(60).toMillis();
        /** Pause, in milliseconds, between status retries in {@link #getAmStatus()}. */
        private static final long RETRY_INTERVAL_MILLIS = 3000L;

        private final MLContext mlContext;
        private final Server server;
        /** Replaced whenever the AM address is re-resolved. */
        private AMClient amClient;
        /** Version read from the AM at the most recent registration. */
        private long version;

        /**
         * Starts the node-side gRPC server and resolves the initial AM client.
         *
         * @param mLContext context identifying this node (role name, index, identity)
         * @param nodeServiceImplBase service implementation exposed to the AM
         * @throws IOException if the gRPC server cannot be started
         */
        DummyNodeServer(MLContext mLContext, NodeServiceGrpc.NodeServiceImplBase nodeServiceImplBase) throws IOException {
            this.mlContext = mLContext;
            this.server = ServerBuilder.forPort(0).addService(nodeServiceImplBase).build();
            this.server.start();
            this.amClient = AMRegistry.getAMClient(mLContext);
        }

        /** Waits for the AM to report {@code AM_INIT}, then registers this node. */
        SimpleResponse registerNode() {
            return registerOnceAmIs(AMStatus.AM_INIT);
        }

        /** Asserts that {@link #registerNode()} answers with {@code RpcCode.OK}. */
        void ensureRegisterSucceed() {
            Assert.assertEquals(this.mlContext.getIdentity() + " register node failed", RpcCode.OK.ordinal(), registerNode().getCode());
        }

        /** Waits for the AM to report {@code AM_RUNNING}, then registers this node again. */
        SimpleResponse twiceRegisterNode() {
            return registerOnceAmIs(AMStatus.AM_RUNNING);
        }

        /** Shared registration path: wait for the given AM status, refresh the version, register. */
        private SimpleResponse registerOnceAmIs(AMStatus requiredStatus) {
            waitForAMStatus(requiredStatus);
            this.version = this.amClient.getVersion().getVersion();
            return this.amClient.registerNode(this.version, getNodeSpec());
        }

        /**
         * Returns the AM status, retrying until a status call succeeds.
         *
         * <p>After a failed call the AM client is re-resolved and the thread pauses for
         * {@link #RETRY_INTERVAL_MILLIS} before trying again. The pause also happens when
         * re-resolving fails, so a missing AM does not turn into a busy loop.
         *
         * @throws IllegalStateException if the thread is interrupted while pausing; the
         *     interrupt flag is restored first
         */
        AMStatus getAmStatus() {
            while (true) {
                try {
                    return this.amClient.getAMStatus();
                } catch (Exception statusFailure) {
                    try {
                        this.amClient = AMRegistry.getAMClient(this.mlContext);
                    } catch (Exception refreshFailure) {
                        // Log the refresh failure itself, not the status failure that triggered it.
                        AppMasterServerTest.LOG.warn("{} failed update AM address", this.mlContext.getIdentity(), refreshFailure);
                    }
                    try {
                        Thread.sleep(RETRY_INTERVAL_MILLIS);
                    } catch (InterruptedException interrupted) {
                        Thread.currentThread().interrupt();
                        throw new IllegalStateException(this.mlContext.getIdentity() + " interrupted while waiting for AM status", interrupted);
                    }
                }
            }
        }

        /** Waits for the AM to report {@code AM_RUNNING}, then reports this node as finished. */
        SimpleResponse finishNode() {
            waitForAMStatus(AMStatus.AM_RUNNING);
            return this.amClient.nodeFinish(this.version, getNodeSpec());
        }

        /** Re-resolves the AM client; a failure is logged and the previous client is kept. */
        void updateAmAddress() {
            try {
                this.amClient = AMRegistry.getAMClient(this.mlContext);
            } catch (Exception e) {
                AppMasterServerTest.LOG.warn("{} failed update AM address", this.mlContext.getIdentity(), e);
            }
        }

        /** Asserts that {@link #finishNode()} answers with {@code RpcCode.OK}. */
        void ensureFinishSucceed() {
            Assert.assertEquals(this.mlContext.getIdentity() + " finish node failed", RpcCode.OK.ordinal(), finishNode().getCode());
        }

        /** Waits for the AM to report {@code AM_RUNNING}, then reports this node as failed. */
        SimpleResponse failNode() {
            waitForAMStatus(AMStatus.AM_RUNNING);
            return this.amClient.reportFailedNode(this.version, getNodeSpec());
        }

        /** Builds the spec for this node: role name, index, test host IP and the server's port. */
        NodeSpec getNodeSpec() {
            return NodeSpec.newBuilder().setRoleName(this.mlContext.getRoleName()).setIndex(this.mlContext.getIndex()).setIp(AppMasterServerTest.ip).setClientPort(this.server.getPort()).build();
        }

        /** Fetches the cluster definition for the version recorded at the last registration. */
        MLClusterDef getCluster() {
            return this.amClient.getClusterInfo(this.version).getClusterDef();
        }

        /**
         * Polls the AM until it reports the requested status.
         *
         * @param aMStatus status to wait for
         * @throws RuntimeException if the status is not observed within {@link #TIMEOUT_MILLIS}
         */
        public void waitForAMStatus(AMStatus aMStatus) {
            long deadline = System.currentTimeMillis() + TIMEOUT_MILLIS;
            AMStatus observed = this.amClient.getAMStatus();
            while (observed != aMStatus) {
                if (System.currentTimeMillis() > deadline) {
                    throw new RuntimeException(String.format("Timed out waiting for status: %s current status: %s", aMStatus, observed));
                }
                Thread.yield();
                observed = this.amClient.getAMStatus();
            }
        }

        /** Closes the current AM client. */
        void close() {
            this.amClient.close();
        }
    }

    /* loaded from: input_file:org/flinkextended/flink/ml/cluster/rpc/AppMasterServerTest$NodeMessage.class */
    /**
     * Counters for the stop and restart callbacks received by a mocked node service.
     *
     * <p>The counters are bumped from the mock's answer callbacks and read from the test
     * thread. The fields are therefore {@code volatile} (so readers see fresh values) and
     * every write is {@code synchronized} (so concurrent increments are not lost).
     */
    public static class NodeMessage {
        volatile int nodeStopNum = 0;
        volatile int nodeRestartNum = 0;

        /** @return number of stop callbacks recorded so far */
        public int getNodeStopNum() {
            return this.nodeStopNum;
        }

        /** @param i new value for the stop counter */
        public synchronized void setNodeStopNum(int i) {
            this.nodeStopNum = i;
        }

        /** Records one more stop callback. */
        public synchronized void addNodeStopNum() {
            this.nodeStopNum++;
        }

        /** @return number of restart callbacks recorded so far */
        public int getNodeRestartNum() {
            return this.nodeRestartNum;
        }

        /** @param i new value for the restart counter */
        public synchronized void setNodeRestartNum(int i) {
            this.nodeRestartNum = i;
        }

        /** Records one more restart callback. */
        public synchronized void addNodeRestartNum() {
            this.nodeRestartNum++;
        }

        /**
         * Blocks until the stop counter equals {@code i}, polling once per second.
         *
         * @param i stop count to wait for
         * @throws IllegalStateException if the waiting thread is interrupted; the interrupt
         *     flag is restored before throwing so callers can still observe it
         */
        public void waitForStop(int i) {
            while (i != this.nodeStopNum) {
                try {
                    Thread.sleep(1000L);
                } catch (InterruptedException e) {
                    // Swallowing the interrupt would leave this loop impossible to cancel.
                    Thread.currentThread().interrupt();
                    throw new IllegalStateException("Interrupted while waiting for " + i + " node stops, saw " + this.nodeStopNum, e);
                }
            }
        }
    }

    /**
     * Stores the value of {@code IpHostUtil.getIpAddress()} once for the whole class; the
     * dummy nodes put it into the node specs they send to the application master.
     */
    @BeforeClass
    public static void init() throws Exception {
        final String resolvedAddress = IpHostUtil.getIpAddress();
        AppMasterServerTest.ip = resolvedAddress;
    }

    /** Per-test setup hook; intentionally empty because no test here needs preparation. */
    @Before
    public void setUp() throws Exception {
    }

    /**
     * Clears {@code StorageFactory.memoryStorage} after every test.
     *
     * <p>NOTE(review): presumably this keeps AM state written by one test from being seen
     * by the next one; confirm against the storage implementation.
     */
    @After
    public void tearDown() throws Exception {
        StorageFactory.memoryStorage.clear();
    }

    /**
     * Starts an application master with three workers, waits until it reports
     * {@code AM_RUNNING}, cancels it, starts a second one, and then checks that every
     * worker can re-resolve the AM, observe {@code AM_RUNNING} and finish successfully.
     */
    @Test
    public void testAMFailOver() throws Exception {
        System.out.println(SysUtil._FUNC_());
        final int workerCount = 3;
        Duration heartbeatTimeout = Duration.ofSeconds(100L);
        ThreadPoolExecutor workerPool = new ThreadPoolExecutor(workerCount, workerCount, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>(workerCount));
        try {
            MLConfig mlConfig = DummyContext.createDummyMLConfig();
            mlConfig.setRoleNum(new WorkerRole().name(), workerCount);
            mlConfig.getProperties().put("heartbeat.timeout", String.valueOf(heartbeatTimeout.toMillis()));
            mlConfig.getProperties().put("failover.strategy", "individual");
            FutureTask<Void> firstAm = startAMServer(mlConfig);

            // Flipped by the main thread once the replacement AM has been started.
            AtomicBoolean amRestarted = new AtomicBoolean(false);
            ArrayList<FutureTask<Void>> nodeTasks = new ArrayList<>(workerCount);
            for (int i = 0; i < workerCount; i++) {
                DummyNodeServer node = new DummyNodeServer(new MLContext(ExecutionMode.TRAIN, mlConfig, new WorkerRole().name(), i, (String) null, (Map) null), mockNodeService());
                FutureTask<Void> nodeTask = new FutureTask<Void>(() -> {
                    node.ensureRegisterSucceed();
                    while (!amRestarted.get()) {
                        try {
                            Thread.sleep(1000L);
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            throw new IllegalStateException(node.mlContext.getIdentity() + " interrupted while waiting for AM restart", e);
                        }
                    }
                    node.updateAmAddress();
                    Assert.assertEquals("am not running", AMStatus.AM_RUNNING, node.getAmStatus());
                    // Call finishNode() once and reuse the response for both message and code.
                    SimpleResponse finishResponse = node.finishNode();
                    Assert.assertEquals(node.mlContext.getIdentity() + " finish node failed " + finishResponse.getMessage(), RpcCode.OK.ordinal(), finishResponse.getCode());
                    node.close();
                }, null);
                workerPool.submit(nodeTask);
                nodeTasks.add(nodeTask);
            }

            AMClient amClient = AMRegistry.getAMClient(new MLContext(ExecutionMode.TRAIN, mlConfig, new AMRole().name(), 0, (String) null, (Map) null));
            LOG.info("am client: {}:{}", amClient.getHost(), amClient.getPort());

            // Poll once per second until the first AM is running. Failed status calls are
            // logged and retried; the deadline keeps a broken AM from hanging the test.
            long pollDeadline = System.currentTimeMillis() + Duration.ofMinutes(2L).toMillis();
            boolean amRunning = false;
            while (!amRunning) {
                try {
                    amRunning = AMStatus.AM_RUNNING.equals(amClient.getAMStatus());
                } catch (Exception e) {
                    LOG.error("error on getting am status", e);
                }
                if (!amRunning) {
                    if (System.currentTimeMillis() > pollDeadline) {
                        Assert.fail("Timed out waiting for the application master to reach AM_RUNNING");
                    }
                    Thread.sleep(1000L);
                }
            }
            Thread.sleep(1000L);

            // Replace the running AM, then release the workers.
            firstAm.cancel(true);
            FutureTask<Void> secondAm = startAMServer(mlConfig);
            amRestarted.set(true);
            for (FutureTask<Void> nodeTask : nodeTasks) {
                nodeTask.get();
            }
            secondAm.get();
        } finally {
            // On success every task is already done; on failure this interrupts stragglers.
            workerPool.shutdownNow();
        }
    }

    /**
     * Runs two workers against an AM configured with a three-second
     * {@code heartbeat.timeout}. Each worker registers, sleeps for that same duration,
     * registers again, finishes, and finally verifies that its node service received
     * exactly one {@code nodeRestart} call.
     */
    @Test
    public void testHeartbeatTimeout() throws Exception {
        System.out.println(SysUtil._FUNC_());
        final int workerCount = 2;
        Duration heartbeatTimeout = Duration.ofSeconds(3L);
        ThreadPoolExecutor workerPool = new ThreadPoolExecutor(workerCount, workerCount, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>(workerCount));
        try {
            MLConfig mlConfig = DummyContext.createDummyMLConfig();
            mlConfig.setRoleNum(new WorkerRole().name(), workerCount);
            mlConfig.getProperties().put("heartbeat.timeout", String.valueOf(heartbeatTimeout.toMillis()));
            FutureTask<Void> amTask = startAMServer(mlConfig);

            ArrayList<FutureTask<Void>> nodeTasks = new ArrayList<>(workerCount);
            for (int i = 0; i < workerCount; i++) {
                MLContext nodeContext = new MLContext(ExecutionMode.TRAIN, mlConfig, new WorkerRole().name(), i, (String) null, (Map) null);
                NodeServiceGrpc.NodeServiceImplBase nodeService = mockNodeService();
                DummyNodeServer node = new DummyNodeServer(nodeContext, nodeService);
                FutureTask<Void> nodeTask = new FutureTask<Void>(() -> {
                    node.ensureRegisterSucceed();
                    try {
                        // Stay silent for the configured heartbeat timeout.
                        Thread.sleep(heartbeatTimeout.toMillis());
                    } catch (InterruptedException e) {
                        // Keep the interrupt visible to the pool before surfacing the failure.
                        Thread.currentThread().interrupt();
                        throw new RuntimeException(e);
                    }
                    node.ensureRegisterSucceed();
                    node.ensureFinishSucceed();
                    node.close();
                    Mockito.verify(nodeService).nodeRestart((NodeRestartRequest) Matchers.any(), (StreamObserver) Matchers.any());
                }, null);
                workerPool.submit(nodeTask);
                nodeTasks.add(nodeTask);
            }

            for (FutureTask<Void> nodeTask : nodeTasks) {
                nodeTask.get();
            }
            amTask.get();
        } finally {
            // On success every task is already done; on failure this interrupts stragglers.
            workerPool.shutdownNow();
        }
    }

    /**
     * Runs three workers that each register with the AM. Worker 0 then waits for
     * {@code AM_RUNNING} and registers a second time; the other workers wait for that
     * signal. Afterwards every worker registers once more, finishes, and verifies that its
     * node service received exactly one {@code nodeRestart} call.
     */
    @Test
    public void multiRegisterNode() throws Exception {
        System.out.println(SysUtil._FUNC_());
        final int workerCount = 3;
        ThreadPoolExecutor workerPool = new ThreadPoolExecutor(workerCount, workerCount, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>(workerCount));
        try {
            MLConfig mlConfig = DummyContext.createDummyMLConfig();
            mlConfig.setRoleNum(new WorkerRole().name(), workerCount);
            FutureTask<Void> amTask = startAMServer(mlConfig);

            ArrayList<FutureTask<Void>> nodeTasks = new ArrayList<>(workerCount);
            // Set by worker 0 just before it registers for the second time.
            AtomicBoolean reRegistrationStarted = new AtomicBoolean(false);
            for (int i = 0; i < workerCount; i++) {
                MLContext nodeContext = new MLContext(ExecutionMode.TRAIN, mlConfig, new WorkerRole().name(), i, (String) null, (Map) null);
                NodeServiceGrpc.NodeServiceImplBase nodeService = mockNodeService();
                DummyNodeServer node = new DummyNodeServer(nodeContext, nodeService);
                final boolean triggersReRegistration = i == 0;
                FutureTask<Void> nodeTask = new FutureTask<Void>(() -> {
                    node.ensureRegisterSucceed();
                    if (triggersReRegistration) {
                        node.waitForAMStatus(AMStatus.AM_RUNNING);
                        reRegistrationStarted.set(true);
                        Assert.assertEquals(node.mlContext.getIdentity() + " re-register node failed", RpcCode.OK.ordinal(), node.twiceRegisterNode().getCode());
                    } else {
                        while (!reRegistrationStarted.get()) {
                            try {
                                Thread.sleep(100L);
                            } catch (InterruptedException e) {
                                Thread.currentThread().interrupt();
                                throw new IllegalStateException(node.mlContext.getIdentity() + " interrupted while waiting for re-registration", e);
                            }
                        }
                    }
                    node.ensureRegisterSucceed();
                    node.ensureFinishSucceed();
                    node.close();
                    Mockito.verify(nodeService).nodeRestart((NodeRestartRequest) Matchers.any(), (StreamObserver) Matchers.any());
                }, null);
                workerPool.submit(nodeTask);
                nodeTasks.add(nodeTask);
            }

            for (FutureTask<Void> nodeTask : nodeTasks) {
                nodeTask.get();
            }
            amTask.get();
        } finally {
            // On success every task is already done; on failure this interrupts stragglers.
            workerPool.shutdownNow();
        }
    }

    /**
     * Registers one PS node and three workers, lets worker 1 finish and worker 2 report
     * failure, then re-registers the remaining nodes. It verifies that the shared mock
     * service saw three {@code nodeRestart} calls, that the version changed, and that the
     * cluster definition still lists two jobs with three tasks in the worker job.
     */
    @Test
    public void testMergeFinishedNodes() throws Exception {
        System.out.println(SysUtil._FUNC_());
        MLConfig mlConfig = DummyContext.createDummyMLConfig();
        mlConfig.setRoleNum(new WorkerRole().name(), 3);
        mlConfig.setRoleNum(new PsRole().name(), 1);
        mlConfig.addProperty("job_has_input", "true");
        FutureTask<Void> amTask = startAMServer(mlConfig);

        // All four nodes share one mock so restart calls can be counted together.
        NodeServiceGrpc.NodeServiceImplBase nodeService = mockNodeService();
        MLContext psContext = new MLContext(ExecutionMode.TRAIN, mlConfig, new PsRole().name(), 0, (String) null, (Map) null);
        MLContext worker0Context = new MLContext(ExecutionMode.TRAIN, mlConfig, new WorkerRole().name(), 0, (String) null, (Map) null);
        MLContext worker1Context = new MLContext(ExecutionMode.TRAIN, mlConfig, new WorkerRole().name(), 1, (String) null, (Map) null);
        MLContext worker2Context = new MLContext(ExecutionMode.TRAIN, mlConfig, new WorkerRole().name(), 2, (String) null, (Map) null);
        DummyNodeServer ps = new DummyNodeServer(psContext, nodeService);
        DummyNodeServer worker0 = new DummyNodeServer(worker0Context, nodeService);
        DummyNodeServer worker1 = new DummyNodeServer(worker1Context, nodeService);
        DummyNodeServer worker2 = new DummyNodeServer(worker2Context, nodeService);

        ps.registerNode();
        worker0.registerNode();
        worker1.registerNode();
        worker2.registerNode();
        long initialVersion = ps.version;

        // Worker 1 finishes for good; worker 2 fails, after which the survivors re-register.
        worker1.finishNode();
        worker1.close();
        worker2.failNode();
        worker0.registerNode();
        ps.registerNode();
        worker2.registerNode();

        Mockito.verify(nodeService, Mockito.times(3)).nodeRestart((NodeRestartRequest) Matchers.any(), (StreamObserver) Matchers.any());
        Assert.assertNotEquals("Version is not updated", initialVersion, worker2.version);

        MLClusterDef cluster = worker0.getCluster();
        Assert.assertEquals("There should be 2 jobs", 2L, cluster.getJobList().size());
        MLJobDef workerJob = null;
        String workerRoleName = new WorkerRole().name();
        for (MLJobDef job : cluster.getJobList()) {
            if (job.getName().equals(workerRoleName)) {
                workerJob = job;
                break;
            }
        }
        Assert.assertNotNull("Worker job not found", workerJob);
        // The message used to say "2 tasks" while asserting 3; it now matches the expected value.
        Assert.assertEquals("There should be 3 tasks in worker job", 3L, workerJob.getTasksCount());

        worker0.finishNode();
        worker0.close();
        worker2.finishNode();
        worker2.close();
        ps.finishNode();
        ps.close();
        amTask.get();
    }

    /**
     * Launches an {@code AppMasterServer} for the given configuration on a daemon thread.
     *
     * @param mLConfig configuration used to build the AM's {@link MLContext}
     * @return the task wrapping the server; {@code get()} blocks until the server's run ends
     * @throws MLException declared for the context/server construction
     */
    private FutureTask<Void> startAMServer(MLConfig mLConfig) throws MLException {
        MLContext amContext = new MLContext(ExecutionMode.TRAIN, mLConfig, new AMRole().name(), 0, (String) null, (Map) null);
        AppMasterServer appMaster = new AppMasterServer(amContext);
        FutureTask<Void> amTask = new FutureTask<>(appMaster, null);

        Thread amThread = new Thread(amTask);
        amThread.setDaemon(true);
        amThread.start();
        return amTask;
    }

    /**
     * Creates a Mockito mock of the node service whose {@code nodeStop} and
     * {@code nodeRestart} calls reply at once with an {@code RpcCode.OK} response and an
     * empty message, then complete the stream.
     */
    private NodeServiceGrpc.NodeServiceImplBase mockNodeService() {
        NodeServiceGrpc.NodeServiceImplBase service = Mockito.mock(NodeServiceGrpc.NodeServiceImplBase.class);

        Mockito.doAnswer(invocation -> {
            NodeStopResponse okResponse = NodeStopResponse.newBuilder().setCode(RpcCode.OK.ordinal()).setMessage("").build();
            // Argument 1 is the response observer passed to the stubbed method.
            @SuppressWarnings("unchecked")
            StreamObserver<NodeStopResponse> observer = (StreamObserver<NodeStopResponse>) invocation.getArguments()[1];
            observer.onNext(okResponse);
            observer.onCompleted();
            return null;
        }).when(service).nodeStop(Matchers.any(NodeStopRequest.class), Matchers.any(StreamObserver.class));

        Mockito.doAnswer(invocation -> {
            NodeRestartResponse okResponse = NodeRestartResponse.newBuilder().setCode(RpcCode.OK.ordinal()).setMessage("").build();
            @SuppressWarnings("unchecked")
            StreamObserver<NodeRestartResponse> observer = (StreamObserver<NodeRestartResponse>) invocation.getArguments()[1];
            observer.onNext(okResponse);
            observer.onCompleted();
            return null;
        }).when(service).nodeRestart(Matchers.any(NodeRestartRequest.class), Matchers.any(StreamObserver.class));

        return service;
    }

    /**
     * Creates a Mockito mock of the node service that replies {@code RpcCode.OK} to
     * {@code nodeStop} and {@code nodeRestart}. After completing each reply it bumps the
     * matching counter on the supplied {@link NodeMessage}.
     *
     * @param nodeMessage receives one stop/restart increment per handled call
     */
    private NodeServiceGrpc.NodeServiceImplBase mockNodeServiceSetFlag(NodeMessage nodeMessage) {
        NodeServiceGrpc.NodeServiceImplBase service = Mockito.mock(NodeServiceGrpc.NodeServiceImplBase.class);

        Mockito.doAnswer(invocation -> {
            NodeStopResponse okResponse = NodeStopResponse.newBuilder().setCode(RpcCode.OK.ordinal()).setMessage("").build();
            // Argument 1 is the response observer passed to the stubbed method.
            @SuppressWarnings("unchecked")
            StreamObserver<NodeStopResponse> observer = (StreamObserver<NodeStopResponse>) invocation.getArguments()[1];
            observer.onNext(okResponse);
            observer.onCompleted();
            nodeMessage.addNodeStopNum();
            return null;
        }).when(service).nodeStop(Matchers.any(NodeStopRequest.class), Matchers.any(StreamObserver.class));

        Mockito.doAnswer(invocation -> {
            NodeRestartResponse okResponse = NodeRestartResponse.newBuilder().setCode(RpcCode.OK.ordinal()).setMessage("").build();
            @SuppressWarnings("unchecked")
            StreamObserver<NodeRestartResponse> observer = (StreamObserver<NodeRestartResponse>) invocation.getArguments()[1];
            observer.onNext(okResponse);
            observer.onCompleted();
            nodeMessage.addNodeRestartNum();
            return null;
        }).when(service).nodeRestart(Matchers.any(NodeRestartRequest.class), Matchers.any(StreamObserver.class));

        return service;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds a {@link NodeSpec} from its four fields.
     *
     * @param str role name
     * @param str2 IP address
     * @param i node index
     * @param i2 client port
     * @return the assembled spec
     */
    public static NodeSpec newNodeSpec(String str, String str2, int i, int i2) {
        NodeSpec.Builder spec = NodeSpec.newBuilder();
        spec.setRoleName(str);
        spec.setClientPort(i2);
        spec.setIndex(i);
        spec.setIp(str2);
        return spec.build();
    }
}
