package edu.iu.dsc.tws.comms.table;

import edu.iu.dsc.tws.api.comms.LogicalPlan;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.resource.IWorkerController;
import edu.iu.dsc.tws.common.table.ArrowColumn;
import edu.iu.dsc.tws.common.table.Table;
import edu.iu.dsc.tws.common.table.arrow.ArrowTable;
import edu.iu.dsc.tws.common.table.arrow.BinaryColumn;
import edu.iu.dsc.tws.common.table.arrow.Float4Column;
import edu.iu.dsc.tws.common.table.arrow.Float8Column;
import edu.iu.dsc.tws.common.table.arrow.Int4Column;
import edu.iu.dsc.tws.common.table.arrow.Int8Column;
import edu.iu.dsc.tws.common.table.arrow.StringColumn;
import edu.iu.dsc.tws.common.table.arrow.UInt2Column;
import edu.iu.dsc.tws.comms.table.channel.ChannelBuffer;
import edu.iu.dsc.tws.comms.utils.TaskPlanUtils;
import io.netty.buffer.ArrowBuf;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.BaseFixedWidthVector;
import org.apache.arrow.vector.BaseVariableWidthVector;
import org.apache.arrow.vector.BitVectorHelper;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.TypeLayout;
import org.apache.arrow.vector.UInt2Vector;
import org.apache.arrow.vector.UInt8Vector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;

/* loaded from: input_file:edu/iu/dsc/tws/comms/table/ArrowAllToAll.class */
/**
 * All-to-all exchange of Arrow {@link Table}s between workers, layered on a buffer-level
 * {@link SimpleAllToAll}.
 *
 * <p>Tables given to {@link #insert(Table, int)} are queued per destination worker and streamed
 * one Arrow buffer at a time whenever {@link #isComplete()} is polled. Every buffer is sent with
 * a {@value #HEADER_SIZE}-int header
 * {@code [columnIndex, bufferIndex, buffersInColumn, valueCount, length, target]}. The receiving
 * side loads the buffers into the vectors of a {@link VectorSchemaRoot}. Once every column of a
 * table has arrived, the table is handed to the {@link ArrowCallback}.
 *
 * <p>NOTE(review): nothing here is synchronized; presumably all methods (including the receive
 * callbacks) are invoked from a single progress thread — confirm against {@code SimpleAllToAll}.
 */
public class ArrowAllToAll implements ReceiveCallback {
    private static final Logger LOG = Logger.getLogger(ArrowAllToAll.class.getName());

    /** Number of ints in the header that accompanies every Arrow buffer. */
    private static final int HEADER_SIZE = 6;

    /** Logical target ids this operation delivers to. */
    private final List<Integer> targets;
    /** Logical source ids this operation receives from. */
    private final List<Integer> srcs;
    /** Distinct workers hosting at least one source. */
    private final List<Integer> sourceWorkerList;
    /** Underlying buffer-level all-to-all between workers. */
    private final SimpleAllToAll all;
    /** Receives every fully reassembled table. */
    private final ArrowCallback recvCallback;
    /** Number of buffers received so far (bookkeeping only). */
    private int receivedBuffers;
    private final int workerId;
    /** Logical sources hosted on this worker, used by {@link #finish(int)}. */
    private final Set<Integer> sourcesOfThisWorker;
    private final Schema schema;
    private final RootAllocator allocator;

    /** Outgoing queues keyed by destination worker id. */
    private final Map<Integer, PendingSendTable> inputs = new HashMap<>();
    /** Receive-side reassembly state keyed by source worker id. */
    private final Map<Integer, PendingReceiveTable> receives = new HashMap<>();
    /** Logical target id to the id of the worker hosting it. */
    private final Map<Integer, Integer> targetToWorker = new HashMap<>();
    /** Source workers that sent a finish header; a set so a repeated signal counts once. */
    private final Set<Integer> finishedSources = new HashSet<>();
    /** Local sources for which {@link #finish(int)} was called; a set so repeats count once. */
    private final Set<Integer> finishedCalledSources = new HashSet<>();
    private boolean finished = false;
    private boolean completed = false;
    private boolean finishedSent = false;

