package org.flinkextended.flink.ml.tensorflow.ops.table;

import com.google.common.base.Preconditions;
import io.grpc.Server;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.util.Collections;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.types.Row;
import org.flinkextended.flink.ml.cluster.ExecutionMode;
import org.flinkextended.flink.ml.cluster.MLConfig;
import org.flinkextended.flink.ml.cluster.node.MLContext;
import org.flinkextended.flink.ml.cluster.role.BaseRole;
import org.flinkextended.flink.ml.data.DataExchange;
import org.flinkextended.flink.ml.operator.util.TypeUtil;
import org.flinkextended.flink.ml.tensorflow.util.JavaInferenceUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/ops/table/TFTableInferenceJavaFunction.class */
/**
 * Flink {@link TableFunction} that performs inference by handing rows to a separately launched Java
 * inference process and emitting the rows read back from it.
 *
 * <p>Each {@link #eval(Object...)} call writes one input row through a {@link DataExchange}. It then
 * emits whatever output rows are currently available, so an output row is not necessarily emitted in
 * the same {@code eval} call that wrote its input. {@link #close()} drains the remaining output. It
 * then checks that the number of rows read back equals the number written.
 */
public class TFTableInferenceJavaFunction extends TableFunction<Row> {
    private static final Logger LOG = LoggerFactory.getLogger(TFTableInferenceJavaFunction.class);

    /** Role whose string form is used when creating the {@link MLContext}. */
    private final BaseRole role;
    private final MLConfig mlConfig;
    /** Row type of the values passed to {@code eval}. */
    private final RowTypeInfo inRowType;
    /** Row type of the rows emitted by this function. */
    private final RowTypeInfo outRowType;

    // Runtime state created in open(); transient so it is not serialized with the function.
    private transient DataExchange<Row, Row> dataExchange;
    private transient Server server;
    private transient MLContext mlContext;
    /** Completes once the inference process has finished (see the checks in eval/close). */
    private transient FutureTask<Void> processFuture;
    /** Number of rows successfully written to the inference process. */
    private transient long numWritten = 0;
    /** Number of rows read back from the inference process and emitted. */
    private transient long numRead = 0;

    /**
     * Creates the function.
     *
     * @param baseRole role used to identify this function's {@link MLContext}
     * @param mLConfig configuration for the inference context
     * @param inputSchema schema of the rows passed to {@code eval}
     * @param outputSchema schema of the rows this function emits
     */
    public TFTableInferenceJavaFunction(BaseRole baseRole, MLConfig mLConfig, TableSchema inputSchema, TableSchema outputSchema) {
        this.role = baseRole;
        this.mlConfig = mLConfig;
        this.inRowType = TypeUtil.schemaToRowTypeInfo(inputSchema);
        this.outRowType = TypeUtil.schemaToRowTypeInfo(outputSchema);
    }

    /**
     * Creates the inference context, starts the context service and launches the inference process.
     */
    public void open(FunctionContext functionContext) throws Exception {
        super.open(functionContext);
        this.mlContext = new MLContext(ExecutionMode.INFERENCE, this.mlConfig, this.role.toString(), -1, this.mlConfig.getEnvPath(), Collections.emptyMap());
        this.server = JavaInferenceUtil.startTFContextService(this.mlContext);
        this.dataExchange = new DataExchange<>(this.mlContext);
        this.processFuture = JavaInferenceUtil.startInferenceProcessWatcher(
                JavaInferenceUtil.launchInferenceProcess(this.mlContext, this.inRowType, this.outRowType), this.mlContext);
    }

    /**
     * Sends one input row to the inference process and emits any output rows already available.
     *
     * @param objArr field values of the input row, in input-schema order
     * @throws IllegalArgumentException if the number of values does not match the input schema
     * @throws IllegalStateException if the inference process has already finished
     * @throws RuntimeException if talking to the process fails or the process itself failed
     */
    public void eval(Object... objArr) {
        Preconditions.checkArgument(objArr.length == this.inRowType.getArity(), "Input fields length mismatch");
        Preconditions.checkState(!this.processFuture.isDone(), "Java inference process already finished");
        Row row = new Row(this.inRowType.getArity());
        for (int i = 0; i < this.inRowType.getArity(); i++) {
            row.setField(i, objArr[i]);
        }
        try {
            // Collect whatever output is already available before writing.
            drainRead(false);
            // write() returning false means the row was not accepted yet: wait up to 1s (or until
            // the process finishes), keep draining output, and retry.
            while (!this.dataExchange.write(row)) {
                Preconditions.checkState(!this.processFuture.isDone(), "Java inference process already finished");
                try {
                    this.processFuture.get(1000L, TimeUnit.MILLISECONDS);
                } catch (TimeoutException ignored) {
                    // Process is still running; fall through and retry the write.
                }
                drainRead(false);
            }
            this.numWritten++;
        } catch (InterruptedIOException e) {
            LOG.info("{} interrupted reading from inference process", this.mlContext.getIdentity());
        } catch (IOException e) {
            throw new RuntimeException("Error interacting with Java inference process", e);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the calling thread (e.g. a cancelling task) can observe it.
            Thread.currentThread().interrupt();
            LOG.info("{} interrupted evaluating rows", this.mlContext.getIdentity());
        } catch (ExecutionException e) {
            throw new RuntimeException("Java inference process failed", e);
        }
    }

    /**
     * Finishes input and drains the remaining output. Then verifies that every written row was read
     * back, and releases the process watcher, the context and the server.
     *
     * @throws ExecutionException if the inference process failed
     * @throws IllegalStateException if the written and read row counts differ
     */
    public void close() throws Exception {
        try {
            // NOTE(review): presumably this signals end-of-input to the inference process — confirm.
            if (this.mlContext != null && this.mlContext.getOutputQueue() != null) {
                this.mlContext.getOutputQueue().markFinished();
            }
            if (this.processFuture != null) {
                // NOTE(review): this loop spins without sleeping unless drainRead(false) waits
                // internally — confirm it does not burn a core during shutdown.
                while (!this.processFuture.isDone()) {
                    drainRead(false);
                }
                // Surfaces a process failure as ExecutionException.
                this.processFuture.get();
                // Final drain of whatever output remains.
                drainRead(true);
            }
            Preconditions.checkState(this.numWritten == this.numRead, String.format("Wrote %d records to inference process but read %d from it", Long.valueOf(this.numWritten), Long.valueOf(this.numRead)));
        } finally {
            try {
                if (this.processFuture != null) {
                    this.processFuture.cancel(true);
                }
                if (this.mlContext != null) {
                    this.mlContext.close();
                }
            } finally {
                // Always shut the server down, even if closing the context threw.
                if (this.server != null) {
                    this.server.shutdown();
                }
            }
        }
    }

    /** Returns the {@code flink.vertex.name} property if set, otherwise the role name. */
    @Override
    public String toString() {
        return (String) this.mlConfig.getProperties().getOrDefault("flink.vertex.name", this.role.name());
    }

    /**
     * Reads rows from the data exchange and emits each one until a read returns {@code null}.
     *
     * @param wait forwarded to {@link DataExchange#read}; presumably makes each read block — TODO confirm
     * @throws IOException if reading from the data exchange fails
     */
    private void drainRead(boolean wait) throws IOException {
        for (Row row = (Row) this.dataExchange.read(wait); row != null; row = (Row) this.dataExchange.read(wait)) {
            collect(row);
            this.numRead++;
        }
    }

    /** Returns the row type of the emitted rows (the output schema). */
    public TypeInformation<Row> getResultType() {
        return this.outRowType;
    }
}
