package org.lenskit.eval.traintest.predict;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.grouplens.grapht.util.ClassLoaders;
import org.grouplens.lenskit.util.io.CompressionMode;
import org.lenskit.api.RatingPredictor;
import org.lenskit.api.Recommender;
import org.lenskit.api.Result;
import org.lenskit.api.ResultMap;
import org.lenskit.eval.traintest.AlgorithmInstance;
import org.lenskit.eval.traintest.ConditionEvaluator;
import org.lenskit.eval.traintest.DataSet;
import org.lenskit.eval.traintest.EvalTask;
import org.lenskit.eval.traintest.EvaluationException;
import org.lenskit.eval.traintest.ExperimentOutputLayout;
import org.lenskit.eval.traintest.TestUser;
import org.lenskit.eval.traintest.metrics.Metric;
import org.lenskit.eval.traintest.metrics.MetricLoaderHelper;
import org.lenskit.eval.traintest.metrics.MetricResult;
import org.lenskit.util.table.TableLayout;
import org.lenskit.util.table.TableLayoutBuilder;
import org.lenskit.util.table.writer.CSVWriter;
import org.lenskit.util.table.writer.TableWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX WARN: Classes with same name are omitted:
  
 */
/* loaded from: input_file:org/lenskit/eval/traintest/predict/PredictEvalTask.class */
/**
 * Eval task that measures rating-prediction accuracy for each algorithm / data-set condition and,
 * when an output file is configured, writes every individual prediction to a CSV table.
 *
 * <p>By default the task runs {@link CoveragePredictMetric}, {@link MAEPredictMetric} and
 * {@link RMSEPredictMetric}; a {@code metrics} entry in the JSON configuration replaces them.
 */
public class PredictEvalTask implements EvalTask {
    private static final Logger logger = LoggerFactory.getLogger(PredictEvalTask.class);

    /** Metrics used when the configuration does not list any. */
    static final PredictMetric<?>[] DEFAULT_METRICS = {new CoveragePredictMetric(), new MAEPredictMetric(), new RMSEPredictMetric()};

    private Path outputFile;
    private List<PredictMetric<?>> predictMetrics = Lists.newArrayList(DEFAULT_METRICS);
    private ExperimentOutputLayout experimentOutputLayout;
    private TableWriter outputTable;

    /**
     * Pairs a prediction metric with the context object it created for one condition, so the
     * metric's type parameter is captured once and the context is always passed back to the
     * metric that produced it.
     *
     * @param <X> the metric's context type
     */
    static class MetricContext<X> {
        final PredictMetric<X> metric;
        final X context;

        public MetricContext(PredictMetric<X> metric, X context) {
            this.metric = metric;
            this.context = context;
        }

        /**
         * Measures one user's predictions with the wrapped metric.
         *
         * @param testUser the user being evaluated
         * @param results  the predictions produced for the user's test items
         * @return the metric's per-user result
         */
        @Nonnull
        public MetricResult measureUser(TestUser testUser, ResultMap results) {
            return metric.measureUser(testUser, results, context);
        }

        /**
         * Retrieves the wrapped metric's aggregate result for this condition.
         *
         * @return the aggregate result reported by the metric
         */
        @Nonnull
        public MetricResult getAggregateMeasurements() {
            return metric.getAggregateMeasurements(context);
        }

        /**
         * Creates a metric context by asking the metric to build its per-condition context.
         *
         * @param metric      the metric to wrap
         * @param algorithm   the algorithm being evaluated
         * @param dataSet     the data set being evaluated
         * @param recommender the recommender built for this condition
         * @param <X>         the metric's context type
         * @return the metric paired with its new context
         */
        public static <X> MetricContext<X> create(PredictMetric<X> metric, AlgorithmInstance algorithm, DataSet dataSet, Recommender recommender) {
            return new MetricContext<>(metric, metric.createContext(algorithm, dataSet, recommender));
        }
    }

    /**
     * Evaluates one condition: computes per-user metric values, optionally writes each
     * prediction, and reports aggregate metric values when the condition finishes.
     */
    class PredictConditionEvaluator implements ConditionEvaluator {
        /** Prediction output table; {@code null} when predictions are not being written. */
        private final TableWriter writer;
        private final RatingPredictor predictor;
        private final List<MetricContext<?>> predictMetricContexts;

        public PredictConditionEvaluator(TableWriter tableWriter, RatingPredictor ratingPredictor, List<MetricContext<?>> list) {
            this.writer = tableWriter;
            this.predictor = ratingPredictor;
            this.predictMetricContexts = list;
        }

