package edu.iu.dsc.tws.rsched.schedulers.k8s.mpi;

import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.config.MPIContext;
import edu.iu.dsc.tws.api.config.SchedulerContext;
import edu.iu.dsc.tws.api.faulttolerance.FaultToleranceContext;
import edu.iu.dsc.tws.common.logging.LoggingHelper;
import edu.iu.dsc.tws.master.JobMasterContext;
import edu.iu.dsc.tws.proto.system.job.JobAPI;
import edu.iu.dsc.tws.rsched.schedulers.k8s.K8sEnvVariables;
import edu.iu.dsc.tws.rsched.schedulers.k8s.KubernetesContext;
import edu.iu.dsc.tws.rsched.schedulers.k8s.KubernetesUtils;
import edu.iu.dsc.tws.rsched.schedulers.k8s.PodWatchUtils;
import edu.iu.dsc.tws.rsched.schedulers.k8s.worker.K8sWorkerUtils;
import edu.iu.dsc.tws.rsched.utils.JobUtils;
import edu.iu.dsc.tws.rsched.utils.ProcessUtils;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/rsched/schedulers/k8s/mpi/MPIMasterStarter.class */
/**
 * Entry point executed in the MPI master pod of a Kubernetes job.
 *
 * <p>{@link #main(String[])} loads the job configuration, resolves the job master IP,
 * collects the IPs of the worker pods, writes an MPI hostfile, optionally runs the
 * password-free ssh check script and finally launches {@code mpirun}.
 *
 * <p>The class keeps the loaded {@code Config} and the job ID in static fields that are
 * assigned by {@code main}. {@link #generateMPIrunCommand} and {@link #executeMpirun}
 * read those fields, so they are only meaningful after {@code main} has initialised them.
 */
public final class MPIMasterStarter {
    /** Name of the MPI hostfile, written relative to the current working directory. */
    private static final String HOSTFILE_NAME = "hostfile";
    /** Directory from which the configuration and the job description file are loaded. */
    private static final String JOB_DIR = "/twister2-memory-dir/twister2-job";
    /** Logging properties file handed to each worker JVM started by mpirun. */
    private static final String LOGGER_PROPERTIES_FILE = JOB_DIR + "/common/logger.properties";
    /** Fully qualified name of the class that mpirun starts on every worker slot. */
    private static final String WORKER_STARTER_CLASS = "edu.iu.dsc.tws.rsched.schedulers.k8s.mpi.MPIWorkerStarter";
    private static final Logger LOG = Logger.getLogger(MPIMasterStarter.class.getName());
    private static Config config = null;
    private static String jobID = null;

    private MPIMasterStarter() {
    }

