package org.flinkextended.flink.ml.pytorch;

import java.net.URL;
import java.util.concurrent.ExecutionException;
import org.apache.flink.api.common.JobStatus;
import org.apache.flink.core.execution.JobClient;
import org.apache.flink.runtime.client.JobCancellationException;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.bridge.java.StreamStatementSet;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Verifies that a PyTorch training job whose node entry always fails can still be cancelled while
 * Flink is repeatedly failing it over, and that the cancellation is surfaced to the client as a
 * {@link JobCancellationException}.
 */
public class FailOverTest {
    /** Interval between job-status polls. */
    private static final long POLL_INTERVAL_MS = 1000L;

    /** How long the job is left running (and failing over) before it is cancelled. */
    private static final long FAILOVER_WINDOW_MS = 10000L;

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private StreamStatementSet statementSet;
    private final String alwaysFail = getScriptPathFromResources("always_fail.py");

    @Before
    public void setUp() {
        env = StreamExecutionEnvironment.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        statementSet = tEnv.createStatementSet();
    }

    @Test
    public void testCancelWhileFailover() throws InterruptedException, ExecutionException {
        PyTorchUtils.train(statementSet, buildPyTorchConfig(alwaysFail));
        cancelWhileFailingOverAndVerify(submitJob());
    }

    @Test
    public void testTrainWithInputCancelWhileFailover()
            throws InterruptedException, ExecutionException {
        PyTorchUtils.train(
                statementSet,
                tEnv.fromDataStream(env.fromElements(1, 2, 3, 4)),
                buildPyTorchConfig(alwaysFail));
        cancelWhileFailingOverAndVerify(submitJob());
    }

    /** Executes the statement set and returns the client of the submitted job. */
    private JobClient submitJob() {
        return statementSet
                .execute()
                .getJobClient()
                .orElseThrow(
                        () -> new IllegalStateException("No JobClient for the submitted job"));
    }

    /**
     * Waits for the job to run, lets it fail over for a while, cancels it and asserts that the job
     * result reports a cancellation.
     *
     * @param jobClient client of the submitted job
     */
    private static void cancelWhileFailingOverAndVerify(JobClient jobClient)
            throws InterruptedException, ExecutionException {
        waitUntilRunning(jobClient);
        // Give the always-failing script time to crash so cancellation hits an ongoing failover.
        Thread.sleep(FAILOVER_WINDOW_MS);
        jobClient.cancel().get();
        while (jobClient.getJobStatus().get() == JobStatus.RUNNING) {
            Thread.sleep(POLL_INTERVAL_MS);
        }
        try {
            jobClient.getJobExecutionResult().get();
            // Without this the test would silently pass if the job was never cancelled.
            Assert.fail("Expected the job to be cancelled, but it completed normally");
        } catch (ExecutionException e) {
            Assert.assertTrue(
                    "Expected JobCancellationException but got " + e.getCause(),
                    e.getCause() instanceof JobCancellationException);
        }
    }

    /**
     * Polls the job status until it is {@link JobStatus#RUNNING}, failing fast if the job reaches a
     * globally terminal state first (otherwise the loop would never end).
     */
    private static void waitUntilRunning(JobClient jobClient)
            throws InterruptedException, ExecutionException {
        JobStatus status;
        while ((status = jobClient.getJobStatus().get()) != JobStatus.RUNNING) {
            if (status.isGloballyTerminalState()) {
                Assert.fail("Job reached terminal state " + status + " before running");
            }
            Thread.sleep(POLL_INTERVAL_MS);
        }
    }

    /** Builds a 3-node PyTorch cluster config whose node entry is {@code map_func} in the script. */
    private PyTorchClusterConfig buildPyTorchConfig(String scriptPath) {
        return PyTorchClusterConfig.newBuilder()
                .setWorldSize(3)
                .setNodeEntry(scriptPath, "map_func")
                .build();
    }

    /**
     * Resolves a classpath resource to its file path.
     *
     * @param resourceName resource name relative to the classpath root
     * @return the resource's URL path
     */
    private static String getScriptPathFromResources(String resourceName) {
        URL resource = Thread.currentThread().getContextClassLoader().getResource(resourceName);
        Assert.assertNotNull("Resource not found on classpath: " + resourceName, resource);
        return resource.getPath();
    }
}
