package org.tribuo.classification.mnb;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.util.ExpNormalizer;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/classification/mnb/MultinomialNaiveBayesModel.class */
/**
 * A {@link Model} over {@link Label} outputs that scores an example by multiplying its
 * sparse feature vector with a label-by-feature weight matrix, adding a smoothing
 * correction for features that are absent from a label's sparse row, and normalizing
 * the per-label scores with an {@link ExpNormalizer}.
 *
 * <p>NOTE(review): the field name and the use of {@code Math.log} plus an exponential
 * normalizer suggest the matrix stores log-probabilities; confirm against the trainer
 * that builds this model.
 */
public class MultinomialNaiveBayesModel extends Model<Label> {
    private static final long serialVersionUID = 1;

    /** Weight matrix; rows are indexed by output id, columns by feature id. */
    private final DenseSparseMatrix labelWordProbs;

    /** Smoothing constant; the smoothing correction in {@link #predict} is skipped unless it is positive. */
    private final double alpha;

    /** Shared normalizer applied to the per-label scores in {@link #predict}. */
    private static final VectorNormalizer normalizer = new ExpNormalizer();

    /**
     * Builds a model around an already-computed weight matrix.
     *
     * <p>The decompiler reports that this constructor was package-private in the original source.
     *
     * @param str                 the model name passed to the superclass
     * @param modelProvenance     the provenance passed to the superclass
     * @param immutableFeatureMap the feature domain passed to the superclass
     * @param immutableOutputInfo the label domain passed to the superclass
     * @param denseSparseMatrix   the label-by-feature weight matrix (stored by reference, not copied)
     * @param d                   the smoothing constant {@code alpha}
     */
    public MultinomialNaiveBayesModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<Label> immutableOutputInfo, DenseSparseMatrix denseSparseMatrix, double d) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, true);
        this.labelWordProbs = denseSparseMatrix;
        this.alpha = d;
    }

    /**
     * Scores the example against every label and returns the highest scoring label
     * together with the normalized score of each label.
     *
     * @param example the example to classify
     * @return the prediction holding the top label and all per-label scores
     * @throws IllegalArgumentException if the example contains a negative feature value,
     *                                  or none of its features are known to this model
     */
    public Prediction<Label> predict(Example<Label> example) {
        SparseVector exampleVector = SparseVector.createSparseVector(example, this.featureIDMap, false);
        if (exampleVector.minValue() < 0.0d) {
            throw new IllegalArgumentException("Example has negative feature values, example = " + example.toString());
        }
        if (exampleVector.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }

        // The matrix rows are sparse, so features that are active in the example but
        // absent from a label's row contribute nothing to the matrix product below.
        // Their smoothed contribution is accumulated here, one offset per label.
        int numLabels = this.outputIDInfo.size();
        double[] smoothingOffsets = new double[numLabels];
        int vocabSize = this.labelWordProbs.getDimension2Size();
        if (this.alpha > 0.0d) {
            for (int labelIdx = 0; labelIdx < numLabels; labelIdx++) {
                SparseVector row = this.labelWordProbs.getRow(labelIdx);
                double unobservedLogWeight = Math.log(this.alpha / (row.oneNorm() + (vocabSize * this.alpha)));
                double offset = 0.0d;
                for (int featureIdx : exampleVector.difference(row)) {
                    offset += exampleVector.get(featureIdx) * unobservedLogWeight;
                }
                smoothingOffsets[labelIdx] = offset;
            }
        }

        DenseVector scores = this.labelWordProbs.leftMultiply(exampleVector);
        scores.intersectAndAddInPlace(DenseVector.createDenseVector(smoothingOffsets));
        scores.normalize(normalizer);

        // LinkedHashMap keeps the labels in the iteration order of the score vector.
        Map<String, Label> labelScores = new LinkedHashMap<>();
        Label bestLabel = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (VectorTuple tuple : scores) {
            String name = this.outputIDInfo.getOutput(tuple.index).getLabel();
            Label scored = new Label(name, tuple.value);
            if (tuple.value > bestScore) {
                bestScore = tuple.value;
                bestLabel = scored;
            }
            labelScores.put(name, scored);
        }
        return new Prediction<>(bestLabel, labelScores, exampleVector.numActiveElements(), example, true);
    }

    /**
     * Returns, for each label, the features stored in that label's row of the weight
     * matrix, ordered by descending weight.
     *
     * @param i the maximum number of features to return per label; a negative value
     *          requests all of them
     * @return a map from label name to a list of (feature name, weight) pairs; each list
     *         holds at most {@code i} entries and may hold fewer when the label's row
     *         has fewer stored features
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        int maxFeatures = i < 0 ? this.featureIDMap.size() : i;
        Map<String, List<Pair<String, Double>>> topFeatures = new HashMap<>();
        for (Pair<Integer, Label> output : this.outputIDInfo) {
            int labelIdx = output.getA();
            List<Pair<String, Double>> features = new ArrayList<>(this.labelWordProbs.numActiveElements(labelIdx));
            for (VectorTuple tuple : this.labelWordProbs.getRow(labelIdx)) {
                features.add(new Pair<>(this.featureIDMap.get(tuple.index).getName(), tuple.value));
            }
            // Negating the weight sorts the largest weights first.
            features.sort(Comparator.comparingDouble((Pair<String, Double> p) -> -p.getB()));
            // Only truncate when the list is actually longer than the limit: the row is
            // sparse, so it can hold fewer entries than the requested count, and an
            // unguarded subList(0, maxFeatures) would throw IndexOutOfBoundsException.
            if (maxFeatures < features.size()) {
                // Copy so the result does not pin the full backing list.
                features = new ArrayList<>(features.subList(0, maxFeatures));
            }
            topFeatures.put(output.getB().getLabel(), features);
        }
        return topFeatures;
    }

    /**
     * Explains a prediction by listing, for each label, the weight-matrix entry of every
     * feature in the example that is known to this model.
     *
     * @param example the example to explain
     * @return an {@link Optional} that always contains the excuse
     * @throws IllegalArgumentException if {@link #predict} rejects the example
     */
    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
        Map<String, List<Pair<String, Double>>> weightsByLabel = new HashMap<>();
        for (Pair<Integer, Label> output : this.outputIDInfo) {
            // The row depends only on the label, so look it up once per label.
            SparseVector row = this.labelWordProbs.getRow(output.getA());
            List<Pair<String, Double>> weights = new ArrayList<>();
            for (Feature feature : example) {
                int featureIdx = this.featureIDMap.getID(feature.getName());
                // Only ids above -1 are looked up in the row; anything else is skipped.
                if (featureIdx > -1) {
                    weights.add(new Pair<>(feature.getName(), row.get(featureIdx)));
                }
            }
            weightsByLabel.put(output.getB().getLabel(), weights);
        }
        return Optional.of(new Excuse<>(example, predict(example), weightsByLabel));
    }

    /**
     * Creates a new model with the supplied name and provenance, reusing this model's
     * feature and output domains and smoothing constant, and passing the weight matrix
     * through the {@code DenseSparseMatrix(DenseSparseMatrix)} constructor.
     *
     * <p>The name {@code m1copy} was assigned by the decompiler; per its rename note the
     * original method was {@code copy} (merged with a bridge method) and was protected.
     * The decompiled name and visibility are retained here so existing references keep
     * resolving.
     *
     * @param str             the name for the new model
     * @param modelProvenance the provenance for the new model
     * @return the newly constructed model
     */
    public MultinomialNaiveBayesModel m1copy(String str, ModelProvenance modelProvenance) {
        return new MultinomialNaiveBayesModel(str, modelProvenance, this.featureIDMap, this.outputIDInfo, new DenseSparseMatrix(this.labelWordProbs), this.alpha);
    }
}