    /** Send-side progress through the table currently being streamed to one worker. */
    private enum ArrowHeader {
        /** No table in flight; the next queued table may be started. */
        HEADER_INIT,
        /** A table is partially sent; resume at {@code columnIndex}/{@code bufferIndex}. */
        COLUMN_CONTINUE
    }

    /** Reassembly state for tables arriving from one source worker. */
    private class PendingReceiveTable {
        private int source;
        private int columnIndex;
        private int bufferIndex;
        private int noBuffers;
        /** Value count of the current column, taken from the header. */
        private int noArray;
        private int length;
        private int target;
        /** Root whose vectors receive the buffers of the table being assembled. */
        private VectorSchemaRoot root;
        /** Buffers of the current column received so far. */
        private final List<ArrowBuf> buffers = new ArrayList<>();
        private final List<ArrowFieldNode> fieldNodes = new ArrayList<>();
        /** Columns of the current table that are fully loaded. */
        private final List<FieldVector> arrays = new ArrayList<>();

        private PendingReceiveTable() {
        }

        /**
         * Resets the state after a table was delivered and starts a fresh root for the next one.
         * The previous root's vectors now belong to the delivered table.
         */
        public void clear() {
            this.source = 0;
            this.columnIndex = 0;
            this.bufferIndex = 0;
            this.noBuffers = 0;
            this.noArray = 0;
            this.length = 0;
            this.target = 0;
            this.buffers.clear();
            this.fieldNodes.clear();
            this.arrays.clear();
            this.root = VectorSchemaRoot.create(ArrowAllToAll.this.schema,
                ArrowAllToAll.this.allocator);
        }
    }

    /** Outgoing queue and resume position for one destination worker. */
    private static final class PendingSendTable {
        /** Tables waiting to be sent, parallel to {@link #target}. */
        private final Queue<Table> pending = new LinkedList<>();
        /** Logical target of each queued table. */
        private final Queue<Integer> target = new LinkedList<>();
        private Table currentTable;
        private int currentTarget;
        private ArrowHeader status = ArrowHeader.HEADER_INIT;
        private int columnIndex;
        private int bufferIndex;
    }

    /**
     * Creates the operation and the underlying {@link SimpleAllToAll}.
     *
     * @param config configuration forwarded to {@link SimpleAllToAll}
     * @param workerController controller of this worker
     * @param sources logical source ids
     * @param targets logical target ids
     * @param plan plan used to map logical ids to worker ids
     * @param edgeId forwarded to {@link SimpleAllToAll} (presumably the communication edge id)
     * @param callback receives every reassembled table
     * @param schema schema of all exchanged tables
     * @param allocator allocator for receive-side vectors and buffers
     */
    public ArrowAllToAll(Config config, IWorkerController workerController,
                         Set<Integer> sources, Set<Integer> targets, LogicalPlan plan,
                         int edgeId, ArrowCallback callback, Schema schema,
                         RootAllocator allocator) {
        this.targets = new ArrayList<>(targets);
        this.srcs = new ArrayList<>(sources);
        this.workerId = workerController.getWorkerInfo().getWorkerID();
        this.recvCallback = callback;
        this.schema = schema;
        this.allocator = allocator;

        // one send queue per distinct worker that hosts at least one target
        Set<Integer> targetWorkers = new HashSet<>();
        for (int target : this.targets) {
            int worker = plan.getWorkerForForLogicalId(target);
            targetToWorker.put(target, worker);
            targetWorkers.add(worker);
        }
        List<Integer> targetWorkerList = new ArrayList<>(targetWorkers);
        for (int worker : targetWorkers) {
            inputs.put(worker, new PendingSendTable());
        }

        // one reassembly state per distinct worker that hosts at least one source
        Set<Integer> sourceWorkers = new HashSet<>();
        for (int source : this.srcs) {
            sourceWorkers.add(plan.getWorkerForForLogicalId(source));
        }
        this.sourceWorkerList = new ArrayList<>(sourceWorkers);
        for (int worker : sourceWorkers) {
            PendingReceiveTable receive = new PendingReceiveTable();
            receive.root = VectorSchemaRoot.create(schema, allocator);
            receives.put(worker, receive);
        }

        this.sourcesOfThisWorker = TaskPlanUtils.getTasksOfThisWorker(plan, sources);
        this.all = new SimpleAllToAll(config, workerController, this.sourceWorkerList,
            targetWorkerList, edgeId, this, new ArrowAllocator(allocator));
    }

