package org.openimaj.demos.sandbox.ml.regression;

import java.io.File;
import java.io.IOException;
import java.util.Iterator;
import javax.swing.JFrame;
import javax.swing.SwingUtilities;
import org.apache.commons.lang.StringUtils;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.data.time.Day;
import org.jfree.data.time.TimeSeries;
import org.joda.time.DateTime;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import org.openimaj.hadoop.tools.twitter.utils.WordDFIDFTimeSeriesCollection;
import org.openimaj.io.Cache;
import org.openimaj.io.IOUtils;
import org.openimaj.ml.timeseries.IncompatibleTimeSeriesException;
import org.openimaj.ml.timeseries.aggregator.MeanSquaredDifferenceAggregator;
import org.openimaj.ml.timeseries.aggregator.SquaredSummedDifferenceAggregator;
import org.openimaj.ml.timeseries.aggregator.WindowedLinearRegressionAggregator;
import org.openimaj.ml.timeseries.processor.IntervalSummationProcessor;
import org.openimaj.ml.timeseries.processor.MovingAverageProcessor;
import org.openimaj.ml.timeseries.processor.WindowedLinearRegressionProcessor;
import org.openimaj.ml.timeseries.series.DoubleSynchronisedTimeSeriesCollection;
import org.openimaj.ml.timeseries.series.DoubleTimeSeries;
import org.openimaj.twitter.finance.YahooFinanceData;
import org.openimaj.util.pair.IndependentPair;

/* loaded from: input_file:org/openimaj/demos/sandbox/ml/regression/MultipleLinearRegressionPlayground.class */
public class MultipleLinearRegressionPlayground {
    /** Date pattern used for the Yahoo Finance request and for parsing the learning-window bounds. */
    private static final String DATE_PATTERN = "YYYY-MM-dd";

    /**
     * Serialised per-word DF-IDF time series read by {@link #loadwords}. This is an absolute,
     * machine-specific path; adjust it before running the demo elsewhere.
     */
    private static final File WORD_SERIES_FILE =
            new File("/Users/ss/Development/data/trendminer-data/datasets/sheffield/2010/part-r-00000");

    /**
     * 30 days expressed in milliseconds (2592000000L), passed to {@link MovingAverageProcessor}.
     * NOTE(review): presumably the processor interprets this as a window length in ms — confirm.
     */
    private static final long THIRTY_DAYS_MILLIS = 30L * 24L * 60L * 60L * 1000L;

    /**
     * Runs the demo on the 2010 MSFT and AAPL price series, using Jan–May 2010 as the learning window.
     *
     * @param args ignored
     * @throws IOException if the cached finance data or the word series cannot be read
     * @throws IncompatibleTimeSeriesException if series cannot be combined into a synchronised collection
     */
    public static void main(String[] args) throws IOException, IncompatibleTimeSeriesException {
        linearRegressStocks("2010-01-01", "2010-12-31", "2010-01-01", "2010-05-01", "MSFT", "AAPL");
    }

    /**
     * Loads the daily "High" price of each ticker and opens three chart windows: the raw prices with a
     * moving average, AAPL regression fits with their SSE printed to stdout, and AAPL regressed
     * against word time series.
     *
     * <p>The AAPL series is hard-coded, so {@code tickers} must contain {@code "AAPL"}.
     *
     * @param start first date of the price data, formatted as {@value #DATE_PATTERN}
     * @param end last date of the price data, formatted as {@value #DATE_PATTERN}
     * @param learnStart start of the window handed to the "unseen" regression (presumably its training span)
     * @param learnEnd end of that window
     * @param tickers stock symbols to load
     */
    private static void linearRegressStocks(String start, String end, String learnStart, String learnEnd,
            String... tickers) throws IncompatibleTimeSeriesException, IOException {
        DoubleSynchronisedTimeSeriesCollection prices = loadHighPrices(start, end, tickers);
        String title = StringUtils.join(tickers, " & ");

        plotPricesWithMovingAverage(prices, title);
        plotStockRegressionFits(prices, title);
        plotWordRegressionFits(prices, learnStart, learnEnd, title);
    }

