package org.flinkextended.flink.ml.tensorflow.util;

import com.google.common.base.Preconditions;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.types.Row;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.flinkextended.flink.ml.cluster.node.MLContext;
import org.flinkextended.flink.ml.cluster.rpc.NodeClient;
import org.flinkextended.flink.ml.cluster.rpc.RpcCode;
import org.flinkextended.flink.ml.data.DataExchange;
import org.flinkextended.flink.ml.proto.ContextResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/util/JavaInferenceRunner.class */
public class JavaInferenceRunner implements Closeable {
    private static final Logger LOG = LoggerFactory.getLogger(JavaInferenceRunner.class);
    private static final Configuration HADOOP_CONF = new Configuration();
    private volatile boolean inputFinished;
    private final int batchSize;
    private final BlockingQueue<Row> batchCache;
    private final NodeClient nodeClient;
    private final MLContext mlContext;
    private final JavaInference javaInference;
    private transient DataExchange<Row, Row> dataExchange;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/util/JavaInferenceRunner$InputRowConsumer.class */
    public class InputRowConsumer implements Runnable {
        private long read;

        private InputRowConsumer() {
            this.read = 0L;
        }

        @Override // java.lang.Runnable
        public void run() {
            while (true) {
                try {
                    try {
                        Row row = (Row) JavaInferenceRunner.this.dataExchange.read(true);
                        if (row == null) {
                            JavaInferenceRunner.LOG.info("{} Input rows depleted", JavaInferenceRunner.this.mlContext.getIdentity());
                            JavaInferenceRunner.LOG.info("{} Read totally {} rows from flink", JavaInferenceRunner.this.mlContext.getIdentity(), Long.valueOf(this.read));
                            JavaInferenceRunner.this.inputFinished = true;
                            return;
                        } else {
                            this.read++;
                            if (this.read % 1000 == 0) {
                                JavaInferenceRunner.LOG.info("{} Read {} rows from flink", JavaInferenceRunner.this.mlContext.getIdentity(), Long.valueOf(this.read));
                            }
                            JavaInferenceRunner.this.batchCache.put(row);
                        }
                    } catch (IOException e) {
                        throw new RuntimeException(e);
                    } catch (InterruptedException e2) {
                        JavaInferenceRunner.LOG.info("{} interrupted", Thread.currentThread().getName());
                        JavaInferenceRunner.this.inputFinished = true;
                        return;
                    }
                } catch (Throwable th) {
                    JavaInferenceRunner.this.inputFinished = true;
                    throw th;
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/util/JavaInferenceRunner$OutputRowProducer.class */
    public class OutputRowProducer implements Runnable {
        private static final long INTERVAL = 1000;
        private final List<Object[]> batch;
        private long written;

        private OutputRowProducer() {
            this.batch = new ArrayList(JavaInferenceRunner.this.batchSize);
            this.written = 0L;
        }

        @Override // java.lang.Runnable
        public void run() {
            while (!JavaInferenceRunner.this.inputFinished) {
                try {
                    try {
                        Row row = (Row) JavaInferenceRunner.this.batchCache.poll(INTERVAL, TimeUnit.MILLISECONDS);
                        if (row != null) {
                            this.batch.add(rowToObjects(row));
                            while (this.batch.size() >= JavaInferenceRunner.this.batchSize) {
                                outputRows();
                            }
                        }
                    } catch (InterruptedException e) {
                        JavaInferenceRunner.LOG.info("{} interrupted", Thread.currentThread().getName());
                        JavaInferenceRunner.LOG.info("Closing output queue");
                        JavaInferenceRunner.this.mlContext.getOutputQueue().markFinished();
                        return;
                    } catch (Exception e2) {
                        JavaInferenceRunner.LOG.error("OutputRowProducer error", e2);
                        JavaInferenceRunner.LOG.info("Closing output queue");
                        JavaInferenceRunner.this.mlContext.getOutputQueue().markFinished();
                        return;
                    }
                } catch (Throwable th) {
                    JavaInferenceRunner.LOG.info("Closing output queue");
                    JavaInferenceRunner.this.mlContext.getOutputQueue().markFinished();
                    throw th;
                }
            }
            JavaInferenceRunner.LOG.info("{} Flush remaining {} records", JavaInferenceRunner.this.mlContext.getIdentity(), Integer.valueOf(JavaInferenceRunner.this.batchCache.size() + this.batch.size()));
            while (!JavaInferenceRunner.this.batchCache.isEmpty()) {
                this.batch.add(rowToObjects((Row) JavaInferenceRunner.this.batchCache.remove()));
            }
            while (!this.batch.isEmpty()) {
                outputRows();
            }
            JavaInferenceRunner.LOG.info("{} Written totally {} rows to flink", JavaInferenceRunner.this.mlContext.getIdentity(), Long.valueOf(this.written));
            JavaInferenceRunner.LOG.info("Closing output queue");
            JavaInferenceRunner.this.mlContext.getOutputQueue().markFinished();
        }

        private Object[] rowToObjects(Row row) {
            Object[] objArr = new Object[row.getArity()];
            for (int i = 0; i < objArr.length; i++) {
                objArr[i] = row.getField(i);
            }
            return objArr;
        }

        private void outputRows() throws IOException, InterruptedException {
            for (Row row : JavaInferenceRunner.this.javaInference.generateRowsOneBatch(this.batch, JavaInferenceRunner.this.batchSize)) {
                this.batch.remove(0);
                while (!JavaInferenceRunner.this.dataExchange.write(row)) {
                    Thread.sleep(INTERVAL);
                }
                this.written++;
                if (this.written % INTERVAL == 0) {
                    JavaInferenceRunner.LOG.info("{} Written {} rows to flink", JavaInferenceRunner.this.mlContext.getIdentity(), Long.valueOf(this.written));
                }
            }
        }
    }

    JavaInferenceRunner(String str, int i, String str2, String str3) throws Exception {
        this(str, i, readRowType(new Path(str2)), readRowType(new Path(str3)));
    }

    JavaInferenceRunner(String str, int i, RowTypeInfo rowTypeInfo, RowTypeInfo rowTypeInfo2) throws Exception {
        this.inputFinished = false;
        this.nodeClient = new NodeClient(str, i);
        ContextResponse mLContext = this.nodeClient.getMLContext();
        Preconditions.checkState(mLContext.getCode() == RpcCode.OK.ordinal(), "Failed to get TFContext");
        this.mlContext = MLContext.fromPB(mLContext.getContext());
        this.javaInference = new JavaInference((Map<String, String>) this.mlContext.getProperties(), rowTypeInfo.getFieldNames(), rowTypeInfo2.getFieldNames());
        this.batchSize = Integer.valueOf((String) this.mlContext.getProperties().getOrDefault("tf.inference.batch.size", "1")).intValue();
        LOG.info("{} java inference with batch size {}", this.mlContext.getIdentity(), Integer.valueOf(this.batchSize));
        this.batchCache = new ArrayBlockingQueue(this.batchSize);
        String str2 = (String) this.mlContext.getProperties().get("sys:input_tf_example_config");
        this.mlContext.getProperties().put("sys:input_tf_example_config", (String) this.mlContext.getProperties().get("sys:output_tf_example_config"));
        this.mlContext.getProperties().put("sys:output_tf_example_config", str2);
        this.dataExchange = new DataExchange<>(this.mlContext);
    }

    public void run() throws Exception {
        ExecutorService newFixedThreadPool = Executors.newFixedThreadPool(1, runnable -> {
            Thread thread = new Thread(runnable);
            thread.setDaemon(true);
            thread.setName(this.mlContext.getIdentity() + "-" + InputRowConsumer.class.getSimpleName());
            return thread;
        });
        ExecutorService newFixedThreadPool2 = Executors.newFixedThreadPool(1, runnable2 -> {
            Thread thread = new Thread(runnable2);
            thread.setDaemon(true);
            thread.setName(this.mlContext.getIdentity() + "-" + OutputRowProducer.class.getSimpleName());
            return thread;
        });
        Future<?> submit = newFixedThreadPool.submit(new InputRowConsumer());
        Future<?> submit2 = newFixedThreadPool2.submit(new OutputRowProducer());
        try {
            try {
                submit.get();
                submit2.get();
                newFixedThreadPool.shutdownNow();
                if (!newFixedThreadPool.isTerminated()) {
                    newFixedThreadPool.awaitTermination(1L, TimeUnit.SECONDS);
                }
                newFixedThreadPool2.shutdownNow();
                if (newFixedThreadPool2.isTerminated()) {
                    return;
                }
                newFixedThreadPool2.awaitTermination(1L, TimeUnit.SECONDS);
            } catch (InterruptedException | ExecutionException e) {
                submit.cancel(true);
                submit2.cancel(true);
                newFixedThreadPool.shutdownNow();
                if (!newFixedThreadPool.isTerminated()) {
                    newFixedThreadPool.awaitTermination(1L, TimeUnit.SECONDS);
                }
                newFixedThreadPool2.shutdownNow();
                if (newFixedThreadPool2.isTerminated()) {
                    return;
                }
                newFixedThreadPool2.awaitTermination(1L, TimeUnit.SECONDS);
            }
        } catch (Throwable th) {
            newFixedThreadPool.shutdownNow();
            if (!newFixedThreadPool.isTerminated()) {
                newFixedThreadPool.awaitTermination(1L, TimeUnit.SECONDS);
            }
            newFixedThreadPool2.shutdownNow();
            if (!newFixedThreadPool2.isTerminated()) {
                newFixedThreadPool2.awaitTermination(1L, TimeUnit.SECONDS);
            }
            throw th;
        }
    }

    private static RowTypeInfo readRowType(Path path) throws IOException, ClassNotFoundException {
        ObjectInputStream objectInputStream = new ObjectInputStream(path.getFileSystem(HADOOP_CONF).open(path));
        Throwable th = null;
        try {
            try {
                RowTypeInfo rowTypeInfo = (RowTypeInfo) objectInputStream.readObject();
                if (objectInputStream != null) {
                    if (0 != 0) {
                        try {
                            objectInputStream.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                    } else {
                        objectInputStream.close();
                    }
                }
                return rowTypeInfo;
            } finally {
            }
        } catch (Throwable th3) {
            if (objectInputStream != null) {
                if (th != null) {
                    try {
                        objectInputStream.close();
                    } catch (Throwable th4) {
                        th.addSuppressed(th4);
                    }
                } else {
                    objectInputStream.close();
                }
            }
            throw th3;
        }
    }

    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() throws IOException {
        if (this.nodeClient != null) {
            this.nodeClient.close();
        }
        if (this.mlContext != null) {
            this.mlContext.close();
        }
        if (this.javaInference != null) {
            this.javaInference.close();
        }
    }

    public static void main(String[] strArr) throws Exception {
        Preconditions.checkArgument(strArr.length == 3, "Takes three arguments, got " + Arrays.toString(strArr));
        String[] split = strArr[0].split(":");
        Preconditions.checkArgument(split.length == 2, String.format("Invalid tf node address %s, please specify in form <IP>:<PORT>", strArr[0]));
        JavaInferenceRunner javaInferenceRunner = new JavaInferenceRunner(split[0], Integer.valueOf(split[1]).intValue(), strArr[1], strArr[2]);
        javaInferenceRunner.run();
        javaInferenceRunner.close();
    }
}
