package edu.cmu.lti.oaqa.baseqa.eval.calculator;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.TreeRangeSet;
import edu.cmu.lti.oaqa.baseqa.eval.EvalCalculator;
import edu.cmu.lti.oaqa.baseqa.eval.EvalCalculatorUtil;
import edu.cmu.lti.oaqa.baseqa.eval.Measure;
import edu.cmu.lti.oaqa.ecd.config.ConfigurableProvider;
import edu.cmu.lti.oaqa.type.retrieval.Passage;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.uima.jcas.JCas;

/* loaded from: input_file:edu/cmu/lti/oaqa/baseqa/eval/calculator/PassageMapEvalCalculator.class */
public class PassageMapEvalCalculator<T extends Passage> extends ConfigurableProvider implements EvalCalculator<T> {

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/cmu/lti/oaqa/baseqa/eval/calculator/PassageMapEvalCalculator$CharacterPosition.class */
    /**
     * A position inside a document, identified by the document URI, a section name and a
     * character offset within that section.
     *
     * <p>Positions are ordered by URI, then by section, then by offset, which lets them act as
     * endpoints of Guava {@link Range}s. Instances are immutable. The URI and section are
     * expected to be non-null: the ordering dereferences both.
     */
    public static class CharacterPosition implements Comparable<CharacterPosition> {
        // Method references (instead of untyped lambdas) give the compiler the element type at
        // the head of the chain; an implicit lambda there would be inferred against Object.
        private static final Comparator<CharacterPosition> POSITION_ORDER =
                Comparator.comparing(CharacterPosition::getUri)
                        .thenComparing(CharacterPosition::getSection)
                        .thenComparingInt(CharacterPosition::getOffset);

        private final String uri;
        private final String section;
        private final int offset;

        private CharacterPosition(String uri, String section, int offset) {
            this.uri = uri;
            this.section = section;
            this.offset = offset;
        }

        /**
         * Converts a passage into the half-open range {@code [begin, end)} of character positions
         * it covers.
         *
         * <p>Both endpoints use the passage URI; the lower endpoint takes the begin section and
         * its offset, the upper endpoint the end section and its offset.
         *
         * @param passage the passage to convert
         * @return the closed-open range spanned by the passage
         */
        public static Range<CharacterPosition> toRange(Passage passage) {
            CharacterPosition begin = new CharacterPosition(
                    passage.getUri(), passage.getBeginSection(), passage.getOffsetInBeginSection());
            CharacterPosition end = new CharacterPosition(
                    passage.getUri(), passage.getEndSection(), passage.getOffsetInEndSection());
            return Range.closedOpen(begin, end);
        }

        /** Orders positions by URI, then section, then offset. */
        @Override
        public int compareTo(CharacterPosition other) {
            return POSITION_ORDER.compare(this, other);
        }

