package de.viadee.xai.anchor.algorithm;

import de.viadee.xai.anchor.algorithm.DataInstance;
import de.viadee.xai.anchor.algorithm.coverage.CoverageIdentification;
import de.viadee.xai.anchor.algorithm.execution.SamplingService;
import de.viadee.xai.anchor.algorithm.execution.SamplingSession;
import de.viadee.xai.anchor.algorithm.exploration.BestAnchorIdentification;
import de.viadee.xai.anchor.algorithm.util.KLBernoulliUtils;
import de.viadee.xai.anchor.algorithm.util.ParameterValidation;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/viadee/xai/anchor/algorithm/AnchorConstruction.class */
/**
 * Constructs an anchor for a single explained instance using a beam search over feature
 * combinations ("candidates").
 *
 * <p>Each round extends the best candidates of the previous round by one more feature. The
 * configured {@link BestAnchorIdentification} then picks the top candidates, and each of them is
 * checked against the precision threshold {@code tau}. Among all candidates found to be valid, the
 * one with the highest coverage is returned. If none is valid, the best candidate overall is
 * returned instead and flagged accordingly in the {@link AnchorResult}.
 *
 * @param <T> type of the explained data instance
 */
public class AnchorConstruction<T extends DataInstance<?>> implements Serializable {
    private static final long serialVersionUID = -478521750533925027L;
    private static final Logger LOGGER = LoggerFactory.getLogger(AnchorConstruction.class);

    private final BestAnchorIdentification bestAnchorIdentification;
    private final CoverageIdentification coverageIdentification;
    private final T explainedInstance;
    private final int explainedInstanceLabel;
    private final int maxAnchorSize;
    private final int beamSize;
    private final double delta;
    private final double epsilon;
    private final double tau;
    private final double tauDiscrepancy;
    private final int initSampleCount;
    private final boolean lazyCoverageEvaluation;
    private final boolean allowSuboptimalSteps;
    private final SamplingService samplingService;