    /**
     * Prepares the MPI job and runs it.
     *
     * <p>Every unrecoverable condition (job master IP not found, worker IPs not found,
     * hostfile not written, ssh check failed) is logged at SEVERE level and makes the
     * method return without starting mpirun.
     *
     * @param strArr command line arguments; not used
     * @throws RuntimeException if the JOB_ID environment variable is not set or the
     *     local host address cannot be resolved
     */
    public static void main(String[] strArr) {
        LoggingHelper.setLoggingFormat("[%1$tF %1$tT] [%4$s] [%7$s] %3$s: %5$s %6$s %n");
        String jobMasterIP = System.getenv(K8sEnvVariables.JOB_MASTER_IP.name());
        String podName = System.getenv(K8sEnvVariables.POD_NAME.name());
        String jvmMemoryMB = System.getenv(K8sEnvVariables.JVM_MEMORY_MB.name());
        jobID = System.getenv(K8sEnvVariables.JOB_ID.name());
        if (jobID == null) {
            throw new RuntimeException("JobID is null");
        }

        config = K8sWorkerUtils.loadConfig(JOB_DIR);
        K8sWorkerUtils.initLogger(config, "mpiMaster", KubernetesContext.persistentVolumeRequested(config));

        String jobDescFile = JOB_DIR + "/" + SchedulerContext.createJobDescriptionFileName(jobID);
        JobAPI.Job job = JobUtils.readJobFile(jobDescFile);
        LOG.info("Job description file is loaded: " + jobDescFile);
        config = JobUtils.overrideConfigs(job, config);
        config = JobUtils.updateConfigs(job, config);

        String namespace = KubernetesContext.namespace(config);
        int workersPerPod = job.getComputeResource(0).getWorkersPerPod();
        int numberOfWorkerPods = KubernetesUtils.numberOfWorkerPods(job);

        String localIP;
        try {
            localIP = InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException e) {
            LOG.log(Level.SEVERE, "Cannot get localHost.", e);
            throw new RuntimeException("Cannot get localHost.", e);
        }

        LOG.info("MPIMaster information summary: \npodName: " + podName + "\npodIP: " + localIP + "\njobID: " + jobID + "\nnamespace: " + namespace + "\nnumberOfWorkers: " + job.getNumberOfWorkers() + "\nnumberOfPods: " + numberOfWorkerPods);

        long podWaitStart = System.currentTimeMillis();

        // When the job master does not run in the submitting client, the IP from the
        // environment is replaced: first the service IP is tried, then the pod is watched.
        if (!JobMasterContext.jobMasterRunsInClient(config)) {
            jobMasterIP = K8sWorkerUtils.getJobMasterServiceIP(namespace, jobID);
            if (jobMasterIP == null) {
                // 100 is the timeout argument of the watcher; its unit is not visible here.
                jobMasterIP = PodWatchUtils.getJobMasterIpByWatchingPodToRunning(namespace, jobID, 100);
            }
            if (jobMasterIP == null) {
                // NOTE(review): PodWatchUtils.close() is not called on this path - confirm
                // whether the watcher needs to be released before returning.
                LOG.severe("Could not get job master IP by wathing job master pod to running. Aborting. You need to terminate this job and resubmit it....");
                return;
            }
        }
        LOG.info("Job Master IP address: " + jobMasterIP);

        ArrayList<String> workerIPs = PodWatchUtils.getWorkerIPsByWatchingPodsToRunning(namespace, jobID, numberOfWorkerPods, 100);
        PodWatchUtils.close();
        if (workerIPs == null) {
            LOG.severe("Could not get IPs of all pods running. Aborting. You need to terminate this job and resubmit it....");
            return;
        }
        if (!createHostFile(workerIPs, workersPerPod)) {
            LOG.severe("hostfile can not be generated. Aborting. You need to terminate this job and resubmit it....");
            return;
        }
        LOG.info("Getting all pods running took: " + (System.currentTimeMillis() - podWaitStart) + " ms.");

        String[] mpirunCommand = generateMPIrunCommand(WORKER_STARTER_CLASS, workersPerPod, jobMasterIP, LOGGER_PROPERTIES_FILE, jvmMemoryMB);

        if (KubernetesContext.checkPwdFreeSsh(config)) {
            long sshCheckStart = System.currentTimeMillis();
            // The hostfile is already written; this pod's own IP is only dropped from the
            // list of hosts handed to the ssh check script.
            workerIPs.remove(localIP);
            boolean sshReady = runScript(generateCheckSshCommand(workerIPs));
            LOG.info("Checking password free access took: " + (System.currentTimeMillis() - sshCheckStart) + " ms");
            if (!sshReady) {
                LOG.severe("Password free ssh can not be setup among pods. Not executing mpirun ...");
                return;
            }
        }
        executeMpirun(mpirunCommand);
    }

