package ml.shifu.shifu.core.yarn.appmaster;

import com.google.common.base.Preconditions;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import ml.shifu.shifu.core.yarn.appmaster.TensorflowSession;
import ml.shifu.shifu.core.yarn.util.CommonUtils;
import ml.shifu.shifu.core.yarn.util.Constants;
import ml.shifu.shifu.core.yarn.util.GlobalConfigurationKeys;
import ml.shifu.shifu.util.HDFSUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ApplicationAccessType;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.api.records.ContainerStatus;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.LocalResourceType;
import org.apache.hadoop.yarn.api.records.NodeReport;
import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
import org.apache.hadoop.yarn.security.AMRMTokenIdentifier;
import org.apache.hadoop.yarn.util.AbstractLivelinessMonitor;

/* loaded from: input_file:ml/shifu/shifu/core/yarn/appmaster/AMRMCallbackHandler.class */
/**
 * Callback handler registered with {@link AMRMClientAsync} by the application master.
 *
 * <p>It reacts to resource-manager events: newly allocated containers are handed a
 * {@link TensorflowTask} from the {@link TensorflowSession} and launched through the
 * {@link NMClientAsync}; completed containers are reported back to the session and removed
 * from the heartbeat monitor; and training progress is derived from the session's epoch
 * counters.
 */
public class AMRMCallbackHandler implements AMRMClientAsync.CallbackHandler {
    private static final Log LOG = LogFactory.getLog(AMRMCallbackHandler.class);
    private static final FileSystem hdfs = HDFSUtils.getFS();

    /** Sentinel for {@link #lastRunEpochs}: no starting epoch has been recorded yet. */
    private static final int NO_PREVIOUS_RUN = -1;

    private final TensorflowSession session;
    private final NMClientAsync nmClientAsync;
    private final AbstractLivelinessMonitor<TensorflowTask> hbMonitor;
    private final Map<String, String> containerEnv;
    private final Path appResourcesPath;
    private final Map<String, LocalResource> containerResources = new ConcurrentHashMap<String, LocalResource>();

    /** Serialized credentials handed to every launched container; built once in the constructor. */
    private ByteBuffer allTokens;

    /**
     * Global epoch observed the first time it was seen above 1, or {@link #NO_PREVIOUS_RUN}.
     * Once set, progress is measured relative to this epoch instead of from zero.
     */
    private int lastRunEpochs = NO_PREVIOUS_RUN;

    /**
     * Creates the handler and prepares everything that is shared by all container launches:
     * the local resources and the serialized security tokens.
     *
     * @param configuration            configuration from which the extra container resources are read
     * @param tensorflowSession        session that owns the tasks and the epoch counters
     * @param nMClientAsync            client used to start containers on node managers
     * @param abstractLivelinessMonitor heartbeat monitor from which finished tasks are unregistered
     * @param map                      environment variables passed to every container
     * @param str                      value passed to {@code Constants.getAppResourcePath} to locate
     *                                 the application's resource directory
     */
    public AMRMCallbackHandler(Configuration configuration, TensorflowSession tensorflowSession, NMClientAsync nMClientAsync, AbstractLivelinessMonitor<TensorflowTask> abstractLivelinessMonitor, Map<String, String> map, String str) {
        this.session = tensorflowSession;
        this.nmClientAsync = nMClientAsync;
        this.hbMonitor = abstractLivelinessMonitor;
        this.containerEnv = map;
        this.appResourcesPath = Constants.getAppResourcePath(str);

        // Optional, user-configured resources.
        String[] configuredResources = configuration.getStrings(GlobalConfigurationKeys.getContainerResourcesKey());
        if (configuredResources != null) {
            for (String resource : configuredResources) {
                CommonUtils.addResource(resource, this.containerResources, hdfs);
            }
        }

        // Resources every container gets: the final configuration file and the jar archive.
        CommonUtils.addResource(new Path(this.appResourcesPath, Constants.GLOBAL_FINAL_XML), this.containerResources, hdfs, LocalResourceType.FILE, Constants.GLOBAL_FINAL_XML);
        CommonUtils.addResource(new Path(this.appResourcesPath, Constants.JAR_LIB_ZIP), this.containerResources, hdfs, LocalResourceType.ARCHIVE, Constants.JAR_LIB_ROOT);

        getAllTokens();
    }

    /**
     * Logs the outcome of each finished container and, when the container belongs to a
     * registered task, reports the exit status to the session and stops heartbeat monitoring
     * for that task.
     *
     * @param list statuses of the containers that completed
     */
    @Override
    public void onContainersCompleted(List<ContainerStatus> list) {
        LOG.info("Completed containers: " + list.size());
        for (ContainerStatus status : list) {
            int exitStatus = status.getExitStatus();
            LOG.info("ContainerID = " + status.getContainerId() + ", state = " + status.getState() + ", exitStatus = " + exitStatus);

            // Diagnostics of a failed container are logged as an error, otherwise as info.
            String diagnostics = status.getDiagnostics();
            if (exitStatus != 0) {
                LOG.error(diagnostics);
            } else {
                LOG.info(diagnostics);
            }

            TensorflowTask task = this.session.getTaskByContainerId(status.getContainerId());
            if (task == null) {
                LOG.warn("No task found for container : [" + status.getContainerId() + "]!");
                continue;
            }

            LOG.warn("container : [" + status.getContainerId() + "] isregister!" + task.isRegister());
            if (!task.isRegister()) {
                LOG.warn("container : [" + status.getContainerId() + "] does not register!");
                continue;
            }

            this.session.onTaskCompleted(task.getJobName(), task.getTaskIndex(), exitStatus);
            LOG.info("Unregister task [" + task.getId() + "] from Heartbeat monitor..");
            this.hbMonitor.unregister(task);
        }
    }

