package org.lenskit.eval.crossfold;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.google.common.base.Charsets;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongSet;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import javax.annotation.Nullable;
import org.lenskit.data.dao.DataAccessException;
import org.lenskit.data.dao.file.EntitySource;
import org.lenskit.data.dao.file.StaticDataSource;
import org.lenskit.data.dao.file.TextEntitySource;
import org.lenskit.data.entities.CommonAttributes;
import org.lenskit.data.entities.CommonTypes;
import org.lenskit.data.entities.EntityType;
import org.lenskit.data.output.OutputFormat;
import org.lenskit.data.output.RatingWriter;
import org.lenskit.data.output.RatingWriters;
import org.lenskit.eval.traintest.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/lenskit/eval/crossfold/Crossfolder.class */
/**
 * Splits the data of a {@link StaticDataSource} into train-test partitions.
 *
 * <p>Running {@link #execute()} writes, into the output directory, an item list (when the input
 * has no item source of its own), one train file and one test file per partition, a YAML manifest
 * for each of those files, and a {@code datasets.yaml} index that {@link #getDataSets()} loads.
 *
 * <p>Defaults set by the constructor: entity type {@link CommonTypes#RATING}, 5 partitions,
 * {@link OutputFormat#CSV} output, user partitioning in random order with a 10-item holdout, and
 * timestamps written.
 */
public class Crossfolder {
    public static final String ITEM_FILE_NAME = "items.txt";
    private static final String DATASETS_FILE_NAME = "datasets.yaml";
    private static final Logger logger = LoggerFactory.getLogger(Crossfolder.class);

    private final Random rng = new Random();
    private String name;
    private StaticDataSource source;
    private EntityType entityType = CommonTypes.RATING;
    private int partitionCount = 5;
    private Path outputDir;
    private OutputFormat outputFormat = OutputFormat.CSV;
    private CrossfoldMethod method =
            CrossfoldMethods.partitionUsers(SortOrder.RANDOM, HistoryPartitions.holdout(10));
    private boolean writeTimestamps = true;
    private boolean executed = false;

    /**
     * Creates a crossfolder without an explicit name; {@link #getName()} then falls back to the
     * name of the data source.
     */
    public Crossfolder() {
        this(null);
    }

    /**
     * Creates a crossfolder with the given name.
     *
     * @param str the crossfolder name, or {@code null} to use the data source's name
     */
    public Crossfolder(String str) {
        this.name = str;
    }

    /**
     * @return the type of entity being crossfolded
     */
    public EntityType getEntityType() {
        return entityType;
    }

    /**
     * Sets the type of entity to crossfold.
     *
     * @param entityType the entity type
     */
    public void setEntityType(EntityType entityType) {
        this.entityType = entityType;
    }

    /**
     * Sets the number of train-test partitions to produce.
     *
     * @param i the partition count
     * @return this crossfolder, for chaining
     */
    public Crossfolder setPartitionCount(int i) {
        this.partitionCount = i;
        return this;
    }

    /**
     * @return the number of train-test partitions
     */
    public int getPartitionCount() {
        return partitionCount;
    }

    /**
     * Sets the output format, which determines the extension of the train and test files.
     *
     * @param outputFormat the output format
     * @return this crossfolder, for chaining
     */
    public Crossfolder setOutputFormat(OutputFormat outputFormat) {
        this.outputFormat = outputFormat;
        return this;
    }

    /**
     * @return the output format
     */
    public OutputFormat getOutputFormat() {
        return outputFormat;
    }

    /**
     * Sets the directory that receives all output files.
     *
     * @param path the output directory, or {@code null} to use the default
     * @return this crossfolder, for chaining
     */
    public Crossfolder setOutputDir(Path path) {
        this.outputDir = path;
        return this;
    }

    /**
     * Sets the output directory from a {@link File}.
     *
     * @param file the output directory
     * @return this crossfolder, for chaining
     */
    public Crossfolder setOutputDir(File file) {
        return setOutputDir(file.toPath());
    }