    /**
     * Creates a new anchor construction and validates all arguments eagerly.
     *
     * <p>NOTE: the decompiler reports this constructor was package-private in the original class
     * file. It is kept public so existing callers keep compiling.
     *
     * @param bestAnchorIdentification strategy that selects the top candidates of a round
     * @param coverageIdentification   computes the coverage of a candidate
     * @param samplingService          evaluates candidates by sampling
     * @param explainedInstance        the instance to explain
     * @param explainedInstanceLabel   label of the explained instance; must not be negative
     * @param maxAnchorSize            maximum number of features of an anchor; must not be negative
     * @param beamSize                 number of candidates kept per round; must not be negative
     * @param delta                    significance level; between 0 and 1
     * @param epsilon                  tolerance handed to the best anchor identification; between 0 and 1
     * @param tau                      required precision of an anchor; between 0 and 1
     * @param tauDiscrepancy           tolerance around {@code tau} when validating candidates; between 0 and 1
     * @param initSampleCount          samples each candidate receives before selection; must not be negative
     * @param lazyCoverageEvaluation   if true, coverage is only computed when it is actually needed
     * @param allowSuboptimalSteps     if false, candidates whose added precision is not positive are discarded
     * @throws IllegalArgumentException if any reference argument is null or any value is out of range
     */
    public AnchorConstruction(BestAnchorIdentification bestAnchorIdentification,
                              CoverageIdentification coverageIdentification,
                              SamplingService samplingService,
                              T explainedInstance,
                              int explainedInstanceLabel,
                              int maxAnchorSize,
                              int beamSize,
                              double delta,
                              double epsilon,
                              double tau,
                              double tauDiscrepancy,
                              int initSampleCount,
                              boolean lazyCoverageEvaluation,
                              boolean allowSuboptimalSteps) {
        if (bestAnchorIdentification == null) {
            throw new IllegalArgumentException("Best anchor identification must not be null");
        }
        if (coverageIdentification == null) {
            throw new IllegalArgumentException("Coverage identification must not be null");
        }
        if (samplingService == null) {
            throw new IllegalArgumentException("Sampling service must not be null");
        }
        if (explainedInstance == null) {
            throw new IllegalArgumentException("Explained instance must not be null");
        }
        if (!ParameterValidation.isUnsigned(Integer.valueOf(explainedInstanceLabel))) {
            throw new IllegalArgumentException("Explained instance label must not be negative");
        }
        if (!ParameterValidation.isUnsigned(Integer.valueOf(maxAnchorSize))) {
            throw new IllegalArgumentException("Max anchor size must not be negative");
        }
        if (!ParameterValidation.isUnsigned(Integer.valueOf(beamSize))) {
            throw new IllegalArgumentException("Beam size must not be negative");
        }
        if (!ParameterValidation.isPercentage(Double.valueOf(delta))) {
            throw new IllegalArgumentException("Delta value must be a value between 0 and 1");
        }
        if (!ParameterValidation.isPercentage(Double.valueOf(epsilon))) {
            throw new IllegalArgumentException("Epsilon value must be a value between 0 and 1");
        }
        if (!ParameterValidation.isPercentage(Double.valueOf(tau))) {
            throw new IllegalArgumentException("Tau value must be a value between 0 and 1");
        }
        if (!ParameterValidation.isPercentage(Double.valueOf(tauDiscrepancy))) {
            throw new IllegalArgumentException("Tau discrepancy value must be a value between 0 and 1");
        }
        if (!ParameterValidation.isUnsigned(Integer.valueOf(initSampleCount))) {
            throw new IllegalArgumentException("Initialization sample count must not be negative");
        }
        this.bestAnchorIdentification = bestAnchorIdentification;
        this.coverageIdentification = coverageIdentification;
        this.explainedInstance = explainedInstance;
        this.explainedInstanceLabel = explainedInstanceLabel;
        this.maxAnchorSize = maxAnchorSize;
        this.beamSize = beamSize;
        this.delta = delta;
        this.epsilon = epsilon;
        this.tau = tau;
        this.tauDiscrepancy = tauDiscrepancy;
        this.initSampleCount = initSampleCount;
        this.lazyCoverageEvaluation = lazyCoverageEvaluation;
        this.allowSuboptimalSteps = allowSuboptimalSteps;
        this.samplingService = samplingService;
    }

    /**
     * @return the sampling service used to evaluate candidates
     */
    public SamplingService getSamplingService() {
        return this.samplingService;
    }

    /**
     * Formats alternating key/value arguments as {@code "key1=value1, key2=value2"} for logging.
     *
     * @param keysAndValues alternating keys and values; must have an even length
     * @return the formatted pairs joined by {@code ", "}; empty if no arguments are given
     * @throws IllegalArgumentException if an odd number of arguments is given
     */
    private static String createKeyValueMap(Object... keysAndValues) {
        if (keysAndValues.length % 2 != 0) {
            throw new IllegalArgumentException(
                    "Expected key/value pairs but got " + keysAndValues.length + " arguments");
        }
        final StringJoiner joiner = new StringJoiner(", ");
        for (int idx = 0; idx < keysAndValues.length; idx += 2) {
            joiner.add(keysAndValues[idx] + "=" + keysAndValues[idx + 1]);
        }
        return joiner.toString();
    }

