package io.activej.dataflow;

import io.activej.async.exception.AsyncCloseException;
import io.activej.bytebuf.ByteBuf;
import io.activej.common.ApplicationSettings;
import io.activej.common.MemSize;
import io.activej.csp.ChannelConsumer;
import io.activej.csp.binary.ByteBufsCodec;
import io.activej.csp.dsl.ChannelTransformer;
import io.activej.csp.net.Messaging;
import io.activej.csp.net.MessagingWithBinaryStreaming;
import io.activej.csp.queue.ChannelQueue;
import io.activej.csp.queue.ChannelZeroBuffer;
import io.activej.dataflow.command.DataflowCommand;
import io.activej.dataflow.command.DataflowCommandDownload;
import io.activej.dataflow.command.DataflowCommandExecute;
import io.activej.dataflow.command.DataflowCommandGetTasks;
import io.activej.dataflow.command.DataflowResponse;
import io.activej.dataflow.command.DataflowResponsePartitionData;
import io.activej.dataflow.command.DataflowResponseResult;
import io.activej.dataflow.command.DataflowResponseTaskData;
import io.activej.dataflow.graph.StreamId;
import io.activej.dataflow.graph.Task;
import io.activej.dataflow.inject.BinarySerializerModule;
import io.activej.datastream.StreamConsumer;
import io.activej.datastream.csp.ChannelSerializer;
import io.activej.eventloop.Eventloop;
import io.activej.inject.ResourceLocator;
import io.activej.jmx.api.attribute.JmxAttribute;
import io.activej.jmx.api.attribute.JmxOperation;
import io.activej.net.AbstractServer;
import io.activej.net.socket.tcp.AsyncTcpSocket;
import io.activej.promise.Promise;
import io.activej.promise.SettablePromise;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.net.InetAddress;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import org.jetbrains.annotations.Nullable;

/* loaded from: input_file:io/activej/dataflow/DataflowServer.class */
/**
 * Server side of a dataflow partition: executes {@link Task}s sent by clients and relays partition
 * data streams.
 * <p>
 * Every accepted connection sends one {@link DataflowCommand}. It is dispatched to the handler
 * registered for the command's exact runtime class:
 * <ul>
 *     <li>{@link DataflowCommandDownload} - streams the data of a {@link StreamId} back to the client;</li>
 *     <li>{@link DataflowCommandExecute} - binds and executes a {@link Task}, then replies with a
 *     {@link DataflowResponseResult} and closes the connection;</li>
 *     <li>{@link DataflowCommandGetTasks} - replies with summary statistics, or with details of one task.</li>
 * </ul>
 * A download and a local {@link #upload upload} of the same {@link StreamId} are joined through a
 * {@link ChannelZeroBuffer}. Whichever side arrives first parks that buffer in {@code pendingStreams}
 * until the other side arrives.
 */
public final class DataflowServer extends AbstractServer<DataflowServer> {
    /** Maximum number of entries kept in {@link #getLastTasks()}; read from application settings, default 1000. */
    private static final int MAX_LAST_RAN_TASKS = ApplicationSettings.getInt(DataflowServer.class, "maxLastRanTasks", 1000);

    /** Streams for which only one side (upload or download) has arrived so far. */
    private final Map<StreamId, ChannelQueue<ByteBuf>> pendingStreams = new HashMap<>();

    /** Command handlers keyed by the exact runtime class of the command they accept. */
    private final Map<Class<?>, BiConsumer<Messaging<DataflowCommand, DataflowResponse>, DataflowCommand>> handlers = new HashMap<>();

    private final ByteBufsCodec<DataflowCommand, DataflowResponse> codec;
    private final BinarySerializerModule.BinarySerializerLocator serializers;
    private final ResourceLocator resourceLocator;

    /** Tasks that have been started and have not completed yet, keyed by task id. */
    private final Map<Long, Task> runningTasks = new HashMap<>();