    /**
     * Sets the output directory from a path string.
     *
     * @param str the output directory path
     * @return this crossfolder, for chaining
     */
    public Crossfolder setOutputDir(String str) {
        return setOutputDir(Paths.get(str));
    }

    /**
     * Gets the effective output directory.
     *
     * @return the configured directory, or {@code <name>.split} when none has been set
     */
    public Path getOutputDir() {
        if (outputDir != null) {
            return outputDir;
        }
        return Paths.get(getName() + ".split");
    }

    /**
     * Sets the data source to crossfold.
     *
     * @param staticDataSource the input data
     * @return this crossfolder, for chaining
     */
    public Crossfolder setSource(StaticDataSource staticDataSource) {
        this.source = staticDataSource;
        return this;
    }

    /**
     * Sets the method used to split the data.
     *
     * @param crossfoldMethod the crossfold method
     * @return this crossfolder, for chaining
     */
    public Crossfolder setMethod(CrossfoldMethod crossfoldMethod) {
        this.method = crossfoldMethod;
        return this;
    }

    /**
     * @return the method used to split the data
     */
    public CrossfoldMethod getMethod() {
        return method;
    }

    /**
     * Sets whether timestamps are written by the writers from {@link #openWriter(Path)}.
     *
     * @param z {@code true} to write timestamps
     * @return this crossfolder, for chaining
     */
    public Crossfolder setWriteTimestamps(boolean z) {
        this.writeTimestamps = z;
        return this;
    }

    /**
     * @return whether timestamps are written
     */
    public boolean getWriteTimestamps() {
        return writeTimestamps;
    }

    /**
     * Gets the crossfolder's name. When no name has been set this delegates to the data source,
     * so a source must be configured in that case.
     *
     * @return the explicit name, or the data source's name if no name was set
     */
    public String getName() {
        return name != null ? name : source.getName();
    }

    /**
     * Sets the crossfolder's name.
     *
     * @param str the name, or {@code null} to fall back to the data source's name
     * @return this crossfolder, for chaining
     */
    public Crossfolder setName(String str) {
        this.name = str;
        return this;
    }

    /**
     * @return the data source being crossfolded
     */
    public StaticDataSource getSource() {
        return source;
    }

    /**
     * Runs the crossfold: creates the output directory, writes the item list if needed, writes
     * the train-test files, and finally writes the manifests.
     *
     * @throws IOException if any output cannot be written
     * @throws IllegalStateException if no data source has been set
     */
    public void execute() throws IOException {
        Preconditions.checkState(source != null, "no data source has been configured");
        // Resolve through the getter so the default directory is honored when none was set.
        Path dir = getOutputDir();
        logger.info("ensuring output directory {} exists", dir);
        Files.createDirectories(dir);
        logger.info("making sure item list is available");
        JsonNode itemDataInfo = writeItemFile(source);
        logger.info("writing train-test split files");
        createTTFiles(source);
        logger.info("writing manifests and specs");
        Map<String, Object> metadata = new HashMap<>();
        for (EntitySource src : source.getSourcesForType(entityType)) {
            metadata.putAll(src.getMetadata());
        }
        writeManifests(source, metadata, itemDataInfo);
        executed = true;
    }

    /**
     * @return the training data files, one per partition, named {@code partNN.train.<ext>}
     */
    public List<Path> getTrainingFiles() {
        return getFileList("part%02d.train." + outputFormat.getExtension());
    }

    /**
     * @return the training manifest files, one per partition
     */
    List<Path> getTrainingManifestFiles() {
        return getFileList("part%02d.train.yaml");
    }

    /**
     * @return the test data files, one per partition, named {@code partNN.test.<ext>}
     */
    public List<Path> getTestFiles() {
        return getFileList("part%02d.test." + outputFormat.getExtension());
    }

    /**
     * @return the test manifest files, one per partition
     */
    List<Path> getTestManifestFiles() {
        return getFileList("part%02d.test.yaml");
    }

    /**
     * @return the spec files, one per partition, named {@code partNN.json}
     */
    List<Path> getSpecFiles() {
        return getFileList("part%02d.json");
    }

