package org.lenskit.eval.traintest.recommend;

import com.fasterxml.jackson.annotation.JsonCreator;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.longs.LongListIterator;
import it.unimi.dsi.fastutil.longs.LongSet;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.lenskit.api.Recommender;
import org.lenskit.api.RecommenderEngine;
import org.lenskit.eval.traintest.AlgorithmInstance;
import org.lenskit.eval.traintest.DataSet;
import org.lenskit.eval.traintest.TestUser;
import org.lenskit.eval.traintest.metrics.MetricColumn;
import org.lenskit.eval.traintest.metrics.MetricResult;
import org.lenskit.eval.traintest.metrics.TypedMetricResult;
import org.lenskit.util.math.MeanAccumulator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/lenskit/eval/traintest/recommend/TopNMAPMetric.class */
/**
 * Top-N metric computing mean average precision (MAP) over recommendation lists.
 *
 * <p>Each test user's recommendation list is scored with average precision against the set of
 * "good" items chosen by an {@link ItemSelector} (by default, the user's test items). The
 * aggregate result is the mean of the per-user values recorded in the {@link Context}.
 *
 * <p>NOTE(review): average precision here is normalized by the number of good items that
 * actually appear in the list, not by the total number of good items (or
 * {@code min(|good|, listLength)}) as in the more common definition. Confirm this is intended
 * before comparing results with other toolkits.
 */
public class TopNMAPMetric extends ListOnlyTopNMetric<TopNMAPMetric.Context> {
    private static final Logger logger = LoggerFactory.getLogger(TopNMAPMetric.class);

    /** Column-name suffix applied to every result; {@code null} means no suffix. */
    private final String suffix;
    /** Selects, per user, the items that count as relevant ("good"). */
    private final ItemSelector goodItems;

    /** Aggregate result: the mean of all per-user average precisions recorded in the context. */
    public static class AggregateResult extends TypedMetricResult {

        @MetricColumn("MAP")
        public final double map;

        public AggregateResult(Context context) {
            this.map = context.allMean.getMean();
        }
    }

    /** Per-run measurement state: the item universe, the engine, and the running AP mean. */
    public static class Context {
        private final LongSet universe;
        private final RecommenderEngine recommenderEngine;
        private final MeanAccumulator allMean = new MeanAccumulator();

        Context(LongSet universe, RecommenderEngine recommenderEngine) {
            this.universe = universe;
            this.recommenderEngine = recommenderEngine;
        }

        /**
         * Folds one user's average precision into the aggregate mean. Synchronized so that
         * concurrent callers cannot corrupt the accumulator.
         */
        synchronized void addUser(UserResult userResult) {
            this.allMean.add(userResult.avgPrecision);
        }
    }

    /** Per-user result: the average precision of that user's recommendation list. */
    public static class UserResult extends TypedMetricResult {

        @MetricColumn("AvgPrec")
        public final double avgPrecision;

        public UserResult(double avgPrecision) {
            this.avgPrecision = avgPrecision;
        }
    }

    /** Creates a metric that treats each user's test items as good, with no column suffix. */
    public TopNMAPMetric() {
        this(ItemSelector.userTestItems(), null);
    }

    /**
     * Creates a metric from a configuration spec.
     *
     * @param spec the metric spec; when its good-items expression is {@code null},
     *     {@code "user.testItems"} is used.
     */
    @JsonCreator
    public TopNMAPMetric(PRMetricSpec spec) {
        this(ItemSelector.compileSelector(StringUtils.defaultString(spec.getGoodItems(), "user.testItems")),
             spec.getSuffix());
    }

    /**
     * Creates a metric with an explicit good-item selector.
     *
     * @param goodItems selector for the items that count as relevant for each user
     * @param suffix    suffix for the result column names, or {@code null} for none
     */
    public TopNMAPMetric(ItemSelector goodItems, String suffix) {
        super(UserResult.class, AggregateResult.class, suffix);
        this.suffix = suffix;
        this.goodItems = goodItems;
    }

    /** Creates fresh measurement state over all items of the data set. */
    @Override
    @Nullable
    public Context createContext(AlgorithmInstance algorithm, DataSet dataSet, RecommenderEngine engine) {
        return new Context(dataSet.getAllItems(), engine);
    }

    /** Returns the MAP over every user recorded via {@link Context#addUser}, with the suffix applied. */
    @Override
    @Nonnull
    public MetricResult getAggregateMeasurements(Context context) {
        return new AggregateResult(context).withSuffix(this.suffix);
    }

    /**
     * Measures average precision for one user's recommendation list.
     *
     * @param recommender     the recommender under test
     * @param testUser        the user being measured
     * @param targetLength    requested list length (not used by this metric)
     * @param recommendations the ranked recommendation list, possibly {@code null} or empty
     * @param context         the per-run measurement state
     * @return the user's result (suffixed), or an empty result when there are no recommendations
     */
    @Override
    @Nonnull
    public MetricResult measureUser(Recommender recommender, TestUser testUser, int targetLength,
                                    LongList recommendations, Context context) {
        LongSet good = this.goodItems.selectItems(context.universe, recommender, testUser);
        if (good.isEmpty()) {
            logger.warn("no good items for user {}", testUser.getUserId());
            // Report 0 for this user but leave it out of the aggregate mean. The suffix must be
            // applied here too, or the value will not land in the declared (suffixed) column.
            return new UserResult(0.0d).withSuffix(this.suffix);
        }
        if (recommendations == null || recommendations.isEmpty()) {
            // NOTE(review): users with no recommendations are excluded from MAP rather than
            // counted as 0; confirm this is the intended convention.
            return MetricResult.empty();
        }
        UserResult result = new UserResult(averagePrecision(recommendations, good));
        context.addUser(result);
        return result.withSuffix(this.suffix);
    }

    /**
     * Computes the mean of precision-at-rank over the ranks of {@code recommendations} that hold
     * a good item. Returns 0 when no good item appears in the list.
     */
    private static double averagePrecision(LongList recommendations, LongSet good) {
        int rank = 0;
        int hits = 0;
        double precisionSum = 0.0d;
        // Use the primitive iterator to avoid boxing every item id.
        for (LongListIterator it = recommendations.iterator(); it.hasNext(); ) {
            rank++;
            if (good.contains(it.nextLong())) {
                hits++;
                precisionSum += (double) hits / rank;
            }
        }
        return hits > 0 ? precisionSum / hits : 0.0d;
    }
}