    /** Successfully bound tasks in insertion order; the eldest entry is evicted once the size exceeds MAX_LAST_RAN_TASKS. */
    private final Map<Long, Task> lastTasks = new LinkedHashMap<Long, Task>() {
        @Override
        protected boolean removeEldestEntry(Map.Entry<Long, Task> eldest) {
            return size() > MAX_LAST_RAN_TASKS;
        }
    };

    private int succeededTasks;
    private int canceledTasks;
    private int failedTasks;

    /**
     * Creates a server and registers the handlers for all supported commands.
     *
     * @param eventloop               eventloop this server runs in
     * @param byteBufsCodec           codec used to exchange commands and responses with clients
     * @param binarySerializerLocator provides serializers for items passed to {@link #upload}
     * @param resourceLocator         passed to every {@link Task} created for an execute command
     */
    public DataflowServer(Eventloop eventloop, ByteBufsCodec<DataflowCommand, DataflowResponse> byteBufsCodec,
            BinarySerializerModule.BinarySerializerLocator binarySerializerLocator, ResourceLocator resourceLocator) {
        super(eventloop);
        this.codec = byteBufsCodec;
        this.serializers = binarySerializerLocator;
        this.resourceLocator = resourceLocator;
        handleCommand(DataflowCommandDownload.class, this::onDownload);
        handleCommand(DataflowCommandExecute.class, this::onExecute);
        handleCommand(DataflowCommandGetTasks.class, this::onGetTasks);
    }

    /**
     * Registers {@code handler} for commands whose runtime class is exactly {@code commandClass}.
     * The stored wrapper casts the command with {@link Class#cast}, so dispatch stays type-safe
     * without unchecked casts.
     */
    private <T> void handleCommand(Class<T> commandClass, BiConsumer<Messaging<DataflowCommand, DataflowResponse>, T> handler) {
        handlers.put(commandClass, (messaging, command) -> handler.accept(messaging, commandClass.cast(command)));
    }

    /** Connects a download request to a pending upload, or parks it until the upload arrives. */
    private void onDownload(Messaging<DataflowCommand, DataflowResponse> messaging, DataflowCommandDownload command) {
        if (logger.isTraceEnabled()) {
            logger.trace("Processing onDownload: {}, {}", command, messaging);
        }
        StreamId streamId = command.getStreamId();
        ChannelQueue<ByteBuf> forwarder = pendingStreams.remove(streamId);
        if (forwarder != null) {
            logger.info("onDownload: transferring {}, pending downloads: {}", streamId, pendingStreams.size());
        } else {
            forwarder = new ChannelZeroBuffer<>();
            pendingStreams.put(streamId, forwarder);
            logger.info("onDownload: waiting {}, pending downloads: {}", streamId, pendingStreams.size());
            // A read failure means the client went away before the upload arrived: drop the parked stream
            messaging.receive().whenException(() -> {
                if (pendingStreams.remove(streamId) != null) {
                    logger.info("onDownload: removing {}, pending downloads: {}", streamId, pendingStreams.size());
                }
            });
        }
        // withAcknowledgement returns a NEW consumer; data must flow into that wrapper,
        // otherwise the completion callback below is never invoked
        ChannelConsumer<ByteBuf> consumer = messaging.sendBinaryStream()
                .withAcknowledgement(ack -> ack.whenComplete((ignored, e) -> {
                    if (e != null) {
                        logger.warn("Exception occurred while trying to send data", e);
                    }
                    messaging.close();
                }));
        forwarder.getSupplier().streamTo(consumer);
    }

