package org.lenskit.eval.traintest.predict;

import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import java.util.Iterator;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.lenskit.api.Recommender;
import org.lenskit.api.Result;
import org.lenskit.api.ResultMap;
import org.lenskit.eval.traintest.AlgorithmInstance;
import org.lenskit.eval.traintest.DataSet;
import org.lenskit.eval.traintest.TestUser;
import org.lenskit.eval.traintest.metrics.MetricColumn;
import org.lenskit.eval.traintest.metrics.MetricResult;
import org.lenskit.eval.traintest.metrics.TypedMetricResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX WARN: Classes with same name are omitted:
  
 */
/* loaded from: input_file:org/lenskit/eval/traintest/predict/RMSEPredictMetric.class */
/**
 * Prediction metric that measures root-mean-square error (RMSE) between the
 * scores in a user's prediction results and that user's test ratings.
 *
 * <p>For every measured user it reports the user's own RMSE (column
 * {@code RMSE}). The aggregate result carries two values: the mean of the
 * per-user RMSE values ({@code RMSE.ByUser}) and the RMSE computed over all
 * scored predictions pooled together ({@code RMSE.ByRating}).
 */
public class RMSEPredictMetric extends PredictMetric<RMSEPredictMetric.Context> {
    private static final Logger logger = LoggerFactory.getLogger(RMSEPredictMetric.class);

    /**
     * Aggregate measurements produced once all users have been measured.
     */
    public static class AggregateResult extends TypedMetricResult {

        /** Mean of the per-user RMSE values. */
        @MetricColumn("RMSE.ByUser")
        public final double userRMSE;

        /** RMSE over every scored prediction, regardless of user. */
        @MetricColumn("RMSE.ByRating")
        public final double globalRMSE;

        /**
         * @param userRMSE   mean of the per-user RMSE values
         * @param globalRMSE RMSE over all scored predictions
         */
        public AggregateResult(double userRMSE, double globalRMSE) {
            this.userRMSE = userRMSE;
            this.globalRMSE = globalRMSE;
        }
    }

    /**
     * Accumulator holding the running totals for one evaluation run.
     *
     * <p>The accumulator is plain mutable state with no synchronization.
     * NOTE(review): confirm the evaluator never calls
     * {@link #addUser(int, double, double)} concurrently on a shared context.
     */
    public class Context {
        /** Sum of squared errors over all scored predictions seen so far. */
        private double totalSSE = 0.0d;
        /** Sum of the per-user RMSE values seen so far. */
        private double totalRMSE = 0.0d;
        /** Number of scored predictions accumulated. */
        private int nratings = 0;
        /** Number of users accumulated. */
        private int nusers = 0;

        public Context() {
        }

        /**
         * Folds one user's measurements into the running totals.
         *
         * @param n    number of scored predictions for the user
         * @param sse  sum of squared errors for the user
         * @param rmse the user's RMSE
         */
        public void addUser(int n, double sse, double rmse) {
            this.totalSSE += sse;
            this.totalRMSE += rmse;
            this.nratings += n;
            this.nusers++;
        }

        /**
         * Produces the aggregate result from the accumulated totals and logs
         * the pooled RMSE at info level.
         *
         * @return an {@link AggregateResult}, or {@link MetricResult#empty()}
         *         when no scored predictions were accumulated
         */
        public MetricResult finish() {
            if (this.nratings <= 0) {
                return MetricResult.empty();
            }
            double globalRMSE = Math.sqrt(this.totalSSE / this.nratings);
            RMSEPredictMetric.logger.info("RMSE: {}", globalRMSE);
            // As called from measureUser, addUser only ever receives a positive
            // count, so nusers is positive whenever nratings is.
            double meanUserRMSE = this.totalRMSE / this.nusers;
            return new AggregateResult(meanUserRMSE, globalRMSE);
        }
    }

    /**
     * Measurement reported for a single user.
     */
    public static class UserResult extends TypedMetricResult {

        /** RMSE over the user's scored predictions. */
        @MetricColumn("RMSE")
        public final double rmse;

        /**
         * @param rmse RMSE over the user's scored predictions
         */
        public UserResult(double rmse) {
            this.rmse = rmse;
        }
    }

    /**
     * Creates the metric, registering {@link UserResult} as the per-user result
     * type and {@link AggregateResult} as the aggregate result type.
     */
    public RMSEPredictMetric() {
        super((Class<? extends TypedMetricResult>) UserResult.class, (Class<? extends TypedMetricResult>) AggregateResult.class);
    }

    /**
     * Creates a fresh, empty accumulator. None of the arguments are used.
     *
     * @return a new {@link Context}
     */
    @Override // org.lenskit.eval.traintest.metrics.Metric
    @Nullable
    public Context createContext(AlgorithmInstance algorithmInstance, DataSet dataSet, Recommender recommender) {
        return new Context();
    }

    /**
     * Computes the RMSE of one user's predictions and records it in the context.
     *
     * <p>Only results that report a score are counted. Each scored result is
     * compared with the value the user's test-rating map returns for the
     * result's item ID.
     * NOTE(review): an item missing from the test ratings is compared against
     * whatever the map returns for an absent key (presumably its default
     * return value); confirm predictions are only made for test items.
     *
     * @param testUser  the user being measured; supplies the test ratings
     * @param resultMap the prediction results for the user
     * @param context   the accumulator to update
     * @return a {@link UserResult} with the user's RMSE, or
     *         {@link MetricResult#empty()} when no result had a score (in which
     *         case the context is left untouched)
     */
    @Override // org.lenskit.eval.traintest.predict.PredictMetric
    @Nonnull
    public MetricResult measureUser(TestUser testUser, ResultMap resultMap, Context context) {
        Long2DoubleMap testRatings = testUser.getTestRatings();
        double sumSquaredError = 0.0d;
        int scored = 0;
        for (Result prediction : resultMap) {
            if (!prediction.hasScore()) {
                continue;
            }
            double error = prediction.getScore() - testRatings.get(prediction.getId());
            sumSquaredError += error * error;
            scored++;
        }
        if (scored <= 0) {
            return MetricResult.empty();
        }
        double userRMSE = Math.sqrt(sumSquaredError / scored);
        context.addUser(scored, sumSquaredError, userRMSE);
        return new UserResult(userRMSE);
    }

    /**
     * Returns the aggregate measurements by delegating to
     * {@link Context#finish()}.
     *
     * @param context the accumulator populated by {@code measureUser}
     * @return the result of {@code context.finish()}
     */
    @Override // org.lenskit.eval.traintest.metrics.Metric
    @Nonnull
    public MetricResult getAggregateMeasurements(Context context) {
        return context.finish();
    }
}