    /**
     * Generates the candidates for the next beam search round.
     *
     * <p>Without previous candidates, one single-feature candidate per feature is created.
     * Otherwise every previous candidate is extended by each feature it does not contain yet.
     * Extensions that yield an already generated feature set are skipped, so the first parent wins.
     *
     * @param previousCandidates the best candidates of the previous round; may be null or empty
     * @param featureCount       number of features of the explained instance
     * @param minCoverage        candidates with a coverage below this value are discarded, because
     *                           they cannot beat the best anchor found so far; {@code 0} disables
     *                           the filter
     * @return the generated candidates, in generation order
     */
    private List<AnchorCandidate> generateCandidateSet(List<AnchorCandidate> previousCandidates,
                                                       int featureCount, double minCoverage) {
        // LinkedHashSet keeps the de-duplication of the original HashSet but iterates deterministically
        final Set<AnchorCandidate> candidates = new LinkedHashSet<>();
        final Set<Set<Integer>> seenFeatureSets = new HashSet<>();
        final boolean firstRound = previousCandidates == null || previousCandidates.isEmpty();

        for (int feature = 0; feature < featureCount; feature++) {
            if (firstRound) {
                candidates.add(new AnchorCandidate(
                        new LinkedHashSet<Integer>(Collections.singletonList(feature)), null));
                continue;
            }
            for (AnchorCandidate parent : previousCandidates) {
                if (parent.getCanonicalFeatures().contains(feature)) {
                    continue;
                }
                final LinkedHashSet<Integer> orderedFeatures =
                        new LinkedHashSet<Integer>(parent.getOrderedFeatures());
                orderedFeatures.add(feature);
                // Order-insensitive key: {a, b} and {b, a} describe the same candidate
                if (seenFeatureSets.add(new HashSet<>(orderedFeatures))) {
                    candidates.add(new AnchorCandidate(orderedFeatures, parent));
                }
            }
        }

        final List<AnchorCandidate> result = new ArrayList<>(candidates.size());
        for (AnchorCandidate candidate : candidates) {
            // Eager mode always needs the coverage; lazy mode only when filtering by coverage
            if (!this.lazyCoverageEvaluation || minCoverage > 0.0d) {
                calculateCandidateCoverage(candidate);
            }
            if (minCoverage > 0.0d && candidate.getCoverage().doubleValue() < minCoverage) {
                // Cannot exceed the coverage of the best anchor found so far
                continue;
            }
            result.add(candidate);
        }
        return result;
    }

    /**
     * Selects the best {@code count} candidates.
     *
     * <p>Every candidate is first sampled until it has at least {@code initSampleCount} samples. If
     * there are no more candidates than requested, a copy of all of them is returned. Otherwise the
     * selection is delegated to the configured {@link BestAnchorIdentification}.
     *
     * @param candidates the candidates to choose from
     * @param count      the number of candidates to select
     * @return the selected candidates
     */
    private List<AnchorCandidate> bestCandidate(List<AnchorCandidate> candidates, int count) {
        final SamplingSession session = this.samplingService.createSession(this.explainedInstanceLabel);
        for (AnchorCandidate candidate : candidates) {
            if (candidate.getSampledSize() < this.initSampleCount) {
                session.registerCandidateEvaluation(candidate, this.initSampleCount - candidate.getSampledSize());
            }
        }
        session.run();
        if (candidates.size() <= count) {
            LOGGER.debug("Number of arms searched for less or equals total number of features. Returning all candidates.");
            return new ArrayList<>(candidates);
        }
        LOGGER.debug("Calling {} to identify top {} candidates with a significance level of {}",
                this.bestAnchorIdentification.getClass().getSimpleName(), count, this.delta);
        return this.bestAnchorIdentification.identify(candidates, this.samplingService,
                this.explainedInstanceLabel, this.delta, this.epsilon, count);
    }

