package org.lenskit.eval.traintest.recommend;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.longs.LongSet;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.grouplens.grapht.util.ClassLoaders;
import org.grouplens.lenskit.util.io.CompressionMode;
import org.lenskit.api.ItemRecommender;
import org.lenskit.api.Recommender;
import org.lenskit.api.Result;
import org.lenskit.api.ResultList;
import org.lenskit.eval.traintest.AlgorithmInstance;
import org.lenskit.eval.traintest.ConditionEvaluator;
import org.lenskit.eval.traintest.DataSet;
import org.lenskit.eval.traintest.EvalTask;
import org.lenskit.eval.traintest.EvaluationException;
import org.lenskit.eval.traintest.ExperimentOutputLayout;
import org.lenskit.eval.traintest.TestUser;
import org.lenskit.eval.traintest.metrics.Metric;
import org.lenskit.eval.traintest.metrics.MetricLoaderHelper;
import org.lenskit.eval.traintest.metrics.MetricResult;
import org.lenskit.eval.traintest.predict.PredictEvalTask;
import org.lenskit.util.collections.LongUtils;
import org.lenskit.util.table.TableLayout;
import org.lenskit.util.table.TableLayoutBuilder;
import org.lenskit.util.table.writer.CSVWriter;
import org.lenskit.util.table.writer.TableWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX WARN: Classes with same name are omitted:
  
 */
/* loaded from: input_file:org/lenskit/eval/traintest/recommend/RecommendEvalTask.class */
/**
 * Evaluation task that measures top-N recommendation lists.
 *
 * <p>For each test user, the task selects a candidate item set and an exclude item set using the
 * configured {@link ItemSelector}s. It asks the algorithm's {@link ItemRecommender} for a list of
 * {@link #getListSize()} items (the size is passed to the recommender unchanged). It then feeds
 * that list to every configured {@link TopNMetric}. If an output file is configured, each
 * recommended item is also written as a {@code User, Rank, Item, Score} row.
 */
public class RecommendEvalTask implements EvalTask {
    private static final Logger logger = LoggerFactory.getLogger(RecommendEvalTask.class);
    /** Metrics used when the configuration does not list any explicitly. */
    private static final TopNMetric<?>[] DEFAULT_METRICS = {new TopNLengthMetric(), new TopNNDCGMetric()};
    private Path outputFile;
    private String labelPrefix;
    private int listSize = -1;
    private List<TopNMetric<?>> topNMetrics = Lists.newArrayList(DEFAULT_METRICS);
    private volatile ItemSelector candidateSelector = ItemSelector.allItems();
    private volatile ItemSelector excludeSelector = ItemSelector.userTrainItems();
    /** Set by {@link #start}, cleared by {@link #finish}; non-null only while an experiment runs. */
    private ExperimentOutputLayout experimentOutputLayout;
    /** Recommendation output table; only opened when an output file is configured. */
    private TableWriter outputTable;

    /**
     * Pairs a top-N metric with the context object it created for one condition.
     * Keeping both under a single type parameter means the context is always handed back to the
     * metric that produced it.
     *
     * @param <X> the metric's context type
     */
    static class MetricContext<X> {
        final TopNMetric<X> metric;
        final X context;

        public MetricContext(TopNMetric<X> topNMetric, X x) {
            this.metric = topNMetric;
            this.context = x;
        }

        /**
         * Reports whether this metric needs the detailed (scored) result list.
         * Only {@link ListOnlyTopNMetric}s can be measured from bare item IDs.
         */
        public boolean usesDetails() {
            return !(metric instanceof ListOnlyTopNMetric);
        }

        /**
         * Measures one user from a detailed recommendation list.
         *
         * @param testUser the user being evaluated
         * @param listSize the requested list size
         * @param resultList the recommendations, with details
         * @return the metric's per-user result
         */
        @Nonnull
        public MetricResult measureUser(TestUser testUser, int listSize, ResultList resultList) {
            return metric.measureUser(testUser, listSize, resultList, context);
        }

