package ml.shifu.shifu.core.yarn.container;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.ServerSocket;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import ml.shifu.guagua.coordinator.zk.GuaguaZooKeeper;
import ml.shifu.shifu.core.yarn.util.CommonUtils;
import ml.shifu.shifu.core.yarn.util.Constants;
import ml.shifu.shifu.core.yarn.util.GlobalConfigurationKeys;
import ml.shifu.shifu.core.yarn.util.HdfsUtils;
import ml.shifu.shifu.util.HDFSUtils;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.zookeeper.CreateMode;
import org.apache.zookeeper.KeeperException;
import org.apache.zookeeper.WatchedEvent;
import org.apache.zookeeper.Watcher;
import org.apache.zookeeper.ZooDefs;
import org.apache.zookeeper.data.Stat;

/* loaded from: input_file:ml/shifu/shifu/core/yarn/container/TensorflowTaskExecutor.class */
public class TensorflowTaskExecutor implements Watcher {
    private static final Log LOG = LogFactory.getLog(TensorflowTaskExecutor.class);
    private GuaguaZooKeeper zookeeper;
    private String containerId;
    private boolean isBackup;
    private String taskIndex;
    private String jobName;
    private String tensorflowCluster;
    private ServerSocket tensorflowSocket;
    private String tensorflowPort;
    private int tbPort;
    private SocketServer socketServer;
    private String pythonScript;
    private String pythonShell;
    final CountDownLatch latch = new CountDownLatch(1);
    public final CountDownLatch backupStartingLatch = new CountDownLatch(1);
    private Configuration globalConf = new Configuration();
    private Map<String, String> shellEnv = new HashMap();
    private Process backupProcess = null;

    public TensorflowTaskExecutor() {
        this.globalConf.addResource(new Path(Constants.GLOBAL_FINAL_XML));
    }

    public void registeryToCluster() throws KeeperException, InterruptedException, IOException {
        this.tensorflowPort = getTensorflowPort();
        if (StringUtils.isBlank(this.tensorflowPort)) {
            throw new RuntimeException("Given port on container is blank!");
        }
        this.zookeeper.exists(Constants.TENSORFLOW_FINAL_CLUSTER, true);
        this.zookeeper.exists(Constants.getTrainingDataZookeeperPath(this.containerId), true);
        this.zookeeper.createOrSetExt(Constants.TENSORFLOW_CLUSTER_ROOT_PATH + this.containerId, (CommonUtils.getCurrentHostIP() + ":" + this.tensorflowPort).getBytes(Charset.forName("UTF-8")), ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT, true, -1);
        this.latch.await();
        this.shellEnv.put("CLUSTER_SPEC", this.tensorflowCluster);
    }

    public void process(WatchedEvent watchedEvent) {
        LOG.info("event path:" + watchedEvent.getPath());
        if (Watcher.Event.EventType.NodeCreated != watchedEvent.getType() || !Constants.TENSORFLOW_FINAL_CLUSTER.equalsIgnoreCase(watchedEvent.getPath())) {
            if (Watcher.Event.EventType.NodeCreated == watchedEvent.getType() && Constants.getTrainingDataZookeeperPath(this.containerId).equalsIgnoreCase(watchedEvent.getPath())) {
                LOG.info("Weak up back host!!");
                try {
                    String str = new String(this.zookeeper.getData(Constants.getTrainingDataZookeeperPath(this.containerId), false, (Stat) null));
                    LOG.info("TRAINING_DATA_PATH: " + str);
                    this.shellEnv.put("TRAINING_DATA_PATH", str);
                    if (this.backupProcess != null) {
                        CommonUtils.killProcessByPort(this.tensorflowPort);
                        this.backupProcess.destroy();
                        LOG.info("Killed backup waiting program... " + this.backupProcess.isAlive());
                    }
                    this.backupStartingLatch.countDown();
                    return;
                } catch (Exception e) {
                    LOG.error("Error when getting backup training data path from zookeeper", e);
                    throw new RuntimeException(e);
                }
            }
            return;
        }
        try {
            this.tensorflowCluster = new String(this.zookeeper.getData(Constants.TENSORFLOW_FINAL_CLUSTER, false, (Stat) null));
            if (this.isBackup) {
                String[] split = this.tensorflowCluster.split(Constants.WORKER_JOB_NAME)[1].split(",");
                int i = 0;
                while (true) {
                    if (i >= split.length) {
                        break;
                    }
                    if (!split[i].contains(CommonUtils.getCurrentHostIP() + ":" + this.tensorflowPort)) {
                        i++;
                    } else if (Integer.valueOf(this.taskIndex).intValue() != i) {
                        LOG.info("change taskid from " + this.taskIndex + " to " + i);
                        this.shellEnv.put("TASK_ID", Integer.toString(i));
                        this.isBackup = false;
                    }
                }
            } else if (Constants.PS_JOB_NAME.equalsIgnoreCase(this.jobName)) {
                String[] split2 = this.tensorflowCluster.split(Constants.WORKER_JOB_NAME)[0].split(",");
                int i2 = 0;
                while (true) {
                    if (i2 >= split2.length) {
                        break;
                    }
                    if (!split2[i2].contains(CommonUtils.getCurrentHostIP() + ":" + this.tensorflowPort)) {
                        i2++;
                    } else if (Integer.valueOf(this.taskIndex).intValue() != i2) {
                        LOG.info("change taskid from " + this.taskIndex + " to " + i2);
                        this.shellEnv.put("TASK_ID", Integer.toString(i2));
                    }
                }
            }
            LOG.info("Cluster:" + this.tensorflowCluster);
            this.latch.countDown();
        } catch (Exception e2) {
            LOG.error("Error when getting final cluster from zookeeper", e2);
            throw new RuntimeException(e2);
        }
    }