        /** Two positions are equal exactly when {@link #compareTo} reports them as equal. */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof CharacterPosition)) {
                return false;
            }
            return compareTo((CharacterPosition) obj) == 0;
        }

        /** Hash over the same three components that define the ordering. */
        @Override
        public int hashCode() {
            int result = uri.hashCode();
            result = 31 * result + section.hashCode();
            result = 31 * result + offset;
            return result;
        }

        /** Renders the position as {@code uri:section:offset}. */
        @Override
        public String toString() {
            return this.uri + ":" + this.section + ":" + this.offset;
        }

        /** @return the document URI of this position */
        public String getUri() {
            return this.uri;
        }

        /** @return the section name of this position */
        public String getSection() {
            return this.section;
        }

        /** @return the character offset within the section */
        public int getOffset() {
            return this.offset;
        }
    }

    /**
     * Computes character-level passage measures for one set of results against a gold standard.
     *
     * <p>Results are sorted with {@code comparator} and converted to character ranges; both the
     * results and the gold standard are merged into range sets. Character counts of the result
     * set, the gold-standard set and their intersection are passed to {@link EvalCalculatorUtil}
     * to obtain precision, recall and F1. Average precision is computed from the sorted result
     * ranges by {@link #calculatePassageAveragePrecision}.
     *
     * @param jCas        not used by this implementation
     * @param collection  the passages produced by the system
     * @param collection2 the gold-standard passages
     * @param comparator  ordering applied to {@code collection} before evaluation
     * @param function    not used by this implementation
     * @return an immutable map holding a count of 1.0 plus the precision, recall, F1 and average
     *         precision measures
     */
    @Override
    public Map<Measure, Double> calculate(JCas jCas, Collection<T> collection, Collection<T> collection2, Comparator<T> comparator, Function<T, String> function) {
        // Keep the sorted list: average precision walks the ranges in this order.
        List<Range<CharacterPosition>> resultRanges = collection.stream()
                .sorted(comparator)
                .map(CharacterPosition::toRange)
                .collect(Collectors.toList());

        // Adding to a RangeSet merges overlapping/adjacent ranges, so one pass is sufficient.
        RangeSet<CharacterPosition> resultRangeSet = TreeRangeSet.create();
        resultRanges.forEach(resultRangeSet::add);

        RangeSet<CharacterPosition> goldRangeSet = TreeRangeSet.create();
        collection2.stream().map(CharacterPosition::toRange).forEach(goldRangeSet::add);

        int resultCount = countCharacters(resultRangeSet);
        int goldCount = countCharacters(goldRangeSet);
        int overlapCount = countCharacters(intersection(resultRangeSet, goldRangeSet));

        double precision = EvalCalculatorUtil.calculatePrecision(resultCount, overlapCount);
        double recall = EvalCalculatorUtil.calculateRecall(goldCount, overlapCount);
        double f1 = EvalCalculatorUtil.calculateF1(precision, recall);
        double averagePrecision = calculatePassageAveragePrecision(resultRanges, goldRangeSet);

        // The explicit type witness is required: a bare builder() at the head of a call chain
        // would be inferred as a builder of Object keys and values.
        return ImmutableMap.<Measure, Double>builder()
                .put(PassageMapEvalMeasure.PASSAGE_MAP_COUNT, 1.0)
                .put(PassageMapEvalMeasure.PASSAGE_PRECISION, precision)
                .put(PassageMapEvalMeasure.PASSAGE_RECALL, recall)
                .put(PassageMapEvalMeasure.PASSAGE_F1, f1)
                .put(PassageMapEvalMeasure.PASSAGE_AVERAGE_PRECISION, averagePrecision)
                .build();
    }

    /**
     * Aggregates per-input measurements into summary measures.
     *
     * <p>The summed {@code PASSAGE_MAP_COUNT} values act as the denominator. Mean precision,
     * recall, F1 and MAP are the summed per-input values divided by that count; GMAP is
     * {@code exp} of the summed log average precisions divided by the count.
     *
     * <p>NOTE(review): no guard exists for a zero count; the divisions are then floating-point
     * divisions by zero and do not yield finite values.
     *
     * @param map per-measure collections of values, as produced by repeated {@code calculate} calls
     * @return an immutable map with the total count and the aggregated measures
     */
    @Override
    public Map<Measure, Double> accumulate(Map<Measure, ? extends Collection<Double>> map) {
        double count = EvalCalculatorUtil.sumMeasurementValues(map.get(PassageMapEvalMeasure.PASSAGE_MAP_COUNT));
        double precisionSum = EvalCalculatorUtil.sumMeasurementValues(map.get(PassageMapEvalMeasure.PASSAGE_PRECISION));
        double recallSum = EvalCalculatorUtil.sumMeasurementValues(map.get(PassageMapEvalMeasure.PASSAGE_RECALL));
        double f1Sum = EvalCalculatorUtil.sumMeasurementValues(map.get(PassageMapEvalMeasure.PASSAGE_F1));
        double averagePrecisionSum = EvalCalculatorUtil.sumMeasurementValues(map.get(PassageMapEvalMeasure.PASSAGE_AVERAGE_PRECISION));
        double logAveragePrecisionSum = EvalCalculatorUtil.sumOfLogMeasurementValues(map.get(PassageMapEvalMeasure.PASSAGE_AVERAGE_PRECISION));

        // The explicit type witness is required: a bare builder() at the head of a call chain
        // would be inferred as a builder of Object keys and values.
        return ImmutableMap.<Measure, Double>builder()
                .put(PassageMapEvalMeasure.PASSAGE_MAP_COUNT, count)
                .put(PassageMapEvalMeasure.PASSAGE_MEAN_PRECISION, precisionSum / count)
                .put(PassageMapEvalMeasure.PASSAGE_MEAN_RECALL, recallSum / count)
                .put(PassageMapEvalMeasure.PASSAGE_MEAN_F1, f1Sum / count)
                .put(PassageMapEvalMeasure.PASSAGE_MAP, averagePrecisionSum / count)
                .put(PassageMapEvalMeasure.PASSAGE_GMAP, Math.exp(logAveragePrecisionSum / count))
                .build();
    }

    /**
     * Sums the character lengths of all ranges in a range set.
     *
     * <p>Ranges whose endpoints lie in different sections are skipped, because their length
     * cannot be derived from the two offsets alone.
     *
     * @param rangeSet the range set to measure
     * @return the total number of characters covered by same-section ranges
     */
    public static int countCharacters(RangeSet<CharacterPosition> rangeSet) {
        int total = 0;
        for (Range<CharacterPosition> range : rangeSet.asRanges()) {
            if (isSameSection(range)) {
                total += countCharacters(range);
            }
        }
        return total;
    }

    /**
     * Tells whether both endpoints of a range lie in the same section.
     *
     * @param range a range with both a lower and an upper endpoint
     * @return {@code true} if the section of the lower endpoint equals that of the upper endpoint
     */
    public static boolean isSameSection(Range<CharacterPosition> range) {
        String lowerSection = range.lowerEndpoint().getSection();
        String upperSection = range.upperEndpoint().getSection();
        return lowerSection.equals(upperSection);
    }

    /**
     * Returns the length of a range as the difference between its endpoint offsets.
     *
     * <p>Only the offsets are compared; the sections of the endpoints are not checked here.
     *
     * @param range a range with both a lower and an upper endpoint
     * @return upper endpoint offset minus lower endpoint offset
     */
    public static int countCharacters(Range<CharacterPosition> range) {
        int endOffset = range.upperEndpoint().getOffset();
        int beginOffset = range.lowerEndpoint().getOffset();
        return endOffset - beginOffset;
    }

    /**
     * Computes the intersection of two range sets.
     *
     * <p>Implemented through De Morgan's law: the union of both complements is built first, and
     * the complement of that union is the set of positions contained in both inputs. Neither
     * argument is modified.
     *
     * @param rangeSet  the first range set
     * @param rangeSet2 the second range set
     * @return a range set containing exactly the positions present in both arguments
     */
    public static RangeSet<CharacterPosition> intersection(RangeSet<CharacterPosition> rangeSet, RangeSet<CharacterPosition> rangeSet2) {
        // Positions missing from at least one of the two inputs.
        RangeSet<CharacterPosition> outsideEither = TreeRangeSet.create(rangeSet.complement());
        outsideEither.addAll(rangeSet2.complement());
        return outsideEither.complement();
    }

    /**
     * Computes the character-level average precision of a ranked list of result ranges.
     *
     * <p>The list is walked in order while accumulating the number of result characters seen so
     * far. Each time a result range is fully enclosed by the gold-standard set, its overlapping
     * characters are added to the running overlap and the precision at that point (from
     * {@link EvalCalculatorUtil#calculatePrecision}) is added to the sum. The sum is finally
     * divided by the number of disjoint ranges in the gold-standard set.
     *
     * <p>If {@code rangeSet} contains no ranges the final division is a floating-point division
     * by zero, so the result is not a finite number.
     *
     * <p>NOTE(review): the running result count uses the plain offset difference of every range,
     * including ranges whose endpoints are in different sections — confirm this is intended.
     *
     * @param list     result ranges in ranked order
     * @param rangeSet the gold-standard range set
     * @return the summed precisions divided by the number of gold-standard ranges
     */
    public static double calculatePassageAveragePrecision(List<Range<CharacterPosition>> list, RangeSet<CharacterPosition> rangeSet) {
        double precisionSum = 0.0;
        double resultCount = 0.0;
        double overlapCount = 0.0;
        // Bind each element once so the same range feeds the count, the enclosure test and the
        // sub-range lookup.
        for (Range<CharacterPosition> resultRange : list) {
            resultCount += countCharacters(resultRange);
            if (rangeSet.encloses(resultRange)) {
                overlapCount += countCharacters(rangeSet.subRangeSet(resultRange));
                precisionSum += EvalCalculatorUtil.calculatePrecision(resultCount, overlapCount);
            }
        }
        return precisionSum / rangeSet.asRanges().size();
    }

    /** Name reported by {@link #getName()}. */
    private static final String CALCULATOR_NAME = "Passage";

    /**
     * Returns the name of this calculator.
     *
     * @return the constant string {@code "Passage"}
     */
    @Override
    public String getName() {
        return CALCULATOR_NAME;
    }
}
