package io.activej.dataflow.dataset;

import io.activej.common.Utils;
import io.activej.dataflow.graph.DataflowContext;
import io.activej.dataflow.graph.DataflowGraph;
import io.activej.dataflow.graph.Partition;
import io.activej.dataflow.graph.StreamId;
import io.activej.dataflow.graph.StreamSchema;
import io.activej.dataflow.node.Node;
import io.activej.dataflow.node.Nodes;
import io.activej.dataflow.node.impl.Reduce;
import io.activej.dataflow.node.impl.Shard;
import io.activej.datastream.processor.reducer.Reducer;
import io.activej.datastream.processor.reducer.Reducers;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.IntStream;

/* loaded from: input_file:io/activej/dataflow/dataset/DatasetUtils.class */
/**
 * Static helpers that wire common multi-partition patterns into a {@link DataflowGraph}: repartitioning by key
 * (optionally followed by a reduce), forwarding a stream to another partition, and offset/limit.
 *
 * <p>NOTE(review): {@link Shard} and {@link Reduce} are handled as raw types here, exactly as before; parameterize
 * them once their type parameters are confirmed.
 */
public class DatasetUtils {
    /** Limit value that {@link #offsetLimit} treats as "no limit". */
    private static final long NO_LIMIT = -1;

    /**
     * Shards every input stream by key and reduces the shards on each target partition.
     *
     * <p>One {@link Shard} node (all sharing one node index) is added on the partition of each input stream. For
     * every target partition a {@link Reduce} node is built whose inputs are one new output of every shard; shard
     * outputs that live on another partition are forwarded with an upload/download pair.
     *
     * @param context       dataflow context supplying the graph, the nonce and fresh node indexes
     * @param inputs        streams to repartition
     * @param streamSchema  schema of the items flowing through the streams
     * @param keyFunction   extracts the key used both for sharding and for reducing
     * @param keyComparator key comparator handed to every reduce node
     * @param reducer       reducer applied to every reduce input
     * @param partitions    target partitions
     * @return the output stream of the reduce node on each target partition, in {@code partitions} order
     */
    public static <K, I, O, A> List<StreamId> repartitionAndReduce(DataflowContext context, List<StreamId> inputs,
            StreamSchema<I> streamSchema, Function<I, K> keyFunction, Comparator<K> keyComparator,
            Reducer<K, I, O, A> reducer, List<Partition> partitions) {
        DataflowGraph graph = context.getGraph();
        List<Shard> shards = addShards(context, inputs, keyFunction);

        // Node indexes are generated in a fixed order: reduce index, then one per shard, then one per partition.
        int reduceIndex = context.generateNodeIndex();
        int[] downloadIndexes = generateIndexes(context, shards.size());
        int[] uploadIndexes = generateIndexes(context, partitions.size());

        List<StreamId> outputs = new ArrayList<>(partitions.size());
        for (int p = 0; p < partitions.size(); p++) {
            Partition target = partitions.get(p);
            Reduce.Builder builder = Reduce.builder(reduceIndex, keyComparator);
            for (int s = 0; s < shards.size(); s++) {
                StreamId reduceInput = forwardShardOutput(context, streamSchema, shards.get(s), target,
                        uploadIndexes[p], downloadIndexes[s]);
                builder.withInput(reduceInput, keyFunction, reducer);
            }
            Reduce reduce = (Reduce) builder.build();
            graph.addNode(target, reduce);
            outputs.add(reduce.output);
        }
        return outputs;
    }

    /**
     * Repartitions and reduces a locally sorted dataset, taking the input streams, schema, key function and key
     * comparator from the dataset itself.
     *
     * @param context   dataflow context; the dataset's channels are requested with
     *                  {@code context.withoutFixedNonce()}
     * @param dataset   dataset to repartition
     * @param reducer   reducer applied on every target partition
     * @param partitions target partitions
     * @return one output stream per target partition, in {@code partitions} order
     */
    public static <K, I, O, A> List<StreamId> repartitionAndReduce(DataflowContext context,
            LocallySortedDataset<K, I> dataset, Reducer<K, I, O, A> reducer, List<Partition> partitions) {
        return repartitionAndReduce(context, dataset.channels(context.withoutFixedNonce()), dataset.streamSchema(),
                dataset.keyFunction(), dataset.keyComparator(), reducer, partitions);
    }