    /**
     * Queues a table for a logical target. Sending happens in {@link #isComplete()}.
     *
     * @param table table to send
     * @param target logical target id; must be one of the targets given at construction
     * @return always {@code true}
     * @throws IllegalArgumentException if {@code target} is not a known target
     */
    public boolean insert(Table table, int target) {
        Integer worker = targetToWorker.get(target);
        if (worker == null) {
            throw new IllegalArgumentException("Unknown target " + target);
        }
        PendingSendTable queue = inputs.get(worker);
        queue.pending.offer(table);
        queue.target.offer(target);
        return true;
    }

    /**
     * Progresses sending and reports whether the whole exchange has finished.
     *
     * <p>Streams queued tables until the underlying all-to-all refuses a buffer. Once everything
     * is sent and {@link #finish()} (or {@link #finish(int)} for all local sources) was called,
     * it sends the finish signal.
     *
     * @return {@code true} once every table is sent, the underlying all-to-all is complete and
     *     every source worker has signalled finish; stays {@code true} afterwards
     */
    public boolean isComplete() {
        if (completed) {
            return true;
        }
        boolean allSent = true;
        for (Map.Entry<Integer, PendingSendTable> entry : inputs.entrySet()) {
            PendingSendTable queue = entry.getValue();
            progressSends(entry.getKey(), queue);
            if (!queue.pending.isEmpty() || queue.status == ArrowHeader.COLUMN_CONTINUE) {
                allSent = false;
            }
        }
        if (allSent && finished && !finishedSent) {
            all.finish();
            finishedSent = true;
        }
        // Poll the underlying all-to-all on every call, not only once this layer has drained,
        // so that any progress it makes inside isComplete() keeps happening while tables are
        // still queued here.
        boolean allToAllComplete = all.isComplete();
        boolean done = allSent && allToAllComplete
            && finishedSources.size() == sourceWorkerList.size();
        if (done) {
            completed = true;
        }
        return done;
    }

    /** Streams queued tables to one worker until the queue is empty or a buffer is refused. */
    private void progressSends(int worker, PendingSendTable queue) {
        while (true) {
            if (queue.status == ArrowHeader.HEADER_INIT) {
                if (queue.pending.isEmpty()) {
                    return;
                }
                queue.currentTable = queue.pending.poll();
                Integer target = queue.target.poll();
                assert target != null : "every queued table must have a target";
                queue.currentTarget = target;
                queue.status = ArrowHeader.COLUMN_CONTINUE;
            }
            if (!sendCurrentTable(worker, queue)) {
                // the all-to-all refused a buffer; resume from the saved position next call
                return;
            }
            queue.columnIndex = 0;
            queue.bufferIndex = 0;
            queue.currentTable = null;
            queue.status = ArrowHeader.HEADER_INIT;
        }
    }

    /**
     * Sends the remaining buffers of the current table, resuming at the saved column/buffer.
     *
     * @return {@code true} if the whole table was handed to the all-to-all
     */
    private boolean sendCurrentTable(int worker, PendingSendTable queue) {
        int columnCount = queue.currentTable.getColumns().size();
        while (queue.columnIndex < columnCount) {
            FieldVector vector = ((ArrowColumn) queue.currentTable.getColumns()
                .get(queue.columnIndex)).getVector();
            List<ArrowFieldNode> nodes = new ArrayList<>();
            List<ArrowBuf> buffers = new ArrayList<>();
            appendNodes(vector, nodes, buffers);
            while (queue.bufferIndex < buffers.size()) {
                ArrowBuf buffer = buffers.get(queue.bufferIndex);
                int length = (int) buffer.capacity();
                int[] header = {queue.columnIndex, queue.bufferIndex, buffers.size(),
                    vector.getValueCount(), length, queue.currentTarget};
                if (!all.insert(buffer.nioBuffer(), length, header, HEADER_SIZE, worker)) {
                    return false;
                }
                queue.bufferIndex++;
            }
            queue.bufferIndex = 0;
            queue.columnIndex++;
        }
        return true;
    }

    /** Marks this side finished; the finish signal is sent once all queued tables are out. */
    public void finish() {
        this.finished = true;
    }

