package org.lenskit.predict.ordrec;

import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongIterators;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.inject.Inject;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import org.grouplens.lenskit.iterative.IterationCount;
import org.grouplens.lenskit.iterative.LearningRate;
import org.grouplens.lenskit.iterative.RegularizationTerm;
import org.lenskit.api.ItemScorer;
import org.lenskit.api.Result;
import org.lenskit.api.ResultMap;
import org.lenskit.basic.AbstractRatingPredictor;
import org.lenskit.data.dao.DataAccessObject;
import org.lenskit.data.entities.CommonAttributes;
import org.lenskit.data.ratings.Rating;
import org.lenskit.data.ratings.Ratings;
import org.lenskit.results.AbstractResult;
import org.lenskit.results.Results;
import org.lenskit.transform.quantize.Quantizer;
import org.lenskit.util.math.Vectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Rating predictor that post-processes the scores of an underlying {@link ItemScorer} with a
 * per-user OrdRec model.
 *
 * <p>For every prediction request a fresh {@link OrdRecModel} is trained by gradient steps on the
 * requesting user's existing ratings (quantized with the configured {@link Quantizer}) and the
 * scorer's scores for those items. Each target item's score is then mapped to a probability
 * distribution over rating levels. The predicted value is the quantizer value of the most probable
 * level.
 */
public class OrdRecRatingPredictor extends AbstractRatingPredictor {
    private static final Logger logger = LoggerFactory.getLogger(OrdRecRatingPredictor.class);

    /** Learning rate used by the package-private convenience constructor. */
    private static final double DEFAULT_LEARNING_RATE = 0.001;
    /** Regularization term used by the package-private convenience constructor. */
    private static final double DEFAULT_REG_TERM = 0.015;
    /** Number of training passes used by the package-private convenience constructor. */
    private static final int DEFAULT_ITERATION_COUNT = 1000;

    private final ItemScorer itemScorer;
    private final DataAccessObject dao;
    private final Quantizer quantizer;
    private final double learningRate;
    private final double regTerm;
    private final int iterationCount;

    /**
     * Detailed prediction result: the OrdRec prediction together with the underlying scorer's
     * result and the per-level probability distribution the prediction was chosen from.
     */
    public static class FullResult extends AbstractResult implements Serializable {
        private static final long serialVersionUID = 1L;
        private final Result original;
        private final RealVector distribution;

        /**
         * @param original     the underlying scorer's result for the item (supplies the item ID)
         * @param score        the predicted rating value
         * @param distribution the probability distribution over rating levels
         */
        FullResult(Result original, double score, RealVector distribution) {
            super(original.getId(), score);
            this.original = original;
            this.distribution = distribution;
        }

        /**
         * @return the result the underlying item scorer produced for this item
         */
        public Result getOriginalResult() {
            return original;
        }

        /**
         * @return the probability distribution over rating levels used to pick the prediction
         */
        public RealVector getDistribution() {
            return distribution;
        }
    }

    /**
     * Creates a predictor with explicit training hyper-parameters.
     *
     * @param itemScorer     the scorer whose scores are fed into the OrdRec model
     * @param dao            the data source used to look up the user's ratings
     * @param quantizer      maps ratings to discrete levels and back
     * @param learningRate   gradient-step size
     * @param regTerm        regularization coefficient
     * @param iterationCount number of passes over the user's ratings during training
     */
    @Inject
    public OrdRecRatingPredictor(ItemScorer itemScorer, DataAccessObject dao, Quantizer quantizer,
                                 @LearningRate double learningRate,
                                 @RegularizationTerm double regTerm,
                                 @IterationCount int iterationCount) {
        this.itemScorer = itemScorer;
        this.dao = dao;
        this.quantizer = quantizer;
        this.learningRate = learningRate;
        this.regTerm = regTerm;
        this.iterationCount = iterationCount;
    }

    /**
     * Creates a predictor with the default hyper-parameters (learning rate 0.001,
     * regularization 0.015, 1000 iterations).
     */
    OrdRecRatingPredictor(ItemScorer itemScorer, DataAccessObject dao, Quantizer quantizer) {
        this(itemScorer, dao, quantizer,
             DEFAULT_LEARNING_RATE, DEFAULT_REG_TERM, DEFAULT_ITERATION_COUNT);
    }

    /**
     * Builds the item-to-rating vector for a user.
     *
     * @param user the user ID
     * @param dao  the data source to query for {@link Rating} entities
     * @return the user's rating vector, or {@code null} if the user has no ratings
     */
    private Long2DoubleMap makeUserVector(long user, DataAccessObject dao) {
        List<Rating> history = dao.query(Rating.class)
                                  .withAttribute(CommonAttributes.USER_ID, user)
                                  .get();
        return history.isEmpty() ? null : Ratings.userRatingVector(history);
    }

