package org.apache.mahout.cf.taste.hadoop.als;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.mapreduce.MergeVectorsCombiner;
import org.apache.mahout.common.mapreduce.MergeVectorsReducer;
import org.apache.mahout.common.mapreduce.TransposeMapper;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
import org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.class */
/**
 * MapReduce driver that factorizes a user-item rating matrix with alternating least squares (ALS).
 *
 * <p>The driver first builds one rating vector per item and one per user. It then seeds the item
 * feature matrix M from the items' average ratings. Finally, for {@code --numIterations} rounds, it
 * alternately recomputes the user feature matrix U (holding M fixed) and M (holding U fixed).
 *
 * <p>The matrices of the last iteration are written to {@code <output>/U} and {@code <output>/M}.
 * The per-user rating vectors are written to {@code <output>/userRatings}. All other intermediate
 * data goes to the job's temp path.
 *
 * <p>Each input line is split with {@link TasteHadoopUtils#splitPrefTokens(String)}. Its first three
 * tokens are parsed as user ID ({@code int}), item ID ({@code int}) and rating ({@code float}).
 */
public class ParallelALSFactorizationJob extends AbstractJob {
    private static final Logger log = LoggerFactory.getLogger(ParallelALSFactorizationJob.class);

    /** Configuration key: dimension of the feature space. */
    static final String NUM_FEATURES = ParallelALSFactorizationJob.class.getName() + ".numFeatures";
    /** Configuration key: regularization parameter passed to the solvers. */
    static final String LAMBDA = ParallelALSFactorizationJob.class.getName() + ".lambda";
    /** Configuration key: confidence parameter (implicit-feedback solver only). */
    static final String ALPHA = ParallelALSFactorizationJob.class.getName() + ".alpha";
    /** Configuration key: path of the feature matrix that is held fixed while solving. */
    static final String FEATURE_MATRIX = ParallelALSFactorizationJob.class.getName() + ".featureMatrix";

    private boolean implicitFeedback;
    private int numIterations;
    private int numFeatures;
    private double lambda;
    private double alpha;

    /**
     * Emits the average of the non-zero entries of each item's rating vector.
     *
     * <p>Every record is written under the constant key 0. Its value is a sparse vector holding the
     * single entry {@code itemIndex -> average}, so that the merging reducer configured in
     * {@link ParallelALSFactorizationJob#run(String[])} can combine all averages into one vector.
     */
    static class AverageRatingMapper extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {

        @Override
        protected void map(IntWritable itemIndex, VectorWritable ratingsWritable, Context ctx)
                throws IOException, InterruptedException {
            FullRunningAverage average = new FullRunningAverage();
            Iterator<Vector.Element> ratings = ratingsWritable.get().iterateNonZero();
            while (ratings.hasNext()) {
                average.addDatum(ratings.next().get());
            }
            Vector averageVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
            averageVector.setQuick(itemIndex.get(), average.getAverage());
            ctx.write(new IntWritable(0), new VectorWritable(averageVector));
        }
    }

    /**
     * Turns one text preference line into a single-entry item rating vector.
     *
     * <p>The output key is the item ID. The value is a sparse vector with the rating stored at the
     * user-ID index, so summing the vectors per item yields that item's full rating vector.
     */
    static class ItemRatingVectorsMapper extends Mapper<LongWritable, Text, IntWritable, VectorWritable> {

        @Override
        protected void map(LongWritable offset, Text line, Context ctx) throws IOException, InterruptedException {
            String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
            // Fail with a readable message instead of an ArrayIndexOutOfBoundsException on malformed input.
            Preconditions.checkArgument(tokens.length >= 3,
                    "Expected userID, itemID and rating but got: %s", line);
            int userID = Integer.parseInt(tokens[0]);
            int itemID = Integer.parseInt(tokens[1]);
            float rating = Float.parseFloat(tokens[2]);

            Vector ratings = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
            ratings.set(userID, rating);
            ctx.write(new IntWritable(itemID), new VectorWritable(ratings, true));
        }
    }

    /**
     * Solves one row of U (or M) for explicit ratings, holding the other feature matrix fixed.
     *
     * <p>The fixed matrix named by {@link #FEATURE_MATRIX} is read fully into memory in
     * {@link #setup}.
     */
    static class SolveExplicitFeedbackMapper extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
        private double lambda;
        private int numFeatures;
        /** Rows of the fixed feature matrix, keyed by row index. */
        private OpenIntObjectHashMap<Vector> fixedFeatures;
        private AlternatingLeastSquaresSolver solver;

        @Override
        protected void setup(Context ctx) throws IOException, InterruptedException {
            Configuration conf = ctx.getConfiguration();
            numFeatures = conf.getInt(ParallelALSFactorizationJob.NUM_FEATURES, -1);
            // Validate before loading the (potentially large) fixed feature matrix.
            Preconditions.checkArgument(numFeatures > 0, "numFeatures was not set correctly!");
            lambda = Double.parseDouble(conf.get(ParallelALSFactorizationJob.LAMBDA));
            solver = new AlternatingLeastSquaresSolver();
            fixedFeatures = ALSUtils.readMatrixByRows(
                    new Path(conf.get(ParallelALSFactorizationJob.FEATURE_MATRIX)), conf);
        }

