package org.flinkextended.flink.ml.tensorflow.client;

import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.time.Duration;
import java.util.concurrent.ExecutionException;

import org.apache.flink.api.common.JobStatus;
import org.apache.flink.core.execution.JobClient;
import org.apache.flink.runtime.client.JobCancellationException;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.bridge.java.StreamStatementSet;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.flinkextended.flink.ml.tensorflow.storage.DummyStorage;
import org.flinkextended.flink.ml.util.SysUtil;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tests that submit TensorFlow clusters via {@link TFUtils#train} using Python scripts from the test
 * classpath ({@code simple_print.py}, {@code failover.py}, {@code failover2.py},
 * {@code always_fail.py}) and check how the resulting Flink job completes, times out, or is
 * cancelled.
 *
 * <p>Each test gets a fresh {@link StreamExecutionEnvironment}, {@link StreamTableEnvironment} and
 * {@link StreamStatementSet} from {@link #setUp()}.
 */
public class RunWithFailTest {
    private static final Logger LOG = LoggerFactory.getLogger(RunWithFailTest.class);

    /** Poll interval while waiting for job status transitions. */
    private static final long POLL_INTERVAL_MS = 1_000L;

    /**
     * How long the cancel tests wait after the job first reports RUNNING before cancelling it;
     * presumably long enough for the always-failing script to trigger at least one failover.
     */
    private static final long FAILOVER_WINDOW_MS = 10_000L;

    private static final String SIMPLE_PRINT_SCRIPT = getScriptPathFromResources("simple_print.py");
    private static final String FAILOVER_SCRIPT = getScriptPathFromResources("failover.py");
    private static final String FAILOVER2_SCRIPT = getScriptPathFromResources("failover2.py");
    private static final String ALWAYS_FAIL_SCRIPT = getScriptPathFromResources("always_fail.py");

    private StreamStatementSet statementSet;
    private StreamTableEnvironment tEnv;
    private StreamExecutionEnvironment env;

    @Rule
    public ExpectedException expectedException = ExpectedException.none();

    /** Creates a fresh execution environment, table environment and statement set per test. */
    @Before
    public void setUp() throws Exception {
        env = StreamExecutionEnvironment.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        statementSet = tEnv.createStatementSet();
    }

    /**
     * Builds a config with 2 workers and 1 PS node, using the current time as the version label.
     *
     * @param scriptPath filesystem path of the Python script that defines {@code map_func}
     * @return the built cluster config
     */
    private TFClusterConfig buildTFConfig(String scriptPath) {
        return buildTFConfig(scriptPath, String.valueOf(System.currentTimeMillis()), 2, 1);
    }

    /**
     * Builds a TF cluster config whose node entry is {@code map_func} in the given script.
     *
     * @param scriptPath filesystem path of the Python script that defines {@code map_func}
     * @param version label that is only logged; it is not stored in the returned config
     * @param workerCount number of worker nodes
     * @param psCount number of PS nodes
     * @return the built cluster config
     */
    private TFClusterConfig buildTFConfig(
            String scriptPath, String version, int workerCount, int psCount) {
        LOG.info("buildTFConfig: {}", SysUtil._FUNC_());
        LOG.info("Current version:{}", version);
        return TFClusterConfig.newBuilder()
                .setWorkerCount(workerCount)
                .setPsCount(psCount)
                .setNodeEntry(scriptPath, "map_func")
                .build();
    }

    @Test
    public void simpleStartupTest() throws Exception {
        TFUtils.train(statementSet, buildTFConfig(SIMPLE_PRINT_SCRIPT, "1", 1, 1));
        statementSet.execute().await();
    }

    @Test
    public void workerFailoverTest() throws Exception {
        LOG.info("############ Start failover test.");
        TFUtils.train(statementSet, buildTFConfig(FAILOVER_SCRIPT));
        statementSet.execute().await();
    }

    @Test
    public void testFailoverWithFinishedNode() throws Exception {
        TFUtils.train(statementSet, buildTFConfig(FAILOVER2_SCRIPT));
        statementSet.execute().await();
    }

    /**
     * Runs a single worker against {@link DummyStorage} with 10s registry/idle timeouts and expects
     * the job to fail with "Failed to wait job finish".
     */
    @Test
    public void testJobTimeout() throws Exception {
        String timeoutMs = String.valueOf(Duration.ofSeconds(10L).toMillis());
        TFUtils.train(
                statementSet,
                buildTFConfig(SIMPLE_PRINT_SCRIPT)
                        .toBuilder()
                        .setWorkerCount(1)
                        .setProperty("storage_type", "storage_custom")
                        .setProperty("storage_impl_class", DummyStorage.class.getName())
                        .setProperty("am.registry.timeout", timeoutMs)
                        .setProperty("node.idle.timeout", timeoutMs)
                        .build());
        expectedException.expect(ExecutionException.class);
        expectedException.expectMessage("Failed to wait job finish");
        statementSet.execute().await();
    }

    @Test
    public void testCancelWhileFailover() throws InterruptedException, ExecutionException {
        // NOTE(review): the resulting Table is never used by the statement set; kept as-is until it
        // is confirmed that removing it has no effect on this test.
        tEnv.fromDataStream(env.fromElements(1, 2, 3, 4));
        TFUtils.train(statementSet, buildTFConfig(ALWAYS_FAIL_SCRIPT));
        cancelDuringFailoverAndVerify(executeAndGetJobClient());
    }

    @Test
    public void testTrainWithInputCancelWhileFailover()
            throws InterruptedException, ExecutionException {
        TFUtils.train(
                statementSet,
                tEnv.fromDataStream(env.fromElements(1, 2, 3, 4)),
                buildTFConfig(ALWAYS_FAIL_SCRIPT));
        cancelDuringFailoverAndVerify(executeAndGetJobClient());
    }

    /** Executes the statement set and returns the client of the submitted job. */
    private JobClient executeAndGetJobClient() {
        return statementSet
                .execute()
                .getJobClient()
                .orElseThrow(() -> new IllegalStateException("Executed statement set has no JobClient"));
    }

    /**
     * Waits until the job is RUNNING, waits {@link #FAILOVER_WINDOW_MS}, cancels it, and asserts
     * that its execution result fails with a {@link JobCancellationException}.
     *
     * @throws AssertionError if the job terminates before running, or the result is not a
     *     cancellation failure
     */
    private static void cancelDuringFailoverAndVerify(JobClient jobClient)
            throws InterruptedException, ExecutionException {
        JobStatus status = jobClient.getJobStatus().get();
        while (status != JobStatus.RUNNING) {
            // Without this guard a job that fails before ever running would make the test hang.
            if (status.isGloballyTerminalState()) {
                Assert.fail("Job reached terminal state " + status + " before it was RUNNING");
            }
            Thread.sleep(POLL_INTERVAL_MS);
            status = jobClient.getJobStatus().get();
        }
        Thread.sleep(FAILOVER_WINDOW_MS);
        jobClient.cancel().get();
        while (jobClient.getJobStatus().get() == JobStatus.RUNNING) {
            Thread.sleep(POLL_INTERVAL_MS);
        }
        try {
            jobClient.getJobExecutionResult().get();
            // Reaching here means the cancelled job reported success: that must fail the test.
            Assert.fail("Expected the cancelled job's execution result to complete exceptionally");
        } catch (ExecutionException e) {
            Assert.assertTrue(
                    "Unexpected failure cause: " + e.getCause(),
                    e.getCause() instanceof JobCancellationException);
        }
    }

    /**
     * Resolves a resource on the context class loader to a decoded filesystem path.
     *
     * <p>Assumes the resource is a plain file ({@code file:} URL), as test resources normally are.
     *
     * @param resourceName classpath-relative resource name
     * @return the resource's filesystem path
     * @throws AssertionError if the resource cannot be found
     * @throws IllegalStateException if the resource URL is not a valid URI
     */
    private static String getScriptPathFromResources(String resourceName) {
        URL resource = Thread.currentThread().getContextClassLoader().getResource(resourceName);
        Assert.assertNotNull("Resource not found on classpath: " + resourceName, resource);
        try {
            // URL#getPath() keeps percent-encoding (e.g. a space as %20); going through the URI
            // yields a path the Python side can actually open.
            return Paths.get(resource.toURI()).toString();
        } catch (URISyntaxException e) {
            throw new IllegalStateException("Invalid resource URL: " + resource, e);
        }
    }
}
