package org.lenskit.eval.traintest.predict;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.longs.Long2DoubleFunction;
import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import it.unimi.dsi.fastutil.longs.LongArrays;
import it.unimi.dsi.fastutil.longs.LongComparators;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.grouplens.lenskit.util.statistics.MeanAccumulator;
import org.lenskit.api.Recommender;
import org.lenskit.api.ResultMap;
import org.lenskit.eval.traintest.AlgorithmInstance;
import org.lenskit.eval.traintest.DataSet;
import org.lenskit.eval.traintest.TestUser;
import org.lenskit.eval.traintest.metrics.Discount;
import org.lenskit.eval.traintest.metrics.Discounts;
import org.lenskit.eval.traintest.metrics.MetricResult;
import org.lenskit.util.collections.LongUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX WARN: Classes with same name are omitted:
  
 */
/* loaded from: input_file:org/lenskit/eval/traintest/predict/NDCGPredictMetric.class */
/**
 * Prediction metric that scores each test user by normalized discounted cumulative gain (nDCG).
 *
 * <p>For each user two orderings are built:
 * <ul>
 *   <li>the <em>ideal</em> ordering: the user's test items sorted by test rating, highest first;</li>
 *   <li>the <em>actual</em> ordering: the items in the prediction result map sorted by predicted
 *       score, highest first.</li>
 * </ul>
 * Both orderings are scored by {@link #computeDCG(long[], Long2DoubleFunction)} using the user's
 * <em>test ratings</em> as gains and the configured {@link Discount}. The user's nDCG is the
 * actual DCG divided by the ideal DCG.
 *
 * <p>Per-user output columns are the nDCG (under the configured column name) and the
 * unnormalized actual DCG (under the column name suffixed with {@code ".Raw"}). The aggregate
 * output is the mean of the per-user nDCG values that are finite.
 */
public class NDCGPredictMetric extends PredictMetric<MeanAccumulator> {
    private static final Logger logger = LoggerFactory.getLogger(NDCGPredictMetric.class);

    /** Column name used when none is configured. */
    public static final String DEFAULT_COLUMN = "Predict.nDCG";

    private final String columnName;
    private final Discount discount;

    /**
     * JSON configuration for {@link NDCGPredictMetric}; any {@code "type"} property is ignored.
     */
    @JsonIgnoreProperties({"type"})
    public static class Spec {
        private String name;
        private String discount;

        /** @return the configured column name, or {@code null} if unset */
        public String getColumnName() {
            return name;
        }

        /** @param str the column name to use for this metric's output */
        public void setColumnName(String str) {
            this.name = str;
        }

        /** @return the configured discount expression, or {@code null} if unset */
        public String getDiscount() {
            return discount;
        }

        /** @param str a discount expression, later passed to {@link Discounts#parse} */
        public void setDiscount(String str) {
            this.discount = str;
        }

        /**
         * Resolve the configured discount.
         *
         * @return {@link Discounts#log2()} when no discount is configured, otherwise the result of
         *         {@link Discounts#parse} on the configured expression
         */
        public Discount getParsedDiscount() {
            return discount == null ? Discounts.log2() : Discounts.parse(discount);
        }
    }

    /** Create the metric with the {@link Discounts#log2()} discount and the default column name. */
    public NDCGPredictMetric() {
        this(Discounts.log2(), DEFAULT_COLUMN);
    }

    /**
     * Create the metric with the given discount and the default column name.
     *
     * @param discount the rank discount to apply
     */
    public NDCGPredictMetric(Discount discount) {
        this(discount, DEFAULT_COLUMN);
    }

    /**
     * Create the metric from a JSON specification.
     *
     * @param spec the specification; a missing column name falls back to {@link #DEFAULT_COLUMN}
     */
    @JsonCreator
    public NDCGPredictMetric(Spec spec) {
        this(spec.getParsedDiscount(), StringUtils.defaultString(spec.getColumnName(), DEFAULT_COLUMN));
    }