    /**
     * Binds and executes the task described by {@code command}, updating statistics and replying with the
     * outcome. The task is canceled if the client disconnects before it has been executed.
     */
    private void onExecute(Messaging<DataflowCommand, DataflowResponse> messaging, DataflowCommandExecute command) {
        long taskId = command.getTaskId();
        Task task = new Task(taskId, resourceLocator, command.getNodes());
        try {
            task.bind();
            lastTasks.put(taskId, task);
            runningTasks.put(taskId, task);
            task.execute().whenComplete((ignored, e) -> {
                runningTasks.remove(taskId);
                if (e == null) {
                    succeededTasks++;
                    logger.info("Task executed successfully: {}", command);
                } else if (e instanceof AsyncCloseException) {
                    canceledTasks++;
                    logger.error("Canceled task: {}", command, e);
                } else {
                    failedTasks++;
                    logger.error("Failed to execute task: {}", command, e);
                }
                sendResponse(messaging, e);
            });
            // A read failure means nobody is waiting for the result any more
            messaging.receive().whenException(() -> {
                if (!task.isExecuted()) {
                    logger.error("Client disconnected. Canceling task: {}", command);
                    task.cancel();
                }
            });
        } catch (Exception e) {
            logger.error("Failed to construct task: {}", command, e);
            sendResponse(messaging, e);
        }
    }

    /**
     * Replies with partition-wide task statistics when the command carries no task id.
     * Otherwise replies with the details of the requested task, or an error result if it is unknown.
     */
    private void onGetTasks(Messaging<DataflowCommand, DataflowResponse> messaging, DataflowCommandGetTasks command) {
        Long taskId = command.getTaskId();
        if (taskId == null) {
            List<DataflowResponsePartitionData.TaskDesc> taskDescs = lastTasks.entrySet().stream()
                    .map(entry -> new DataflowResponsePartitionData.TaskDesc(entry.getKey(), entry.getValue().getStatus()))
                    .collect(Collectors.toList());
            messaging.send(new DataflowResponsePartitionData(runningTasks.size(), succeededTasks, failedTasks, canceledTasks, taskDescs))
                    .whenException(e -> logger.error("Failed to send answer for the partition data request", e));
            return;
        }

        Task task = lastTasks.get(taskId);
        if (task == null) {
            messaging.send(new DataflowResponseResult("No task found with id " + taskId))
                    .whenException(e -> logger.error("Failed to send answer for the task ({}) data request", taskId, e));
            return;
        }

        // The task error is transported as its full stack trace text
        Throwable error = task.getError();
        String errorString = null;
        if (error != null) {
            StringWriter stackTrace = new StringWriter();
            error.printStackTrace(new PrintWriter(stackTrace));
            errorString = stackTrace.toString();
        }

        messaging.send(new DataflowResponseTaskData(
                        task.getStatus(),
                        task.getStartTime(),
                        task.getFinishTime(),
                        errorString,
                        task.getNodes().stream()
                                .filter(node -> node.getStats() != null)
                                .collect(Collectors.toMap(node -> node.getIndex(), node -> node.getStats())),
                        task.getGraphViz()))
                .whenException(e -> logger.error("Failed to send answer for the task ({}) data request", taskId, e));
    }

    /**
     * Sends a {@link DataflowResponseResult} and closes the connection once the send completes.
     *
     * @param messaging connection to reply on
     * @param e         failure to report as {@code "SimpleClassName: message"}, or {@code null} on success
     */
    private void sendResponse(Messaging<DataflowCommand, DataflowResponse> messaging, @Nullable Throwable e) {
        String error = e == null ? null : e.getClass().getSimpleName() + ": " + e.getMessage();
        messaging.send(new DataflowResponseResult(error))
                .whenComplete(messaging::close);
    }

