package org.flinkextended.flink.ml.pytorch;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;

import org.flinkextended.flink.ml.operator.client.NodeUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.invoke.SerializedLambda;
import java.util.Objects;

@Internal
/* loaded from: input_file:org/flinkextended/flink/ml/pytorch/PyTorchNodeIterationBody.class */
public class PyTorchNodeIterationBody implements IterationBody {
    private final StreamExecutionEnvironment env;
    private final PyTorchClusterConfig pyTorchClusterConfig;
    private final Integer maxEpoch;
    private final Configuration flinkConfig;

    /* loaded from: input_file:org/flinkextended/flink/ml/pytorch/PyTorchNodeIterationBody$TerminateOnEpoch.class */
    public static class TerminateOnEpoch implements IterationListener<Integer>, FlatMapFunction<Integer, Integer> {
        private static final Logger LOG = LoggerFactory.getLogger(TerminateOnEpoch.class);
        private final Integer maxEpoch;
        private boolean earlyTerminated = true;

        public TerminateOnEpoch(Integer num) {
            this.maxEpoch = num;
        }

        public void flatMap(Integer num, Collector<Integer> collector) {
            this.earlyTerminated = false;
        }

        public void onEpochWatermarkIncremented(int i, IterationListener.Context context, Collector<Integer> collector) {
            if (this.earlyTerminated) {
                LOG.info("Early Terminated at epoch {}", Integer.valueOf(i));
            } else if (i >= this.maxEpoch.intValue() - 1) {
                LOG.info("Terminate at epoch {}", Integer.valueOf(i));
            } else {
                collector.collect(Integer.valueOf(i));
                this.earlyTerminated = true;
            }
        }

        public void onIterationTerminated(IterationListener.Context context, Collector<Integer> collector) {
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Integer) obj, (Collector<Integer>) collector);
        }
    }

    public PyTorchNodeIterationBody(StreamExecutionEnvironment streamExecutionEnvironment, PyTorchClusterConfig pyTorchClusterConfig, Integer num, Configuration configuration) {
        this.env = streamExecutionEnvironment;
        this.pyTorchClusterConfig = pyTorchClusterConfig;
        this.maxEpoch = num;
        this.flinkConfig = configuration;
    }

    public IterationBodyResult process(DataStreamList dataStreamList, DataStreamList dataStreamList2) {
        DataStream scheduleNodes = NodeUtils.scheduleNodes(this.env, dataStreamList2.get(0), this.pyTorchClusterConfig, TypeInformation.of(Void.class), PyTorchClusterConfig.WORKER_NODE_TYPE, this.flinkConfig);
        return new IterationBodyResult(DataStreamList.of(new DataStream[]{dataStreamList.get(0).map(obj -> {
            return obj;
        }).setParallelism(1)}), DataStreamList.of(new DataStream[]{scheduleNodes}), scheduleNodes.getSideOutput(new OutputTag<Integer>("termination") { // from class: org.flinkextended.flink.ml.pytorch.PyTorchNodeIterationBody.1
        }).flatMap(new TerminateOnEpoch(this.maxEpoch)).name("TerminationDecider").setParallelism(1));
    }

    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case -621975376:
                if (implMethodName.equals("lambda$process$76355c68$1")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/flinkextended/flink/ml/pytorch/PyTorchNodeIterationBody") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;")) {
                    return obj -> {
                        return obj;
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
