package org.flinkextended.flink.ml.tensorflow.util;

import com.google.common.base.Joiner;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.lang.ProcessBuilder;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.FutureTask;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.flinkextended.flink.ml.cluster.node.MLContext;
import org.flinkextended.flink.ml.util.ContextService;
import org.flinkextended.flink.ml.util.IpHostUtil;
import org.flinkextended.flink.ml.util.ShellExec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/util/JavaInferenceUtil.class */
/**
 * Static helpers for running inference in a separate Java process: starting the gRPC context
 * service, launching the child JVM, and watching it until it exits.
 *
 * <p>This class is not instantiable.
 */
public class JavaInferenceUtil {
    private static final Logger LOG = LoggerFactory.getLogger(JavaInferenceUtil.class);

    private JavaInferenceUtil() {
    }

    /**
     * Starts a gRPC server hosting a {@link ContextService} on an ephemeral port and records the
     * server's address in the given context.
     *
     * <p>If anything fails after the server has been started, the server is shut down before the
     * failure is propagated so that the listening port is not leaked.
     *
     * @param mLContext the context that receives the server IP/port and is handed to the service
     * @return the started server; the caller is responsible for shutting it down
     * @throws Exception if the server cannot be started or the context cannot be configured
     */
    public static Server startTFContextService(MLContext mLContext) throws Exception {
        ContextService contextService = new ContextService();
        // Port 0 lets the OS pick a free port; the actual port is read back after start().
        Server server = ServerBuilder.forPort(0).addService(contextService).build();
        server.start();
        try {
            mLContext.setNodeServerIP(IpHostUtil.getIpAddress());
            mLContext.setNodeServerPort(server.getPort());
            contextService.setMlContext(mLContext);
        } catch (Throwable t) {
            // Do not leave a running server behind when setup fails half-way.
            server.shutdownNow();
            throw t;
        }
        return server;
    }

    /**
     * Starts daemon threads that forward the process' stdout/stderr to the log and a daemon
     * watcher thread that waits for the process to exit.
     *
     * <p>The returned task completes normally when the process exits with code 0 or when the
     * watcher is interrupted, and completes exceptionally (with a {@link RuntimeException}) when
     * the exit code is non-zero. In every case the process is forcibly destroyed before the task
     * finishes.
     *
     * @param process the inference process to watch
     * @param mLContext the context whose identity is used for thread names and log messages
     * @return the already-running watcher task
     */
    public static FutureTask<Void> startInferenceProcessWatcher(Process process, MLContext mLContext) {
        Thread stdoutLogger = new Thread(
                new ShellExec.ProcessLogger(process.getInputStream(), new ShellExec.StdOutConsumer()));
        Thread stderrLogger = new Thread(
                new ShellExec.ProcessLogger(process.getErrorStream(), new ShellExec.StdOutConsumer()));
        stdoutLogger.setName(mLContext.getIdentity() + "-JavaInferenceProcess-in-logger");
        stdoutLogger.setDaemon(true);
        stderrLogger.setName(mLContext.getIdentity() + "-JavaInferenceProcess-err-logger");
        stderrLogger.setDaemon(true);
        stdoutLogger.start();
        stderrLogger.start();

        FutureTask<Void> watcherTask = new FutureTask<>(() -> {
            try {
                int exitCode = process.waitFor();
                // Let the loggers drain the remaining output before reporting the result.
                stdoutLogger.join();
                stderrLogger.join();
                if (exitCode != 0) {
                    throw new RuntimeException("Java inference process exited with " + exitCode);
                }
                LOG.info("{} Java inference process finished successfully", mLContext.getIdentity());
            } catch (InterruptedException e) {
                LOG.info("{} Java inference process watcher interrupted, killing the process", mLContext.getIdentity());
                // Preserve the interrupt status for whoever owns this thread.
                Thread.currentThread().interrupt();
            } finally {
                // Single cleanup point for the success, interrupt and failure paths.
                process.destroyForcibly();
            }
        }, null);

        Thread watcherThread = new Thread(watcherTask);
        watcherThread.setName(mLContext.getIdentity() + "-JavaInferenceWatcher");
        watcherThread.setDaemon(true);
        watcherThread.start();
        return watcherTask;
    }

