package io.activej.dataflow.dataset.impl;

import io.activej.dataflow.dataset.Dataset;
import io.activej.dataflow.dataset.DatasetUtils;
import io.activej.dataflow.graph.DataflowContext;
import io.activej.dataflow.graph.DataflowGraph;
import io.activej.dataflow.graph.Partition;
import io.activej.dataflow.graph.StreamId;
import io.activej.dataflow.node.NodeShard;
import io.activej.dataflow.node.NodeUnion;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import org.jetbrains.annotations.Nullable;

/* loaded from: input_file:io/activej/dataflow/dataset/impl/DatasetRepartition.class */
/**
 * Dataset whose channels are produced by re-sharding the channels of an input dataset
 * across a list of target partitions.
 *
 * <p>For every input channel a {@link NodeShard} is added to the graph, and for every
 * target partition a {@link NodeUnion} is added that collects one forwarded output from
 * each shard node. The key function is handed to each {@link NodeShard} as-is; how keys
 * are mapped to shard outputs is decided by {@code NodeShard}, not by this class.
 *
 * @param <T> type of the dataset items
 * @param <K> type of the key produced by the key function
 */
public final class DatasetRepartition<T, K> extends Dataset<T> {
    private final Dataset<T> input;
    private final Function<T, K> keyFunction;

    /** Explicit target partitions, or {@code null} to use the partitions reported by the graph. */
    @Nullable
    private final List<Partition> partitions;

    /**
     * Creates a repartitioning view over {@code dataset}.
     *
     * @param dataset  input dataset; its value type becomes the value type of this dataset
     * @param function key function passed to every shard node
     * @param list     target partitions, or {@code null} to use
     *                 {@link DataflowGraph#getAvailablePartitions()} when channels are built
     */
    public DatasetRepartition(Dataset<T> dataset, Function<T, K> function, @Nullable List<Partition> list) {
        super(dataset.valueType());
        this.input = dataset;
        this.keyFunction = function;
        this.partitions = list;
    }

    /**
     * Adds the shard and union nodes for this dataset to the graph of the given context.
     *
     * @param dataflowContext context supplying the graph, the nonce and the node indexes
     * @return the output stream of each union node, one per target partition, in the order
     *         of the target partition list
     */
    @Override
    public List<StreamId> channels(DataflowContext dataflowContext) {
        DataflowGraph graph = dataflowContext.getGraph();
        List<Partition> targetPartitions = partitions != null ? partitions : graph.getAvailablePartitions();
        int nonce = dataflowContext.getNonce();
        List<StreamId> outputStreamIds = new ArrayList<>();
        List<NodeShard<K, T>> shards = new ArrayList<>();

        // The shard index is generated before the input channels are requested, and is
        // shared by all shard nodes; the call order on the context is kept as-is.
        int shardNodeIndex = dataflowContext.generateNodeIndex();
        for (StreamId inputStreamId : input.channels(dataflowContext.withoutFixedNonce())) {
            // Each shard node is placed on the partition that owns its input stream.
            Partition sourcePartition = graph.getPartition(inputStreamId);
            NodeShard<K, T> shard = new NodeShard<>(shardNodeIndex, keyFunction, inputStreamId, nonce);
            graph.addNode(sourcePartition, shard);
            shards.add(shard);
        }

        // A single index is shared by all union nodes.
        int unionNodeIndex = dataflowContext.generateNodeIndex();
        // One generated index per shard node and one per target partition; both are passed to forwardChannel.
        int[] perShardIndexes = DatasetUtils.generateIndexes(dataflowContext, shards.size());
        int[] perPartitionIndexes = DatasetUtils.generateIndexes(dataflowContext, targetPartitions.size());

        for (int p = 0; p < targetPartitions.size(); p++) {
            Partition targetPartition = targetPartitions.get(p);
            List<StreamId> unionInputs = new ArrayList<>();
            for (int s = 0; s < shards.size(); s++) {
                NodeShard<K, T> shard = shards.get(s);
                // Every (target partition, shard) pair gets its own shard output stream.
                StreamId shardOutput = shard.newPartition();
                graph.addNodeStream(shard, shardOutput);
                StreamId forwarded = DatasetUtils.forwardChannel(
                        dataflowContext, input.valueType(), shardOutput, targetPartition,
                        perPartitionIndexes[p], perShardIndexes[s]);
                unionInputs.add(forwarded);
            }
            NodeUnion<T> union = new NodeUnion<>(unionNodeIndex, unionInputs);
            graph.addNode(targetPartition, union);
            outputStreamIds.add(union.getOutput());
        }
        return outputStreamIds;
    }
}
