package io.activej.dataflow.dataset;

import io.activej.dataflow.graph.DataflowContext;
import io.activej.dataflow.graph.DataflowGraph;
import io.activej.dataflow.graph.Partition;
import io.activej.dataflow.graph.StreamId;
import io.activej.dataflow.node.NodeDownload;
import io.activej.dataflow.node.NodeReduce;
import io.activej.dataflow.node.NodeShard;
import io.activej.dataflow.node.NodeUpload;
import io.activej.datastream.processor.StreamReducers;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.IntStream;

/* loaded from: input_file:io/activej/dataflow/dataset/DatasetUtils.class */
/**
 * Static helpers that wire repartitioning stages into a {@link DataflowGraph}.
 *
 * <p>The central operation, {@link #repartitionAndReduce}, feeds every channel of a locally
 * sorted dataset into a {@link NodeShard}. It forwards the sharder outputs to every target
 * partition through an upload/download node pair, and combines them on each target partition
 * with a {@link NodeReduce}.
 */
public class DatasetUtils {

    /**
     * Redistributes a locally sorted dataset across the partitions in {@code list} and reduces it
     * there.
     *
     * <p>Graph construction, in order:
     * <ol>
     *   <li>one {@link NodeShard} is added on the partition of each input channel; all sharders
     *       share a single node index and the dataset's key function;</li>
     *   <li>one {@link NodeReduce} is added on each target partition; all reducers share a single
     *       node index and use the dataset's key comparator;</li>
     *   <li>for every (target partition, sharder) pair a new sharder output stream is created,
     *       forwarded to the target partition via {@link #forwardChannel}, and registered as an
     *       input of that partition's reducer.</li>
     * </ol>
     *
     * @param dataflowContext      context supplying the graph, the nonce and node indexes
     * @param locallySortedDataset input dataset; supplies channels, key function, key comparator
     *                             and value type
     * @param reducer              reducer passed with every input added to each {@link NodeReduce}
     * @param list                 target partitions; one reducer is created per partition
     * @return the reducer output streams, in the same order as {@code list}
     */
    public static <K, I, O, A> List<StreamId> repartitionAndReduce(DataflowContext dataflowContext, LocallySortedDataset<K, I> locallySortedDataset, StreamReducers.Reducer<K, I, O, A> reducer, List<Partition> list) {
        DataflowGraph graph = dataflowContext.getGraph();
        int nonce = dataflowContext.getNonce();
        Function<I, K> keyFunction = locallySortedDataset.keyFunction();

        // Stage 1: one sharder per input channel, placed on that channel's partition.
        // The shard index is allocated before channels() is called, exactly as before, because
        // node-index allocation order is observable in the resulting graph.
        List<NodeShard> sharders = new ArrayList<>();
        int shardIndex = dataflowContext.generateNodeIndex();
        for (StreamId inputStreamId : locallySortedDataset.channels(dataflowContext.withoutFixedNonce())) {
            Partition sourcePartition = graph.getPartition(inputStreamId);
            NodeShard sharder = new NodeShard(shardIndex, keyFunction, inputStreamId, nonce);
            graph.addNode(sourcePartition, sharder);
            sharders.add(sharder);
        }

        // Stage 2: one reducer per target partition, fed by every sharder.
        // Download indexes are selected by source sharder (j), so the download nodes landing on
        // one target partition get distinct indexes. Upload indexes are selected by target
        // partition (i), so the upload nodes created for one sharder get distinct indexes.
        int reduceIndex = dataflowContext.generateNodeIndex();
        int[] downloadIndexes = generateIndexes(dataflowContext, sharders.size());
        int[] uploadIndexes = generateIndexes(dataflowContext, list.size());

        List<StreamId> outputs = new ArrayList<>(list.size());
        for (int i = 0; i < list.size(); i++) {
            Partition targetPartition = list.get(i);
            NodeReduce reducerNode = new NodeReduce(reduceIndex, locallySortedDataset.keyComparator());
            graph.addNode(targetPartition, reducerNode);
            for (int j = 0; j < sharders.size(); j++) {
                NodeShard sharder = sharders.get(j);
                StreamId shardOutput = sharder.newPartition();
                graph.addNodeStream(sharder, shardOutput);
                StreamId reducerInput = forwardChannel(dataflowContext, locallySortedDataset.valueType(), shardOutput, targetPartition, uploadIndexes[i], downloadIndexes[j]);
                reducerNode.addInput(reducerInput, keyFunction, reducer);
            }
            outputs.add(reducerNode.getOutput());
        }
        return outputs;
    }

