package org.apache.mahout.clustering.classify;

import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.clustering.iterator.ClusteringPolicy;
import org.apache.mahout.clustering.iterator.ClusteringPolicyWritable;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.stochasticsvd.YtYJob;

/* loaded from: input_file:WEB-INF/lib/mahout-core-0.7.jar:org/apache/mahout/clustering/classify/ClusterClassifier.class */
/**
 * A vector classifier backed by a list of {@link Cluster} models and a {@link ClusteringPolicy}.
 *
 * <p>Classification is delegated to the policy; online training forwards each observation to the
 * model at the given index. A classifier can be persisted either as a single {@link Writable}
 * ({@link #write}/{@link #readFields}) or as a directory of sequence files
 * ({@link #writeToSeqFiles}/{@link #readFromSeqFiles}).
 *
 * <p>NOTE(review): the class contains no synchronization; treat instances as not thread-safe.
 */
public class ClusterClassifier extends AbstractVectorClassifier implements OnlineLearner, Writable {

    /** Name of the sequence file, relative to the model directory, that holds the policy. */
    private static final String POLICY_FILE_NAME = "_policy";

    /** Prefix of the per-model sequence files; the index is appended as five zero-padded digits. */
    private static final String MODEL_FILE_PREFIX = "part-";

    private List<Cluster> models;
    private String modelClass;
    private ClusteringPolicy policy;