    /**
     * Records that a local source is done. This side is marked finished once the number of
     * distinct sources recorded equals the number of sources hosted on this worker.
     *
     * @param source logical source id
     */
    public void finish(int source) {
        finishedCalledSources.add(source);
        if (finishedCalledSources.size() == sourcesOfThisWorker.size()) {
            this.finished = true;
        }
    }

    /**
     * Drops queued outgoing tables and closes the underlying all-to-all. Also closes the
     * receive-side roots that were not handed to the callback; delivered tables own earlier roots.
     */
    public void close() {
        inputs.clear();
        try {
            all.close();
        } finally {
            for (PendingReceiveTable receive : receives.values()) {
                receive.root.close();
            }
        }
    }

    /**
     * Accepts one Arrow buffer from a source worker. When it completes a column, the column is
     * loaded; when it completes a table, the table is delivered to the callback.
     *
     * <p>NOTE(review): the received {@link ArrowBuf}s are retained by the vectors when loaded but
     * never released here; confirm whether {@code ArrowChannelBuffer}'s owner releases them.
     */
    @Override // edu.iu.dsc.tws.comms.table.ReceiveCallback
    public void onReceive(int source, ChannelBuffer buffer, int length) {
        PendingReceiveTable receive = receiveStateFor(source);
        receivedBuffers++;
        receive.buffers.add(((ArrowChannelBuffer) buffer).getArrowBuf());
        if (receive.noBuffers != receive.bufferIndex + 1) {
            // more buffers of the current column are still to come
            return;
        }

        VectorSchemaRoot root = receive.root;
        List<FieldVector> vectors = root.getFieldVectors();
        FieldVector vector = vectors.get(receive.columnIndex);
        // A zero null count makes Arrow rebuild an all-valid bitmap and ignore the received one,
        // so derive the real count from the transferred validity buffer.
        long nullCount = nullCountOf(vector, receive.buffers.get(0), receive.noArray);
        receive.fieldNodes.add(new ArrowFieldNode(receive.noArray, nullCount));
        loadBuffers(vector, vector.getField(), receive.buffers.iterator(),
            receive.fieldNodes.iterator());
        receive.arrays.add(vector);
        receive.buffers.clear();
        receive.fieldNodes.clear();
        if (receive.arrays.size() != vectors.size()) {
            return;
        }

        List<ArrowColumn> columns = new ArrayList<>(vectors.size());
        for (FieldVector fieldVector : vectors) {
            columns.add(toColumn(fieldVector));
        }
        Table arrowTable = new ArrowTable(root.getSchema(), receive.noArray, columns);
        LOG.info("Received table from source " + source + " to " + receive.target + " count"
            + arrowTable.rowCount());
        this.recvCallback.onReceive(source, receive.target, arrowTable);
        receive.clear();
    }

    /**
     * Accepts a header from a source worker: either a finish signal or the description of the
     * buffer that follows.
     *
     * @throws RuntimeException if a data header does not have {@value #HEADER_SIZE} ints
     */
    @Override // edu.iu.dsc.tws.comms.table.ReceiveCallback
    public void onReceiveHeader(int source, boolean finishSignal, int[] header, int length) {
        if (finishSignal) {
            finishedSources.add(source);
            return;
        }
        if (length != HEADER_SIZE) {
            String msg = "Incorrect length on header, expected " + HEADER_SIZE + " ints got "
                + length;
            LOG.log(Level.SEVERE, msg);
            throw new RuntimeException(msg);
        }
        PendingReceiveTable receive = receiveStateFor(source);
        receive.columnIndex = header[0];
        receive.bufferIndex = header[1];
        receive.noBuffers = header[2];
        receive.noArray = header[3];
        receive.length = header[4];
        receive.target = header[5];
    }

    /** Always returns {@code false}; send completion needs no handling here. */
    @Override // edu.iu.dsc.tws.comms.table.ReceiveCallback
    public boolean onSendComplete(int target, ByteBuffer buffer, int length) {
        return false;
    }

    /** Returns the reassembly state for a source worker, failing clearly if it is unknown. */
    private PendingReceiveTable receiveStateFor(int source) {
        PendingReceiveTable receive = receives.get(source);
        if (receive == null) {
            throw new IllegalStateException("Received data from unknown source worker " + source);
        }
        return receive;
    }

