package org.flinkextended.flink.ml.tensorflow.util;

import com.google.common.base.Joiner;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.Arrays;
import java.util.Comparator;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;
import org.apache.flink.util.Preconditions;
import org.flinkextended.flink.ml.cluster.node.MLContext;
import org.flinkextended.flink.ml.cluster.rpc.NodeServer;
import org.flinkextended.flink.ml.data.DataExchange;
import org.flinkextended.flink.ml.operator.coding.RowCSVCoding;
import org.flinkextended.flink.ml.util.DummyContext;
import org.flinkextended.flink.ml.util.MLException;
import org.flinkextended.flink.ml.util.ShellExec;
import org.flinkextended.flink.ml.util.TestUtil;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.internal.util.reflection.Whitebox;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

/**
 * Integration test for {@link JavaInferenceRunner}.
 *
 * <p>Setup exports a model by running {@code src/test/python/mnist_model.py} with the {@code python}
 * executable found on the PATH. It points an {@link MLContext} at the exported model and starts a
 * {@link NodeServer} on a background daemon thread. The test then connects a
 * {@link JavaInferenceRunner} to that server. It feeds the runner a single image row through a
 * spied {@link DataExchange} and checks that the row written back has exactly one field.
 */
public class JavaInferenceRunnerTest {
    private static final String rootPath = TestUtil.getProjectRootPath() + "/dl-on-flink-tensorflow-2.x";

    /** Number of floats in the input image; 784 = 28 x 28 (presumably a flattened MNIST image). */
    private static final int IMAGE_SIZE = 784;

    /** Constant value written into every element of the input image. */
    private static final float IMAGE_PIXEL_VALUE = 0.1f;

    /** How often to poll the node server for its bound port. */
    private static final long PORT_POLL_INTERVAL_MS = 100L;

    /** Upper bound on how long the test waits for the node server to bind a port. */
    private static final long NODE_SERVER_START_TIMEOUT_MS = 60_000L;

    /** Upper bound on how long tear-down waits for the node server task to finish after STOP. */
    private static final long NODE_SERVER_STOP_TIMEOUT_SECONDS = 120L;

    /** Temporary directory holding the exported model; deleted in {@link #tearDown()}. */
    private Path tempDir;
    private FutureTask<Void> nodeFuture;
    private MLContext mlContext;
    private NodeServer nodeServer;

    /**
     * Exports the test model into a fresh temp directory and configures the ML context for it.
     * Then starts the node server in the background.
     *
     * @throws IllegalStateException if the Python model-export script reports failure
     */
    @Before
    public void setUp() throws Exception {
        this.tempDir = Files.createTempDirectory("java-inference-runner-test");
        final String exportPath = tempDir.resolve("model").toAbsolutePath().toString();
        this.mlContext = DummyContext.createDummyMLContext();

        final String exportCommand =
                String.format("python %s %s", rootPath + "/src/test/python/mnist_model.py", exportPath);
        Preconditions.checkState(
                ShellExec.run(exportCommand, new ShellExec.StdOutConsumer()),
                "Failed to export the test model with command: %s",
                exportCommand);

        this.mlContext.getProperties().put("tf.inference.export.path", exportPath);
        this.mlContext.getProperties().put("tf.inference.input.tensor.names", "image");
        this.mlContext.getProperties().put("tf.inference.output.tensor.names", "prediction");
        this.mlContext.getProperties().put("tf.inference.output.row.fields", "prediction");
        setExampleCodingType(this.mlContext);
        startNodeServer();
    }

    /**
     * Stops the node server and waits, with a bound, for its task to finish. Then deletes the
     * temp directory.
     *
     * <p>Tolerates a {@link #setUp()} that failed part-way, because JUnit still runs
     * {@code @After} methods in that case.
     */
    @After
    public void tearDown() throws Exception {
        try {
            if (this.nodeServer != null) {
                this.nodeServer.setAmCommand(NodeServer.AMCommand.STOP);
            }
            if (this.nodeFuture != null) {
                // A bounded wait so a server that ignores STOP fails the build instead of hanging it.
                this.nodeFuture.get(NODE_SERVER_STOP_TIMEOUT_SECONDS, TimeUnit.SECONDS);
            }
        } finally {
            deleteRecursively(this.tempDir);
        }
    }