        /**
         * Measures one user from a bare list of recommended item IDs.
         * Only valid when {@link #usesDetails()} is {@code false}.
         *
         * @param testUser the user being evaluated
         * @param listSize the requested list size
         * @param longList the recommended item IDs, in rank order
         * @return the metric's per-user result
         */
        @Nonnull
        public MetricResult measureUser(TestUser testUser, int listSize, LongList longList) {
            // The context is the metric's own context object (type X). Pass it through unchanged;
            // casting it to LongList (as the decompiled code did) throws ClassCastException.
            return ((ListOnlyTopNMetric<X>) metric).measureUser(testUser, listSize, longList, context);
        }

        /** Returns the metric's aggregate measurements over all users measured with this context. */
        @Nonnull
        public MetricResult getAggregateMeasurements() {
            return metric.getAggregateMeasurements(context);
        }

        /**
         * Creates the metric's context for one condition and wraps it together with the metric.
         */
        public static <X> MetricContext<X> create(TopNMetric<X> topNMetric, AlgorithmInstance algorithmInstance, DataSet dataSet, Recommender recommender) {
            return new MetricContext<>(topNMetric, topNMetric.createContext(algorithmInstance, dataSet, recommender));
        }
    }

    /**
     * Per-condition evaluator. Computes a recommendation list for each test user, runs every
     * metric on it, and optionally writes the list to the output table.
     */
    class TopNConditionEvaluator implements ConditionEvaluator {
        private final TableWriter writer;
        private final Recommender recommender;
        private final ItemRecommender itemRecommender;
        private final List<MetricContext<?>> predictMetricContexts;
        private final LongSet allItems;
        /** When true, request scored results; otherwise request bare item IDs only. */
        private final boolean useDetails;

        /**
         * @param tableWriter output table for recommendation rows, or {@code null} to skip writing
         * @param recommender the recommender under evaluation (passed to the item selectors)
         * @param itemRecommender the item recommender that produces the lists
         * @param list the metric contexts to measure
         * @param longSet all items in the data set (input to the item selectors)
         * @param z whether to request detailed results; must be {@code true} when
         *          {@code tableWriter} is non-null or any metric needs details
         */
        public TopNConditionEvaluator(TableWriter tableWriter, Recommender recommender, ItemRecommender itemRecommender, List<MetricContext<?>> list, LongSet longSet, boolean z) {
            this.writer = tableWriter;
            this.recommender = recommender;
            this.itemRecommender = itemRecommender;
            this.predictMetricContexts = list;
            this.allItems = longSet;
            this.useDetails = z;
        }

        @Override
        @Nonnull
        public Map<String, Object> measureUser(TestUser testUser) {
            LongSet candidates = RecommendEvalTask.this.getCandidateSelector().selectItems(allItems, recommender, testUser);
            LongSet excludes = RecommendEvalTask.this.getExcludeSelector().selectItems(allItems, recommender, testUser);
            int n = RecommendEvalTask.this.getListSize();

            ResultList results = null;
            LongList items = null;
            if (useDetails) {
                results = itemRecommender.recommendWithDetails(testUser.getUserId(), n, candidates, excludes);
            } else {
                // No consumer needs scores, so skip computing details.
                items = LongUtils.asLongList(itemRecommender.recommend(testUser.getUserId(), n, candidates, excludes));
            }

            Map<String, Object> row = new HashMap<>();
            for (MetricContext<?> mc : predictMetricContexts) {
                MetricResult result = useDetails
                        ? mc.measureUser(testUser, n, results)
                        : mc.measureUser(testUser, n, items);
                row.putAll(result.withPrefix(RecommendEvalTask.this.getLabelPrefix()).getValues());
            }

            if (writer != null) {
                // createConditionEvaluator forces useDetails whenever a writer exists.
                assert results != null : "recommendation writer requires detailed results";
                int rank = 0;
                for (Result rec : results) {
                    rank += 1;
                    try {
                        writer.writeRow(new Object[]{testUser.getUserId(), rank, rec.getId(), rec.getScore()});
                    } catch (IOException e) {
                        throw new EvaluationException("error writing prediction row", e);
                    }
                }
            }
            return row;
        }

