package org.flinkextended.flink.ml.operator.ops.inputformat;

import java.io.IOException;
import java.util.Arrays;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
import org.flinkextended.flink.ml.cluster.ClusterConfig;
import org.flinkextended.flink.ml.cluster.rpc.NodeServer;
import org.flinkextended.flink.ml.util.MLException;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

/**
 * Unit tests for {@link NodeInputFormat}.
 *
 * <p>Every test runs against a format created for the {@code "worker"} node type of a cluster
 * config that registers that node type with a count of two, and whose runtime context is a
 * Mockito mock of {@link StreamingRuntimeContext}.
 */
public class NodeInputFormatTest {
    /** Node type the format under test is created for. */
    private static final String NODE_TYPE = "worker";

    /** Count passed to {@code addNodeType} for {@link #NODE_TYPE} in the cluster config. */
    private static final int NODE_COUNT = 2;

    private NodeInputFormat<Integer> nodeInputFormat;

    @Before
    public void setUp() throws Exception {
        ClusterConfig clusterConfig =
                ClusterConfig.newBuilder()
                        .addNodeType(NODE_TYPE, NODE_COUNT)
                        .setNodeEntry("entry.py", "main")
                        .build();
        nodeInputFormat = new NodeInputFormat<>(NODE_TYPE, clusterConfig);

        // Only the members stubbed below are configured; any other call on these mocks
        // returns Mockito's default value.
        TaskManagerRuntimeInfo taskManagerRuntimeInfo = Mockito.mock(TaskManagerRuntimeInfo.class);
        Mockito.when(taskManagerRuntimeInfo.getTmpDirectories()).thenReturn(new String[] {"/tmp"});

        StreamingRuntimeContext runtimeContext = Mockito.mock(StreamingRuntimeContext.class);
        Mockito.when(runtimeContext.getTaskManagerRuntimeInfo()).thenReturn(taskManagerRuntimeInfo);
        Mockito.when(runtimeContext.getJobId()).thenReturn(new JobID());
        nodeInputFormat.setRuntimeContext(runtimeContext);
    }

    /** Requesting as many splits as configured nodes yields exactly splits (2, 0) and (2, 1). */
    @Test
    public void testCreateInputSplits() throws IOException {
        NodeInputSplit[] splits = nodeInputFormat.createInputSplits(NODE_COUNT);

        // Together with the length check, hasItems over two distinct splits pins the exact set.
        Assert.assertEquals(NODE_COUNT, splits.length);
        MatcherAssert.assertThat(
                Arrays.asList(splits),
                Matchers.hasItems(new NodeInputSplit(2, 0), new NodeInputSplit(2, 1)));
    }

    /** Requesting more splits than configured nodes must fail with {@link IllegalStateException}. */
    @Test(expected = IllegalStateException.class)
    public void testCreatInputSplitsThrowExceptionWhenMoreSplitAreDesired() throws IOException {
        nodeInputFormat.createInputSplits(NODE_COUNT + 1);
    }

    /** The ML context prepared for index 0 exposes an empty {@code gpu_info} env property. */
    @Test
    public void testPrepareMLContext() throws MLException {
        Assert.assertEquals("", nodeInputFormat.prepareMLContext(0).getEnvProperty("gpu_info"));
    }

    /** The runnable built from a prepared ML context is a {@link NodeServer}. */
    @Test
    public void testGetNodeServerRunnable() throws MLException {
        MatcherAssert.assertThat(
                nodeInputFormat.getNodeServerRunnable(nodeInputFormat.prepareMLContext(0)),
                Matchers.isA(NodeServer.class));
    }

    /** Smoke test: passes as long as configuring with an empty {@link Configuration} does not throw. */
    @Test
    public void testConfigure() {
        nodeInputFormat.configure(new Configuration());
    }
}
