package org.lenskit.eval.traintest;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.lenskit.LenskitConfiguration;
import org.lenskit.data.dao.file.StaticDataSource;
import org.lenskit.eval.crossfold.CrossfoldMethods;
import org.lenskit.eval.crossfold.Crossfolder;
import org.lenskit.eval.crossfold.HistoryPartitions;
import org.lenskit.eval.crossfold.SortOrder;
import org.lenskit.eval.traintest.predict.PredictEvalTask;
import org.lenskit.eval.traintest.predict.PredictMetric;
import org.lenskit.util.table.Table;

/* loaded from: input_file:org/lenskit/eval/traintest/SimpleEvaluator.class */
public class SimpleEvaluator {
    /** Crossfolders whose splits have not yet been generated and registered with the experiment. */
    private final List<Crossfolder> crossfolders;
    /** The underlying experiment; a {@link PredictEvalTask} is added to it on construction. */
    private final TrainTestExperiment experiment = new TrainTestExperiment();
    /** Directory under which crossfold splits are written; {@code null} until set. */
    private Path workDir;

    /**
     * Create an evaluator whose experiment has a prediction-evaluation task and no
     * algorithms or data sets yet.
     */
    public SimpleEvaluator() {
        this.experiment.addTask(new PredictEvalTask());
        this.crossfolders = new ArrayList<>();
    }

    /**
     * Get the working directory used for crossfold output.
     *
     * @return the working directory, or {@code null} if none has been set (split
     *         directories are then resolved relative to the current directory).
     */
    public Path getWorkDir() {
        return this.workDir;
    }

    /**
     * Set the working directory under which crossfold splits are written. Only data sets
     * added after this call use the new directory.
     *
     * @param path the working directory.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator setWorkDir(Path path) {
        this.workDir = path;
        return this;
    }

    /**
     * Add an algorithm to the experiment.
     *
     * @param algorithmInstance the algorithm to evaluate.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator addAlgorithm(AlgorithmInstance algorithmInstance) {
        this.experiment.addAlgorithm(algorithmInstance);
        return this;
    }

    /**
     * Add an algorithm to the experiment from a name and a configuration.
     *
     * @param name   the algorithm name.
     * @param config the LensKit configuration for the algorithm.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator addAlgorithm(String name, LenskitConfiguration config) {
        this.experiment.addAlgorithm(new AlgorithmInstance(name, config));
        return this;
    }

    /**
     * Add a crossfolder. It is run, and its data sets are registered with the experiment,
     * the next time {@link #execute()} is called.
     *
     * @param crossfolder the crossfolder to add.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator addDataSet(Crossfolder crossfolder) {
        this.crossfolders.add(crossfolder);
        return this;
    }

    /**
     * Add a user-partitioned crossfold data set with random user order and fractional holdout.
     *
     * @param name       the data set name; splits are written to {@code <name>.split}.
     * @param source     the data source to split.
     * @param partitions the number of partitions.
     * @param holdout    the fraction of each test user's history to hold out.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator addDataSet(String name, StaticDataSource source, int partitions, double holdout) {
        addDataSet(new Crossfolder(name)
                .setSource(source)
                .setPartitionCount(partitions)
                .setMethod(CrossfoldMethods.partitionUsers(SortOrder.RANDOM,
                                                           HistoryPartitions.holdoutFraction(holdout)))
                .setOutputDir(splitDirectory(name)));
        return this;
    }

    /**
     * Add a crossfold data set named after its source. See
     * {@link #addDataSet(String, StaticDataSource, int, double)}.
     *
     * @param source     the data source to split.
     * @param partitions the number of partitions.
     * @param holdout    the fraction of each test user's history to hold out.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator addDataSet(StaticDataSource source, int partitions, double holdout) {
        return addDataSet(source.getName(), source, partitions, holdout);
    }

    /**
     * Add a crossfold data set that uses the crossfolder's default method.
     *
     * @param name       the data set name; splits are written to {@code <name>.split}.
     * @param source     the data source to split.
     * @param partitions the number of partitions.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator addDataSet(String name, StaticDataSource source, int partitions) {
        return addDataSet(new Crossfolder(name)
                .setSource(source)
                .setPartitionCount(partitions)
                .setOutputDir(splitDirectory(name)));
    }

    /**
     * Add a crossfold data set named after its source. See
     * {@link #addDataSet(String, StaticDataSource, int)}.
     *
     * @param source     the data source to split.
     * @param partitions the number of partitions.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator addDataSet(StaticDataSource source, int partitions) {
        return addDataSet(source.getName(), source, partitions);
    }

    /**
     * Add a pre-built data set directly to the experiment.
     *
     * @param dataSet the data set.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator addDataSet(DataSet dataSet) {
        this.experiment.addDataSet(dataSet);
        return this;
    }

    /**
     * Add a single train/test pair, named {@code generic-data-source}, to the experiment.
     *
     * @param train the training data.
     * @param test  the test data.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator addDataSet(StaticDataSource train, StaticDataSource test) {
        // NOTE(review): m18build looks like a decompiler-renamed build method; confirm against the DataSet builder.
        this.experiment.addDataSet(DataSet.newBuilder("generic-data-source").setTrain(train).setTest(test).m18build());
        return this;
    }

    /**
     * Add a prediction metric to the experiment's prediction task.
     *
     * @param predictMetric the metric.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator addMetric(PredictMetric<?> predictMetric) {
        this.experiment.getPredictionTask().addMetric(predictMetric);
        return this;
    }

    /**
     * Set the experiment's aggregate output file.
     *
     * @param path the output file.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator setOutput(Path path) {
        this.experiment.setOutputFile(path);
        return this;
    }

    /**
     * Set the experiment's per-user output file.
     *
     * @param path the per-user output file.
     * @return this evaluator, for chaining.
     */
    public SimpleEvaluator setUserOutput(Path path) {
        this.experiment.setUserOutputFile(path);
        return this;
    }

    /**
     * Get the underlying experiment, for configuration beyond what this class exposes.
     *
     * @return the experiment.
     */
    public TrainTestExperiment getExperiment() {
        return this.experiment;
    }

    /**
     * Run all pending crossfolders, register their data sets, and then run the experiment.
     * Each crossfolder is consumed once its data sets are registered, so calling this
     * method again does not register the same splits twice.
     *
     * @return the experiment's result table.
     * @throws UncheckedIOException if a crossfolder fails with an {@link IOException}.
     */
    public Table execute() {
        Iterator<Crossfolder> pending = this.crossfolders.iterator();
        while (pending.hasNext()) {
            Crossfolder crossfolder = pending.next();
            try {
                crossfolder.execute();
                this.experiment.addDataSets(crossfolder.getDataSets());
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
            // The data sets now live in the experiment; drop the crossfolder to avoid duplicates.
            pending.remove();
        }
        return this.experiment.execute();
    }

    /**
     * Compute the output directory for a named crossfold split. The directory is placed
     * under the working directory if one is set, and is otherwise relative to the current directory.
     */
    private Path splitDirectory(String name) {
        String dirName = name + ".split";
        return this.workDir != null ? this.workDir.resolve(dirName) : Paths.get(dirName);
    }
}
