package edu.iu.dsc.tws.comms.dfw.io.join;

import edu.iu.dsc.tws.api.comms.BulkReceiver;
import edu.iu.dsc.tws.api.comms.CommunicationContext;
import edu.iu.dsc.tws.api.comms.messaging.types.MessageType;
import edu.iu.dsc.tws.api.comms.structs.Tuple;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.exceptions.Twister2RuntimeException;
import edu.iu.dsc.tws.comms.shuffle.ResettableIterator;
import edu.iu.dsc.tws.comms.shuffle.RestorableIterator;
import edu.iu.dsc.tws.comms.utils.HashJoinUtils;
import edu.iu.dsc.tws.comms.utils.JoinRelation;
import edu.iu.dsc.tws.comms.utils.KeyComparatorWrapper;
import edu.iu.dsc.tws.comms.utils.SortJoinUtils;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/comms/dfw/io/join/JoinBatchCombinedReceiver.class */
/**
 * Collects, per target, the data of every {@link JoinRelation} side of a batch join, joins the
 * left and right sides once all sides have arrived, and hands the joined iterator to a
 * downstream {@link BulkReceiver}.
 *
 * <p>With {@link CommunicationContext.JoinAlgorithm#SORT} the join is delegated to
 * {@link SortJoinUtils}; for any other algorithm it is delegated to {@link HashJoinUtils}.
 *
 * <p>This class performs no synchronization of its own; its per-target state lives in plain
 * {@link HashMap}s.
 */
public class JoinBatchCombinedReceiver {
    private static final Logger LOG = Logger.getLogger(JoinBatchCombinedReceiver.class.getName());

    /** Per target: which join relations have synced since the last completed sync round. */
    private final Map<Integer, boolean[]> syncCounts = new HashMap<>();

    /** Per target: data received for each join relation, indexed by {@link JoinRelation#ordinal()}. */
    private final Map<Integer, Object[]> joinRelations = new HashMap<>();

    private final BulkReceiver rcvr;
    private final CommunicationContext.JoinAlgorithm algorithm;
    private final CommunicationContext.JoinType joinType;
    private final KeyComparatorWrapper keyComparator;
    private final MessageType keyType;

    /** Set by {@link #init(Config, Set)}; read by the sort-join path. */
    private Config config;

    /**
     * Creates a receiver; {@link #init(Config, Set)} must be called before data arrives.
     *
     * @param bulkReceiver downstream receiver that gets the joined iterators and sync signals
     * @param joinAlgorithm SORT selects {@link SortJoinUtils}; anything else selects {@link HashJoinUtils}
     * @param joinType join type forwarded to the join utilities
     * @param keyComparatorWrapper key comparator used by the sort join
     * @param messageType key type used by the hash join
     */
    public JoinBatchCombinedReceiver(BulkReceiver bulkReceiver,
                                     CommunicationContext.JoinAlgorithm joinAlgorithm,
                                     CommunicationContext.JoinType joinType,
                                     KeyComparatorWrapper keyComparatorWrapper,
                                     MessageType messageType) {
        this.rcvr = bulkReceiver;
        this.algorithm = joinAlgorithm;
        this.joinType = joinType;
        this.keyComparator = keyComparatorWrapper;
        this.keyType = messageType;
    }

    /**
     * Stores the configuration and allocates empty per-target state for each target.
     *
     * @param config configuration read by the sort-join path
     * @param set target ids this receiver will accept data and sync signals for
     */
    public void init(Config config, Set<Integer> set) {
        this.config = config;
        int relationCount = JoinRelation.values().length;
        for (Integer target : set) {
            // New Java arrays are zero-initialized: every flag starts false, every slot null.
            this.syncCounts.put(target, new boolean[relationCount]);
            this.joinRelations.put(target, new Object[relationCount]);
        }
    }

