package org.ranksys.javafm.example;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.Arrays;
import java.util.Random;
import org.ranksys.javafm.BoundedFM;
import org.ranksys.javafm.FMInstance;
import org.ranksys.javafm.data.SimpleListWiseFMData;
import org.ranksys.javafm.learner.gd.PointWiseError;
import org.ranksys.javafm.learner.gd.PointWiseGradientDescent;

/**
 * Example: rating prediction on the MovieLens 100k {@code u1} split with a {@link BoundedFM}
 * trained by {@link PointWiseGradientDescent} under an RMSE loss.
 *
 * <p>Each rating becomes an {@link FMInstance} with exactly two active features, both of
 * value 1.0:
 * <ul>
 *   <li>the user, at index {@code userId - 1} (range {@code [0, NUM_USERS)});</li>
 *   <li>the item, at index {@code NUM_USERS + itemId - 1}.</li>
 * </ul>
 * The split files are downloaded into the working directory on first use.
 */
public class ML100kRatingPredictionExample {

    /** Number of user features; raw user ids are treated as 1-based. */
    private static final int NUM_USERS = 943;

    /** Number of item features; raw item ids are treated as 1-based. */
    private static final int NUM_ITEMS = 1682;

    /** Remote directory the split files are fetched from when missing locally. */
    private static final String DATASET_BASE_URL =
            "https://files.grouplens.org/datasets/movielens/ml-100k/";

    /**
     * Trains on {@code u1.base} and evaluates against {@code u1.test}.
     *
     * @param strArr ignored
     * @throws Exception if loading the data or training fails
     */
    public static void main(String[] strArr) throws Exception {
        SimpleListWiseFMData trainData = getRecommendationDataset("u1.base");
        SimpleListWiseFMData testData = getRecommendationDataset("u1.test");

        int numFeatures = trainData.numFeatures();
        // Per-feature hyper-parameter arrays for PointWiseGradientDescent, all set to 0.1.
        // NOTE(review): presumably regularisation for linear weights and factors; confirm.
        double[] perFeatureParamsA = new double[numFeatures];
        Arrays.fill(perFeatureParamsA, 0.1d);
        double[] perFeatureParamsB = new double[numFeatures];
        Arrays.fill(perFeatureParamsB, 0.1d);

        PointWiseGradientDescent learner = new PointWiseGradientDescent(
                0.01d, 200, PointWiseError.rmse(), 0.1d, perFeatureParamsA, perFeatureParamsB);
        // 1.0 / 5.0 presumably bound predictions to the 1-5 star rating scale; 100 factors.
        BoundedFM fm = new BoundedFM(1.0d, 5.0d, numFeatures, 100, new Random(), 0.1d);
        learner.learn(fm, trainData, testData);
    }

    /**
     * Loads one tab-separated MovieLens 100k ratings file, downloading it first if absent.
     *
     * <p>Each non-blank line must have at least three tab-separated fields:
     * user id, item id, rating. Any further fields are ignored. Every line is added as an
     * {@link FMInstance} keyed by the zero-based user index.
     *
     * @param str file name, resolved against the working directory and appended to the
     *            dataset URL for download
     * @return the parsed dataset
     * @throws IOException if downloading or reading fails, or a line has fewer than three fields
     * @throws NumberFormatException if a field cannot be parsed as a number
     */
    public static SimpleListWiseFMData getRecommendationDataset(String str) throws IOException {
        SimpleListWiseFMData data = new SimpleListWiseFMData(NUM_USERS + NUM_ITEMS);

        if (!new File(str).exists()) {
            download(DATASET_BASE_URL + str, Paths.get(str));
        }

        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(str), StandardCharsets.UTF_8))) {
            String line;
            int lineNumber = 0;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank / trailing lines
                }
                String[] fields = line.split("\t");
                if (fields.length < 3) {
                    throw new IOException(
                            "Malformed line " + lineNumber + " in " + str + ": \"" + line + "\"");
                }
                int userIndex = Integer.parseInt(fields[0]) - 1;
                int itemIndex = NUM_USERS + Integer.parseInt(fields[1]) - 1;
                double rating = Double.parseDouble(fields[2]);
                data.add(new FMInstance(rating,
                        new int[]{userIndex, itemIndex},
                        new double[]{1.0d, 1.0d}), userIndex);
            }
        }
        return data;
    }

    /**
     * Downloads {@code url} into {@code target}.
     *
     * <p>The data goes to a temporary file in the same directory first and is moved into
     * place only after the download completes. An interrupted download therefore never leaves
     * a truncated file that a later existence check would treat as complete.
     */
    private static void download(String url, Path target) throws IOException {
        Path directory = target.toAbsolutePath().getParent();
        Path temp = Files.createTempFile(directory, target.getFileName().toString(), ".part");
        try (InputStream in = new URL(url).openStream()) {
            Files.copy(in, temp, StandardCopyOption.REPLACE_EXISTING);
            Files.move(temp, target, StandardCopyOption.REPLACE_EXISTING);
        } finally {
            Files.deleteIfExists(temp); // no-op after a successful move
        }
    }
}