    /**
     * Writes the MPI hostfile, one line per entry in the form {@code <ip> slots=<i>},
     * and logs the list of written IPs.
     *
     * <p>The writer is managed by try-with-resources so the file handle is released even
     * when a write fails.
     *
     * @param arrayList IP addresses to list in the hostfile
     * @param i number of slots written for every IP
     * @return {@code true} if the file was written, {@code false} if any exception
     *     occurred (the exception is logged)
     */
    public static boolean createHostFile(ArrayList<String> arrayList, int i) {
        StringBuilder writtenIPs = new StringBuilder();
        try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(HOSTFILE_NAME)))) {
            for (String ip : arrayList) {
                writer.write(ip + " slots=" + i + System.lineSeparator());
                writtenIPs.append(ip).append(System.lineSeparator());
            }
        } catch (Exception e) {
            LOG.log(Level.SEVERE, "Exception when writing the file: hostfile", e);
            return false;
        }
        LOG.info("File: hostfile is written with the content:\n" + writtenIPs);
        return true;
    }

    /**
     * Builds the argument vector of the {@code mpirun} invocation that starts one JVM per
     * worker slot.
     *
     * <p>The Kubernetes service host/port, the job ID, the job master IP, the job
     * submission time and the restore flag are exported to the workers with {@code -x}.
     * Extra mpirun parameters from the configuration are inserted before the {@code java}
     * command; they are split on runs of whitespace so that leading, trailing or repeated
     * blanks do not produce empty arguments.
     *
     * @param str fully qualified name of the class to start in each JVM
     * @param i value passed to {@code -npernode}
     * @param str2 job master IP exported to the workers
     * @param str3 path given to {@code -Djava.util.logging.config.file}
     * @param str4 value used for both {@code -Xms} and {@code -Xmx}, suffixed with {@code m}
     * @return the command as an array suitable for process execution
     */
    public static String[] generateMPIrunCommand(String str, int i, String str2, String str3, String str4) {
        String jobSubmissionTime = System.getenv(K8sEnvVariables.JOB_SUBMISSION_TIME.name());
        String restoreJob = System.getenv(K8sEnvVariables.RESTORE_JOB.name());

        ArrayList<String> command = new ArrayList<>(Arrays.asList(
            "mpirun",
            "--hostfile", HOSTFILE_NAME,
            "--allow-run-as-root",
            "-npernode", String.valueOf(i),
            "-tag-output",
            "-x", "KUBERNETES_SERVICE_HOST=" + System.getenv("KUBERNETES_SERVICE_HOST"),
            "-x", "KUBERNETES_SERVICE_PORT=" + System.getenv("KUBERNETES_SERVICE_PORT"),
            "-x", K8sEnvVariables.JOB_ID.name() + "=" + jobID,
            "-x", K8sEnvVariables.JOB_MASTER_IP.name() + "=" + str2,
            "-x", K8sEnvVariables.JOB_SUBMISSION_TIME.name() + "=" + jobSubmissionTime,
            "-x", K8sEnvVariables.RESTORE_JOB.name() + "=" + restoreJob));

        String mpiParams = MPIContext.mpiParams(config);
        if (mpiParams != null) {
            String trimmedParams = mpiParams.trim();
            if (!trimmedParams.isEmpty()) {
                command.addAll(Arrays.asList(trimmedParams.split("\\s+")));
            }
        }

        command.add("java");
        command.add("-Xms" + str4 + "m");
        command.add("-Xmx" + str4 + "m");
        command.add("-Djava.util.logging.config.file=" + str3);
        command.add("-cp");
        // NOTE(review): a missing CLASSPATH variable puts a null element in the command.
        command.add(System.getenv("CLASSPATH"));
        command.add(str);
        return command.toArray(new String[0]);
    }

    /**
     * Runs the given mpirun command, retrying on a non-zero exit status.
     *
     * <p>The command is attempted at most {@code FaultToleranceContext.maxMpiJobRestarts(config)}
     * times. When every attempt fails, the final log message reports the number of
     * attempts actually made together with the output captured from the last attempt.
     *
     * @param strArr the mpirun command and its arguments
     */
    public static void executeMpirun(String[] strArr) {
        int maxAttempts = FaultToleranceContext.maxMpiJobRestarts(config);
        StringBuilder stderr = new StringBuilder();
        int attempts = 0;

        while (attempts < maxAttempts) {
            attempts++;
            // Start every attempt with an empty buffer; the content of the last failed
            // attempt is kept for the final error message below.
            stderr.setLength(0);
            LOG.info("mpirun will execute with the command: \n" + commandAsAString(strArr));
            int status = ProcessUtils.runSyncProcess(false, strArr, stderr, new File("."), true);
            if (status == 0) {
                LOG.info("mpirun completed with success...");
                if (stderr.length() != 0) {
                    LOG.info("The output:\n " + stderr.toString());
                }
                return;
            }
            if (attempts < maxAttempts) {
                LOG.severe(String.format("Failed to execute mpirun. Will try again. STDERR=%s", stderr));
            }
        }
        LOG.severe(String.format("Failed to execute mpirun. Tried %s times. STDERR=%s", attempts, stderr));
    }

    /**
     * Joins the command parts into a single string for logging.
     *
     * @param strArr command and arguments
     * @return every element followed by a single space; empty string for an empty array
     */
    public static String commandAsAString(String[] strArr) {
        StringBuilder joined = new StringBuilder();
        for (String part : strArr) {
            joined.append(part).append(' ');
        }
        return joined.toString();
    }

    /**
     * Builds the command that runs {@code ./check_pwd_free_ssh.sh} with the given
     * addresses as its arguments.
     *
     * @param arrayList addresses passed to the script, in order
     * @return the script path followed by all entries of {@code arrayList}
     */
    public static String[] generateCheckSshCommand(ArrayList<String> arrayList) {
        String[] command = new String[arrayList.size() + 1];
        command[0] = "./check_pwd_free_ssh.sh";
        int next = 1;
        for (String ip : arrayList) {
            command[next++] = ip;
        }
        return command;
    }

    /**
     * Executes the given script command synchronously in the current directory and logs
     * the outcome.
     *
     * @param strArr the script and its arguments
     * @return {@code true} if the process exit status was zero
     */
    public static boolean runScript(String[] strArr) {
        StringBuilder stderr = new StringBuilder();
        String commandLine = commandAsAString(strArr);
        LOG.info("the script will be executed with the command: \n" + commandLine);

        int status = ProcessUtils.runSyncProcess(false, strArr, stderr, new File("."), true);
        if (status != 0) {
            LOG.severe(String.format("Failed to execute the script file command=%s, STDERR=%s", commandLine, stderr));
            return false;
        }

        LOG.info("script: check_pwd_free_ssh.sh execution completed with success...");
        if (stderr.length() != 0) {
            LOG.info("The error output:\n " + stderr.toString());
        }
        return true;
    }
}