    /**
     * Builds the per-partition file list for a file name pattern.
     *
     * @param str a {@link String#format} pattern taking the 1-based partition number
     * @return the files, resolved against the output directory, in partition order
     */
    private List<Path> getFileList(String str) {
        Path dir = getOutputDir();
        List<Path> files = new ArrayList<>(partitionCount);
        for (int part = 1; part <= partitionCount; part++) {
            files.add(dir.resolve(String.format(str, part)));
        }
        return files;
    }

    /**
     * Writes the item ID list to {@link #ITEM_FILE_NAME} unless the input already has an item
     * source.
     *
     * @param staticDataSource the input data
     * @return a JSON description of the written item file, or {@code null} if the input
     *         already provides an item source and nothing was written
     * @throws IOException if the item file cannot be written
     */
    @Nullable
    private JsonNode writeItemFile(StaticDataSource staticDataSource) throws IOException {
        if (!staticDataSource.getSourcesForType(CommonTypes.ITEM).isEmpty()) {
            logger.info("input data specifies an item source, reusing that");
            return null;
        }

        logger.info("writing item IDs to {}", ITEM_FILE_NAME);
        Path itemFile = getOutputDir().resolve(ITEM_FILE_NAME);
        LongSet items = staticDataSource.get().getEntityIds(CommonTypes.ITEM);
        try (BufferedWriter writer = Files.newBufferedWriter(itemFile, Charsets.UTF_8)) {
            LongIterator iter = items.iterator();
            while (iter.hasNext()) {
                // nextLong() avoids boxing each ID
                writer.write(Long.toString(iter.nextLong()));
                writer.newLine();
            }
        }
        logger.info("wrote {} item IDs", items.size());

        JsonNodeFactory nf = JsonNodeFactory.instance;
        ObjectNode info = nf.objectNode();
        info.set("name", nf.textNode(getName() + ".items"));
        info.set("type", nf.textNode("textfile"));
        info.set("format", nf.textNode("tsv"));
        info.set("file", nf.textNode(ITEM_FILE_NAME));
        info.set("entity_type", nf.textNode(CommonTypes.ITEM.getName()));
        ArrayNode columns = nf.arrayNode();
        columns.add(CommonAttributes.ENTITY_ID.getName());
        info.set("columns", columns);
        return info;
    }

    /**
     * Runs the crossfold method over the input data, writing through a {@link CrossfoldOutput}
     * that is closed when the method finishes or fails.
     *
     * @param staticDataSource the input data
     * @throws IOException if the split files cannot be written
     */
    private void createTTFiles(StaticDataSource staticDataSource) throws IOException {
        if (entityType != CommonTypes.RATING) {
            logger.warn("entity type is not 'rating', crossfolding may not work correctly");
            logger.warn("crossfolding non-rating data is a work in progress");
        }

        List<EntitySource> sources = staticDataSource.getSourcesForType(entityType);
        logger.info("crossfolding {} data from {} sources", entityType, sources);
        for (EntitySource src : sources) {
            Set<EntityType> types = src.getTypes();
            if (types.size() > 1) {
                logger.warn("source {} has multiple entity types", src);
                logger.warn("the following types will be ignored: {}",
                            Sets.difference(types, ImmutableSet.of(entityType)));
            }
        }

        try (CrossfoldOutput output = new CrossfoldOutput(this, rng)) {
            logger.info("running crossfold method {}", method);
            method.crossfold(staticDataSource.get(), output, entityType);
        }
    }