        /**
         * Predicts the user's test items and measures every metric on the predictions. If a
         * writer is configured, also writes one row per test rating.
         *
         * @param testUser the user to evaluate
         * @return the merged per-user values of all metrics
         * @throws EvaluationException if writing a prediction row fails
         */
        @Override
        @Nonnull
        public Map<String, Object> measureUser(TestUser testUser) {
            Long2DoubleMap testRatings = testUser.getTestRatings();
            ResultMap results = predictor.predictWithDetails(testUser.getUserId(), testRatings.keySet());

            Map<String, Object> row = new HashMap<>();
            for (MetricContext<?> metricContext : predictMetricContexts) {
                row.putAll(metricContext.measureUser(testUser, results).getValues());
            }

            // The writer check is loop-invariant; skip iterating the ratings when not writing.
            if (writer != null) {
                writePredictions(testUser.getUserId(), testRatings, results);
            }
            return row;
        }

        /**
         * Writes one (user, item, rating, prediction) row per test rating. The prediction
         * column is {@code null} for items without a prediction.
         */
        private void writePredictions(long userId, Long2DoubleMap testRatings, ResultMap results) {
            for (Long2DoubleMap.Entry entry : testRatings.long2DoubleEntrySet()) {
                Result result = results.get(entry.getLongKey());
                Double score = result != null ? Double.valueOf(result.getScore()) : null;
                Object[] row = {
                    Long.valueOf(userId),
                    Long.valueOf(entry.getLongKey()),
                    Double.valueOf(entry.getDoubleValue()),
                    score
                };
                try {
                    writer.writeRow(row);
                } catch (IOException e) {
                    throw new EvaluationException("error writing prediction row", e);
                }
            }
        }

        /**
         * Collects the aggregate values of every metric. A metric that returns {@code null}
         * is logged and skipped.
         *
         * @return the merged aggregate values of all metrics
         */
        @Override
        @Nonnull
        public Map<String, Object> finish() {
            Map<String, Object> aggregates = new HashMap<>();
            for (MetricContext<?> metricContext : predictMetricContexts) {
                MetricResult result = metricContext.getAggregateMeasurements();
                if (result != null) {
                    aggregates.putAll(result.getValues());
                } else {
                    logger.warn("Metric {} returned null results", metricContext.metric);
                }
            }
            return aggregates;
        }
    }

    /**
     * Configures a task from JSON.
     *
     * <p>An {@code output_file} entry is resolved against {@code base}. A non-null
     * {@code metrics} entry replaces the default metrics with the listed ones.
     *
     * @param jsonNode the task configuration
     * @param uri      base URI for resolving relative paths
     * @return the configured task
     * @throws IOException              if loading a metric fails (propagated from the metric loader)
     * @throws IllegalArgumentException if a metric entry does not yield a metric
     */
    public static PredictEvalTask fromJSON(JsonNode jsonNode, URI uri) throws IOException {
        PredictEvalTask task = new PredictEvalTask();
        String outputName = jsonNode.path("output_file").asText((String) null);
        if (outputName != null) {
            task.setOutputFile(Paths.get(uri.resolve(outputName)));
        }

        MetricLoaderHelper loader = new MetricLoaderHelper(ClassLoaders.inferDefault(PredictEvalTask.class), "predict-metrics");
        JsonNode metricsNode = jsonNode.get("metrics");
        if (metricsNode != null && !metricsNode.isNull()) {
            task.predictMetrics.clear();
            for (JsonNode metricNode : metricsNode) {
                PredictMetric<?> metric = (PredictMetric<?>) loader.createMetric(PredictMetric.class, metricNode);
                if (metric == null) {
                    // Fail here with context rather than storing null and crashing later.
                    throw new IllegalArgumentException("cannot build prediction metric from " + metricNode);
                }
                task.addMetric(metric);
            }
        }
        return task;
    }

    /**
     * @return the file predictions are written to, or {@code null} if they are not written
     */
    public Path getOutputFile() {
        return outputFile;
    }

    /**
     * @param path the file to write predictions to; {@code null} disables prediction output
     */
    public void setOutputFile(Path path) {
        this.outputFile = path;
    }

    /**
     * @return the task's prediction metrics (the live, mutable list)
     */
    public List<PredictMetric<?>> getPredictMetrics() {
        return predictMetrics;
    }

    /**
     * @return an immutable snapshot of all metrics used by this task
     */
    public List<Metric<?>> getAllMetrics() {
        return ImmutableList.<Metric<?>>copyOf(predictMetrics);
    }