    /**
     * Derivative coefficient of threshold {@code level} with respect to model parameter
     * {@code param}.
     *
     * <p>Parameter 0 is {@code t1}. It contributes 1 to every level {@code >= 0}. Parameter
     * {@code k >= 1} is the (k-1)-th beta entry. It contributes {@code exp(value)} to every level
     * {@code >= k}. All other combinations contribute 0.
     *
     * <p>NOTE(review): this presumably matches OrdRec's threshold parameterization
     * {@code t_r = t1 + sum_k exp(beta_k)}; confirm against {@link OrdRecModel}.
     */
    private static double dBeta(int level, int param, double value) {
        if (param == 0) {
            return level >= 0 ? 1.0 : 0.0;
        }
        if (param < 0 || level < param) {
            return 0.0;
        }
        return Math.exp(value);
    }

    /**
     * Trains {@code model} by repeated gradient steps over the user's ratings.
     *
     * <p>Rated items for which {@code scores} has no entry are skipped instead of failing.
     *
     * @param model   the model to update in place
     * @param ratings the user's ratings, keyed by item ID
     * @param scores  the item scorer's scores, keyed by item ID
     */
    private void trainModel(OrdRecModel model, Long2DoubleMap ratings, Map<Long, Double> scores) {
        if (iterationCount <= 0) {
            return;
        }

        // Scores and quantized levels do not change between passes, so resolve them once. The
        // map's iteration order is stable, so the update order matches iterating the map per pass.
        int size = ratings.size();
        double[] itemScores = new double[size];
        int[] levels = new int[size];
        int count = 0;
        for (Long2DoubleMap.Entry entry : Vectors.fastEntries(ratings)) {
            Double score = scores.get(entry.getLongKey());
            if (score == null) {
                logger.debug("no score for rated item {}, skipping it in training", entry.getLongKey());
                continue;
            }
            itemScores[count] = score;
            levels[count] = quantizer.index(entry.getDoubleValue());
            count++;
        }

        RealVector beta = model.getBeta();
        ArrayRealVector deltaBeta = new ArrayRealVector(beta.getDimension());
        for (int iter = 0; iter < iterationCount; iter++) {
            for (int j = 0; j < count; j++) {
                double score = itemScores[j];
                int r = levels[j];

                double probEq = model.getProbEQ(score, r);
                double probLe = model.getProbLE(score, r);
                double probLePrev = model.getProbLE(score, r - 1);
                double gradLe = probLe * (1.0 - probLe);
                double gradLePrev = probLePrev * (1.0 - probLePrev);
                // NOTE(review): a zero probEq would make this step infinite; not guarded here.
                double step = learningRate / probEq;

                double t1 = model.getT1();
                double deltaT1 = step * (gradLe * dBeta(r, 0, t1)
                                         - gradLePrev * dBeta(r - 1, 0, t1)
                                         - regTerm * t1);

                for (int k = 0; k < beta.getDimension(); k++) {
                    double bk = beta.getEntry(k);
                    deltaBeta.setEntry(k, step * (gradLe * dBeta(r, k + 1, bk)
                                                  - gradLePrev * dBeta(r - 1, k + 1, bk)
                                                  - regTerm * bk));
                }
                model.update(deltaT1, deltaBeta);
            }
        }
    }

    @Nonnull
    @Override
    public Map<Long, Double> predict(long user, @Nonnull Collection<Long> items) {
        return computePredictions(user, items, false).scoreMap();
    }

    @Nonnull
    @Override
    public ResultMap predictWithDetails(long user, @Nonnull Collection<Long> items) {
        return computePredictions(user, items, true);
    }

    /**
     * Trains a per-user OrdRec model and predicts ratings for {@code items}.
     *
     * <p>Users without ratings get predictions from an untrained model. Items the underlying
     * scorer cannot score are omitted from the result.
     *
     * @param user        the user ID
     * @param items       the items to predict
     * @param withDetails whether to return {@link FullResult}s carrying the scorer result and
     *                    level distribution
     * @return the predictions
     */
    @Nonnull
    private ResultMap computePredictions(long user, @Nonnull Collection<Long> items, boolean withDetails) {
        logger.debug("predicting {} items for {}", items.size(), user);

        Long2DoubleMap ratings = makeUserVector(user, dao);

        // Rated items are scored too: training needs the scorer's score for each of them.
        LongOpenHashSet toScore = new LongOpenHashSet();
        if (ratings != null) {
            toScore.addAll(ratings.keySet());
        }
        toScore.addAll(items);

        ResultMap details = null;
        Map<Long, Double> scores;
        if (withDetails) {
            details = itemScorer.scoreWithDetails(user, toScore);
            scores = details.scoreMap();
        } else {
            scores = itemScorer.score(user, toScore);
        }

        OrdRecModel model = new OrdRecModel(quantizer);
        if (ratings != null) {
            trainModel(model, ratings, scores);
        }
        logger.debug("trained parameters for {}: {}", user, model);

        ArrayRealVector distribution = new ArrayRealVector(model.getLevelCount());
        List<Result> results = new ArrayList<>();
        for (long item : items) {
            Double score = scores.get(item);
            if (score == null) {
                continue;
            }
            model.getProbDistribution(score, distribution);
            double predicted = quantizer.getIndexValue(distribution.getMaxIndex());
            if (withDetails) {
                // Copy the distribution: the scratch vector is overwritten for the next item.
                results.add(new FullResult(details.get(item), predicted, new ArrayRealVector(distribution)));
            } else {
                results.add(Results.create(item, predicted));
            }
        }
        return Results.newResultMap(results);
    }
}