    @Test
    public void testJavaInferenceRunner() throws Exception {
        final int port = waitForNodeServerPort();
        final JavaInferenceRunner runner =
                new JavaInferenceRunner(
                        "localhost",
                        port,
                        new RowTypeInfo(
                                new TypeInformation[] {Types.PRIMITIVE_ARRAY(Types.FLOAT)},
                                new String[] {"image"}),
                        new RowTypeInfo(new TypeInformation[] {Types.LONG}, new String[] {"prediction"}));
        try {
            // Replace the runner's private DataExchange with a spy so reads can be stubbed and
            // writes observed.
            final DataExchange dataExchange =
                    Mockito.spy((DataExchange) Whitebox.getInternalState(runner, "dataExchange"));
            Whitebox.setInternalState(runner, "dataExchange", dataExchange);
            // doAnswer(...).when(spy) stubs without invoking the real read(); when(spy.read(..))
            // would call it once during stubbing.
            Mockito.doAnswer(singleImageThenEndOfInput()).when(dataExchange).read(Mockito.anyBoolean());

            runner.run();

            final ArgumentCaptor<Row> written = ArgumentCaptor.forClass(Row.class);
            Mockito.verify(dataExchange).write(written.capture());
            Assert.assertEquals(1, written.getValue().getArity());
        } finally {
            runner.close();
        }
    }

    /**
     * Polls until the node server reports its bound port.
     *
     * @return the port the node server is listening on
     * @throws AssertionError if the server task terminates first, or the start timeout elapses
     */
    private int waitForNodeServerPort() throws InterruptedException {
        final long deadline = System.currentTimeMillis() + NODE_SERVER_START_TIMEOUT_MS;
        Integer port = this.nodeServer.getPort();
        while (port == null) {
            if (this.nodeFuture.isDone()) {
                Assert.fail("NodeServer terminated before binding a port");
            }
            if (System.currentTimeMillis() >= deadline) {
                Assert.fail("Timed out after " + NODE_SERVER_START_TIMEOUT_MS + " ms waiting for the NodeServer port");
            }
            Thread.sleep(PORT_POLL_INTERVAL_MS);
            port = this.nodeServer.getPort();
        }
        return port;
    }

    /**
     * Builds an answer for {@code DataExchange.read}. The first call returns one image row and
     * every later call returns {@code null}. The test relies on {@code null} ending the runner's
     * input, presumably as the end-of-input signal; confirm against {@code JavaInferenceRunner}.
     */
    private static Answer<Row> singleImageThenEndOfInput() {
        return new Answer<Row>() {
            private boolean served;

            @Override
            public Row answer(InvocationOnMock invocation) {
                if (served) {
                    return null;
                }
                served = true;
                return createImageRow();
            }
        };
    }

    /**
     * Creates a one-field INSERT row. Its only field is a float array of length
     * {@link #IMAGE_SIZE}, filled with {@link #IMAGE_PIXEL_VALUE}.
     */
    private static Row createImageRow() {
        final float[] pixels = new float[IMAGE_SIZE];
        Arrays.fill(pixels, IMAGE_PIXEL_VALUE);
        final Row row = new Row(RowKind.INSERT, 1);
        row.setField(0, pixels);
        return row;
    }

    /** Configures {@link RowCSVCoding} as both the encoding and the decoding class. */
    private void setExampleCodingType(MLContext mLContext) {
        mLContext.getProperties().put("sys:encoding_class", RowCSVCoding.class.getCanonicalName());
        mLContext.getProperties().put("sys:decoding_class", RowCSVCoding.class.getCanonicalName());
    }

    /**
     * Creates a "worker" {@link NodeServer} and runs it as a {@link FutureTask} on a daemon thread.
     *
     * @return the future tracking the server task (also stored in {@link #nodeFuture})
     */
    private FutureTask<Void> startNodeServer() throws MLException {
        this.nodeServer = new NodeServer(this.mlContext, "worker");
        this.nodeFuture = new FutureTask<>(this.nodeServer, null);
        final Thread thread = new Thread(this.nodeFuture, "java-inference-runner-test-node-server");
        thread.setDaemon(true);
        thread.start();
        return this.nodeFuture;
    }

    /**
     * Deletes {@code root} and everything beneath it, deepest entries first. This is a no-op
     * when {@code root} is null or does not exist. Deletion is best-effort: an entry that cannot
     * be deleted is left behind without failing the test.
     */
    private static void deleteRecursively(Path root) throws IOException {
        if (root == null || !Files.exists(root)) {
            return;
        }
        try (Stream<Path> paths = Files.walk(root)) {
            paths.sorted(Comparator.reverseOrder()).map(Path::toFile).forEach(File::delete);
        }
    }
}
