package de.viadee.xai.anchor.algorithm.exploration;

import de.viadee.xai.anchor.algorithm.AnchorCandidate;
import de.viadee.xai.anchor.algorithm.execution.SamplingService;
import de.viadee.xai.anchor.algorithm.util.ParameterValidation;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/* loaded from: input_file:de/viadee/xai/anchor/algorithm/exploration/BatchSAR.class */
/**
 * Batched best-candidate identification that repeatedly samples the still-undecided candidates and,
 * per stage, either accepts the currently most precise candidate or rejects the least precise one.
 *
 * <p>NOTE(review): the structure (stage count, {@code nn = max(ceil(b / r), 2)}, per-stage sample size)
 * looks like the BatchSAR procedure from the batch-arm-pull bandit literature; confirm against the
 * project documentation before relying on that correspondence.
 */
public class BatchSAR extends AbstractBRAlgorithm {
    private static final long serialVersionUID = 6756864957771948578L;

    /** Budget factor; multiplied with the inherited {@code b} when sizing the per-stage sample count. */
    private final int batchBudget;

    /** Number of candidates left undecided when the final stage starts; never smaller than 2. */
    private final int nn;

    /**
     * Convenience constructor that forwards {@code i} as both of the first two arguments of
     * {@link #BatchSAR(int, int, int)}.
     *
     * @param i  value forwarded as first and second argument
     * @param i2 the batch budget, must not be negative
     */
    public BatchSAR(int i, int i2) {
        this(i, i, i2);
    }

    /**
     * Creates the algorithm instance.
     *
     * @param i  first argument handed to the superclass constructor (presumably the batch size
     *           {@code b} — confirm against {@code AbstractBRAlgorithm})
     * @param i2 second argument handed to the superclass constructor (presumably the repeated-pull
     *           limit {@code r} — confirm against {@code AbstractBRAlgorithm})
     * @param i3 the batch budget, must not be negative
     * @throws IllegalArgumentException if {@code i3} fails the unsigned check
     * @throws ArithmeticException      if {@code i2} is zero
     */
    public BatchSAR(int i, int i2, int i3) {
        super(i, i2);
        if (!ParameterValidation.isUnsigned(Integer.valueOf(i3))) {
            throw new IllegalArgumentException("Batch budget must not be negative");
        }
        this.batchBudget = i3;
        // Exact ceiling division: plain "i / i2" on ints truncates before any rounding up can happen.
        this.nn = Math.max(ceilDiv(i, i2), 2);
    }

    /** Rounds {@code d} up to the next integer and narrows it to {@code int}. */
    private static int ceil(double d) {
        return (int) Math.ceil(d);
    }

    /**
     * Integer division rounded towards positive infinity.
     *
     * @throws ArithmeticException if {@code divisor} is zero
     */
    private static int ceilDiv(int dividend, int divisor) {
        return -Math.floorDiv(-dividend, divisor);
    }

    /**
     * Samples the undecided candidates stage by stage and collects up to {@code i2} accepted ones.
     *
     * <p>In every stage but the last, the candidates are ranked by descending precision and the gap
     * at the top of the ranking is compared with the gap at the bottom: if the top gap is at least as
     * large, the best candidate is accepted, otherwise the worst one is discarded. The last stage
     * accepts the best of the remaining candidates to fill the open slots.
     *
     * <p>If {@code list} holds fewer than {@code nn} candidates no stage runs and an empty list is
     * returned.
     *
     * <p>NOTE(review): during the elimination stages the ranking is indexed with the number of open
     * slots and that number minus one, so an {@code i2} of zero or one that is not smaller than the
     * number of undecided candidates raises {@link IndexOutOfBoundsException}. Left as is because the
     * intended result for that case cannot be determined from this class alone.
     *
     * @param list            the candidates to choose from
     * @param samplingService service handed through to {@code batchSample}
     * @param i               value handed through to {@code batchSample} unchanged
     * @param d               not used by this implementation
     * @param d2              not used by this implementation
     * @param i2              the number of candidates to return
     * @return the accepted candidates, in no particular order
     */
    @Override // de.viadee.xai.anchor.algorithm.exploration.BestAnchorIdentification
    public List<AnchorCandidate> identify(List<AnchorCandidate> list, SamplingService samplingService, int i, double d, double d2, int i2) {
        final int candidateCount = list.size();
        final int fixedCost = this.b + (this.nn * Math.min(this.r, ceil(this.b / 2.0d))) + candidateCount;
        final HashSet<AnchorCandidate> undecided = new HashSet<>(list);
        final HashSet<AnchorCandidate> accepted = new HashSet<>();
        final Comparator<AnchorCandidate> byPrecisionDescending =
                Comparator.<AnchorCandidate>comparingDouble(AnchorCandidate::getPrecision).reversed();

        final int eliminationStages = candidateCount - this.nn;
        for (int stage = 1; stage <= eliminationStages + 1; stage++) {
            batchSample(undecided, samplingService, i, calculateM(candidateCount, fixedCost, stage));
            final int openSlots = i2 - accepted.size();

            if (stage > eliminationStages) {
                // Final stage: fill the open slots with the most precise remaining candidates.
                accepted.addAll(undecided.stream()
                        .sorted(byPrecisionDescending)
                        .limit(openSlots)
                        .collect(Collectors.toList()));
                continue;
            }

            final List<AnchorCandidate> ranked = undecided.stream()
                    .sorted(byPrecisionDescending)
                    .collect(Collectors.toList());
            final AnchorCandidate best = ranked.get(0);
            // One candidate leaves per stage, so this index is the last element of the ranking.
            final AnchorCandidate worst = ranked.get(candidateCount - stage);
            final double topGap = best.getPrecision() - ranked.get(openSlots).getPrecision();
            final double bottomGap = ranked.get(openSlots - 1).getPrecision() - worst.getPrecision();

            if (topGap >= bottomGap) {
                undecided.remove(best);
                accepted.add(best);
            } else {
                undecided.remove(worst);
            }

            if (undecided.size() == i2 - accepted.size()) {
                // Exactly as many candidates left as slots open: accept all of them.
                accepted.addAll(undecided);
                break;
            }
            if (accepted.size() == i2) {
                break;
            }
        }
        return new ArrayList<>(accepted);
    }

    /**
     * Computes the sample count handed to {@code batchSample} for one stage.
     *
     * @param i  the total number of candidates
     * @param i2 the fixed cost term subtracted from the overall budget
     * @param i3 the 1-based stage number
     * @return the stage's sample count, rounded up
     */
    private int calculateM(int i, int i2, int i3) {
        // Exact ceiling per term: "this.b / k" on ints would truncate and make the ceiling a no-op.
        final int roundedShares = IntStream.rangeClosed(this.nn + 1, i)
                .map(k -> ceilDiv(this.b, k))
                .sum();
        final double harmonicTail = IntStream.rangeClosed(this.nn + 1, i)
                .mapToDouble(k -> 1.0d / k)
                .sum();
        // Multiply in double so that a large b * batchBudget cannot wrap around as an int.
        final double usableBudget = ((double) this.b * this.batchBudget) - roundedShares - i2;
        final double perUnit = usableBudget / ((this.nn / 2.0d) + harmonicTail);
        // Elimination stages share the budget over the remaining candidates; the final stage uses 1/2.
        final double stageShare = 1.0d / (i3 <= i - this.nn ? (i - i3) + 1 : 2.0d);
        return ceil(perUnit * stageShare);
    }
}
