package org.flinkextended.flink.ml.examples.pytorch.it;

import com.google.common.io.Files;
import java.io.File;
import java.io.IOException;
import org.flinkextended.flink.ml.examples.pytorch.PyTorchRunDist;
import org.flinkextended.flink.ml.examples.util.CodeUtil;
import org.flinkextended.flink.ml.util.MiniCluster;
import org.flinkextended.flink.ml.util.SysUtil;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/flinkextended/flink/ml/examples/pytorch/it/PyTorchUtilIT.class */
/**
 * Integration tests that submit {@link PyTorchRunDist} jobs to a {@link MiniCluster} and check
 * that the submission output reports a finished program.
 *
 * <p>A new cluster is started before every test and stopped afterwards. Each test runs one Python
 * script in either {@code Stream} or {@code Table} mode.
 */
public class PyTorchUtilIT {
    /** Cluster used by the currently running test; {@code null} when no cluster is active. */
    private static MiniCluster miniCluster;

    /** Value passed to {@code MiniCluster.start} (presumably the number of task managers). */
    private static final int numTMs = 3;

    /** Marker that must appear in the submission output for a run to count as successful. */
    private static final String SUCCESS_MARKER = "Program execution finished";

    /**
     * Starts a cluster and points it at the examples jar built for the current project version.
     *
     * @throws Exception if the cluster cannot be started or configured
     */
    @Before
    public void setUp() throws Exception {
        miniCluster = MiniCluster.start(numTMs);
        miniCluster.setExecJar("/dl-on-flink-examples/target/dl-on-flink-examples-" + SysUtil.getProjectVersion() + ".jar");
    }

    /**
     * Stops the cluster started by {@link #setUp()}, if any.
     *
     * @throws Exception if stopping the cluster fails
     */
    @After
    public void tearDown() throws Exception {
        if (miniCluster != null) {
            try {
                miniCluster.stop();
            } finally {
                // Drop the reference even when stop() throws, so that a later test whose
                // setUp() fails before assigning a new cluster does not stop this one again.
                miniCluster = null;
            }
        }
    }

    /** Runs {@code greeter.py} in stream mode. */
    @Test
    public void trainStream() throws Exception {
        runAndVerify(miniCluster, PyTorchRunDist.EnvMode.Stream, "greeter.py");
    }

    /** Runs {@code greeter.py} in table mode. */
    @Test
    public void trainTable() throws Exception {
        runAndVerify(miniCluster, PyTorchRunDist.EnvMode.Table, "greeter.py");
    }

    /** Runs {@code all_reduce_test.py} in stream mode. */
    @Test
    public void allReduceStream() throws Exception {
        runAndVerify(miniCluster, PyTorchRunDist.EnvMode.Stream, "all_reduce_test.py");
    }

    /** Runs {@code all_reduce_test.py} in table mode. */
    @Test
    public void allReduceTable() throws Exception {
        runAndVerify(miniCluster, PyTorchRunDist.EnvMode.Table, "all_reduce_test.py");
    }

    /**
     * Copies the example code to HDFS and submits {@link PyTorchRunDist} to the given cluster.
     *
     * @param miniCluster2 cluster to submit the job to
     * @param envMode value passed as the {@code --mode} argument
     * @param str script name passed as the {@code --script} argument
     * @return whatever {@code MiniCluster.flinkRun} returns for the submission
     * @throws IOException if copying the code or submitting the job fails; it propagates
     *     unchanged so the test runner reports it with its full stack trace
     */
    private static String run(MiniCluster miniCluster2, PyTorchRunDist.EnvMode envMode, String str) throws IOException {
        String codePath = CodeUtil.copyCodeToHdfs(miniCluster2);
        String[] args = {
            "--zk-conn-str", miniCluster2.getZKContainer(),
            "--mode", envMode.toString(),
            "--script", str,
            "--envpath", miniCluster2.getVenvHdfsPath(),
            "--code-path", codePath
        };
        return miniCluster2.flinkRun(PyTorchRunDist.class.getCanonicalName(), args);
    }

    /**
     * Submits the script, prints the submission output, dumps the Flink logs to a fresh temporary
     * directory, and fails the test unless the output contains the success marker.
     *
     * @param miniCluster2 cluster to submit the job to
     * @param envMode execution mode of the job
     * @param str script name to run
     * @return {@code true} when the run succeeded; on failure {@code Assert.fail} throws, so
     *     {@code false} is never actually observed by callers
     * @throws IOException if the submission fails
     */
    static boolean runAndVerify(MiniCluster miniCluster2, PyTorchRunDist.EnvMode envMode, String str) throws IOException {
        String output = run(miniCluster2, envMode, str);
        System.out.println(output);

        // Logs are dumped on both outcomes, so do it once before branching.
        File logDir = Files.createTempDir();
        miniCluster2.dumpFlinkLogs(logDir);

        // A null output is treated as a failed run instead of surfacing as an NPE.
        if (output != null && output.contains(SUCCESS_MARKER)) {
            System.out.println("logs in " + logDir.getAbsolutePath());
            return true;
        }
        Assert.fail("run failed in mode " + envMode + ", check logs in " + logDir.getAbsolutePath());
        return false;
    }
}
