package org.flinkextended.flink.ml.tensorflow.ops;

import com.google.common.base.Preconditions;
import io.grpc.Server;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.flinkextended.flink.ml.cluster.ExecutionMode;
import org.flinkextended.flink.ml.cluster.MLConfig;
import org.flinkextended.flink.ml.cluster.node.MLContext;
import org.flinkextended.flink.ml.cluster.role.BaseRole;
import org.flinkextended.flink.ml.data.DataExchange;
import org.flinkextended.flink.ml.tensorflow.util.JavaInferenceUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/flinkextended/flink/ml/tensorflow/ops/TFJavaInferenceFlatMap.class */
/**
 * Flat-map operator that delegates TensorFlow inference to a separately launched Java inference process.
 *
 * <p>Each input {@link Row} is written to the inference process through a {@link DataExchange}, and every
 * result row read back is emitted downstream. Input rows stay in {@code rowCache} until a result has been
 * emitted for them; the cache is the operator's checkpointed state and is replayed into a freshly launched
 * inference process after a restore.
 *
 * <p>NOTE(review): the code assumes the inference process produces exactly one output row per input row, in
 * input order (each emitted output removes the head of {@code rowCache}) — confirm against the process side.
 */
public class TFJavaInferenceFlatMap extends RichFlatMapFunction<Row, Row>
        implements ResultTypeQueryable<Row>, ListCheckpointed<Row> {
    private static final Logger LOG = LoggerFactory.getLogger(TFJavaInferenceFlatMap.class);

    /** Number of cached rows between two "Caching" progress log lines. */
    private static final int CACHE_LOG_INTERVAL = 1000;

    /** Maximum time to wait on the inference process when its input side is full, before draining again. */
    private static final long WRITE_RETRY_WAIT_MS = 1000L;

    private final BaseRole role;
    private final MLConfig mlConfig;
    private final RowTypeInfo inTypeInfo;
    private final RowTypeInfo outTypeInfo;

    private transient DataExchange<Row, Row> dataExchange;
    /** Input rows already sent to the inference process whose result has not been emitted yet. */
    private transient Deque<Row> rowCache;
    /** Collector of the most recent {@code flatMap} call; {@code null} until the first input row arrives. */
    private transient Collector<Row> collector;
    private transient Server server;
    private transient MLContext mlContext;
    private transient FutureTask<Void> processFuture;

    /**
     * Creates the operator.
     *
     * @param role role whose {@code toString()} is used as the role name of the {@link MLContext}
     * @param mlConfig cluster configuration; also supplies the environment path
     * @param inTypeInfo type of the input rows; must be a {@link RowTypeInfo}
     * @param outTypeInfo type of the produced rows; must be a {@link RowTypeInfo}
     */
    public TFJavaInferenceFlatMap(
            BaseRole role, MLConfig mlConfig, TypeInformation<?> inTypeInfo, TypeInformation<?> outTypeInfo) {
        this.role = role;
        this.mlConfig = mlConfig;
        this.inTypeInfo = (RowTypeInfo) inTypeInfo;
        this.outTypeInfo = (RowTypeInfo) outTypeInfo;
    }

    /**
     * Starts the context service and the inference process, then replays any rows restored from a checkpoint.
     */
    @Override
    public void open(Configuration configuration) throws Exception {
        super.open(configuration);
        mlContext = new MLContext(ExecutionMode.INFERENCE, mlConfig, role.toString(),
                getRuntimeContext().getIndexOfThisSubtask(), mlConfig.getEnvPath(), Collections.emptyMap());
        server = JavaInferenceUtil.startTFContextService(mlContext);
        dataExchange = new DataExchange<>(mlContext);
        processFuture = JavaInferenceUtil.startInferenceProcessWatcher(
                JavaInferenceUtil.launchInferenceProcess(mlContext, inTypeInfo, outTypeInfo), mlContext);

        if (rowCache == null) {
            // No state was restored: start with an empty cache.
            rowCache = new ArrayDeque<>();
            return;
        }
        // rowCache was populated by restoreState: re-send rows whose results were never emitted.
        LOG.info("{} replaying {} rows", mlContext.getIdentity(), rowCache.size());
        // Iterate over a copy because writeToJavaInference may drain outputs, which removes from rowCache.
        for (Row row : new ArrayList<>(rowCache)) {
            writeToJavaInference(row);
        }
    }

    /**
     * Finishes inference, emitting every remaining output, and then releases all resources.
     *
     * @throws IllegalStateException if some cached input rows never received an output
     * @throws ExecutionException if the inference process terminated with an error
     */
    @Override
    public void close() throws Exception {
        try {
            finishInference();
        } finally {
            releaseResources();
        }
    }

    /** Signals the end of input, drains all remaining outputs and checks that every cached row was answered. */
    private void finishInference() throws Exception {
        if (mlContext != null && mlContext.getOutputQueue() != null) {
            // Presumably tells the inference process that no more input will arrive — confirm in MLContext.
            mlContext.getOutputQueue().markFinished();
        }
        if (processFuture != null) {
            // Keep consuming results while the process is still running.
            while (!processFuture.isDone()) {
                drainRead(false);
            }
            // Surfaces a failure of the process watcher as an ExecutionException.
            processFuture.get();
            drainRead(true);
        }
        if (rowCache != null) {
            Preconditions.checkState(rowCache.isEmpty(), "Still got %s unprocessed rows", rowCache.size());
        }
    }

    /** Cancels the process watcher and closes the context and the gRPC server, even if one step fails. */
    private void releaseResources() throws Exception {
        if (processFuture != null) {
            processFuture.cancel(true);
        }
        try {
            if (mlContext != null) {
                mlContext.close();
            }
        } finally {
            // Shut the server down even when closing the context throws, so it is not leaked.
            if (server != null) {
                server.shutdown();
            }
        }
    }

    /**
     * Emits any results already available, then sends {@code row} to the inference process and caches it
     * until its result is emitted.
     *
     * @throws IllegalArgumentException if the row's arity does not match the input type
     * @throws IllegalStateException if the inference process has already finished
     */
    @Override
    public void flatMap(Row row, Collector<Row> collector) throws Exception {
        Preconditions.checkArgument(row.getArity() == inTypeInfo.getArity(), "Input fields length mismatch");
        Preconditions.checkState(!processFuture.isDone(), "Java inference process already finished");
        this.collector = collector;
        drainRead(false);
        writeToJavaInference(row);
        rowCache.add(row);
        // rowCache is non-empty here, so this fires once for every CACHE_LOG_INTERVAL-th cached row.
        if (rowCache.size() % CACHE_LOG_INTERVAL == 0) {
            LOG.info("{} Caching {} rows", mlContext.getIdentity(), rowCache.size());
        }
    }

    @Override
    public TypeInformation<Row> getProducedType() {
        return outTypeInfo;
    }

    /** Checkpoints the rows that were sent to the inference process but have no emitted result yet. */
    @Override
    public List<Row> snapshotState(long checkpointId, long timestamp) throws Exception {
        return rowCache == null ? Collections.<Row>emptyList() : new ArrayList<>(rowCache);
    }

    /** Restores unanswered rows; {@link #open} replays them into the new inference process. */
    @Override
    public void restoreState(List<Row> state) throws Exception {
        LOG.info("Restoring from state with {} cached records", state.size());
        if (rowCache == null) {
            rowCache = new ArrayDeque<>();
        }
        rowCache.addAll(state);
    }

    /**
     * Emits every result row currently readable from the inference process and drops the matching cached
     * input row (the oldest one) for each.
     *
     * @param readWait forwarded to {@code DataExchange.read}; presumably makes the read block — confirm
     * @throws IllegalStateException if a result arrives before any collector is available, e.g. while
     *     replaying restored rows in {@code open()}
     */
    private void drainRead(boolean readWait) throws IOException {
        Row result;
        while ((result = (Row) dataExchange.read(readWait)) != null) {
            Preconditions.checkState(collector != null,
                    "Got output from the Java inference process before any input row was processed");
            collector.collect(result);
            rowCache.remove();
        }
    }

    /**
     * Writes {@code row} to the inference process, draining results while its input side is full.
     *
     * @throws IllegalStateException if the inference process finished before the row could be written
     * @throws ExecutionException if the inference process terminated with an error
     */
    private void writeToJavaInference(Row row) throws IOException, ExecutionException, InterruptedException {
        while (!dataExchange.write(row)) {
            Preconditions.checkState(!processFuture.isDone(), "Java inference process already finished");
            try {
                // Back off for a while; returns early (or throws) if the process terminates meanwhile.
                processFuture.get(WRITE_RETRY_WAIT_MS, TimeUnit.MILLISECONDS);
            } catch (TimeoutException ignored) {
                // Expected while the process is still running: drain its output below and retry the write.
            }
            drainRead(false);
        }
    }
}