    public String getTensorflowPort() throws IOException {
        this.tensorflowSocket = new ServerSocket(CommonUtils.getValidTensorflowPort());
        return Integer.toString(this.tensorflowSocket.getLocalPort());
    }

    public static CommandLine initOpts(String[] strArr) throws ParseException {
        Options options = new Options();
        options.addOption("job_name", true, "");
        options.addOption("task_id", true, "");
        options.addOption("container_id", true, "");
        options.addOption("training_data_path", true, "");
        options.addOption("zookeeper_server", true, "");
        options.addOption("is_backup", true, "");
        return new GnuParser().parse(options, strArr);
    }

    public void init(CommandLine commandLine) throws ParseException, IOException, InterruptedException {
        this.taskIndex = commandLine.getOptionValue("task_id");
        this.jobName = commandLine.getOptionValue("job_name");
        this.shellEnv.put("JOB_NAME", this.jobName);
        this.shellEnv.put("TASK_ID", this.taskIndex);
        this.shellEnv.put("TRAINING_DATA_PATH", commandLine.getOptionValue("training_data_path"));
        this.shellEnv.put("TOTAL_TRAINING_DATA_NUMBER", this.globalConf.get(GlobalConfigurationKeys.TOTAL_TRAINING_DATA_NUM));
        this.shellEnv.put("WEIGHT_COLUMN_NUM", this.globalConf.get(GlobalConfigurationKeys.WEIGHT_COLUMN_NUM, "-1"));
        this.shellEnv.put("TARGET_COLUMN_NUM", this.globalConf.get(GlobalConfigurationKeys.TARGET_COLUMN_NUM, GlobalConfigurationKeys.DEFAULT_TARGET_COLUMN_NUM));
        if (StringUtils.isNotBlank(this.globalConf.get(GlobalConfigurationKeys.SELECTED_COLUMN_NUMS))) {
            this.shellEnv.put("SELECTED_COLUMN_NUMS", this.globalConf.get(GlobalConfigurationKeys.SELECTED_COLUMN_NUMS, "-1"));
        } else {
            this.shellEnv.put("SELECTED_NUMERIC_COLUMN_NUMS", this.globalConf.get(GlobalConfigurationKeys.SELECTED_NUMERIC_COLUMN_NUMS, "-1"));
            this.shellEnv.put("SELECTED_CATEGORY_COLUMN_NUMS", this.globalConf.get(GlobalConfigurationKeys.SELECTED_CATEGORY_COLUMN_NUMS, "-1"));
        }
        this.shellEnv.put("TMP_MODEL_PATH", this.globalConf.get(GlobalConfigurationKeys.TMP_MODEL_PATH));
        this.shellEnv.put("FINAL_MODEL_PATH", this.globalConf.get(GlobalConfigurationKeys.FINAL_MODEL_PATH));
        this.containerId = commandLine.getOptionValue("container_id").trim();
        this.isBackup = Boolean.valueOf(commandLine.getOptionValue("is_backup")).booleanValue();
        this.zookeeper = new GuaguaZooKeeper(commandLine.getOptionValue("zookeeper_server"), 3000000, 5, GlobalConfigurationKeys.DEFAULT_TASK_HEARTBEAT_INTERVAL_MS, this);
        this.socketServer = new SocketServer(this.zookeeper, this.containerId);
        this.socketServer.start();
        this.shellEnv.put("SOCKET_SERVER_PORT", Integer.toString(this.socketServer.getServerPort()));
    }

