package edu.cmu.lti.oaqa.baseqa.eval.calculator;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Range;
import com.google.common.collect.RangeMap;
import com.google.common.collect.RangeSet;
import com.google.common.collect.Sets;
import com.google.common.collect.TreeRangeMap;
import com.google.common.collect.TreeRangeSet;
import edu.cmu.lti.oaqa.baseqa.eval.EvalCalculator;
import edu.cmu.lti.oaqa.baseqa.eval.EvalCalculatorUtil;
import edu.cmu.lti.oaqa.baseqa.eval.Measure;
import edu.cmu.lti.oaqa.ecd.config.ConfigurableProvider;
import edu.cmu.lti.oaqa.type.retrieval.Passage;
import edu.cmu.lti.oaqa.util.TypeUtil;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.uima.jcas.JCas;

/* loaded from: input_file:edu/cmu/lti/oaqa/baseqa/eval/calculator/TrecPassageMapEvalCalculator.class */
/**
 * Evaluation calculator for TREC-style ranked passage retrieval. For each question,
 * {@link #calculate} computes four average-precision values: document, passage,
 * character-level ("passage2") and aspect. {@link #accumulate} then averages those values into
 * MAP scores over all evaluated questions.
 *
 * <p>NOTE(review): the measure names suggest the TREC Genomics passage / passage2 / aspect MAP
 * definitions — confirm tie-breaking details against the track guidelines.
 *
 * @param <T> the retrieved / gold passage type
 */
