package io.activej.dataflow.dataset.impl;

import io.activej.dataflow.dataset.Dataset;
import io.activej.dataflow.dataset.DatasetUtils;
import io.activej.dataflow.graph.DataflowContext;
import io.activej.dataflow.graph.DataflowGraph;
import io.activej.dataflow.graph.Partition;
import io.activej.dataflow.graph.StreamId;
import io.activej.dataflow.graph.StreamSchema;
import io.activej.dataflow.node.Node;
import io.activej.dataflow.node.Nodes;
import io.activej.dataflow.node.impl.Reduce;
import io.activej.dataflow.node.impl.ReduceSimple;
import io.activej.dataflow.node.impl.Shard;
import io.activej.datastream.processor.reducer.ReducerToResult;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;

/* loaded from: input_file:io/activej/dataflow/dataset/impl/SplitSortReduceRepartitionReduce.class */
/**
 * Dataset that reduces its input by key in two phases across partitions.
 *
 * <p>The graph is built in two phases:
 * <ol>
 *   <li>Every input channel is sharded by {@link #inputKeyFunction}.</li>
 *   <li>For each available partition of the graph, the following happens for every shard output:
 *     <ul>
 *       <li>It is sorted with a buffer of {@link #sortBufferSize}, on the partition that owns that shard output.</li>
 *       <li>It is pre-reduced into accumulators via {@code reducer.inputToAccumulator()}.</li>
 *       <li>The accumulators are forwarded (using {@link #accumulatorStreamSchema}) to the target partition.</li>
 *     </ul>
 *     On the target partition, a {@link Reduce} node combines the forwarded accumulator streams. It is keyed by
 *     {@link #accumulatorKeyFunction} and uses {@code reducer.accumulatorToOutput()}.</li>
 * </ol>
 *
 * @param <K> key type
 * @param <I> input item type
 * @param <O> output item type
 * @param <A> accumulator type
 */
public final class SplitSortReduceRepartitionReduce<K, I, O, A> extends Dataset<O> {
    public final Dataset<I> input;
    public final Function<I, K> inputKeyFunction;
    public final Function<A, K> accumulatorKeyFunction;
    public final Comparator<K> keyComparator;
    public final ReducerToResult<K, I, O, A> reducer;
    public final StreamSchema<A> accumulatorStreamSchema;
    public final int sortBufferSize;

    /**
     * Creates the dataset.
     *
     * @param input                   dataset whose items are reduced
     * @param inputKeyFunction        extracts the key of an input item; used for sharding, sorting and pre-reduction
     * @param accumulatorKeyFunction  extracts the key of an accumulator; used by the final reduce
     * @param keyComparator           key ordering used by the sort, the pre-reduce and the final reduce
     * @param reducer                 provides {@code inputToAccumulator()} for the pre-reduce and
     *                                {@code accumulatorToOutput()} for the final reduce
     * @param outputStreamSchema      schema of the resulting output streams (passed to {@link Dataset})
     * @param accumulatorStreamSchema schema used when forwarding accumulators between partitions
     * @param sortBufferSize          buffer size passed to the sort node
     */
    public SplitSortReduceRepartitionReduce(Dataset<I> input, Function<I, K> inputKeyFunction,
                                            Function<A, K> accumulatorKeyFunction, Comparator<K> keyComparator,
                                            ReducerToResult<K, I, O, A> reducer, StreamSchema<O> outputStreamSchema,
                                            StreamSchema<A> accumulatorStreamSchema, int sortBufferSize) {
        super(outputStreamSchema);
        this.input = input;
        this.inputKeyFunction = inputKeyFunction;
        this.accumulatorKeyFunction = accumulatorKeyFunction;
        this.keyComparator = keyComparator;
        this.reducer = reducer;
        this.accumulatorStreamSchema = accumulatorStreamSchema;
        this.sortBufferSize = sortBufferSize;
    }

