package org.flinkextended.flink.ml.pytorch;

import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.TableDescriptor;
import org.apache.flink.table.api.bridge.java.StreamStatementSet;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.lang.invoke.SerializedLambda;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.concurrent.ExecutionException;

/**
 * Integration tests for {@link PyTorchUtils}.
 *
 * <p>Each test registers a job on a shared {@link StreamStatementSet}. The job runs the {@code
 * main} entry of a Python script taken from the test resources on a PyTorch cluster of the
 * configured world size. The test then executes the statement set and waits for it to finish.
 */
public class PyTorchUtilsTest {
    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private StreamStatementSet statementSet;

    /** Creates a fresh stream environment, table environment and statement set for each test. */
    @Before
    public void setUp() throws Exception {
        env = StreamExecutionEnvironment.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        statementSet = tEnv.createStatementSet();
    }

    /** Runs {@code all_gather.py} on 3 nodes without an input table. */
    @Test
    public void testTrainWithoutInput() throws ExecutionException, InterruptedException {
        PyTorchClusterConfig config =
                PyTorchClusterConfig.newBuilder()
                        .setNodeEntry(getScriptPathFromResources("all_gather.py"), "main")
                        .setWorldSize(3)
                        .build();

        PyTorchUtils.train(statementSet, config);
        statementSet.execute().await();
    }

    /** Runs {@code with_input.py} on 3 nodes, feeding it a table of six INT_32 values. */
    @Test
    public void testTrainWithInput() throws ExecutionException, InterruptedException {
        PyTorchClusterConfig config =
                PyTorchClusterConfig.newBuilder()
                        .setNodeEntry(getScriptPathFromResources("with_input.py"), "main")
                        .setProperty("input_types", "INT_32")
                        .setWorldSize(3)
                        .build();

        PyTorchUtils.train(
                statementSet, tEnv.fromDataStream(env.fromElements(1, 2, 3, 4, 5, 6)), config);
        statementSet.execute().await();
    }

    /** Runs {@code with_input_iter.py} on 2 nodes with an iteration bound of 4. */
    @Test
    public void testIterationTrain() throws ExecutionException, InterruptedException {
        PyTorchClusterConfig config =
                PyTorchClusterConfig.newBuilder()
                        .setNodeEntry(getScriptPathFromResources("with_input_iter.py"), "main")
                        .setWorldSize(2)
                        .setProperty("input_types", "INT_32")
                        .build();

        // The identity map exists only so the input can run with parallelism 2;
        // the fromElements source itself is non-parallel.
        PyTorchUtils.train(
                statementSet,
                tEnv.fromDataStream(env.fromElements(1, 2, 3, 4).map(num -> num).setParallelism(2)),
                config,
                4);
        statementSet.execute().await();
    }

    /**
     * Same as {@link #testIterationTrain()} but with an effectively unbounded iteration count.
     *
     * <p>NOTE(review): this presumably relies on the script terminating the iteration on its own.
     * Confirm against {@code with_input_iter.py}.
     */
    @Test
    public void testIterationTrainWithEarlyTermination()
            throws ExecutionException, InterruptedException {
        PyTorchClusterConfig config =
                PyTorchClusterConfig.newBuilder()
                        .setNodeEntry(getScriptPathFromResources("with_input_iter.py"), "main")
                        .setWorldSize(2)
                        .setProperty("input_types", "INT_32")
                        .build();

        PyTorchUtils.train(
                statementSet,
                tEnv.fromDataStream(env.fromElements(1, 2, 3, 4).map(num -> num).setParallelism(2)),
                config,
                Integer.MAX_VALUE);
        statementSet.execute().await();
    }

    /**
     * Runs {@code inference.py} on 3 nodes over six INT_32 inputs.
     *
     * <p>The script's two INT_32 outputs, exposed as columns {@code x} and {@code y}, are written
     * to a print sink.
     */
    @Test
    public void testInference() throws ExecutionException, InterruptedException {
        PyTorchClusterConfig config =
                PyTorchClusterConfig.newBuilder()
                        .setNodeEntry(getScriptPathFromResources("inference.py"), "main")
                        .setProperty("input_types", "INT_32")
                        .setProperty("output_types", "INT_32,INT_32")
                        .setWorldSize(3)
                        .build();
        Schema outputSchema =
                Schema.newBuilder()
                        .column("x", DataTypes.INT())
                        .column("y", DataTypes.INT())
                        .build();

        statementSet.addInsert(
                TableDescriptor.forConnector("print").build(),
                PyTorchUtils.inference(
                        statementSet,
                        tEnv.fromDataStream(env.fromElements(1, 2, 3, 4, 5, 6)),
                        config,
                        outputSchema));
        statementSet.execute().await();
    }

    /**
     * Resolves a class-path resource to a file-system path usable by the Python process.
     *
     * <p>The path is derived from {@link URL#toURI()} rather than {@link URL#getPath()}.
     * {@code getPath()} keeps percent-encoding, e.g. a space becomes {@code %20}, so it would
     * break when the build directory contains such characters.
     *
     * @param str resource name, resolved by the context class loader
     * @return decoded file-system path of the resource
     * @throws AssertionError if the resource cannot be found
     * @throws IllegalStateException if the resource URL is not a valid URI
     */
    private String getScriptPathFromResources(String str) {
        URL resource = Thread.currentThread().getContextClassLoader().getResource(str);
        Assert.assertNotNull("Test resource not found: " + str, resource);
        try {
            return Paths.get(resource.toURI()).toString();
        } catch (URISyntaxException e) {
            throw new IllegalStateException("Invalid resource URI: " + resource, e);
        }
    }
}