    /**
     * Launches {@code JavaInferenceRunner} in a child JVM.
     *
     * <p>The child uses the current {@code java.home}, and its class path is the current
     * {@code java.class.path} plus the URLs of the thread context class loader when that loader
     * is a {@link URLClassLoader}. The program arguments are the node server address
     * ({@code ip:port}) followed by the URIs of two temp files holding the serialized row types.
     * The child's standard output is inherited from this process.
     *
     * @param mLContext the context providing the node server address and temp-file creation
     * @param rowTypeInfo the first row type passed to the runner (presumably the input type —
     *     confirm against {@code JavaInferenceRunner})
     * @param rowTypeInfo2 the second row type passed to the runner (presumably the output type —
     *     confirm against {@code JavaInferenceRunner})
     * @return the started process
     * @throws IOException if a row type cannot be serialized or the process cannot be started
     */
    public static Process launchInferenceProcess(MLContext mLContext, RowTypeInfo rowTypeInfo, RowTypeInfo rowTypeInfo2) throws IOException {
        List<String> command = new ArrayList<>();
        command.add(Joiner.on(File.separator).join(System.getProperty("java.home"), "bin", "java"));

        List<String> classPathEntries = new ArrayList<>();
        classPathEntries.add(System.getProperty("java.class.path"));
        ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
        if (contextLoader instanceof URLClassLoader) {
            for (URL url : ((URLClassLoader) contextLoader).getURLs()) {
                classPathEntries.add(toClassPathEntry(url));
            }
        }

        command.add("-cp");
        command.add(Joiner.on(File.pathSeparator).join(classPathEntries));
        command.add(JavaInferenceRunner.class.getCanonicalName());
        command.add(String.format("%s:%d", mLContext.getNodeServerIP(), mLContext.getNodeServerPort()));
        command.add(serializeRowType(mLContext, rowTypeInfo).toString());
        command.add(serializeRowType(mLContext, rowTypeInfo2).toString());
        LOG.info("Java Inference Cmd: {}", Joiner.on(" ").join(command));

        ProcessBuilder processBuilder = new ProcessBuilder(command);
        processBuilder.redirectOutput(ProcessBuilder.Redirect.INHERIT);
        return processBuilder.start();
    }

    /**
     * Converts a class loader URL into a string usable as a {@code -cp} entry.
     *
     * <p>The JVM class path takes filesystem paths, not URLs, so {@code file:} URLs are turned
     * into their local path (which also decodes percent-escapes). Any other URL, or a
     * {@code file:} URL that cannot be mapped to a local file, is passed through in its string
     * form as before.
     */
    private static String toClassPathEntry(URL url) {
        if ("file".equalsIgnoreCase(url.getProtocol())) {
            try {
                return new File(url.toURI()).getPath();
            } catch (URISyntaxException | IllegalArgumentException e) {
                LOG.debug("Cannot convert {} to a local path, using it as is", url, e);
            }
        }
        return url.toString();
    }

    /**
     * Java-serializes the row type into a temp file created through the context.
     *
     * @param mLContext the context used to create the temp file
     * @param rowTypeInfo the row type to write
     * @return the URI of the file containing the serialized row type
     * @throws IOException if the file cannot be created or written
     */
    private static URI serializeRowType(MLContext mLContext, RowTypeInfo rowTypeInfo) throws IOException {
        File rowTypeFile = mLContext.createTempFile("RowType", (String) null);
        // Two resources so the file stream is closed even if the ObjectOutputStream
        // constructor (which writes the stream header) fails.
        try (FileOutputStream fileOut = new FileOutputStream(rowTypeFile);
                ObjectOutputStream objectOut = new ObjectOutputStream(fileOut)) {
            objectOut.writeObject(rowTypeInfo);
        }
        return rowTypeFile.toURI();
    }
}
