package edu.cmu.lti.oaqa.baseqa.eval.calculator;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import edu.cmu.lti.oaqa.baseqa.eval.EvalCalculator;
import edu.cmu.lti.oaqa.baseqa.eval.EvalCalculatorUtil;
import edu.cmu.lti.oaqa.baseqa.eval.Measure;
import edu.cmu.lti.oaqa.ecd.config.ConfigurableProvider;
import edu.cmu.lti.oaqa.type.retrieval.SearchResult;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.uima.jcas.JCas;

/**
 * {@link EvalCalculator} for ranked retrieval output (any {@link SearchResult} subtype).
 *
 * <p>{@link #calculate} scores one set of system results against a gold standard (retrieved/relevant counts,
 * precision, recall, F1, average precision); {@link #accumulate} aggregates those per-call scores into means,
 * MAP and GMAP.
 *
 * @param <T> the kind of search result being evaluated
 */
public class RetrievalEvalCalculator<T extends SearchResult> extends ConfigurableProvider
        implements EvalCalculator<T> {

    /**
     * Computes retrieval measures for a single set of system results.
     *
     * <p>Items are compared by the string key produced by {@code uriFunction}. System results are first sorted
     * with {@code comparator} and then de-duplicated by key, keeping the first occurrence; that ranked key list
     * is what average precision is computed over.
     *
     * @param jcas the CAS under evaluation (not read by this method)
     * @param resultEvaluatees system results, in any order
     * @param gsEvaluatees gold-standard items
     * @param comparator ranking order applied to {@code resultEvaluatees}
     * @param uriFunction maps an item to the key used for matching results against gold items
     * @return an immutable map with RETRIEVAL_COUNT (always 1.0), RETRIEVED, RELEVANT, RELEVANT_RETRIEVED,
     *     BINARY_RELEVANT, PRECISION, RECALL, F1 and AVERAGE_PRECISION, in that iteration order
     */
    @Override
    public Map<Measure, Double> calculate(JCas jcas, Collection<T> resultEvaluatees, Collection<T> gsEvaluatees,
            Comparator<T> comparator, Function<T, String> uriFunction) {
        // Gold-standard keys. Sequential on purpose: uriFunction is caller-supplied and not known to be
        // thread-safe, and a set collector gives the same result either way.
        Set<String> relevant = gsEvaluatees.stream().map(uriFunction).collect(Collectors.toSet());

        // System result keys in rank order, duplicates removed (first occurrence wins).
        List<String> rankedRetrieved = resultEvaluatees.stream()
                .sorted(comparator)
                .map(uriFunction)
                .distinct()
                .collect(Collectors.toList());
        Set<String> retrieved = new HashSet<>(rankedRetrieved);

        double retrievedCount = retrieved.size();
        double relevantCount = relevant.size();
        double relevantRetrievedCount = Sets.intersection(retrieved, relevant).size();
        // 1.0 if at least one relevant item was retrieved, else 0.0; averaged into MEAN_BINARY_RECALL.
        double binaryRelevant = relevantRetrievedCount > 0.0 ? 1.0 : 0.0;

        double precision = EvalCalculatorUtil.calculatePrecision(retrievedCount, relevantRetrievedCount);
        double recall = EvalCalculatorUtil.calculateRecall(relevantCount, relevantRetrievedCount);
        double f1 = EvalCalculatorUtil.calculateF1(precision, recall);
        double averagePrecision = EvalCalculatorUtil.calculateAveragePrecision(rankedRetrieved, relevant);

        return ImmutableMap.<Measure, Double>builder()
                // One entry per call; accumulate() sums these to get the number of evaluated items.
                .put(RetrievalEvalMeasure.RETRIEVAL_COUNT, 1.0)
                .put(RetrievalEvalMeasure.RETRIEVED, retrievedCount)
                .put(RetrievalEvalMeasure.RELEVANT, relevantCount)
                .put(RetrievalEvalMeasure.RELEVANT_RETRIEVED, relevantRetrievedCount)
                .put(RetrievalEvalMeasure.BINARY_RELEVANT, binaryRelevant)
                .put(RetrievalEvalMeasure.PRECISION, precision)
                .put(RetrievalEvalMeasure.RECALL, recall)
                .put(RetrievalEvalMeasure.F1, f1)
                .put(RetrievalEvalMeasure.AVERAGE_PRECISION, averagePrecision)
                .build();
    }

    /**
     * Aggregates per-call measures into corpus-level means.
     *
     * <p>Every mean divides a summed measure by the summed RETRIEVAL_COUNT. NOTE(review): if that count is 0
     * the results are NaN or infinite; what {@code sumMeasurementValues} returns for a missing key is decided by
     * {@link EvalCalculatorUtil} — confirm.
     *
     * @param measure2values collected values per measure, presumably the outputs of {@link #calculate}
     * @return an immutable map with RETRIEVAL_COUNT, MEAN_PRECISION, MEAN_RECALL, MEAN_BINARY_RECALL, MEAN_F1,
     *     MAP and GMAP, in that iteration order
     */
    @Override
    public Map<Measure, Double> accumulate(Map<Measure, ? extends Collection<Double>> measure2values) {
        double count = EvalCalculatorUtil.sumMeasurementValues(
                measure2values.get(RetrievalEvalMeasure.RETRIEVAL_COUNT));

        double meanPrecision = EvalCalculatorUtil.sumMeasurementValues(
                measure2values.get(RetrievalEvalMeasure.PRECISION)) / count;
        double meanRecall = EvalCalculatorUtil.sumMeasurementValues(
                measure2values.get(RetrievalEvalMeasure.RECALL)) / count;
        double meanBinaryRecall = EvalCalculatorUtil.sumMeasurementValues(
                measure2values.get(RetrievalEvalMeasure.BINARY_RELEVANT)) / count;
        double meanF1 = EvalCalculatorUtil.sumMeasurementValues(
                measure2values.get(RetrievalEvalMeasure.F1)) / count;
        double map = EvalCalculatorUtil.sumMeasurementValues(
                measure2values.get(RetrievalEvalMeasure.AVERAGE_PRECISION)) / count;
        // exp(mean of logs); presumably sumOfLogMeasurementValues sums log(AP), possibly with smoothing for
        // zero values — confirm in EvalCalculatorUtil.
        double gmap = Math.exp(EvalCalculatorUtil.sumOfLogMeasurementValues(
                measure2values.get(RetrievalEvalMeasure.AVERAGE_PRECISION)) / count);

        return ImmutableMap.<Measure, Double>builder()
                .put(RetrievalEvalMeasure.RETRIEVAL_COUNT, count)
                .put(RetrievalEvalMeasure.MEAN_PRECISION, meanPrecision)
                .put(RetrievalEvalMeasure.MEAN_RECALL, meanRecall)
                .put(RetrievalEvalMeasure.MEAN_BINARY_RECALL, meanBinaryRecall)
                .put(RetrievalEvalMeasure.MEAN_F1, meanF1)
                .put(RetrievalEvalMeasure.MAP, map)
                .put(RetrievalEvalMeasure.GMAP, gmap)
                .build();
    }

    /**
     * Returns the display name of this calculator.
     *
     * @return {@code "Retrieval"}
     */
    @Override
    public String getName() {
        return "Retrieval";
    }
}