        @Override
        @Nonnull
        public Map<String, Object> finish() {
            Map<String, Object> aggregates = new HashMap<>();
            for (MetricContext<?> mc : predictMetricContexts) {
                logger.debug("finishing metric {}", mc.metric);
                aggregates.putAll(mc.getAggregateMeasurements().withPrefix(RecommendEvalTask.this.getLabelPrefix()).getValues());
            }
            return aggregates;
        }
    }

    /**
     * Builds a task from its JSON configuration.
     *
     * <p>The recognized keys are {@code output_file} (resolved against {@code uri}),
     * {@code label_prefix}, {@code list_size} (default -1), {@code candidates}, {@code exclude}
     * (selector expressions) and {@code metrics}. If {@code metrics} is present, it replaces the
     * default metrics.
     *
     * @param jsonNode the task configuration
     * @param uri base URI for resolving relative paths
     * @return the configured task
     * @throws IOException declared for compatibility with existing callers
     * @throws RuntimeException if a metric specification cannot be instantiated
     */
    public static RecommendEvalTask fromJSON(JsonNode jsonNode, URI uri) throws IOException {
        RecommendEvalTask task = new RecommendEvalTask();

        String outFile = jsonNode.path("output_file").asText((String) null);
        if (outFile != null) {
            task.setOutputFile(Paths.get(uri.resolve(outFile)));
        }

        task.setLabelPrefix(jsonNode.path("label_prefix").asText((String) null));
        task.setListSize(jsonNode.path("list_size").asInt(-1));

        String candidates = jsonNode.path("candidates").asText((String) null);
        if (candidates != null) {
            task.setCandidateSelector(ItemSelector.compileSelector(candidates));
        }
        String exclude = jsonNode.path("exclude").asText((String) null);
        if (exclude != null) {
            task.setExcludeSelector(ItemSelector.compileSelector(exclude));
        }

        JsonNode metricsNode = jsonNode.get("metrics");
        if (metricsNode != null && !metricsNode.isNull()) {
            task.topNMetrics.clear();
            // NOTE(review): class loader is inferred from PredictEvalTask, as in the original.
            MetricLoaderHelper mlh = new MetricLoaderHelper(ClassLoaders.inferDefault(PredictEvalTask.class), "topn-metrics");
            for (JsonNode metricNode : metricsNode) {
                TopNMetric<?> metric = (TopNMetric<?>) mlh.createMetric(TopNMetric.class, metricNode);
                if (metric == null) {
                    throw new RuntimeException("cannot build metric for " + metricNode.toString());
                }
                task.addMetric(metric);
            }
        }
        return task;
    }

    /** Returns the recommendation output file, or {@code null} if recommendations are not written. */
    public Path getOutputFile() {
        return this.outputFile;
    }

    public void setOutputFile(Path path) {
        this.outputFile = path;
    }

    /** Returns the prefix for this task's column labels, or {@code null} for no prefix. */
    public String getLabelPrefix() {
        return this.labelPrefix;
    }

    public void setLabelPrefix(String str) {
        this.labelPrefix = str;
    }

    /** Returns the list size passed to the recommender (default -1). */
    public int getListSize() {
        return this.listSize;
    }

    public void setListSize(int i) {
        this.listSize = i;
    }

    public ItemSelector getCandidateSelector() {
        return this.candidateSelector;
    }

    public void setCandidateSelector(ItemSelector itemSelector) {
        this.candidateSelector = itemSelector;
    }

    public ItemSelector getExcludeSelector() {
        return this.excludeSelector;
    }

    public void setExcludeSelector(ItemSelector itemSelector) {
        this.excludeSelector = itemSelector;
    }

    /** Returns the live, mutable list of configured top-N metrics. */
    public List<TopNMetric<?>> getTopNMetrics() {
        return this.topNMetrics;
    }

    /** Returns an immutable snapshot of all metrics this task uses. */
    public List<Metric<?>> getAllMetrics() {
        ImmutableList.Builder<Metric<?>> builder = ImmutableList.builder();
        builder.addAll(topNMetrics);
        return builder.build();
    }

    public void addMetric(TopNMetric<?> topNMetric) {
        this.topNMetrics.add(topNMetric);
    }

