package org.lenskit.predict.ordrec;

import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongIterators;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.inject.Inject;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import org.grouplens.lenskit.iterative.IterationCount;
import org.grouplens.lenskit.iterative.LearningRate;
import org.grouplens.lenskit.iterative.RegularizationTerm;
import org.grouplens.lenskit.vectors.ImmutableSparseVector;
import org.grouplens.lenskit.vectors.MutableSparseVector;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;
import org.lenskit.api.ItemScorer;
import org.lenskit.api.Result;
import org.lenskit.api.ResultMap;
import org.lenskit.basic.AbstractRatingPredictor;
import org.lenskit.data.dao.UserEventDAO;
import org.lenskit.data.history.UserHistory;
import org.lenskit.data.ratings.Rating;
import org.lenskit.data.ratings.Ratings;
import org.lenskit.results.AbstractResult;
import org.lenskit.results.Results;
import org.lenskit.transform.quantize.Quantizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Rating predictor that converts the scores of an underlying {@link ItemScorer} into discrete
 * rating predictions using a per-user {@link OrdRecModel}.
 *
 * <p>For every prediction request a fresh {@link OrdRecModel} is fitted to the user's existing
 * ratings. The fit runs {@code iterationCount} passes of gradient updates with the configured
 * learning rate and regularization term. Each requested item is then assigned the rating level
 * with the highest probability under the fitted model. That level is mapped back to a rating
 * value by the {@link Quantizer}.
 *
 * <p>Items the underlying scorer cannot score are omitted from the results.
 */
public class OrdRecRatingPredictor extends AbstractRatingPredictor {
    private static final Logger logger = LoggerFactory.getLogger(OrdRecRatingPredictor.class);

    private final ItemScorer itemScorer;
    private final UserEventDAO userEventDao;
    private final Quantizer quantizer;
    private final double learningRate;
    private final double regTerm;
    private final int iterationCount;

    /**
     * Detailed prediction result.
     *
     * <p>Carries the underlying scorer's result and the per-level distribution computed by
     * {@link OrdRecModel#getProbDistribution}.
     */
    public static class FullResult extends AbstractResult implements Serializable {
        private static final long serialVersionUID = 1;
        // NOTE(review): serialization only succeeds if the wrapped Result is itself
        // Serializable; this is not enforced here.
        private final Result original;
        private final RealVector distribution;

        /**
         * @param original the underlying scorer's result for the item; its id becomes this
         *                 result's id
         * @param score the predicted rating value
         * @param distribution the per-level distribution for the item; stored as given, not copied
         */
        FullResult(Result original, double score, RealVector distribution) {
            super(original.getId(), score);
            this.original = original;
            this.distribution = distribution;
        }

        /**
         * @return the underlying item scorer's result for this item
         */
        public Result getOriginalResult() {
            return original;
        }

        /**
         * @return the per-level distribution the predicted value was chosen from
         */
        public RealVector getDistribution() {
            return distribution;
        }
    }

    /**
     * Creates a predictor with explicit training hyperparameters.
     *
     * @param itemScorer the scorer whose outputs feed the ordinal model
     * @param userEventDAO the source of users' existing ratings
     * @param quantizer maps rating values to discrete levels and back
     * @param learningRate step size of the gradient updates
     * @param regTerm regularization coefficient applied to the model parameters
     * @param iterationCount number of passes over the user's ratings during training
     */
    @Inject
    public OrdRecRatingPredictor(ItemScorer itemScorer, UserEventDAO userEventDAO, Quantizer quantizer,
                                 @LearningRate double learningRate,
                                 @RegularizationTerm double regTerm,
                                 @IterationCount int iterationCount) {
        this.userEventDao = userEventDAO;
        this.itemScorer = itemScorer;
        this.quantizer = quantizer;
        this.learningRate = learningRate;
        this.regTerm = regTerm;
        this.iterationCount = iterationCount;
    }

    /**
     * Creates a predictor with fixed hyperparameters: learning rate 0.001, regularization
     * 0.015 and 1000 iterations.
     */
    OrdRecRatingPredictor(ItemScorer itemScorer, UserEventDAO userEventDAO, Quantizer quantizer) {
        this(itemScorer, userEventDAO, quantizer, 0.001d, 0.015d, 1000);
    }

    /**
     * Fetches the user's ratings as an immutable vector keyed by item id.
     *
     * @return the rating vector, or {@code null} if the DAO has no history for the user
     */
    private SparseVector makeUserVector(long user, UserEventDAO dao) {
        UserHistory<Rating> history = dao.getEventsForUser(user, Rating.class);
        if (history == null) {
            return null;
        }
        return ImmutableSparseVector.create(Ratings.userRatingVector(history));
    }