    /** Fetches (via {@link Cache}) each ticker's "High" series and synchronises them into one collection. */
    private static DoubleSynchronisedTimeSeriesCollection loadHighPrices(String start, String end,
            String... tickers) throws IncompatibleTimeSeriesException, IOException {
        DoubleSynchronisedTimeSeriesCollection prices = new DoubleSynchronisedTimeSeriesCollection();
        for (String ticker : tickers) {
            YahooFinanceData data = (YahooFinanceData) Cache.load(new YahooFinanceData(ticker, start, end, DATE_PATTERN));
            prices.addTimeSeries(ticker, (DoubleTimeSeries) data.seriesMap().get("High"));
        }
        return prices;
    }

    /** Charts every price series next to its 30-day moving average (suffixed "-MA"). */
    private static void plotPricesWithMovingAverage(DoubleSynchronisedTimeSeriesCollection prices, String title)
            throws IncompatibleTimeSeriesException, IOException {
        TSCollection chart = new TSCollection();
        timeSeriesToChart(prices, chart);
        timeSeriesToChart(prices.processInternal(new MovingAverageProcessor(THIRTY_DAYS_MILLIS)), chart, "-MA");
        displayTimeSeries(chart, title, "Date", "Price");
    }

    /**
     * Charts AAPL against a windowed regression on its own history and against a regression using all
     * loaded tickers, printing the sum of squared differences of each fit to stdout.
     */
    private static void plotStockRegressionFits(DoubleSynchronisedTimeSeriesCollection prices, String title)
            throws IncompatibleTimeSeriesException, IOException {
        TSCollection chart = new TSCollection();
        timeSeriesToChart("AAPL", prices.series("AAPL"), chart);

        DoubleTimeSeries selfFit = prices.series("AAPL").process(new WindowedLinearRegressionProcessor(10, 7));
        timeSeriesToChart("AAPL-interp", selfFit, chart);
        System.out.println("AAPL linear regression SSE: " + new SquaredSummedDifferenceAggregator()
                .aggregate(alignWithTruth("AAPL", prices.series("AAPL"), "AAPL-interp", selfFit)));

        DoubleTimeSeries jointFit = new WindowedLinearRegressionAggregator("AAPL", 10, 7, true).aggregate(prices);
        timeSeriesToChart("AAPL-interpmstf", jointFit, chart);
        System.out.println("AAPL+MSFT linear regression SSE: " + new SquaredSummedDifferenceAggregator()
                .aggregate(alignWithTruth("AAPL", prices.series("AAPL"), "AAPLMSFT-interp", jointFit)));

        displayTimeSeries(chart, title + " Interp", "Date", "Price");
    }

    /**
     * Charts AAPL against regressions over the word time series. The "unseen" variant additionally
     * receives the sub-collection restricted to {@code [learnStart, learnEnd]}. Each fit is labelled
     * with its mean squared difference against the true AAPL series.
     */
    private static void plotWordRegressionFits(DoubleSynchronisedTimeSeriesCollection prices, String learnStart,
            String learnEnd, String title) throws IncompatibleTimeSeriesException, IOException {
        TSCollection chart = new TSCollection();
        DoubleTimeSeries truth = prices.series("AAPL");

        DateTimeFormatter format = DateTimeFormat.forPattern(DATE_PATTERN);
        long learnFrom = format.parseDateTime(learnStart).getMillis();
        long learnTo = format.parseDateTime(learnEnd).getMillis();

        DoubleSynchronisedTimeSeriesCollection words = loadwords("AAPL", prices.series("AAPL"));
        DoubleSynchronisedTimeSeriesCollection learnWindow = words.get(learnFrom, learnTo);

        DoubleTimeSeries fit10x7 = new WindowedLinearRegressionAggregator("AAPL", 10, 7, true).aggregate(words);
        DoubleTimeSeries fit3x1 = new WindowedLinearRegressionAggregator("AAPL", 3, 1, true).aggregate(words);
        DoubleTimeSeries unseen10x7 =
                new WindowedLinearRegressionAggregator("AAPL", 10, 7, true, learnWindow).aggregate(words);

        double mse10x7 = MeanSquaredDifferenceAggregator.error(new DoubleTimeSeries[] { fit10x7, truth }).doubleValue();
        double mse3x1 = MeanSquaredDifferenceAggregator.error(new DoubleTimeSeries[] { fit3x1, truth }).doubleValue();
        double mseUnseen = MeanSquaredDifferenceAggregator.error(new DoubleTimeSeries[] { unseen10x7, truth }).doubleValue();

        timeSeriesToChart("High Value", truth, chart);
        timeSeriesToChart(String.format("OLR (m=7,n=10) (MSE=%.2f)", Double.valueOf(mse10x7)), fit10x7, chart);
        timeSeriesToChart(String.format("OLR (m=1,n=3) (MSE=%.2f)", Double.valueOf(mse3x1)), fit3x1, chart);
        timeSeriesToChart(String.format("OLR unseen (m=7,n=10) (MSE=%.2f)", Double.valueOf(mseUnseen)), unseen10x7, chart);
        displayTimeSeries(chart, title + " Interp", "Date", "Price");
    }