    /**
     * Repartitions a locally sorted dataset, combining the shards on each target partition with
     * {@link Reducers#mergeReducer()}.
     *
     * @param context    dataflow context
     * @param dataset    dataset to repartition
     * @param partitions target partitions
     * @return one output stream per target partition, in {@code partitions} order
     */
    public static <K, T> List<StreamId> repartitionAndSort(DataflowContext context,
            LocallySortedDataset<K, T> dataset, List<Partition> partitions) {
        return repartitionAndReduce(context, dataset, Reducers.mergeReducer(), partitions);
    }

    /**
     * Shards every input stream by key and unions the shards on each target partition.
     *
     * @param context      dataflow context supplying the graph, the nonce and fresh node indexes
     * @param inputs       streams to repartition
     * @param streamSchema schema of the items flowing through the streams
     * @param keyFunction  extracts the sharding key from an item
     * @param partitions   target partitions
     * @return the outputs of the union node on each target partition, in {@code partitions} order
     */
    public static <T, K> List<StreamId> repartition(DataflowContext context, List<StreamId> inputs,
            StreamSchema<T> streamSchema, Function<T, K> keyFunction, List<Partition> partitions) {
        DataflowGraph graph = context.getGraph();
        List<Shard> shards = addShards(context, inputs, keyFunction);

        // Node indexes are generated in a fixed order: union index, then one per shard, then one per partition.
        int unionIndex = context.generateNodeIndex();
        int[] downloadIndexes = generateIndexes(context, shards.size());
        int[] uploadIndexes = generateIndexes(context, partitions.size());

        List<StreamId> outputs = new ArrayList<>();
        for (int p = 0; p < partitions.size(); p++) {
            Partition target = partitions.get(p);
            List<StreamId> unionInputs = new ArrayList<>(shards.size());
            for (int s = 0; s < shards.size(); s++) {
                unionInputs.add(forwardShardOutput(context, streamSchema, shards.get(s), target,
                        uploadIndexes[p], downloadIndexes[s]));
            }
            Node union = Nodes.union(unionIndex, unionInputs);
            graph.addNode(target, union);
            outputs.addAll(union.getOutputs());
        }
        return outputs;
    }

    /** Adds one {@link Shard} node, all sharing a single fresh node index, on the partition of each input. */
    private static <T, K> List<Shard> addShards(DataflowContext context, List<StreamId> inputs,
            Function<T, K> keyFunction) {
        DataflowGraph graph = context.getGraph();
        int nonce = context.getNonce();
        int shardIndex = context.generateNodeIndex();
        List<Shard> shards = new ArrayList<>(inputs.size());
        for (StreamId input : inputs) {
            Partition partition = graph.getPartition(input);
            Shard shard = Shard.create(shardIndex, keyFunction, input, nonce);
            graph.addNode(partition, shard);
            shards.add(shard);
        }
        return shards;
    }

    /** Opens a new output on {@code shard} and forwards it to {@code target}, returning the stream usable there. */
    private static <T> StreamId forwardShardOutput(DataflowContext context, StreamSchema<T> streamSchema,
            Shard shard, Partition target, int uploadIndex, int downloadIndex) {
        StreamId shardOutput = shard.newPartition();
        context.getGraph().addNodeStream(shard, shardOutput);
        return forwardChannel(context, streamSchema, shardOutput, target, uploadIndex, downloadIndex);
    }

    /**
     * Makes {@code streamId} available on {@code partition}.
     *
     * @param context       dataflow context supplying the graph
     * @param streamSchema  schema of the forwarded items
     * @param streamId      stream to forward; its current partition is looked up in the graph
     * @param partition     partition that should consume the stream
     * @param uploadIndex   node index of the upload node added on the source partition
     * @param downloadIndex node index of the download node added on {@code partition}
     * @return {@code streamId} itself if it already lives on {@code partition}, otherwise the output of a new
     *         download node on {@code partition}
     */
    public static <T> StreamId forwardChannel(DataflowContext context, StreamSchema<T> streamSchema,
            StreamId streamId, Partition partition, int uploadIndex, int downloadIndex) {
        Partition source = context.getGraph().getPartition(streamId);
        return forwardChannel(context, streamSchema, source, partition, streamId, uploadIndex, downloadIndex);
    }

    private static <T> StreamId forwardChannel(DataflowContext context, StreamSchema<T> streamSchema,
            Partition source, Partition target, StreamId streamId, int uploadIndex, int downloadIndex) {
        // Partitions are compared by identity: only the very same Partition instance skips the network hop.
        if (source == target) {
            return streamId;
        }
        DataflowGraph graph = context.getGraph();
        Node upload = Nodes.upload(uploadIndex, streamSchema, streamId);
        Node download = Nodes.download(downloadIndex, streamSchema, source.address(), streamId);
        graph.addNode(source, upload);
        graph.addNode(target, download);
        return (StreamId) Utils.first(download.getOutputs());
    }