    /**
     * Returns the derivative term used in the gradient updates.
     *
     * <p>The term is 1 for parameter {@code n == 0} when {@code k >= 0}. It is
     * {@code exp(beta)} for {@code n > 0} when {@code k >= n}, and 0 otherwise.
     * Presumably this is the derivative of threshold {@code k} with respect to
     * parameter {@code n}, with thresholds built as {@code t1} plus {@code exp(beta)}
     * increments (TODO confirm against {@link OrdRecModel}).
     */
    private static double dBeta(int k, int n, double beta) {
        if (n == 0) {
            return k >= 0 ? 1.0d : 0.0d;
        }
        return (n > 0 && k >= n) ? Math.exp(beta) : 0.0d;
    }

    /**
     * Fits {@code model} to the user's ratings in place.
     *
     * @param model the model to train; updated via {@link OrdRecModel#update}
     * @param ratings the user's ratings, keyed by item id
     * @param scores the underlying scorer's scores; rated items missing from it are skipped
     */
    private void trainModel(OrdRecModel model, SparseVector ratings, MutableSparseVector scores) {
        RealVector beta = model.getBeta();
        ArrayRealVector deltaBeta = new ArrayRealVector(beta.getDimension());
        for (int iter = 0; iter < iterationCount; iter++) {
            for (VectorEntry rating : ratings) {
                double score = scores.get(rating.getKey(), Double.NaN);
                if (Double.isNaN(score)) {
                    // The scorer produced nothing for this rated item, so it cannot
                    // contribute a training signal. A plain get() would throw here.
                    continue;
                }
                int level = quantizer.index(rating.getValue());
                double probEq = model.getProbEQ(score, level);
                double probLe = model.getProbLE(score, level);
                double probLePrev = model.getProbLE(score, level - 1);

                // Factors shared by every parameter's gradient for this rating.
                // NOTE(review): probEq == 0 would yield an infinite step; not guarded here.
                double step = learningRate / probEq;
                double gradLe = probLe * (1.0d - probLe);
                double gradLePrev = probLePrev * (1.0d - probLePrev);

                double t1 = model.getT1();
                double dTheta = step * (gradLe * dBeta(level, 0, t1)
                        - gradLePrev * dBeta(level - 1, 0, t1)
                        - regTerm * t1);

                for (int k = 0; k < beta.getDimension(); k++) {
                    double b = beta.getEntry(k);
                    deltaBeta.setEntry(k, step * (gradLe * dBeta(level, k + 1, b)
                            - gradLePrev * dBeta(level - 1, k + 1, b)
                            - regTerm * b));
                }
                model.update(dTheta, deltaBeta);
            }
        }
    }

    /**
     * Predicts ratings for the given items.
     *
     * @return predicted rating values keyed by item id; unscorable items are omitted
     */
    @Override
    @Nonnull
    public Map<Long, Double> predict(long user, @Nonnull Collection<Long> items) {
        return computePredictions(user, items, false).scoreMap();
    }

    /**
     * Predicts ratings for the given items, with details.
     *
     * @return results that are {@link FullResult} instances carrying the underlying scorer's
     *         result and the predicted distribution
     */
    @Override
    @Nonnull
    public ResultMap predictWithDetails(long user, @Nonnull Collection<Long> items) {
        return computePredictions(user, items, true);
    }

    /**
     * Shared implementation of {@link #predict} and {@link #predictWithDetails}.
     *
     * <p>Scores the requested items together with the user's rated items, trains a per-user
     * model on the rated items, and then picks the most probable level for each requested item.
     *
     * @param withDetails whether to build {@link FullResult}s (requires detailed scoring)
     */
    @Nonnull
    private ResultMap computePredictions(long user, @Nonnull Collection<Long> items, boolean withDetails) {
        logger.debug("predicting {} items for {}", items.size(), user);
        SparseVector ratings = makeUserVector(user, userEventDao);

        // Rated items are scored too, because their scores are the model's training inputs.
        LongOpenHashSet toScore = new LongOpenHashSet();
        if (ratings != null) {
            toScore.addAll(ratings.keySet());
        }
        toScore.addAll(items);

        ResultMap details = null;
        Map<Long, Double> rawScores;
        if (withDetails) {
            details = itemScorer.scoreWithDetails(user, toScore);
            rawScores = details.scoreMap();
        } else {
            rawScores = itemScorer.score(user, toScore);
        }
        MutableSparseVector scores = MutableSparseVector.create(rawScores);

        OrdRecModel model = new OrdRecModel(quantizer);
        if (ratings != null) {
            trainModel(model, ratings, scores);
        }
        logger.debug("trained parameters for {}: {}", user, model);

        // Reused buffer; FullResult receives its own copy.
        ArrayRealVector distribution = new ArrayRealVector(model.getLevelCount());
        List<Result> results = new ArrayList<>(items.size());
        for (long item : items) {
            double score = scores.get(item, Double.NaN);
            if (Double.isNaN(score)) {
                continue; // the underlying scorer could not score this item
            }
            model.getProbDistribution(score, distribution);
            double predicted = quantizer.getIndexValue(distribution.getMaxIndex());
            if (withDetails) {
                results.add(new FullResult(details.get(item), predicted, new ArrayRealVector(distribution)));
            } else {
                results.add(Results.create(item, predicted));
            }
        }
        return Results.newResultMap(results);
    }
}
