package org.ranksys.javafm.example;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Random;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.ranksys.javafm.BoundedFM;
import org.ranksys.javafm.FMInstance;
import org.ranksys.javafm.data.FMData;
import org.ranksys.javafm.data.SimpleFMData;
import org.ranksys.javafm.learner.gd.PointWiseError;
import org.ranksys.javafm.learner.gd.PointWiseGradientDescent;

/**
 * Example that trains a {@link BoundedFM} with point-wise gradient descent on the UCI "Wine
 * Quality" (white wine) dataset.
 *
 * <p>The CSV is downloaded into the working directory on first use. Its 11 input columns are
 * min-max scaled into [0, 1] over the whole dataset. The instances are then split at random:
 * each instance goes to the first part with probability 0.6 and otherwise to the second part.
 * Both parts are passed to the learner.
 */
public class WineQualityExample {

    /** File name (relative to the working directory) under which the dataset is cached. */
    private static final String DATASET_FILE = "winequality-white.csv";

    /** Remote location the dataset is fetched from when no local copy exists. */
    private static final String DATASET_URL =
            "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv";

    /** Number of input columns; the column at this index holds the regression target. */
    private static final int NUM_FEATURES = 11;

    /**
     * Loads the dataset, splits it, and runs 200 iterations of point-wise gradient descent with
     * an RMSE error.
     *
     * @param args ignored
     * @throws Exception if the dataset cannot be downloaded/read or learning fails
     */
    public static void main(String[] args) throws Exception {
        // Fixed seed so that the train/test split is reproducible between runs.
        List<FMData> partition = getRandomPartition(getWineQualityDataset(), 0.6d, new Random(1L));
        FMData train = partition.get(0);
        FMData test = partition.get(1);

        // Per-feature regularization arrays; presumably for the linear weights and the latent
        // factors respectively - confirm against PointWiseGradientDescent's constructor.
        double[] regW = new double[train.numFeatures()];
        Arrays.fill(regW, 0.01d);
        double[] regM = new double[train.numFeatures()];
        Arrays.fill(regM, 0.01d);

        // NOTE(review): the third argument is train.numInstances(), while the regularization
        // arrays above are sized by numFeatures(); confirm which dimension BoundedFM expects.
        // The model's own Random is unseeded, so results differ between runs.
        BoundedFM fm = new BoundedFM(3.0d, 9.0d, train.numInstances(), 10, new Random(), 1.0d);
        new PointWiseGradientDescent(0.001d, 200, PointWiseError.rmse(), 0.01d, regW, regM)
                .learn(fm, train, test);
    }

    /**
     * Returns the white-wine dataset with all feature columns min-max scaled into [0, 1].
     *
     * <p>The CSV is downloaded to {@link #DATASET_FILE} if it does not exist yet. The first row
     * is treated as a header and skipped. Blank lines are ignored.
     *
     * @return the parsed and normalized dataset
     * @throws IOException if the file cannot be downloaded or read, or a row has too few fields
     */
    private static FMData getWineQualityDataset() throws IOException {
        Path csv = Paths.get(DATASET_FILE);
        if (!Files.exists(csv)) {
            download(DATASET_URL, csv);
        }

        SimpleFMData data = new SimpleFMData(NUM_FEATURES);
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(csv.toFile()), StandardCharsets.UTF_8))) {
            reader.readLine(); // header row with the column names
            int lineNumber = 1;
            String line;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank lines instead of failing on them
                }
                data.add(parseInstance(line, lineNumber));
            }
        }

        normalizeFeatures(data);
        return data;
    }

    /**
     * Copies the resource at {@code url} to {@code target}. The data is written to a temporary
     * file in the same directory first and then moved into place. This ensures an interrupted
     * download never leaves a truncated file that later runs would silently reuse.
     */
    private static void download(String url, Path target) throws IOException {
        Path directory = target.toAbsolutePath().getParent();
        Path temp = Files.createTempFile(directory, "winequality-", ".part");
        try {
            try (InputStream in = new URL(url).openStream()) {
                Files.copy(in, temp, StandardCopyOption.REPLACE_EXISTING);
            }
            Files.move(temp, target, StandardCopyOption.REPLACE_EXISTING);
        } finally {
            Files.deleteIfExists(temp); // no-op after a successful move
        }
    }

    /**
     * Parses one ';'-separated data row. Columns {@code 0..NUM_FEATURES-1} become the feature
     * values at indices {@code 0..NUM_FEATURES-1}, and column {@code NUM_FEATURES} is the target.
     *
     * @throws IOException if the row has fewer than {@code NUM_FEATURES + 1} fields
     */
    private static FMInstance parseInstance(String line, int lineNumber) throws IOException {
        String[] fields = line.split(";");
        if (fields.length <= NUM_FEATURES) {
            throw new IOException("Malformed row at line " + lineNumber + " of " + DATASET_FILE
                    + ": expected at least " + (NUM_FEATURES + 1) + " ';'-separated fields, found "
                    + fields.length);
        }
        double target = Double.parseDouble(fields[NUM_FEATURES]);
        int[] indices = IntStream.range(0, NUM_FEATURES).toArray();
        double[] values = Arrays.stream(fields, 0, NUM_FEATURES)
                .mapToDouble(Double::parseDouble)
                .toArray();
        return new FMInstance(target, indices, values);
    }

    /**
     * Rescales each feature column in place to {@code (v - min) / (max - min)}. A constant
     * column is set to 0. Statistics are computed over the whole dataset, before any split.
     */
    private static void normalizeFeatures(SimpleFMData data) {
        for (int i = 0; i < NUM_FEATURES; i++) {
            final int feature = i;
            DoubleSummaryStatistics stats = data.stream()
                    .mapToDouble(instance -> instance.get(feature))
                    .summaryStatistics();
            double max = stats.getMax();
            double min = stats.getMin();
            if (max == min) {
                data.stream().forEach(instance -> instance.set(feature, 0.0d));
            } else {
                data.stream().forEach(instance ->
                        instance.set(feature, (instance.get(feature) - min) / (max - min)));
            }
        }
    }

    /**
     * Randomly splits {@code data} in two. Each instance goes to the first part with probability
     * {@code trainFraction} and otherwise to the second part.
     *
     * <p>Calls {@code data.shuffle()} first, which presumably reorders the input in place. The
     * same {@code random} is also handed to both resulting {@link SimpleFMData} instances.
     *
     * @param data the dataset to split
     * @param trainFraction probability of an instance landing in the first part
     * @param random source of randomness for the split
     * @return a two-element list: first part, then second part
     */
    private static List<FMData> getRandomPartition(FMData data, double trainFraction, Random random) {
        List<FMInstance> first = new ArrayList<>();
        List<FMInstance> second = new ArrayList<>();
        data.shuffle();
        data.stream().forEach(instance -> {
            if (random.nextDouble() < trainFraction) {
                first.add(instance);
            } else {
                second.add(instance);
            }
        });
        return Arrays.asList(
                new SimpleFMData(data.numFeatures(), random, first),
                new SimpleFMData(data.numFeatures(), random, second));
    }
}