    /**
     * Assigns a task to every newly allocated container and starts it asynchronously. When the
     * session reports that all tasks have a container, the session moves to the
     * {@code REGESTERING_CLUSTER} state.
     *
     * @param list containers granted by the resource manager
     * @throws NullPointerException if the session has no task left for one of the containers
     */
    @Override
    public void onContainersAllocated(List<Container> list) {
        LOG.info("Allocated: " + list.size() + " containers.");
        for (Container container : list) {
            LOG.info("Launching a task in container, containerId = " + container.getId() + ", containerNode = " + container.getNodeId().getHost() + ":" + container.getNodeId().getPort() + ", resourceRequest = " + container.getResource());

            TensorflowTask task = this.session.distributeTaskToContainer(container);
            Preconditions.checkNotNull(task, "Task was null! Nothing to schedule.");
            CommonUtils.printTaskUrl(task.getTaskUrl(), LOG);

            // Single shell line: the task command with stdout/stderr redirected into the
            // container log directory. The trailing space is kept from the original format.
            List<String> commands = new ArrayList<String>(1);
            commands.add(task.getTaskCommand() + " 1><LOG_DIR>/stdout 2><LOG_DIR>/stderr ");
            LOG.info("Constructed command: " + commands);

            Map<ApplicationAccessType, String> acls = new HashMap<ApplicationAccessType, String>(2);
            acls.put(ApplicationAccessType.VIEW_APP, "*");
            acls.put(ApplicationAccessType.MODIFY_APP, " ");

            ContainerLaunchContext launchContext = ContainerLaunchContext.newInstance(this.containerResources, this.containerEnv, commands, (Map<String, ByteBuffer>) null, (ByteBuffer) null, acls);
            // slice() gives each launch context its own buffer position over the shared bytes.
            launchContext.setTokens(this.allTokens.slice());
            this.nmClientAsync.startContainerAsync(container, launchContext);
        }

        if (this.session.isAllTaskAssignedContainer()) {
            this.session.setState(TensorflowSession.SessionState.REGESTERING_CLUSTER);
            this.session.setStartTimeOfRegisteringCluster(System.currentTimeMillis());
            LOG.info("Session goes to REGESTERING_CLUSTER");
        }
    }

    /**
     * Serializes the current user's credentials into {@link #allTokens} for use by containers.
     *
     * <p>The AM-to-RM token is removed <em>before</em> serialization so that it is not part of
     * the bytes shipped to containers; removing it after the write would leave the already
     * written buffer unchanged.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the credentials cannot be
     *                          read or serialized
     */
    private void getAllTokens() {
        try {
            Credentials credentials = UserGroupInformation.getCurrentUser().getCredentials();

            Iterator<Token<?>> tokens = credentials.getAllTokens().iterator();
            while (tokens.hasNext()) {
                Token<?> token = tokens.next();
                LOG.info("Token type : " + token.getKind());
                if (token.getKind().equals(AMRMTokenIdentifier.KIND_NAME)) {
                    tokens.remove();
                }
            }

            DataOutputBuffer serialized = new DataOutputBuffer();
            credentials.writeTokenStorageToStream(serialized);
            this.allTokens = ByteBuffer.wrap(serialized.getData(), 0, serialized.getLength());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** No action is taken on a shutdown request. */
    @Override
    public void onShutdownRequest() {
    }

    /**
     * Node updates are ignored.
     *
     * @param list updated node reports (unused)
     */
    @Override
    public void onNodesUpdated(List<NodeReport> list) {
    }

    /**
     * Computes training progress from the session's epoch counters.
     *
     * <p>The first time the global epoch is observed above 1 it is remembered in
     * {@link #lastRunEpochs}, and from then on progress counts epochs relative to that value
     * (inclusive). Before that, progress is simply {@code globalEpoch / totalEpochs}. The
     * division is done in floating point so partial progress is reported instead of being
     * truncated to zero.
     *
     * @return progress in the range {@code [0, 1]}; {@code 0} when the total epoch count is not
     *         positive and no epoch has been completed
     */
    @Override
    public float getProgress() {
        // Read the counter once so every comparison below sees the same value.
        int currentEpoch = this.session.getGlobalEpoch().get();
        float totalEpochs = this.session.getTotalEpochs();

        if (this.lastRunEpochs == NO_PREVIOUS_RUN && currentEpoch > 1) {
            this.lastRunEpochs = currentEpoch;
        }

        int epochsDone;
        if (this.lastRunEpochs == NO_PREVIOUS_RUN) {
            epochsDone = currentEpoch;
        } else if (currentEpoch <= this.lastRunEpochs - 1) {
            // The counter is behind the recorded starting epoch: nothing done yet.
            return 0.0f;
        } else {
            epochsDone = currentEpoch - this.lastRunEpochs + 1;
        }

        if (epochsDone > totalEpochs) {
            return 1.0f;
        }
        if (totalEpochs <= 0.0f) {
            // Avoids 0/0; reached only when epochsDone is also not positive.
            return 0.0f;
        }
        return Math.max(0.0f, epochsDone / totalEpochs);
    }

    /**
     * Logs the error and stops the node-manager client.
     *
     * @param th the error reported by the asynchronous resource-manager client
     */
    @Override
    public void onError(Throwable th) {
        LOG.error("Error: stop nmClientAsync", th);
        this.nmClientAsync.stop();
    }
}