    /**
     * Joins the left and right inputs with the configured algorithm.
     *
     * @param left data received for {@link JoinRelation#LEFT}
     * @param right data received for {@link JoinRelation#RIGHT}
     * @return the iterator produced by the selected join utility
     * @throws Twister2RuntimeException if the two inputs are not both of a supported type
     */
    @SuppressWarnings("unchecked")
    private Iterator doJoin(Object left, Object right) {
        if (this.algorithm.equals(CommunicationContext.JoinAlgorithm.SORT)) {
            if (left instanceof RestorableIterator && right instanceof RestorableIterator) {
                boolean useCache = this.config.getBooleanValue(
                        SortJoinUtils.CONFIG_USE_SORT_JOIN_CACHE, true).booleanValue();
                if (useCache) {
                    return SortJoinUtils.joinWithCache((RestorableIterator) left, (RestorableIterator) right,
                            this.keyComparator, this.joinType, this.config);
                }
                return SortJoinUtils.join((RestorableIterator<Tuple<?, ?>>) left,
                        (RestorableIterator<Tuple<?, ?>>) right, this.keyComparator, this.joinType);
            }
            if (left instanceof List && right instanceof List) {
                return SortJoinUtils.join((List<Tuple>) left, (List<Tuple>) right,
                        this.keyComparator, this.joinType);
            }
            throw unsupportedFormats(left, right);
        }
        if (left instanceof ResettableIterator && right instanceof ResettableIterator) {
            return HashJoinUtils.join((ResettableIterator<Tuple<?, ?>>) left,
                    (ResettableIterator<Tuple<?, ?>>) right, this.joinType, this.keyType);
        }
        if (left instanceof List && right instanceof List) {
            return HashJoinUtils.join((List<Tuple>) left, (List<Tuple>) right, this.joinType, this.keyType);
        }
        throw unsupportedFormats(left, right);
    }

    /** Builds the error raised when the join inputs are not a supported (matching) pair. */
    private static Twister2RuntimeException unsupportedFormats(Object left, Object right) {
        return new Twister2RuntimeException("Unsupported data formats received from sources : "
                + left.getClass() + ", " + right.getClass());
    }

    /**
     * Returns the per-target state, failing clearly for a target that was not passed to init.
     */
    private static <T> T stateFor(Map<Integer, T> states, int target) {
        T state = states.get(target);
        if (state == null) {
            throw new Twister2RuntimeException("Join receiver was not initialized for target " + target);
        }
        return state;
    }

    /**
     * Records the data of one join relation for a target. Once every relation has non-null data,
     * joins LEFT with RIGHT, forwards the result to the downstream receiver, logs the elapsed
     * time, and clears the stored data for that target.
     *
     * @param target target id
     * @param data data for {@code relation}; overwrites any earlier data for the same relation
     * @param relation which side of the join {@code data} belongs to
     * @return always {@code true}
     * @throws Twister2RuntimeException if {@code target} was not passed to init, or the inputs
     *     have unsupported types
     */
    public boolean receive(int target, Object data, JoinRelation relation) {
        Object[] received = stateFor(this.joinRelations, target);
        received[relation.ordinal()] = data;
        if (Arrays.stream(received).filter(Objects::nonNull).count() != JoinRelation.values().length) {
            // Still waiting for at least one side of the join.
            return true;
        }
        long start = System.currentTimeMillis();
        this.rcvr.receive(target, doJoin(received[JoinRelation.LEFT.ordinal()],
                received[JoinRelation.RIGHT.ordinal()]));
        LOG.info("Join time : " + (System.currentTimeMillis() - start));
        Arrays.fill(received, null);
        return true;
    }

    /**
     * Marks {@code relation} as synced for {@code target}. Returns true, and resets the flags for
     * the next round, only when every join relation has been marked.
     */
    private boolean isAllSynced(int target, JoinRelation relation) {
        boolean[] synced = stateFor(this.syncCounts, target);
        synced[relation.ordinal()] = true;
        for (JoinRelation r : JoinRelation.values()) {
            if (!synced[r.ordinal()]) {
                return false;
            }
        }
        Arrays.fill(synced, false);
        return true;
    }

    /**
     * Records a sync signal from one join relation; once all relations of {@code target} have
     * synced, forwards the sync (with {@code value}) to the downstream receiver.
     *
     * @param target target id
     * @param value payload forwarded to the downstream receiver's sync
     * @param relation join relation that is signalling sync
     * @return always {@code true}
     * @throws Twister2RuntimeException if {@code target} was not passed to init
     */
    public boolean sync(int target, byte[] value, JoinRelation relation) {
        if (isAllSynced(target, relation)) {
            this.rcvr.sync(target, value);
        }
        return true;
    }
}