    /** Returns the union of the component roots required by every metric. */
    @Override
    public Set<Class<?>> getRequiredRoots() {
        return FluentIterable.from(getAllMetrics()).transformAndConcat(new Function<Metric<?>, Iterable<Class<?>>>() {
            @Nullable
            @Override
            public Iterable<Class<?>> apply(Metric<?> metric) {
                return metric.getRequiredRoots();
            }
        }).toSet();
    }

    /** Returns every metric's aggregate column labels, with the label prefix applied. */
    @Override
    public List<String> getGlobalColumns() {
        ImmutableList.Builder<String> columns = ImmutableList.builder();
        for (Metric<?> metric : getAllMetrics()) {
            for (String label : metric.getAggregateColumnLabels()) {
                columns.add(prefixColumn(label));
            }
        }
        return columns.build();
    }

    /** Returns every top-N metric's per-user column labels, with the label prefix applied. */
    @Override
    public List<String> getUserColumns() {
        ImmutableList.Builder<String> columns = ImmutableList.builder();
        for (TopNMetric<?> metric : getTopNMetrics()) {
            for (String label : metric.getColumnLabels()) {
                columns.add(prefixColumn(label));
            }
        }
        return columns.build();
    }

    /** Prepends {@code labelPrefix + "."} to a column label when a prefix is configured. */
    private String prefixColumn(String str) {
        String prefix = getLabelPrefix();
        return prefix == null ? str : prefix + "." + str;
    }

    /**
     * Records the experiment layout and, if an output file is configured, opens the
     * recommendation table. That table has the condition columns plus User, Rank, Item and Score.
     *
     * @throws EvaluationException if the output file cannot be opened
     */
    @Override
    public void start(ExperimentOutputLayout experimentOutputLayout) {
        this.experimentOutputLayout = experimentOutputLayout;
        Path outFile = getOutputFile();
        if (outFile == null) {
            return;
        }
        TableLayout layout = TableLayoutBuilder.copy(experimentOutputLayout.getConditionLayout())
                                               .addColumn("User")
                                               .addColumn("Rank")
                                               .addColumn("Item")
                                               .addColumn("Score")
                                               .build();
        try {
            logger.info("writing recommendations to {}", outFile);
            this.outputTable = CSVWriter.open(outFile.toFile(), layout, CompressionMode.AUTO);
        } catch (IOException e) {
            throw new EvaluationException("error opening prediction output file", e);
        }
    }

    /**
     * Ends the experiment and closes the recommendation table if one is open.
     * The table reference is cleared even if closing fails, so a failed writer is never reused.
     *
     * @throws EvaluationException if closing the output file fails
     */
    @Override
    public void finish() {
        this.experimentOutputLayout = null;
        if (this.outputTable != null) {
            try {
                this.outputTable.close();
            } catch (IOException e) {
                throw new EvaluationException("error closing prediction output file", e);
            } finally {
                this.outputTable = null;
            }
        }
    }

    /**
     * Creates the evaluator for one algorithm/data-set condition.
     *
     * @return the evaluator, or {@code null} if the algorithm has no item recommender
     * @throws IllegalStateException if {@link #start} has not been called
     */
    @Override
    public ConditionEvaluator createConditionEvaluator(AlgorithmInstance algorithmInstance, DataSet dataSet, Recommender recommender) {
        Preconditions.checkState(this.experimentOutputLayout != null, "experiment not started");
        TableWriter recTable = this.experimentOutputLayout.prefixTable(this.outputTable, dataSet, algorithmInstance);
        LongSet items = dataSet.getAllItems();
        ItemRecommender irec = recommender.getItemRecommender();
        if (irec == null) {
            logger.warn("algorithm {} has no item recommender", algorithmInstance);
            return null;
        }

        // Writing recommendation rows needs scores, so a writer forces detailed results; so does
        // any metric that is not list-only.
        boolean useDetails = recTable != null;
        List<MetricContext<?>> contexts = new ArrayList<>(this.topNMetrics.size());
        for (TopNMetric<?> metric : this.topNMetrics) {
            logger.debug("setting up metric {}", metric);
            MetricContext<?> mc = MetricContext.create(metric, algorithmInstance, dataSet, recommender);
            contexts.add(mc);
            useDetails |= mc.usesDetails();
        }
        return new TopNConditionEvaluator(recTable, recommender, irec, contexts, items, useDetails);
    }
}
