package org.flinkextended.flink.ml.operator.ops;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.types.Row;
import org.flinkextended.flink.ml.cluster.ClusterConfig;
import org.flinkextended.flink.ml.cluster.node.MLContext;
import org.flinkextended.flink.ml.data.RecordWriter;
import org.flinkextended.flink.ml.operator.coding.RowCSVCoding;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

/* loaded from: input_file:org/flinkextended/flink/ml/operator/ops/NodeOperatorTest.class */
public class NodeOperatorTest {
    private TestNodeOperator nodeOperator;
    static StringBuilder writtenSb;

    /* loaded from: input_file:org/flinkextended/flink/ml/operator/ops/NodeOperatorTest$TestNodeOperator.class */
    private static class TestNodeOperator extends NodeOperator<Integer> {
        private final AtomicBoolean finished;
        private final StreamingRuntimeContext runtimeContext;

        public TestNodeOperator(String str, ClusterConfig clusterConfig, StreamingRuntimeContext streamingRuntimeContext) {
            super(str, clusterConfig);
            this.finished = new AtomicBoolean(false);
            this.runtimeContext = streamingRuntimeContext;
            this.closeTimeoutMs = 1000L;
        }

        Runnable createNodeServerRunnable() {
            return () -> {
                while (!this.finished.get()) {
                    synchronized (this.finished) {
                        try {
                            if (this.finished.get()) {
                                return;
                            } else {
                                this.finished.wait();
                            }
                        } catch (InterruptedException e) {
                            return;
                        }
                    }
                }
            };
        }

        public void markFinish() {
            synchronized (this.finished) {
                this.finished.set(true);
                this.finished.notify();
            }
        }

        public StreamingRuntimeContext getRuntimeContext() {
            return this.runtimeContext;
        }
    }

    /* loaded from: input_file:org/flinkextended/flink/ml/operator/ops/NodeOperatorTest$TestRecordWriter.class */
    public static class TestRecordWriter implements RecordWriter {
        private final MLContext mlContext;

        public TestRecordWriter(MLContext mLContext) {
            this.mlContext = mLContext;
        }

        public boolean write(byte[] bArr, int i, int i2) throws IOException {
            throw new UnsupportedOperationException();
        }

        public boolean write(byte[] bArr) throws IOException {
            NodeOperatorTest.writtenSb.append(new String(bArr)).append("\n");
            return true;
        }

        public void close() throws IOException {
        }

        public MLContext getMlContext() {
            return this.mlContext;
        }
    }

    @Before
    public void setUp() throws Exception {
        ClusterConfig build = ClusterConfig.newBuilder().addNodeType("worker", 2).setNodeEntry("entry.py", "main").setProperty("sys:encoding_class", RowCSVCoding.class.getName()).setProperty("input_types", "INT_32").setProperty("sys:record_writer_class", TestRecordWriter.class.getName()).build();
        StreamingRuntimeContext streamingRuntimeContext = (StreamingRuntimeContext) Mockito.mock(StreamingRuntimeContext.class);
        TaskManagerRuntimeInfo taskManagerRuntimeInfo = (TaskManagerRuntimeInfo) Mockito.mock(TaskManagerRuntimeInfo.class);
        Mockito.when(taskManagerRuntimeInfo.getTmpDirectories()).thenReturn(new String[]{"/tmp"});
        Mockito.when(streamingRuntimeContext.getTaskManagerRuntimeInfo()).thenReturn(taskManagerRuntimeInfo);
        Mockito.when(streamingRuntimeContext.getJobId()).thenReturn(new JobID());
        this.nodeOperator = new TestNodeOperator("worker", build, streamingRuntimeContext);
        writtenSb = new StringBuilder();
    }

    @Test
    public void testOpen() throws Exception {
        Assert.assertNull(this.nodeOperator.getMlContext());
        Assert.assertNull(this.nodeOperator.getServerFuture());
        Assert.assertNull(this.nodeOperator.getDataExchange());
        this.nodeOperator.open();
        Assert.assertEquals("", this.nodeOperator.getMlContext().getEnvProperty("gpu_info"));
        Assert.assertNotNull(this.nodeOperator.getServerFuture());
        Assert.assertNotNull(this.nodeOperator.getDataExchange());
    }

    @Test
    public void testProcessElement() throws Exception {
        this.nodeOperator.open();
        Row row = new Row(1);
        row.setField(0, 0);
        this.nodeOperator.processElement(new StreamRecord(row));
        row.setField(0, 1);
        this.nodeOperator.processElement(new StreamRecord(row));
        Assert.assertEquals("0\n1\n", writtenSb.toString());
    }

    @Test
    public void testClose() throws Exception {
        this.nodeOperator.open();
        this.nodeOperator.markFinish();
        this.nodeOperator.getServerFuture().get();
        this.nodeOperator.getMlContext().getInputQueue().markFinished();
        this.nodeOperator.close();
        Assert.assertNull(this.nodeOperator.getServerFuture());
        Assert.assertNull(this.nodeOperator.getDataExchangeConsumerFuture());
    }

    @Test
    public void testCloseWithTimeout() throws Exception {
        this.nodeOperator.open();
        this.nodeOperator.close();
        Assert.assertNull(this.nodeOperator.getServerFuture());
        Assert.assertNull(this.nodeOperator.getDataExchangeConsumerFuture());
    }
}