    /**
     * Adds a prediction metric to the task.
     *
     * @param predictMetric the metric to add; must not be {@code null}
     * @throws NullPointerException if {@code predictMetric} is {@code null}
     */
    public void addMetric(PredictMetric<?> predictMetric) {
        predictMetrics.add(Preconditions.checkNotNull(predictMetric, "predictMetric"));
    }

    /**
     * @return the union of the required roots of all metrics
     */
    @Override
    public Set<Class<?>> getRequiredRoots() {
        return FluentIterable.from(getAllMetrics()).transformAndConcat(new Function<Metric<?>, Iterable<Class<?>>>() {
            @Override
            public Iterable<Class<?>> apply(Metric<?> metric) {
                return metric.getRequiredRoots();
            }
        }).toSet();
    }

    /**
     * @return the aggregate column labels of all metrics, in metric order
     */
    @Override
    public List<String> getGlobalColumns() {
        ImmutableList.Builder<String> columns = ImmutableList.builder();
        for (Metric<?> metric : getAllMetrics()) {
            columns.addAll(metric.getAggregateColumnLabels());
        }
        return columns.build();
    }

    /**
     * @return the per-user column labels of all prediction metrics, in metric order
     */
    @Override
    public List<String> getUserColumns() {
        ImmutableList.Builder<String> columns = ImmutableList.builder();
        for (PredictMetric<?> metric : getPredictMetrics()) {
            columns.addAll(metric.getColumnLabels());
        }
        return columns.build();
    }

    /**
     * Starts the task by recording the output layout.
     *
     * <p>If an output file is set, this also opens a CSV writer with
     * {@code CompressionMode.AUTO}. Its columns are the condition columns followed by
     * {@code User}, {@code Item}, {@code Rating} and {@code Prediction}.
     *
     * @param experimentOutputLayout the experiment's output layout
     * @throws EvaluationException if the output file cannot be opened
     */
    @Override
    public void start(ExperimentOutputLayout experimentOutputLayout) {
        this.experimentOutputLayout = experimentOutputLayout;
        Path file = getOutputFile();
        if (file == null) {
            return;
        }
        TableLayout layout = TableLayoutBuilder.copy(experimentOutputLayout.getConditionLayout())
                                               .addColumn("User")
                                               .addColumn("Item")
                                               .addColumn("Rating")
                                               .addColumn("Prediction")
                                               .build();
        try {
            logger.info("writing predictions to {}", file);
            this.outputTable = CSVWriter.open(file.toFile(), layout, CompressionMode.AUTO);
        } catch (IOException e) {
            throw new EvaluationException("error opening prediction output file", e);
        }
    }

    /**
     * Finishes the task, closing the prediction output file if one is open.
     *
     * @throws EvaluationException if closing the output file fails
     */
    @Override
    public void finish() {
        this.experimentOutputLayout = null;
        TableWriter table = outputTable;
        // Clear the field first so a failed close is not retried by a later finish().
        this.outputTable = null;
        if (table != null) {
            try {
                table.close();
            } catch (IOException e) {
                throw new EvaluationException("error closing prediction output file", e);
            }
        }
    }

    /**
     * Creates an evaluator for one algorithm / data-set condition.
     *
     * @param algorithmInstance the algorithm being evaluated
     * @param dataSet           the data set being evaluated
     * @param recommender       the recommender built for this condition
     * @return an evaluator, or {@code null} if the recommender has no rating predictor
     * @throws IllegalStateException if {@link #start} has not been called, or {@link #finish}
     *                               has since been called
     */
    @Override
    public ConditionEvaluator createConditionEvaluator(AlgorithmInstance algorithmInstance, DataSet dataSet, Recommender recommender) {
        Preconditions.checkState(experimentOutputLayout != null, "experiment not started");
        TableWriter prefixTable = experimentOutputLayout.prefixTable(outputTable, dataSet, algorithmInstance);
        RatingPredictor ratingPredictor = recommender.getRatingPredictor();
        if (ratingPredictor == null) {
            logger.warn("algorithm {} has no rating predictor", algorithmInstance);
            return null;
        }
        List<MetricContext<?>> contexts = new ArrayList<>(predictMetrics.size());
        for (PredictMetric<?> metric : predictMetrics) {
            contexts.add(MetricContext.create(metric, algorithmInstance, dataSet, recommender));
        }
        return new PredictConditionEvaluator(prefixTable, ratingPredictor, contexts);
    }
}