        @Override
        protected void map(IntWritable rowIndex, VectorWritable ratingsWritable, Context ctx)
                throws IOException, InterruptedException {
            SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(ratingsWritable.get());
            // Collect the fixed feature vectors of every entity this row has rated, in rating order.
            ArrayList<Vector> featureVectors = Lists.newArrayList();
            Iterator<Vector.Element> nonZeros = ratings.iterateNonZero();
            while (nonZeros.hasNext()) {
                int index = nonZeros.next().index();
                Vector features = fixedFeatures.get(index);
                if (features == null) {
                    throw new IllegalStateException("No feature vector found for index " + index);
                }
                featureVectors.add(features);
            }
            Vector solution = solver.solve(featureVectors, ratings, lambda, numFeatures);
            ctx.write(rowIndex, new VectorWritable(solution));
        }
    }

    /**
     * Solves one row of U (or M) for implicit feedback, holding the other feature matrix fixed.
     *
     * <p>The fixed matrix named by {@link #FEATURE_MATRIX} is read fully into memory and handed to
     * the solver in {@link #setup}.
     */
    static class SolveImplicitFeedbackMapper extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
        private ImplicitFeedbackAlternatingLeastSquaresSolver solver;

        @Override
        protected void setup(Context ctx) throws IOException, InterruptedException {
            Configuration conf = ctx.getConfiguration();
            int numFeatures = conf.getInt(ParallelALSFactorizationJob.NUM_FEATURES, -1);
            // Validate before the solver is built and the fixed feature matrix is loaded.
            Preconditions.checkArgument(numFeatures > 0, "numFeatures was not set correctly!");
            double lambda = Double.parseDouble(conf.get(ParallelALSFactorizationJob.LAMBDA));
            double alpha = Double.parseDouble(conf.get(ParallelALSFactorizationJob.ALPHA));
            OpenIntObjectHashMap<Vector> fixedFeatures = ALSUtils.readMatrixByRows(
                    new Path(conf.get(ParallelALSFactorizationJob.FEATURE_MATRIX)), conf);
            solver = new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, fixedFeatures);
        }

        @Override
        protected void map(IntWritable rowIndex, VectorWritable ratingsWritable, Context ctx)
                throws IOException, InterruptedException {
            Vector solution = solver.solve(new SequentialAccessSparseVector(ratingsWritable.get()));
            ctx.write(rowIndex, new VectorWritable(solution));
        }
    }

    public static void main(String[] args) throws Exception {
        ToolRunner.run(new ParallelALSFactorizationJob(), args);
    }

    /**
     * Parses the command line, prepares the rating matrices and runs the ALS iterations.
     *
     * @param args command-line arguments
     * @return 0 on success, -1 if argument parsing or one of the preparation jobs fails
     * @throws IllegalArgumentException if {@code numFeatures} or {@code numIterations} is not positive
     * @throws IllegalStateException if one of the solver jobs fails
     */
    @Override
    public int run(String[] args) throws Exception {
        addInputOption();
        addOutputOption();
        addOption("lambda", null, "regularization parameter", true);
        addOption("implicitFeedback", null, "data consists of implicit feedback?", String.valueOf(false));
        addOption("alpha", null, "confidence parameter (only used on implicit feedback)", String.valueOf(40));
        addOption("numFeatures", null, "dimension of the feature space", true);
        addOption("numIterations", null, "number of iterations", true);
        if (parseArguments(args) == null) {
            return -1;
        }

        numFeatures = Integer.parseInt(getOption("numFeatures"));
        numIterations = Integer.parseInt(getOption("numIterations"));
        lambda = Double.parseDouble(getOption("lambda"));
        alpha = Double.parseDouble(getOption("alpha"));
        implicitFeedback = Boolean.parseBoolean(getOption("implicitFeedback"));

        // numFeatures <= 0 would otherwise crash in initializeM; numIterations <= 0 would finish without producing U.
        Preconditions.checkArgument(numFeatures > 0, "numFeatures must be positive, was %s", numFeatures);
        Preconditions.checkArgument(numIterations > 0, "numIterations must be positive, was %s", numIterations);

        // Text preferences -> one summed rating vector per item.
        Job itemRatings = prepareJob(getInputPath(), pathToItemRatings(), TextInputFormat.class,
                ItemRatingVectorsMapper.class, IntWritable.class, VectorWritable.class, VectorSumReducer.class,
                IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class);
        itemRatings.setCombinerClass(VectorSumReducer.class);
        if (!itemRatings.waitForCompletion(true)) {
            return -1;
        }

        // Transpose the item rating vectors into one rating vector per user.
        Job userRatings = prepareJob(pathToItemRatings(), pathToUserRatings(), TransposeMapper.class,
                IntWritable.class, VectorWritable.class, MergeVectorsReducer.class, IntWritable.class,
                VectorWritable.class);
        userRatings.setCombinerClass(MergeVectorsCombiner.class);
        if (!userRatings.waitForCompletion(true)) {
            return -1;
        }

        // Average rating per item, merged into a single vector.
        Job averageItemRatings = prepareJob(pathToItemRatings(), getTempPath("averageRatings"),
                AverageRatingMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class,
                IntWritable.class, VectorWritable.class);
        averageItemRatings.setCombinerClass(MergeVectorsCombiner.class);
        if (!averageItemRatings.waitForCompletion(true)) {
            return -1;
        }

        initializeM(ALSUtils.readFirstRow(getTempPath("averageRatings"), getConf()));

        for (int iteration = 0; iteration < numIterations; iteration++) {
            // Log 1-based so the last iteration reads "n/n".
            log.info("Recomputing U (iteration {}/{})", iteration + 1, numIterations);
            runSolver(pathToUserRatings(), pathToU(iteration), pathToM(iteration - 1));
            log.info("Recomputing M (iteration {}/{})", iteration + 1, numIterations);
            runSolver(pathToItemRatings(), pathToM(iteration), pathToU(iteration));
        }
        return 0;
    }

    /**
     * Writes the initial item feature matrix M to {@code pathToM(-1)}.
     *
     * <p>Each item that has a non-zero entry in {@code averageRatings} gets a row. The row's first
     * feature is the item's average rating; the remaining features are drawn uniformly from
     * [0, 1).
     *
     * @param averageRatings vector mapping item index to average rating
     * @throws IOException if writing or closing the sequence file fails
     */
    private void initializeM(Vector averageRatings) throws IOException {
        Random random = RandomUtils.getRandom();
        Path initialM = pathToM(-1);
        FileSystem fs = FileSystem.get(initialM.toUri(), getConf());
        SequenceFile.Writer writer = new SequenceFile.Writer(fs, getConf(), new Path(initialM, "part-m-00000"),
                IntWritable.class, VectorWritable.class);
        boolean threw = true;
        try {
            Iterator<Vector.Element> averages = averageRatings.iterateNonZero();
            while (averages.hasNext()) {
                Vector.Element average = averages.next();
                Vector row = new DenseVector(numFeatures);
                row.setQuick(0, average.get());
                for (int feature = 1; feature < numFeatures; feature++) {
                    row.setQuick(feature, random.nextDouble());
                }
                writer.append(new IntWritable(average.index()), new VectorWritable(row));
            }
            threw = false;
        } finally {
            // Closing flushes buffered records, so a close failure must surface. It is swallowed only
            // when the body already threw, to avoid masking the original exception.
            Closeables.close(writer, threw);
        }
    }

    /**
     * Runs one map-only solver job that recomputes U or M.
     *
     * @param ratings rating vectors of the rows to solve for
     * @param output where the recomputed feature matrix is written
     * @param fixedFeatureMatrix feature matrix held fixed during this step
     * @throws IllegalStateException if the job does not complete successfully
     */
    private void runSolver(Path ratings, Path output, Path fixedFeatureMatrix)
            throws ClassNotFoundException, IOException, InterruptedException {
        Class<? extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>> solverMapper =
                implicitFeedback ? SolveImplicitFeedbackMapper.class : SolveExplicitFeedbackMapper.class;
        Job solverJob = prepareJob(ratings, output, SequenceFileInputFormat.class, solverMapper,
                IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class);
        Configuration conf = solverJob.getConfiguration();
        conf.set(LAMBDA, String.valueOf(lambda));
        conf.set(ALPHA, String.valueOf(alpha));
        conf.setInt(NUM_FEATURES, numFeatures);
        conf.set(FEATURE_MATRIX, fixedFeatureMatrix.toString());
        if (!solverJob.waitForCompletion(true)) {
            throw new IllegalStateException("Job failed!");
        }
    }

    /** Path of M after {@code iteration}; the last iteration goes to the output path, others to temp. */
    private Path pathToM(int iteration) {
        return iteration == numIterations - 1 ? getOutputPath("M") : getTempPath("M-" + iteration);
    }

    /** Path of U after {@code iteration}; the last iteration goes to the output path, others to temp. */
    private Path pathToU(int iteration) {
        return iteration == numIterations - 1 ? getOutputPath("U") : getTempPath("U-" + iteration);
    }

    /** Temp path holding one rating vector per item. */
    private Path pathToItemRatings() {
        return getTempPath("itemRatings");
    }

    /**
     * Path holding one rating vector per user. It lives under the output path rather than temp;
     * presumably so downstream jobs can reuse it (NOTE(review): confirm).
     */
    private Path pathToUserRatings() {
        return getOutputPath("userRatings");
    }
}