    /**
     * Draws {@code count} fresh node indexes from {@code context}, in call order.
     *
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public static int[] generateIndexes(DataflowContext context, int count) {
        // The method reference fails fast with NullPointerException if context is null.
        return IntStream.generate(context::generateNodeIndex).limit(count).toArray();
    }

    /**
     * Adds an offset/limit node with offset 0 and the given limit on the partition of {@code input}.
     *
     * @return the outputs of the added node
     */
    public static Collection<StreamId> limitStream(DataflowGraph graph, int index, long limit, StreamId input) {
        Node limiter = Nodes.offsetLimit(index, 0L, limit, input);
        graph.addNode(graph.getPartition(input), limiter);
        return limiter.getOutputs();
    }

    /**
     * Applies {@code offset}/{@code limit} to the combination of {@code inputs}.
     *
     * <p>{@code inputs} is returned unchanged when {@code offset == 0 && limit == -1} (no limit) or when it is
     * empty. A single input gets one offset/limit node. For several inputs, when a limit is set each input is first
     * limited to {@code offset + limit} items (saturating at {@link Long#MAX_VALUE}). The streams are then combined by
     * {@code merger} on the partition of a nonce-selected input, and the offset/limit node is added on the partition
     * of the merged stream.
     *
     * @param context dataflow context supplying the graph, the nonce and fresh node indexes
     * @param inputs  streams to combine
     * @param offset  number of leading items to skip
     * @param limit   maximum number of items to keep, or {@code -1} for no limit
     * @param merger  combines the (possibly pre-limited) streams on the given partition into one stream
     * @return the resulting output streams
     */
    public static List<StreamId> offsetLimit(DataflowContext context, List<StreamId> inputs, long offset, long limit,
            BiFunction<List<StreamId>, Partition, StreamId> merger) {
        if (offset == 0 && limit == NO_LIMIT) {
            return inputs;
        }
        DataflowGraph graph = context.getGraph();
        if (inputs.isEmpty()) {
            return inputs;
        }
        if (inputs.size() == 1) {
            return toOutput(graph, context.generateNodeIndex(), inputs.get(0), offset, limit);
        }

        List<StreamId> streams = inputs;
        if (limit != NO_LIMIT) {
            // The final node only needs offset + limit items. Pre-limiting each input to that many presumably
            // relies on merger keeping each input's prefix; saturate so a huge limit cannot overflow to negative.
            long perStreamLimit = saturatedAdd(offset, limit);
            List<StreamId> limited = new ArrayList<>(inputs.size());
            for (StreamId input : inputs) {
                limited.addAll(limitStream(graph, context.generateNodeIndex(), perStreamLimit, input));
            }
            streams = limited;
        }

        // The final node's index is drawn before merger runs, keeping the original node-index order.
        int outputIndex = context.generateNodeIndex();
        // Math.abs(nonce % size) equals Math.abs(nonce) % size for every nonce except Integer.MIN_VALUE, where
        // Math.abs(nonce) stays negative and would yield a negative list index.
        StreamId pivot = streams.get(Math.abs(context.getNonce() % streams.size()));
        StreamId merged = merger.apply(streams, graph.getPartition(pivot));
        return toOutput(graph, outputIndex, merged, offset, limit);
    }

    /** Returns {@code a + b}, clamped to {@link Long#MAX_VALUE}/{@link Long#MIN_VALUE} instead of overflowing. */
    private static long saturatedAdd(long a, long b) {
        long sum = a + b;
        // Overflow happened iff both operands share a sign that the result does not.
        if (((a ^ sum) & (b ^ sum)) < 0) {
            return a < 0 ? Long.MIN_VALUE : Long.MAX_VALUE;
        }
        return sum;
    }

    /** Adds an offset/limit node on the partition of {@code input} and returns an immutable copy of its outputs. */
    private static List<StreamId> toOutput(DataflowGraph graph, int index, StreamId input, long offset, long limit) {
        Node offsetLimit = Nodes.offsetLimit(index, offset, limit, input);
        graph.addNode(graph.getPartition(input), offsetLimit);
        return List.copyOf(offsetLimit.getOutputs());
    }
}