    /**
     * Returns a consumer that serializes items of {@code type} into the stream identified by {@code streamId}.
     * The data passes through {@code transformer} on its way to the downloading client. If the upload fails
     * while still pending, the parked stream is removed and closed.
     */
    public <T> StreamConsumer<T> upload(StreamId streamId, Class<T> type, ChannelTransformer<ByteBuf, ByteBuf> transformer) {
        ChannelSerializer<T> streamSerializer = ChannelSerializer.create(serializers.get(type))
                .withInitialBufferSize(MemSize.kilobytes(256L))
                .withAutoFlushInterval(Duration.ZERO)
                .withExplicitEndOfStream();

        ChannelQueue<ByteBuf> forwarder = pendingStreams.remove(streamId);
        if (forwarder == null) {
            forwarder = new ChannelZeroBuffer<>();
            pendingStreams.put(streamId, forwarder);
            logger.info("onUpload: waiting {}, pending downloads: {}", streamId, pendingStreams.size());
        } else {
            logger.info("onUpload: transferring {}, pending downloads: {}", streamId, pendingStreams.size());
        }

        ChannelConsumer<ByteBuf> output = forwarder.getConsumer().transformWith(transformer);
        streamSerializer.getOutput().set(output);
        streamSerializer.getAcknowledgement().whenException(() -> {
            ChannelQueue<ByteBuf> pending = pendingStreams.remove(streamId);
            if (pending != null) {
                logger.info("onUpload: removing {}, pending downloads: {}", streamId, pendingStreams.size());
                pending.close();
            }
        });
        return streamSerializer;
    }

    /** Same as {@link #upload(StreamId, Class, ChannelTransformer)} with an identity transformer. */
    public <T> StreamConsumer<T> upload(StreamId streamId, Class<T> type) {
        return upload(streamId, type, ChannelTransformer.identity());
    }

    /** Reads the single command sent on a new connection and dispatches it. */
    @Override
    protected void serve(AsyncTcpSocket socket, InetAddress remoteAddress) {
        MessagingWithBinaryStreaming<DataflowCommand, DataflowResponse> messaging = MessagingWithBinaryStreaming.create(socket, codec);
        messaging.receive()
                .whenResult(command -> {
                    if (command == null) {
                        logger.warn("unexpected end of stream");
                        messaging.close();
                        return;
                    }
                    doRead(messaging, command);
                })
                .whenException(e -> {
                    logger.error("received error while trying to read", e);
                    messaging.close();
                });
    }

    /** Dispatches {@code command} to its registered handler, closing the connection if none exists. */
    private void doRead(Messaging<DataflowCommand, DataflowResponse> messaging, DataflowCommand command) {
        BiConsumer<Messaging<DataflowCommand, DataflowResponse>, DataflowCommand> handler = handlers.get(command.getClass());
        if (handler == null) {
            logger.error("missing handler for {}", command);
            messaging.close();
            return;
        }
        handler.accept(messaging, command);
    }

    /** Closes all streams still waiting for their counterpart, then completes the close promise. */
    @Override
    protected void onClose(SettablePromise<Void> cb) {
        // Copy before closing: close callbacks may touch pendingStreams
        List<ChannelQueue<ByteBuf>> streams = new ArrayList<>(pendingStreams.values());
        pendingStreams.clear();
        streams.forEach(ChannelQueue::close);
        cb.set(null);
    }

    /** Returns the live (mutable) map of the most recently submitted tasks, keyed by task id. */
    public Map<Long, Task> getLastTasks() {
        return lastTasks;
    }

    @JmxAttribute
    public int getRunningTasks() {
        return runningTasks.size();
    }

    @JmxAttribute
    public int getSucceededTasks() {
        return succeededTasks;
    }

    @JmxAttribute
    public int getFailedTasks() {
        return failedTasks;
    }

    @JmxAttribute
    public int getCanceledTasks() {
        return canceledTasks;
    }

    /** Cancels every currently running task. */
    @JmxOperation
    public void cancelAll() {
        runningTasks.values().forEach(Task::cancel);
    }

    /**
     * Cancels the running task with the given id.
     *
     * @return {@code true} if such a task was running and has been canceled, {@code false} otherwise
     */
    @JmxOperation
    public boolean cancel(long taskId) {
        Task task = runningTasks.get(taskId);
        if (task == null) {
            return false;
        }
        task.cancel();
        return true;
    }

    /** Cancels the running task with the given id, if there is one. */
    @JmxOperation
    public void cancelTask(long taskId) {
        cancel(taskId);
    }
}