    /**
     * Decides whether a candidate is an anchor, i.e. whether its precision reaches {@code tau}.
     *
     * <p>The candidate is sampled further until the KL-Bernoulli confidence bounds allow a
     * decision with tolerance {@code tauDiscrepancy}.
     *
     * @param candidate       the candidate to validate
     * @param currentBeamSize the beam size used in the current round
     * @return true if the candidate's precision is at least {@code tau} and its lower bound exceeds
     *         {@code tau - tauDiscrepancy}
     */
    private boolean isValidCandidate(AnchorCandidate candidate, int currentBeamSize) {
        // Confidence term, computed in double arithmetic so that large beam/feature counts
        // cannot overflow int
        final double beta = Math.log(1.0d / (this.delta
                / (1.0d + (double) (currentBeamSize - 1) * this.explainedInstance.getFeatureCount())));
        // Sample at least once per round. initSampleCount may be 0, and a 0-sample round presumably
        // adds no information, which would make the loop below never terminate.
        final int samplesPerRound = Math.max(1, this.initSampleCount);

        double precision = candidate.getPrecision();
        double lowerBound = KLBernoulliUtils.dlowBernoulli(precision, beta / candidate.getSampledSize());
        double upperBound = KLBernoulliUtils.dupBernoulli(precision, beta / candidate.getSampledSize());
        while (true) {
            // Either the precision is below tau or the lower bound confirms it is high enough...
            final boolean acceptanceSettled =
                    precision < this.tau || lowerBound >= this.tau - this.tauDiscrepancy;
            // ...and either the precision is at least tau or the upper bound confirms it is too low
            final boolean rejectionSettled =
                    precision >= this.tau || upperBound < this.tau + this.tauDiscrepancy;
            if (acceptanceSettled && rejectionSettled) {
                break;
            }
            LOGGER.debug("Cannot confirm or reject {} is an anchor. Taking more samples.", candidate.getCanonicalFeatures());
            this.samplingService.createSession(this.explainedInstanceLabel)
                    .registerCandidateEvaluation(candidate, samplesPerRound)
                    .run();
            precision = candidate.getPrecision();
            lowerBound = KLBernoulliUtils.dlowBernoulli(precision, beta / candidate.getSampledSize());
            upperBound = KLBernoulliUtils.dupBernoulli(precision, beta / candidate.getSampledSize());
        }
        return precision >= this.tau && lowerBound > this.tau - this.tauDiscrepancy;
    }

    /**
     * Computes and stores the candidate's coverage, unless it has already been computed.
     *
     * @param candidate the candidate whose coverage is required
     */
    private void calculateCandidateCoverage(AnchorCandidate candidate) {
        if (candidate.isCoverageUndefined()) {
            candidate.setCoverage(this.coverageIdentification.calculateCoverage(candidate.getCanonicalFeatures()));
        }
    }