    /**
     * Redistributes a locally sorted dataset across {@code list}.
     *
     * <p>Equivalent to {@link #repartitionAndReduce} with {@code StreamReducers.mergeReducer()}
     * as the reducer.
     *
     * @param dataflowContext      context supplying the graph, the nonce and node indexes
     * @param locallySortedDataset input dataset
     * @param list                 target partitions
     * @return one output stream per target partition, in the same order as {@code list}
     */
    public static <K, T> List<StreamId> repartitionAndSort(DataflowContext dataflowContext, LocallySortedDataset<K, T> locallySortedDataset, List<Partition> list) {
        return repartitionAndReduce(dataflowContext, locallySortedDataset, StreamReducers.mergeReducer(), list);
    }

    /**
     * Forwards {@code streamId} from the partition that currently owns it to {@code partition}.
     *
     * <p>The owning partition is looked up in the context's graph. A {@link NodeUpload} is added
     * on that partition and a {@link NodeDownload}, pointed at its address, on {@code partition}.
     *
     * @param dataflowContext context whose graph receives the new nodes
     * @param cls             element type of the forwarded stream
     * @param streamId        stream to forward
     * @param partition       destination partition
     * @param i               node index for the {@link NodeUpload}
     * @param i2              node index for the {@link NodeDownload}
     * @return the download node's output stream, located on {@code partition}
     */
    public static <T> StreamId forwardChannel(DataflowContext dataflowContext, Class<T> cls, StreamId streamId, Partition partition, int i, int i2) {
        return forwardChannel(dataflowContext, cls, dataflowContext.getGraph().getPartition(streamId), partition, streamId, i, i2);
    }

    /**
     * Adds an upload node on {@code sourcePartition} and a matching download node on
     * {@code targetPartition} for {@code streamId}.
     *
     * @return the download node's output stream
     */
    private static <T> StreamId forwardChannel(DataflowContext context, Class<T> type, Partition sourcePartition, Partition targetPartition, StreamId streamId, int uploadIndex, int downloadIndex) {
        DataflowGraph graph = context.getGraph();
        NodeUpload upload = new NodeUpload(uploadIndex, type, streamId);
        NodeDownload download = new NodeDownload(downloadIndex, type, sourcePartition.getAddress(), streamId);
        graph.addNode(sourcePartition, upload);
        graph.addNode(targetPartition, download);
        return download.getOutput();
    }

    /**
     * Draws {@code i} node indexes from {@link DataflowContext#generateNodeIndex()}.
     *
     * <p>{@code generateNodeIndex()} is called exactly {@code i} times, and the results are stored
     * in call order.
     *
     * @param dataflowContext context to allocate indexes from
     * @param i               number of indexes to allocate
     * @return the allocated indexes; empty when {@code i == 0}
     * @throws NullPointerException     if {@code dataflowContext} is {@code null}
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public static int[] generateIndexes(DataflowContext dataflowContext, int i) {
        Objects.requireNonNull(dataflowContext, "dataflowContext");
        if (i < 0) {
            throw new IllegalArgumentException("Negative index count: " + i);
        }
        int[] indexes = new int[i];
        for (int k = 0; k < i; k++) {
            indexes[k] = dataflowContext.generateNodeIndex();
        }
        return indexes;
    }
}