    /**
     * Counts nulls in the first {@code valueCount} bits of a received validity buffer. Only
     * fixed- and variable-width vectors are known here to carry validity as buffer 0; other
     * vectors keep the previous behaviour of reporting 0.
     */
    private static long nullCountOf(FieldVector vector, ArrowBuf firstBuffer, int valueCount) {
        if (vector instanceof BaseFixedWidthVector || vector instanceof BaseVariableWidthVector) {
            return BitVectorHelper.getNullCount(firstBuffer, valueCount);
        }
        return 0L;
    }

    /** Wraps a loaded vector in the matching column type. */
    private static ArrowColumn toColumn(FieldVector vector) {
        if (vector instanceof IntVector) {
            return new Int4Column((IntVector) vector);
        } else if (vector instanceof Float4Vector) {
            return new Float4Column((Float4Vector) vector);
        } else if (vector instanceof Float8Vector) {
            return new Float8Column((Float8Vector) vector);
        } else if (vector instanceof UInt8Vector) {
            return new Int8Column((UInt8Vector) vector);
        } else if (vector instanceof UInt2Vector) {
            return new UInt2Column((UInt2Vector) vector);
        } else if (vector instanceof VarCharVector) {
            return new StringColumn((VarCharVector) vector);
        } else if (vector instanceof VarBinaryVector) {
            return new BinaryColumn((VarBinaryVector) vector);
        }
        throw new RuntimeException("Un-supported type : " + vector.getClass().getName());
    }

    /**
     * Loads a vector (and recursively its children) from buffers and field nodes, consuming as
     * many buffers as the field's type layout requires.
     *
     * @throws IllegalArgumentException if nodes/buffers run out or loading fails
     */
    private void loadBuffers(FieldVector vector, Field field, Iterator<ArrowBuf> buffers,
                             Iterator<ArrowFieldNode> nodes) {
        Preconditions.checkArgument(nodes.hasNext(),
            "no more field nodes for for field %s and vector %s", field, vector);
        ArrowFieldNode node = nodes.next();
        int bufferCount = TypeLayout.getTypeBufferCount(field.getType());
        List<ArrowBuf> ownBuffers = new ArrayList<>(bufferCount);
        for (int b = 0; b < bufferCount; b++) {
            Preconditions.checkArgument(buffers.hasNext(), "no more buffers for field %s", field);
            ownBuffers.add(buffers.next());
        }
        try {
            vector.loadFieldBuffers(node, ownBuffers);
            List<Field> children = field.getChildren();
            if (!children.isEmpty()) {
                List<FieldVector> childVectors = vector.getChildrenFromFields();
                Preconditions.checkArgument(children.size() == childVectors.size(),
                    "should have as many children as in the schema: found %s expected %s",
                    childVectors.size(), children.size());
                for (int c = 0; c < childVectors.size(); c++) {
                    loadBuffers(childVectors.get(c), children.get(c), buffers, nodes);
                }
            }
        } catch (RuntimeException e) {
            throw new IllegalArgumentException("Could not load buffers for field " + field
                + ". error message: " + e.getMessage(), e);
        }
    }

    /**
     * Collects a field node and the buffers of a vector and, recursively, of its children.
     * Only the buffers are transmitted; the nodes are not sent.
     *
     * @throws IllegalArgumentException if the buffer count does not match the type layout
     */
    private void appendNodes(FieldVector vector, List<ArrowFieldNode> nodes,
                             List<ArrowBuf> buffers) {
        nodes.add(new ArrowFieldNode(vector.getValueCount(), 0L));
        List<ArrowBuf> fieldBuffers = vector.getFieldBuffers();
        if (fieldBuffers.size() != TypeLayout.getTypeBufferCount(vector.getField().getType())) {
            throw new IllegalArgumentException(String.format(
                "wrong number of buffers for field %s in vector %s. found: %s",
                vector.getField(), vector.getClass().getSimpleName(), fieldBuffers));
        }
        buffers.addAll(fieldBuffers);
        for (FieldVector child : vector.getChildrenFromFields()) {
            appendNodes(child, nodes, buffers);
        }
    }
}