    /**
     * Writes the train and test manifest for every partition, plus the {@code datasets.yaml}
     * index that points at them.
     *
     * <p>Each train manifest lists the partition's training file, the generated item file (if
     * any), and every input source that does not provide the crossfolded entity type; sources
     * that are not {@link TextEntitySource}s are skipped with a warning.
     *
     * @param staticDataSource the input data
     * @param map metadata copied into every train and test entry
     * @param jsonNode description of the generated item file, or {@code null} if none was written
     * @throws IOException if a manifest cannot be written
     */
    private void writeManifests(StaticDataSource staticDataSource, Map<String, Object> map, JsonNode jsonNode) throws IOException {
        logger.debug("writing manifests");
        ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
        JsonNodeFactory nf = JsonNodeFactory.instance;
        Path dir = getOutputDir();
        List<Path> trainFiles = getTrainingFiles();
        List<Path> trainManifestFiles = getTrainingManifestFiles();
        List<Path> testFiles = getTestFiles();
        List<Path> testManifestFiles = getTestManifestFiles();

        ObjectNode index = nf.objectNode();
        // getName() rather than the raw field, so an unnamed crossfolder reports the source's name
        index.set("name", nf.textNode(getName()));
        ArrayNode dataSets = nf.arrayNode();
        for (int i = 0; i < partitionCount; i++) {
            Path trainManifest = trainManifestFiles.get(i);
            Path testManifest = testManifestFiles.get(i);

            ObjectNode entry = nf.objectNode();
            entry.set("train", nf.textNode(dir.relativize(trainManifest).toString()));
            entry.set("test", nf.textNode(dir.relativize(testManifest).toString()));
            dataSets.add(entry);

            logger.debug("writing train manifest {}", i);
            ArrayNode trainSources = nf.arrayNode();
            trainSources.add(makeDataNode(mapper,
                                          String.format("%s.%d.train", getName(), i),
                                          dir.relativize(trainFiles.get(i)).toString(),
                                          map));
            if (jsonNode != null) {
                trainSources.add(jsonNode);
            }
            for (EntitySource src : staticDataSource.getSources()) {
                if (src.getTypes().contains(entityType)) {
                    // replaced by the crossfolded training file
                    continue;
                }
                if (src instanceof TextEntitySource) {
                    trainSources.add(((TextEntitySource) src).toJSON(trainManifest.toUri()));
                } else {
                    logger.warn("ignoring non-file data source {}", src);
                }
            }
            mapper.writeValue(trainManifest.toFile(), trainSources);

            logger.debug("writing test manifest {}", i);
            ObjectNode test = makeDataNode(mapper,
                                           String.format("%s.%d.test", getName(), i),
                                           dir.relativize(testFiles.get(i)).toString(),
                                           map);
            mapper.writeValue(testManifest.toFile(), test);
        }
        index.set("datasets", dataSets);
        mapper.writeValue(dir.resolve(DATASETS_FILE_NAME).toFile(), index);
    }

    /**
     * Builds the manifest entry describing one crossfold output file.
     *
     * @param mapper the mapper used to convert the metadata to a JSON tree
     * @param nodeName the value of the entry's {@code name} key
     * @param file the data file path to record, relative to the output directory
     * @param metadata the metadata to attach
     * @return the manifest entry
     */
    private ObjectNode makeDataNode(ObjectMapper mapper, String nodeName, String file, Map<String, Object> metadata) {
        JsonNodeFactory nf = JsonNodeFactory.instance;
        ObjectNode node = nf.objectNode();
        node.set("name", nf.textNode(nodeName));
        node.set("type", nf.textNode("textfile"));
        node.set("file", nf.textNode(file));
        node.set("format", nf.textNode("csv"));
        node.set("entity_type", nf.textNode(entityType.getName()));
        node.set("metadata", mapper.valueToTree(metadata));
        return node;
    }

    /**
     * Loads the data sets written by {@link #execute()}.
     *
     * @return the train-test data sets listed in {@code datasets.yaml}
     * @throws IllegalStateException if the crossfolder has not been executed
     * @throws DataAccessException if the data set index cannot be read
     */
    public List<DataSet> getDataSets() {
        Preconditions.checkState(executed, "crossfolder has not been executed");
        try {
            return DataSet.load(getOutputDir().resolve(DATASETS_FILE_NAME));
        } catch (IOException e) {
            throw new DataAccessException("cannot load data sets", e);
        }
    }

    /**
     * Opens a CSV rating writer on a file, honoring the write-timestamps setting.
     *
     * @param path the file to write
     * @return the rating writer
     * @throws IOException if the writer cannot be opened
     */
    public RatingWriter openWriter(Path path) throws IOException {
        return RatingWriters.csv(path.toFile(), writeTimestamps);
    }

    @Override
    public String toString() {
        return String.format("{CXManager %s}", source);
    }
}