    public void prepare() throws IOException {
        if (new File(Constants.PYTHON_VENV_ZIP).exists()) {
            LOG.info("Unpacking Python virtual environment.. ");
            CommonUtils.unzipArchive(Constants.PYTHON_VENV_ZIP, ".");
        } else {
            LOG.info("No virtual environment uploaded.");
        }
        if (new File(Constants.GLIBC_VENV_ZIP).exists()) {
            LOG.info("Unpacking Python virtual environment.. ");
            CommonUtils.unzipArchive(Constants.GLIBC_VENV_ZIP, ".");
        } else {
            LOG.info("No virtual environment uploaded.");
        }
        String str = this.globalConf.get(GlobalConfigurationKeys.PYTHON_BINARY_PATH);
        this.shellEnv.put("GLIBC_HOME", "." + this.globalConf.get(GlobalConfigurationKeys.GLIBC_BINARY_PATH));
        this.shellEnv.put("PYTHON_HOME", "." + str);
        Files.copy(getClass().getResourceAsStream(Constants.BACKUP_SCRIPT), Paths.get("./backup.py", new String[0]), StandardCopyOption.REPLACE_EXISTING);
        this.pythonScript = "./" + new Path(this.globalConf.get(GlobalConfigurationKeys.PYTHON_SCRIPT_PATH)).getName();
        this.pythonShell = "./" + new Path(this.globalConf.get(GlobalConfigurationKeys.PYTHON_SHELL_PATH)).getName();
        HdfsUtils.givePerms(HDFSUtils.getLocalFS(), new File(this.pythonShell), true);
        this.shellEnv.put("WORKER_CNT", Integer.toString(this.globalConf.getInt(GlobalConfigurationKeys.getInstancesKey(Constants.WORKER_JOB_NAME), GlobalConfigurationKeys.getDefaultInstances(Constants.WORKER_JOB_NAME))));
    }

    private int executeBackupExecutor() throws IOException, InterruptedException {
        this.shellEnv.put("TRAIN_SCRIPT_PATH", "./backup.py");
        this.shellEnv.put("IS_BACKUP", "True");
        this.tensorflowSocket.close();
        this.backupProcess = CommonUtils.executeShellAndGetProcess(this.pythonShell, this.shellEnv);
        this.backupProcess.waitFor();
        if (this.backupProcess.exitValue() != 0 && this.backupProcess.exitValue() != 137) {
            LOG.info("backup task waiting process fails");
        }
        return this.backupProcess.exitValue();
    }

    public int run() throws IOException, InterruptedException {
        this.shellEnv.put("TRAIN_SCRIPT_PATH", this.pythonScript);
        if (Constants.WORKER_JOB_NAME.equalsIgnoreCase(this.jobName)) {
            while (true) {
                if (this.shellEnv.containsKey("TRAINING_DATA_PATH") && !StringUtils.isBlank(this.shellEnv.get("TRAINING_DATA_PATH"))) {
                    break;
                }
                Thread.sleep(5000L);
            }
        }
        if (!this.tensorflowSocket.isClosed()) {
            this.tensorflowSocket.close();
        }
        return CommonUtils.executeShell(this.pythonShell, 0L, this.shellEnv);
    }

    public static void main(String[] strArr) throws Exception {
        LOG.info("TaskExecutor is running..");
        TensorflowTaskExecutor tensorflowTaskExecutor = new TensorflowTaskExecutor();
        tensorflowTaskExecutor.init(initOpts(strArr));
        tensorflowTaskExecutor.prepare();
        tensorflowTaskExecutor.registeryToCluster();
        if (tensorflowTaskExecutor.isBackup()) {
            LOG.info("This is backup host..");
            LOG.info("backup task exit value: " + tensorflowTaskExecutor.executeBackupExecutor());
            tensorflowTaskExecutor.backupStartingLatch.await();
        }
        int run = tensorflowTaskExecutor.run();
        LOG.info("current worker finish..");
        System.exit(run);
    }

    public boolean isBackup() {
        return this.isBackup;
    }
}