    /**
     * Runs the beam search.
     *
     * <p>Round {@code n} creates candidates with {@code n} features, for {@code n} from 1 up to
     * {@code maxAnchorSize}. The search stops early if no candidates are left or an anchor with
     * coverage 1 has been found.
     *
     * @return the best anchor, or the best candidate if no anchor satisfies the parameters
     * @throws NoCandidateFoundException if not even a fallback candidate could be determined
     */
    private AnchorResult<T> beamSearch() throws NoCandidateFoundException {
        final long startTime = System.currentTimeMillis();
        // Best candidates per anchor size (number of features)
        final Map<Integer, List<AnchorCandidate>> bestPerSize = new HashMap<>();
        AnchorCandidate bestAnchor = null;
        boolean perfectCoverageFound = false;

        for (int size = 1; size <= this.maxAnchorSize && !perfectCoverageFound; size++) {
            LOGGER.debug("Adding feature {} of {}", size, this.maxAnchorSize);
            final double minCoverage = bestAnchor != null ? bestAnchor.getCoverage().doubleValue() : 0.0d;
            final List<AnchorCandidate> candidates = generateCandidateSet(bestPerSize.get(size - 1),
                    this.explainedInstance.getFeatureCount(), minCoverage);
            if (candidates.isEmpty()) {
                break;
            }

            final int currentBeamSize = Math.min(candidates.size(), this.beamSize);
            // Copy: the identification strategy is not guaranteed to return a mutable list
            final List<AnchorCandidate> topCandidates = new ArrayList<>(bestCandidate(candidates, currentBeamSize));
            final Iterator<AnchorCandidate> it = topCandidates.iterator();
            while (it.hasNext()) {
                final AnchorCandidate candidate = it.next();
                if (candidates.size() > currentBeamSize && candidate.getPrecision() <= 0.0d) {
                    LOGGER.debug("Removing candidate {} as its precision is 0", candidate.getOrderedFeatures());
                    it.remove();
                } else if (!this.allowSuboptimalSteps && candidate.getAddedPrecision() <= 0.0d) {
                    LOGGER.debug("Removing candidate {} as it decreases its parent's precision", candidate.getOrderedFeatures());
                    it.remove();
                }
            }
            if (topCandidates.isEmpty()) {
                LOGGER.warn("No valid candidates found during best arm identification. Stopping search.");
                break;
            }
            bestPerSize.put(size, topCandidates);

            for (AnchorCandidate candidate : topCandidates) {
                final boolean isValid = isValidCandidate(candidate, currentBeamSize);
                LOGGER.debug("Top candidate {} is{} a valid anchor with precision {}",
                        candidate.getCanonicalFeatures(), isValid ? "" : " not", candidate.getPrecision());
                if (!isValid) {
                    continue;
                }
                calculateCandidateCoverage(candidate);
                if (bestAnchor == null || candidate.getCoverage().doubleValue() > bestAnchor.getCoverage().doubleValue()) {
                    LOGGER.debug("Found a new best anchor ({}) with a coverage of {}", candidate.getCanonicalFeatures(), candidate.getCoverage());
                    bestAnchor = candidate;
                    if (candidate.getCoverage().doubleValue() == 1.0d) {
                        LOGGER.info("Found an anchor with a coverage of 1. Stopping search prematurely.");
                        perfectCoverageFound = true;
                    }
                }
            }
        }

        final boolean isAnchor = bestAnchor != null;
        if (!isAnchor) {
            LOGGER.warn("Could not identify an anchor satisfying the parameters. Searching for best candidate.");
            final List<AnchorCandidate> allCandidates = bestPerSize.values().stream()
                    .flatMap(List::stream)
                    .collect(Collectors.toList());
            final List<AnchorCandidate> fallback = bestCandidate(allCandidates, 1);
            if (fallback == null || fallback.isEmpty()) {
                LOGGER.warn("Could not find an Anchor or any candidate with a precision > 0. Throwing NoCandidateFoundException.");
                throw new NoCandidateFoundException();
            }
            bestAnchor = fallback.get(0);
            calculateCandidateCoverage(bestAnchor);
            LOGGER.warn("No anchor found, returning best candidate");
        }

        final double elapsedMillis = System.currentTimeMillis() - startTime;
        LOGGER.info("Found result {} in {}ms", bestAnchor, elapsedMillis);
        return new AnchorResult<>(bestAnchor, this.explainedInstance, this.explainedInstanceLabel,
                isAnchor, elapsedMillis, this.samplingService.getTimeSpentSampling());
    }

    /**
     * Constructs the anchor for the explained instance.
     *
     * @return the construction result
     * @throws NoCandidateFoundException if no anchor and no fallback candidate could be determined
     */
    public AnchorResult<T> constructAnchor() throws NoCandidateFoundException {
        return constructAnchor(true);
    }

    /**
     * Constructs the anchor for the explained instance.
     *
     * @param flag NOTE(review): this parameter is not read by the current implementation; confirm
     *             its intended meaning
     * @return the construction result
     * @throws NoCandidateFoundException if no anchor and no fallback candidate could be determined
     */
    public AnchorResult<T> constructAnchor(boolean flag) throws NoCandidateFoundException {
        LOGGER.info("Starting Anchor Construction for instance {} and label {} with params: {}",
                this.explainedInstance, this.explainedInstanceLabel,
                createKeyValueMap(
                        "maxAnchorSize", this.maxAnchorSize,
                        "beamSize", this.beamSize,
                        "delta", this.delta,
                        "epsilon", this.epsilon,
                        "tau", this.tau,
                        "tauDiscrepancy", this.tauDiscrepancy,
                        "initSampleCount", this.initSampleCount,
                        "lazyCoverageEvaluation", this.lazyCoverageEvaluation,
                        "allowSuboptimalSteps", this.allowSuboptimalSteps));
        return beamSearch();
    }
}