    /**
     * Create the metric.
     *
     * @param discount the rank discount to apply
     * @param str      the column name; per-user output also gets a {@code str + ".Raw"} column
     */
    public NDCGPredictMetric(Discount discount, String str) {
        super(Lists.newArrayList(str, str + ".Raw"), Lists.newArrayList(str));
        this.columnName = str;
        this.discount = discount;
    }

    @Override
    @Nullable
    public MeanAccumulator createContext(AlgorithmInstance algorithmInstance, DataSet dataSet, Recommender recommender) {
        return new MeanAccumulator();
    }

    /**
     * Report the mean nDCG accumulated over all measured users.
     *
     * @param meanAccumulator the context created by {@link #createContext}
     * @return a single-column result holding the mean
     */
    @Override
    @Nonnull
    public MetricResult getAggregateMeasurements(MeanAccumulator meanAccumulator) {
        return MetricResult.singleton(columnName, meanAccumulator.getMean());
    }

    /**
     * Compute nDCG for one test user.
     *
     * @param testUser        the user whose test ratings provide the gains
     * @param resultMap       the predictions for the user; {@code null} or empty yields an empty result
     * @param meanAccumulator the shared context; the user's nDCG is added only when it is finite
     * @return the user's nDCG and raw (unnormalized) DCG
     */
    @Override
    @Nonnull
    public MetricResult measureUser(TestUser testUser, ResultMap resultMap, MeanAccumulator meanAccumulator) {
        if (resultMap == null || resultMap.isEmpty()) {
            return MetricResult.empty();
        }
        Long2DoubleMap ratings = testUser.getTestRatings();

        // Ideal ordering: test items by test rating, highest first.
        long[] ideal = ratings.keySet().toLongArray();
        LongArrays.quickSort(ideal, LongComparators.oppositeComparator(LongUtils.keyValueComparator(ratings)));

        // Actual ordering: predicted items by predicted score, highest first.
        long[] actual = LongUtils.asLongSet(resultMap.keySet()).toLongArray();
        LongArrays.quickSort(actual, LongComparators.oppositeComparator(
                LongUtils.keyValueComparator(LongUtils.asLong2DoubleFunction(resultMap.scoreMap()))));

        // Both orderings are scored with the test ratings as gains, so only the order differs.
        double idealGain = computeDCG(ideal, ratings);
        double gain = computeDCG(actual, ratings);
        logger.debug("user {} has gain of {} (ideal {})", testUser.getUserId(), gain, idealGain);

        double score = gain / idealGain;
        if (Double.isNaN(score) || Double.isInfinite(score)) {
            // nDCG is undefined here (e.g. the ideal DCG is zero). Adding it would make the
            // aggregate mean NaN/Infinity for the whole run, so leave this user out of it.
            logger.debug("user {} has undefined nDCG (ideal gain {}); excluding from aggregate",
                         testUser.getUserId(), idealGain);
        } else {
            meanAccumulator.add(score);
        }

        // Explicit type witness: a bare builder() infers <Object, Object>.
        return MetricResult.fromMap(ImmutableMap.<String, Double>builder()
                                                .put(columnName, score)
                                                .put(columnName + ".Raw", gain)
                                                .build());
    }

    /**
     * Compute discounted cumulative gain for a ranked list of items.
     *
     * <p>The item at 1-based rank {@code r} contributes {@code gains.get(item) * discount.discount(r)}.
     * Items absent from {@code gains} contribute the function's default return value (presumably 0
     * for the rating maps used here; verify if callers change).
     *
     * @param items the items in ranked order, best first
     * @param gains the gain for each item
     * @return the summed discounted gain; 0 for an empty list
     */
    double computeDCG(long[] items, Long2DoubleFunction gains) {
        double total = 0.0;
        for (int idx = 0; idx < items.length; idx++) {
            int rank = idx + 1;
            total += gains.get(items[idx]) * discount.discount(rank);
        }
        return total;
    }
}