public class TrecPassageMapEvalCalculator<T extends Passage> extends ConfigurableProvider
        implements EvalCalculator<T> {

    /**
     * Computes the per-question measures.
     *
     * @param jCas unused by this calculator
     * @param results the retrieved passages; ranked with {@code comparator} before scoring
     * @param golds the gold-standard passages
     * @param comparator the ranking order of the retrieved passages
     * @param uniqueIdMapper unused by this calculator
     * @return a count of {@code 1.0} plus the document, passage, passage2 and aspect average
     *     precision for this question
     */
    @Override // edu.cmu.lti.oaqa.baseqa.eval.EvalCalculator
    public Map<Measure, Double> calculate(JCas jCas, Collection<T> results, Collection<T> golds,
            Comparator<T> comparator, Function<T, String> uniqueIdMapper) {
        // Average precision is rank-sensitive, so every measure scores the same comparator-ranked
        // list instead of the collection's incidental iteration order.
        List<T> ranked = results.stream().sorted(comparator).collect(Collectors.toList());
        List<String> rankedUris = ranked.stream()
                .map(Passage::getUri)
                .distinct()
                .collect(Collectors.toList());
        Set<String> goldUris = golds.stream().map(Passage::getUri).collect(Collectors.toSet());
        return ImmutableMap.<Measure, Double>builder()
                .put(TrecPassageMapEvalMeasure.TREC_PASSAGE_MAP_COUNT, 1.0)
                .put(TrecPassageMapEvalMeasure.TREC_DOCUMENT_AVERAGE_PRECISION,
                        EvalCalculatorUtil.calculateAveragePrecision(rankedUris, goldUris))
                .put(TrecPassageMapEvalMeasure.TREC_PASSAGE_AVERAGE_PRECISION,
                        calculatePassageAveragePrecision(ranked, golds))
                .put(TrecPassageMapEvalMeasure.TREC_PASSAGE2_AVERAGE_PRECISION,
                        calculatePassage2AveragePrecision(ranked, golds))
                .put(TrecPassageMapEvalMeasure.TREC_ASPECT_AVERAGE_PRECISION,
                        calculateAspectAveragePrecision(ranked, golds))
                .build();
    }

    /**
     * Aggregates per-question values into MAP scores. Each MAP is the sum of the corresponding
     * per-question average precisions divided by the summed question count.
     *
     * @param map per-measure collections of per-question values produced by {@link #calculate}
     * @return the summed question count and the four MAP values
     */
    @Override // edu.cmu.lti.oaqa.baseqa.eval.EvalCalculator
    public Map<Measure, Double> accumulate(Map<Measure, ? extends Collection<Double>> map) {
        double count = EvalCalculatorUtil.sumMeasurementValues(
                map.get(TrecPassageMapEvalMeasure.TREC_PASSAGE_MAP_COUNT));
        return ImmutableMap.<Measure, Double>builder()
                .put(TrecPassageMapEvalMeasure.TREC_PASSAGE_MAP_COUNT, count)
                .put(TrecPassageMapEvalMeasure.TREC_DOCUMENT_MAP,
                        meanOver(map, TrecPassageMapEvalMeasure.TREC_DOCUMENT_AVERAGE_PRECISION, count))
                .put(TrecPassageMapEvalMeasure.TREC_PASSAGE_MAP,
                        meanOver(map, TrecPassageMapEvalMeasure.TREC_PASSAGE_AVERAGE_PRECISION, count))
                .put(TrecPassageMapEvalMeasure.TREC_PASSAGE2_MAP,
                        meanOver(map, TrecPassageMapEvalMeasure.TREC_PASSAGE2_AVERAGE_PRECISION, count))
                .put(TrecPassageMapEvalMeasure.TREC_ASPECT_MAP,
                        meanOver(map, TrecPassageMapEvalMeasure.TREC_ASPECT_AVERAGE_PRECISION, count))
                .build();
    }

    /** Sums the per-question values recorded for {@code measure} and divides by {@code count}. */
    private static double meanOver(Map<Measure, ? extends Collection<Double>> measure2values,
            Measure measure, double count) {
        return EvalCalculatorUtil.sumMeasurementValues(measure2values.get(measure)) / count;
    }

    /**
     * Computes passage-level average precision against a flat collection of gold passages.
     *
     * @param results retrieved passages, in ranked order
     * @param golds gold-standard passages
     * @return the passage-level average precision
     */
    public static <T extends Passage> double calculatePassageAveragePrecision(Collection<T> results,
            Collection<T> golds) {
        return calculatePassageAveragePrecision(results, toUriPassages(golds));
    }

    /**
     * Computes passage-level average precision.
     *
     * <p>Each retrieved passage's span length is added to the running retrieved length. The
     * passage is then matched against the first gold passage of the same URI whose span overlaps
     * it. On a match:
     * <ul>
     *   <li>the overlap length is added to the running relevant length;</li>
     *   <li>the running precision (relevant length / retrieved length) is accumulated.</li>
     * </ul>
     * The sum is divided by the number of matches plus the number of gold passages that were
     * never matched. A gold passage may be matched by several retrieved passages; each match is
     * counted.
     *
     * @param results retrieved passages, in ranked order
     * @param uri2golds gold passages indexed by URI
     * @return the passage-level average precision; {@code NaN} when there are neither matches
     *     nor gold passages
     */
    public static <T extends Passage> double calculatePassageAveragePrecision(Collection<T> results,
            Multimap<String, T> uri2golds) {
        int retrievedLength = 0;
        int relevantRetrievedLength = 0;
        double sumPrecision = 0.0;
        int matchCount = 0;
        Set<T> matchedGolds = new HashSet<>();
        for (T result : results) {
            Range<Integer> span = TypeUtil.spanRangeInSection(result);
            retrievedLength += span.upperEndpoint() - span.lowerEndpoint();
            // Multimap.get returns an empty collection for unknown URIs, so the loop simply ends
            // when no gold passage of this URI overlaps the retrieved span.
            for (T gold : uri2golds.get(result.getUri())) {
                Range<Integer> goldSpan = TypeUtil.spanRangeInSection(gold);
                // intersection() throws on disconnected ranges, hence the isConnected guard.
                if (!goldSpan.isConnected(span)) {
                    continue;
                }
                Range<Integer> overlap = span.intersection(goldSpan);
                if (overlap.isEmpty()) {
                    continue;
                }
                relevantRetrievedLength += overlap.upperEndpoint() - overlap.lowerEndpoint();
                sumPrecision += (double) relevantRetrievedLength / retrievedLength;
                matchCount++;
                matchedGolds.add(gold);
                break; // only the first overlapping gold passage is credited
            }
        }
        long unmatchedGoldCount = uri2golds.values().stream()
                .filter(gold -> !matchedGolds.contains(gold))
                .count();
        return sumPrecision / (matchCount + unmatchedGoldCount);
    }

    /** Indexes passages by their URI. */
    private static <T extends Passage> Multimap<String, T> toUriPassages(Collection<T> passages) {
        return Multimaps.index(passages, Passage::getUri);
    }

    /**
     * Computes character-level ("passage2") average precision against a flat collection of gold
     * passages.
     *
     * @param results retrieved passages, in ranked order
     * @param golds gold-standard passages
     * @return the character-level average precision
     */
    public static double calculatePassage2AveragePrecision(Collection<? extends Passage> results,
            Collection<? extends Passage> golds) {
        return calculatePassage2AveragePrecision(results, toUriSpans(golds));
    }

    /**
     * Computes character-level ("passage2") average precision.
     *
     * <p>A retrieved passage whose span is not enclosed by a single gold span of its URI counts
     * all of its characters as retrieved and non-relevant. Otherwise each character is scored:
     * <ul>
     *   <li>characters outside the gold spans count as retrieved non-relevant;</li>
     *   <li>gold characters seen for the first time count as retrieved relevant and accumulate
     *       the running precision;</li>
     *   <li>gold characters already retrieved by an earlier passage are ignored.</li>
     * </ul>
     * The sum is divided by the total number of gold characters. The input map is not modified.
     *
     * @param results retrieved passages, in ranked order
     * @param map gold spans per URI
     * @return the character-level average precision; {@code NaN} when there are no gold
     *     characters
     */
    public static double calculatePassage2AveragePrecision(Collection<? extends Passage> results,
            Map<String, RangeSet<Integer>> map) {
        // Deep copy: the per-URI RangeSets are mutated below to track not-yet-retrieved gold
        // characters. A shallow map copy would share them, corrupting both the caller's map
        // and the total-gold-characters denominator.
        Map<String, RangeSet<Integer>> unretrieved = new HashMap<>();
        map.forEach((uri, spans) -> unretrieved.put(uri, TreeRangeSet.create(spans)));
        int retrievedChars = 0;
        int relevantRetrievedChars = 0;
        double sumPrecision = 0.0;
        for (Passage result : results) {
            Range<Integer> span = TypeUtil.spanRangeInSection(result);
            String uri = result.getUri();
            RangeSet<Integer> goldSpans = map.get(uri);
            if (goldSpans == null || !goldSpans.encloses(span)) {
                retrievedChars += span.upperEndpoint() - span.lowerEndpoint();
                continue;
            }
            RangeSet<Integer> remaining = unretrieved.get(uri);
            int end = span.upperEndpoint();
            for (int offset = span.lowerEndpoint(); offset < end; offset++) {
                // NOTE(review): under the encloses() guard above this branch appears unreachable
                // for closed-open spans; it only matters if partially overlapping passages are
                // ever scored per character — confirm the intended rule.
                if (!goldSpans.contains(offset)) {
                    retrievedChars++;
                } else if (remaining.contains(offset)) {
                    remaining.remove(Range.singleton(offset));
                    retrievedChars++;
                    relevantRetrievedChars++;
                    sumPrecision += (double) relevantRetrievedChars / retrievedChars;
                }
            }
        }
        int totalGoldChars = map.values().stream()
                .flatMap(spans -> spans.asRanges().stream())
                .mapToInt(range -> range.upperEndpoint() - range.lowerEndpoint())
                .sum();
        return sumPrecision / totalGoldChars;
    }

    /** Groups passages by URI, unioning their section spans into one range set per URI. */
    private static <T extends Passage> Map<String, RangeSet<Integer>> toUriSpans(
            Collection<T> passages) {
        // Not CONCURRENT: TreeRangeSet is not thread-safe. Union is order-independent, so
        // UNORDERED is accurate.
        return passages.stream().collect(Collectors.groupingBy(Passage::getUri,
                Collector.<T, RangeSet<Integer>>of(
                        TreeRangeSet::create,
                        (spans, passage) -> spans.add(TypeUtil.spanRangeInSection(passage)),
                        (left, right) -> {
                            left.addAll(right);
                            return left;
                        },
                        Collector.Characteristics.UNORDERED,
                        Collector.Characteristics.IDENTITY_FINISH)));
    }

    /**
     * Computes aspect-level average precision against a flat collection of gold passages.
     *
     * @param results retrieved passages, in ranked order
     * @param golds gold-standard passages; only those with aspects contribute
     * @return the aspect-level average precision
     */
    public static <T extends Passage> double calculateAspectAveragePrecision(Collection<T> results,
            Collection<T> golds) {
        return calculateAspectAveragePrecision(results, toUriSpanAspects(golds));
    }

    /**
     * Computes aspect-level average precision.
     *
     * <p>Each retrieved passage collects the aspects of the gold spans it overlaps in its URI.
     * <ul>
     *   <li>A passage with no aspects, including one from a URI with no gold aspects, counts as
     *       retrieved and non-relevant.</li>
     *   <li>A passage contributing at least one previously unseen aspect counts as relevant. It
     *       adds {@code newAspectCount * relevant / retrieved} to the sum.</li>
     *   <li>A passage whose aspects were all seen before is ignored.</li>
     * </ul>
     * The sum is divided by the number of distinct gold aspects.
     *
     * @param results retrieved passages, in ranked order
     * @param map per-URI maps from gold spans to {@code '|'}-joined aspect strings
     * @return the aspect-level average precision; {@code NaN} when there are no gold aspects
     */
    public static <T extends Passage> double calculateAspectAveragePrecision(Collection<T> results,
            Map<String, RangeMap<Integer, String>> map) {
        int relevantCount = 0;
        int retrievedCount = 0;
        double sumPrecision = 0.0;
        Set<String> seenAspects = new HashSet<>();
        for (T result : results) {
            RangeMap<Integer, String> goldAspects = map.get(result.getUri());
            if (goldAspects == null) {
                // A passage from a document with no gold aspects carries no aspects: treat it
                // like the empty-aspect case below so it still lowers precision.
                retrievedCount++;
                continue;
            }
            Range<Integer> span = TypeUtil.spanRangeInSection(result);
            Set<String> aspects =
                    toSplitAspects(goldAspects.subRangeMap(span).asMapOfRanges().values());
            if (aspects.isEmpty()) {
                retrievedCount++;
                continue;
            }
            long newAspectCount = aspects.stream()
                    .filter(aspect -> !seenAspects.contains(aspect))
                    .count();
            if (newAspectCount > 0) {
                relevantCount++;
                retrievedCount++;
                sumPrecision += (double) newAspectCount * relevantCount / retrievedCount;
            }
            seenAspects.addAll(aspects);
        }
        Set<String> joinedGoldAspects = map.values().stream()
                .flatMap(rangeMap -> rangeMap.asMapOfRanges().values().stream())
                .collect(Collectors.toSet());
        return sumPrecision / toSplitAspects(joinedGoldAspects).size();
    }

    /**
     * Groups passages that have aspects by URI. Within each URI, every passage's section span is
     * mapped to its aspect string.
     *
     * <p>NOTE(review): {@code RangeMap.put} overwrites existing mappings in an overlapping
     * region, so where gold spans overlap only the later passage's aspects survive there.
     */
    private static <T extends Passage> Map<String, RangeMap<Integer, String>> toUriSpanAspects(
            Collection<T> passages) {
        // Not CONCURRENT (TreeRangeMap is not thread-safe) and not UNORDERED (put order decides
        // which aspects win on overlaps); sequential encounter order is preserved.
        return passages.stream()
                .filter(passage -> passage.getAspects() != null)
                .collect(Collectors.groupingBy(Passage::getUri,
                        Collector.<T, RangeMap<Integer, String>>of(
                                TreeRangeMap::create,
                                (aspects, passage) -> aspects.put(
                                        TypeUtil.spanRangeInSection(passage), passage.getAspects()),
                                (left, right) -> {
                                    left.putAll(right);
                                    return left;
                                },
                                Collector.Characteristics.IDENTITY_FINISH)));
    }

    /** Splits each {@code '|'}-joined aspect string and returns the set of individual aspects. */
    private static Set<String> toSplitAspects(Collection<String> joinedAspects) {
        return joinedAspects.stream()
                .flatMap(joined -> Stream.of(joined.split("\\|")))
                .collect(Collectors.toSet());
    }

    /** Returns the display name of this calculator. */
    @Override // edu.cmu.lti.oaqa.baseqa.eval.EvalCalculator
    public String getName() {
        return "Trec";
    }
}