    /**
     * Adds the shard, sort, pre-reduce, forward and reduce nodes to the graph and returns the resulting streams.
     *
     * <p>The number of input channels and the number of available partitions may differ.
     *
     * @param dataflowContext context supplying the graph, the nonce and node indexes
     * @return one output stream per available partition of the graph, in the order of
     *         {@code graph.getAvailablePartitions()}
     */
    @Override
    public List<StreamId> channels(DataflowContext dataflowContext) {
        DataflowGraph graph = dataflowContext.getGraph();
        int nonce = dataflowContext.getNonce();

        // Phase 1: shard every input channel on the partition that owns it; all shards share one node index.
        int shardIndex = dataflowContext.generateNodeIndex();
        List<Shard> shards = new ArrayList<>();
        for (StreamId inputStream : this.input.channels(dataflowContext.withoutFixedNonce())) {
            Partition inputPartition = graph.getPartition(inputStream);
            Shard shard = Shard.create(shardIndex, this.inputKeyFunction, inputStream, nonce);
            graph.addNode(inputPartition, shard);
            shards.add(shard);
        }

        // Phase 2: for every target partition, sort/pre-reduce each shard's slice, forward it and reduce.
        int reduceIndex = dataflowContext.generateNodeIndex();
        List<Partition> partitions = graph.getAvailablePartitions();
        // Each array is sized by the dimension it is indexed with below:
        // perTargetIndexes[target] and perShardIndexes[source].
        int[] perTargetIndexes = DatasetUtils.generateIndexes(dataflowContext, partitions.size());
        int[] perShardIndexes = DatasetUtils.generateIndexes(dataflowContext, shards.size());

        List<StreamId> outputs = new ArrayList<>(partitions.size());
        for (int target = 0; target < partitions.size(); target++) {
            Partition targetPartition = partitions.get(target);
            Reduce.Builder reduceBuilder = Reduce.builder(reduceIndex, this.keyComparator);
            // Sort and pre-reduce node indexes are shared by every shard feeding this target partition.
            int sortIndex = dataflowContext.generateNodeIndex();
            int reduceSimpleIndex = dataflowContext.generateNodeIndex();
            for (int source = 0; source < shards.size(); source++) {
                Shard shard = shards.get(source);
                StreamId shardOutput = shard.newPartition();
                graph.addNodeStream(shard, shardOutput);
                StreamId forwarded = sortReduceForward(dataflowContext, shardOutput, targetPartition,
                        sortIndex, reduceSimpleIndex, perTargetIndexes[target], perShardIndexes[source]);
                reduceBuilder.withInput(forwarded, this.accumulatorKeyFunction, this.reducer.accumulatorToOutput());
            }
            Reduce reduce = (Reduce) reduceBuilder.build();
            graph.addNode(targetPartition, reduce);
            outputs.add(reduce.output);
        }
        return outputs;
    }

    /**
     * Sorts and pre-reduces one shard output, then forwards the accumulators to {@code targetPartition}.
     *
     * <p>The sort and pre-reduce nodes are placed on the partition that owns {@code shardOutput}.
     *
     * @param dataflowContext   context supplying the graph
     * @param shardOutput       shard output stream to process
     * @param targetPartition   partition the accumulator stream is forwarded to
     * @param sortIndex         node index of the sort node
     * @param reduceSimpleIndex node index of the {@link ReduceSimple} pre-reduce node
     * @param perTargetIndex    index allocated for the target partition, passed to
     *                          {@link DatasetUtils#forwardChannel} as its first index argument
     * @param perShardIndex     index allocated for the source shard, passed to
     *                          {@link DatasetUtils#forwardChannel} as its second index argument
     * @return the stream returned by {@link DatasetUtils#forwardChannel}
     */
    private StreamId sortReduceForward(DataflowContext dataflowContext, StreamId shardOutput, Partition targetPartition,
                                       int sortIndex, int reduceSimpleIndex, int perTargetIndex, int perShardIndex) {
        DataflowGraph graph = dataflowContext.getGraph();
        Partition sourcePartition = graph.getPartition(shardOutput);

        // NOTE(review): the meaning of the boolean argument is defined by Nodes.sort — confirm before changing it.
        Node sort = Nodes.sort(sortIndex, this.input.streamSchema(), this.inputKeyFunction, this.keyComparator,
                false, this.sortBufferSize, shardOutput);
        graph.addNode(sourcePartition, sort);

        // Pre-reduce the sorted items into accumulators before they are forwarded off the source partition.
        ReduceSimple.Builder preReduceBuilder = ReduceSimple.builder(reduceSimpleIndex, this.inputKeyFunction,
                this.keyComparator, this.reducer.inputToAccumulator());
        for (StreamId sorted : sort.getOutputs()) {
            preReduceBuilder.withInput(sorted);
        }
        ReduceSimple preReduce = (ReduceSimple) preReduceBuilder.build();
        graph.addNode(sourcePartition, preReduce);

        return DatasetUtils.forwardChannel(dataflowContext, this.accumulatorStreamSchema, preReduce.output,
                targetPartition, perTargetIndex, perShardIndex);
    }

    /**
     * Returns the single base dataset of this dataset.
     *
     * @return an immutable list containing {@link #input}
     */
    @Override
    public Collection<Dataset<?>> getBases() {
        return List.of(this.input);
    }
}