    /**
     * Creates a classifier over the given models.
     *
     * @param list the cluster models; stored by reference (not copied); must be non-empty
     * @param clusteringPolicy the policy used by {@link #classify(Vector)} and {@link #close()}
     * @throws IllegalArgumentException if {@code list} is null or empty
     */
    public ClusterClassifier(List<Cluster> list, ClusteringPolicy clusteringPolicy) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("ClusterClassifier requires at least one model");
        }
        this.models = list;
        // The first model's concrete class is serialized by write() and used by readFields() to
        // re-instantiate every model. NOTE(review): assumes all models share that class.
        this.modelClass = list.get(0).getClass().getName();
        this.policy = clusteringPolicy;
    }

    /**
     * Creates an empty classifier with no models and no policy. The fields must be populated
     * (e.g. via {@link #readFields} or {@link #readFromSeqFiles}) before the classifier is used.
     */
    public ClusterClassifier() {
    }

    /**
     * Creates a classifier holding only a policy; the model list is left unset.
     *
     * @param clusteringPolicy the policy to use
     */
    protected ClusterClassifier(ClusteringPolicy clusteringPolicy) {
        this.policy = clusteringPolicy;
    }

    /**
     * Classifies {@code vector} by delegating to the clustering policy.
     *
     * @param vector the instance to classify
     * @return whatever {@code policy.classify(vector, this)} returns
     */
    @Override
    public Vector classify(Vector vector) {
        return this.policy.classify(vector, this);
    }

    /**
     * Returns the first model's share of the combined density, {@code pdf0 / (pdf0 + pdf1)}.
     * Only defined for exactly two models. The result is NaN when both densities are zero.
     *
     * @param vector the instance to score
     * @return the relative density of model 0
     * @throws IllegalStateException if the classifier does not hold exactly two models
     */
    @Override
    public double classifyScalar(Vector vector) {
        if (this.models.size() != 2) {
            throw new IllegalStateException(
                "classifyScalar requires exactly 2 models but found " + this.models.size());
        }
        VectorWritable instance = new VectorWritable(vector);
        double pdf0 = this.models.get(0).pdf(instance);
        double pdf1 = this.models.get(1).pdf(instance);
        return pdf0 / (pdf0 + pdf1);
    }

    /**
     * Returns the number of models (categories) held by this classifier.
     *
     * @return the model count
     */
    @Override
    public int numCategories() {
        return this.models.size();
    }

    /**
     * Serializes this classifier in the order: model count (int), model class name (UTF),
     * policy (via {@link ClusteringPolicyWritable}), then each model's own serialization.
     *
     * @param dataOutput the destination
     * @throws IOException if writing fails
     */
    @Override
    public void write(DataOutput dataOutput) throws IOException {
        dataOutput.writeInt(this.models.size());
        dataOutput.writeUTF(this.modelClass);
        new ClusteringPolicyWritable(this.policy).write(dataOutput);
        for (Cluster model : this.models) {
            model.write(dataOutput);
        }
    }

    /**
     * Reads state written by {@link #write}, replacing the current models and policy. Each model
     * is instantiated from the serialized class name and then populated from the stream.
     *
     * @param dataInput the source
     * @throws IOException if reading fails
     */
    @Override
    public void readFields(DataInput dataInput) throws IOException {
        int modelCount = dataInput.readInt();
        this.modelClass = dataInput.readUTF();
        this.models = Lists.newArrayList();
        ClusteringPolicyWritable policyWritable = new ClusteringPolicyWritable();
        policyWritable.readFields(dataInput);
        this.policy = policyWritable.getValue();
        for (int i = 0; i < modelCount; i++) {
            Cluster cluster = (Cluster) ClassUtils.instantiateAs(this.modelClass, Cluster.class);
            cluster.readFields(dataInput);
            this.models.add(cluster);
        }
    }

    /**
     * Adds {@code vector} as an observation of the model at index {@code i}.
     *
     * @param i index of the model to train
     * @param vector the observed instance
     */
    @Override
    public void train(int i, Vector vector) {
        // No cast: wrapping in VectorWritable selects Model.observe(VectorWritable).
        this.models.get(i).observe(new VectorWritable(vector));
    }

    /**
     * Adds {@code vector} with weight {@code d} as an observation of the model at index {@code i}.
     *
     * @param i index of the model to train
     * @param vector the observed instance
     * @param d the observation weight
     */
    public void train(int i, Vector vector, double d) {
        this.models.get(i).observe(new VectorWritable(vector), d);
    }

    /**
     * Adds {@code vector} as an observation of the model at index {@code i}. The tracking key
     * {@code j} and group key {@code str} are ignored.
     */
    @Override
    public void train(long j, String str, int i, Vector vector) {
        this.models.get(i).observe(new VectorWritable(vector));
    }

    /**
     * Adds {@code vector} as an observation of the model at index {@code i}. The tracking key
     * {@code j} is ignored.
     */
    @Override
    public void train(long j, int i, Vector vector) {
        this.models.get(i).observe(new VectorWritable(vector));
    }

    /** Finishes training by delegating to {@code policy.close(this)}. */
    @Override
    public void close() {
        this.policy.close(this);
    }

    /**
     * Returns the live model list (not a copy); changes are visible to this classifier.
     *
     * @return the models
     */
    public List<Cluster> getModels() {
        return this.models;
    }

    /**
     * Returns the clustering policy.
     *
     * @return the policy
     */
    public ClusteringPolicy getPolicy() {
        return this.policy;
    }

    /**
     * Writes the policy to {@code path/_policy}, then each model to its own sequence file
     * {@code path/part-NNNNN}. Each file holds one record: key {@link IntWritable} index,
     * value {@link ClusterWritable}.
     *
     * @param path the output directory
     * @throws IOException if any file cannot be written or closed
     */
    public void writeToSeqFiles(Path path) throws IOException {
        writePolicy(this.policy, path);
        Configuration configuration = new Configuration();
        FileSystem fileSystem = FileSystem.get(path.toUri(), configuration);
        ClusterWritable clusterWritable = new ClusterWritable();
        for (int i = 0; i < this.models.size(); i++) {
            clusterWritable.setValue(this.models.get(i));
            Path modelPath =
                new Path(path, MODEL_FILE_PREFIX + String.format(Locale.ENGLISH, "%05d", i));
            // try-with-resources rather than a quiet close: closing a writer may flush buffered
            // data, so a failure there must be reported rather than swallowed.
            try (SequenceFile.Writer writer = new SequenceFile.Writer(
                    fileSystem, configuration, modelPath, IntWritable.class, ClusterWritable.class)) {
                writer.append(new IntWritable(i), clusterWritable);
            }
        }
    }

    /**
     * Replaces this classifier's models and policy with those stored under {@code path} (the
     * layout produced by {@link #writeToSeqFiles}). Each loaded cluster is configured with
     * {@code configuration}. Fields are only updated after everything has been read successfully.
     *
     * @param configuration configuration passed to each cluster's {@code configure}
     * @param path the directory to read
     * @throws IOException if reading fails, or if no cluster records are found
     */
    public void readFromSeqFiles(Configuration configuration, Path path) throws IOException {
        // NOTE(review): files are read with a fresh Configuration; the caller's `configuration`
        // is only applied to the clusters. Preserved from the original; confirm this is intended.
        Configuration readConfiguration = new Configuration();
        List<Cluster> clusters = Lists.newArrayList();
        // NOTE(review): logsCRCFilter() is presumably what keeps the `_policy` file out of
        // this scan; verify against PathFilters.
        for (ClusterWritable clusterWritable : new SequenceFileDirValueIterable<ClusterWritable>(
                path, PathType.LIST, PathFilters.logsCRCFilter(), readConfiguration)) {
            Cluster cluster = clusterWritable.getValue();
            cluster.configure(configuration);
            clusters.add(cluster);
        }
        if (clusters.isEmpty()) {
            throw new IOException("No cluster records found under " + path);
        }
        ClusteringPolicy loadedPolicy = readPolicy(path);
        this.models = clusters;
        this.modelClass = clusters.get(0).getClass().getName();
        this.policy = loadedPolicy;
    }

    /**
     * Reads the policy stored by {@link #writePolicy} from {@code path/_policy}.
     *
     * @param path the model directory
     * @return the deserialized policy
     * @throws IOException if the file cannot be read or contains no record
     */
    public static ClusteringPolicy readPolicy(Path path) throws IOException {
        Path policyPath = new Path(path, POLICY_FILE_NAME);
        Configuration configuration = new Configuration();
        FileSystem fileSystem = FileSystem.get(policyPath.toUri(), configuration);
        try (SequenceFile.Reader reader = new SequenceFile.Reader(fileSystem, policyPath, configuration)) {
            Text key = new Text();
            ClusteringPolicyWritable policyWritable = new ClusteringPolicyWritable();
            if (!reader.next(key, policyWritable)) {
                throw new IOException("Policy file " + policyPath + " contains no records");
            }
            return policyWritable.getValue();
        }
    }

    /**
     * Writes {@code clusteringPolicy} as a single record (empty {@link Text} key) to
     * {@code path/_policy}.
     *
     * @param clusteringPolicy the policy to persist
     * @param path the model directory
     * @throws IOException if the file cannot be written or closed
     */
    public static void writePolicy(ClusteringPolicy clusteringPolicy, Path path) throws IOException {
        Path policyPath = new Path(path, POLICY_FILE_NAME);
        Configuration configuration = new Configuration();
        FileSystem fileSystem = FileSystem.get(policyPath.toUri(), configuration);
        try (SequenceFile.Writer writer = new SequenceFile.Writer(
                fileSystem, configuration, policyPath, Text.class, ClusteringPolicyWritable.class)) {
            writer.append(new Text(), new ClusteringPolicyWritable(clusteringPolicy));
        }
    }
}