    /**
     * Pairs a fitted series with the true series truncated to the fit's first..last timestamp, so the
     * two can be compared by an aggregator.
     *
     * @throws IllegalStateException if the fit contains no points (there is no span to truncate to)
     */
    private static DoubleSynchronisedTimeSeriesCollection alignWithTruth(String truthName, DoubleTimeSeries truth,
            String fitName, DoubleTimeSeries fit) throws IncompatibleTimeSeriesException, IOException {
        long[] times = fit.getTimes();
        if (times.length == 0) {
            throw new IllegalStateException("Regression " + fitName + " produced no points to compare");
        }
        return new DoubleSynchronisedTimeSeriesCollection(new IndependentPair[] {
                IndependentPair.pair(truthName, truth.get(times[0], times[times.length - 1])),
                IndependentPair.pair(fitName, fit) });
    }

    /**
     * Reads the word series from {@link #WORD_SERIES_FILE} and sums each word's values into the intervals
     * defined by {@code target}'s timestamps. Returns a collection holding {@code target} under
     * {@code name} followed by one series per word.
     */
    private static DoubleSynchronisedTimeSeriesCollection loadwords(String name, DoubleTimeSeries target)
            throws IOException, IncompatibleTimeSeriesException {
        WordDFIDFTimeSeriesCollection words = IOUtils.read(WORD_SERIES_FILE, WordDFIDFTimeSeriesCollection.class);
        words.processInternalInplace(new IntervalSummationProcessor(target.getTimes()));

        DoubleSynchronisedTimeSeriesCollection combined = new DoubleSynchronisedTimeSeriesCollection();
        combined.addTimeSeries(name, target);
        for (String word : words.getNames()) {
            combined.addTimeSeries(word, words.series(word).doubleTimeSeries());
        }
        return combined;
    }

    /**
     * Opens a zoomable time-series chart window. Closing the window exits the JVM. The Swing components
     * are built on the Event Dispatch Thread, as Swing requires.
     */
    private static void displayTimeSeries(final TSCollection dataset, final String title, final String xLabel,
            final String yLabel) {
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                ChartPanel panel = new ChartPanel(
                        ChartFactory.createTimeSeriesChart(title, xLabel, yLabel, dataset, true, false, false));
                panel.setFillZoomRectangle(true);
                JFrame frame = new JFrame();
                // Configure before showing the frame so the close policy is in place from the start.
                frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
                frame.setContentPane(panel);
                frame.pack();
                frame.setVisible(true);
            }
        });
    }

    /**
     * Adds every series of {@code collection} to {@code chart}. Each chart series is named after its
     * source series with the {@code "-"}-joined {@code suffixes} appended.
     */
    private static void timeSeriesToChart(DoubleSynchronisedTimeSeriesCollection collection, TSCollection chart,
            String... suffixes) {
        String suffix = StringUtils.join(suffixes, "-");
        for (String name : collection.getNames()) {
            timeSeriesToChart(name + suffix, collection.series(name), chart);
        }
    }

    /**
     * Converts {@code series} into a JFreeChart series with one point per calendar day (in the default
     * time zone) and adds it to {@code chart}.
     */
    private static void timeSeriesToChart(String name, DoubleTimeSeries series, TSCollection chart) {
        TimeSeries daily = new TimeSeries(name);
        Iterator<?> points = series.iterator();
        while (points.hasNext()) {
            IndependentPair<?, ?> point = (IndependentPair<?, ?>) points.next();
            DateTime when = new DateTime(point.firstObject());
            daily.add(new Day(when.getDayOfMonth(), when.getMonthOfYear(), when.getYear()),
                    (Number) point.secondObject());
        }
        chart.addSeries(daily);
    }
}
