package org.flinkextended.flink.ml.operator.ops.inputformat;

import java.io.IOException;
import java.lang.Thread;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.configuration.Configuration;
import org.flinkextended.flink.ml.cluster.ClusterConfig;
import org.flinkextended.flink.ml.cluster.node.MLContext;
import org.flinkextended.flink.ml.cluster.rpc.AppMasterServer;
import org.flinkextended.flink.ml.util.MLException;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

/**
 * Unit tests for {@link AMInputFormat}: split creation, ML context preparation for the {@code AM}
 * role, the node server runnable it produces, and the blocking behaviour of {@code reachedEnd()}.
 */
public class AMInputFormatTest {
    /** Upper bound, in milliseconds, on how long a test waits for its background thread to finish. */
    private static final long THREAD_JOIN_TIMEOUT_MS = 10_000L;

    private AMInputFormat amInputFormat;

    @Before
    public void setUp() throws Exception {
        amInputFormat =
                new AMInputFormat(
                        ClusterConfig.newBuilder()
                                .addNodeType("worker", 2)
                                .setNodeEntry("entry.py", "main")
                                .build());
        amInputFormat.setRuntimeContext(Mockito.mock(RuntimeContext.class));
    }

    @Test
    public void testConfigure() {
        amInputFormat.configure(new Configuration());
    }

    @Test
    public void testCreateInputSplits() throws IOException {
        NodeInputSplit[] splits = amInputFormat.createInputSplits(1);
        Assert.assertEquals(1, splits.length);
        Assert.assertEquals(new NodeInputSplit(1, 0), splits[0]);
    }

    @Test(expected = IllegalStateException.class)
    public void testCreateMoreThanOneInputSplitsThrowException() throws IOException {
        amInputFormat.createInputSplits(2);
    }

    @Test
    public void testPrepareMLContext() throws MLException {
        // Bind the context once so both assertions check the same instance.
        MLContext mlContext = amInputFormat.prepareMLContext(0);
        Assert.assertEquals("AM", mlContext.getRoleName());
        Assert.assertEquals(0L, mlContext.getIndex());
    }

    @Test(expected = IllegalStateException.class)
    public void testPrepareMLContextWithIndexGreaterThanZeroThrowException() throws MLException {
        amInputFormat.prepareMLContext(1);
    }

    @Test
    public void testGetNodeServerRunnable() throws MLException {
        MatcherAssert.assertThat(
                amInputFormat.getNodeServerRunnable(amInputFormat.prepareMLContext(0)),
                Matchers.isA(AppMasterServer.class));
    }

    /**
     * Verifies that {@code reachedEnd()} blocks (thread state WAITING) while the server future is
     * pending. After the future is cancelled, it must return {@code true}.
     */
    @Test
    public void testReachedEnd() throws IOException, InterruptedException {
        amInputFormat.open(new NodeInputSplit(1, 0));

        // JUnit only observes failures thrown on the test thread. Record anything the
        // background thread throws (including a failed assertTrue) and rethrow it below.
        AtomicReference<Throwable> backgroundFailure = new AtomicReference<>();
        Thread thread =
                new Thread(
                        () -> {
                            try {
                                Assert.assertTrue(amInputFormat.reachedEnd());
                            } catch (Throwable t) {
                                backgroundFailure.set(t);
                            }
                        });
        // Daemon so a thread that never unblocks cannot keep the JVM alive.
        thread.setDaemon(true);
        thread.start();

        try {
            Thread.sleep(1000L);
            Assert.assertEquals(Thread.State.WAITING, thread.getState());
        } finally {
            // Always release the blocked thread, even if the state assertion failed.
            amInputFormat.getServerFuture().cancel(true);
            thread.join(THREAD_JOIN_TIMEOUT_MS);
        }

        Assert.assertFalse(
                "reachedEnd() did not return after the server future was cancelled",
                thread.isAlive());
        Throwable failure = backgroundFailure.get();
        if (failure != null) {
            throw new AssertionError("reachedEnd() failed on the background thread", failure);
        }
    }
}
