package org.tribuo.classification.libsvm;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Set;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.LibSVMTrainer;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/classification/libsvm/LibSVMClassificationModel.class */
/**
 * A classification model backed by a single LibSVM {@code svm_model}.
 * <p>
 * The model keeps track of the labels that are present in the output domain but that the
 * underlying {@code svm_model} has no class for; those labels are added to every
 * {@link Prediction} with a score of zero so the prediction covers the whole domain.
 */
public class LibSVMClassificationModel extends LibSVMModel<Label> {
    private static final long serialVersionUID = 3L;

    /** Zero-scored labels from the output domain that the svm_model has no class for. */
    private final Set<Label> unobservedLabels;

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Builds the model around the first element of {@code list}.
     *
     * @param str the model name, passed to the superclass.
     * @param modelProvenance the model provenance, passed to the superclass.
     * @param immutableFeatureMap the feature domain, passed to the superclass.
     * @param immutableOutputInfo the output domain; iterated as (id, label) pairs.
     * @param list the LibSVM models; only element 0 is inspected here.
     */
    public LibSVMClassificationModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<Label> immutableOutputInfo, List<svm_model> list) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, list.get(0).param.probability == 1, list);
        this.unobservedLabels = findUnobservedLabels(immutableOutputInfo, list.get(0).label);
    }

    /**
     * Collects the labels of the output domain whose id is not among the class ids of the svm_model.
     *
     * @param outputInfo the output domain.
     * @param observedIds the class ids stored in the svm_model.
     * @return an unmodifiable set of zero-scored labels, empty when every label has a class.
     */
    private static Set<Label> findUnobservedLabels(ImmutableOutputInfo<Label> outputInfo, int[] observedIds) {
        if (observedIds.length == outputInfo.size()) {
            return Collections.emptySet();
        }
        // Compare against the class ids themselves, not their positions in the array:
        // predict() resolves labels via getOutput(observedIds[i]), so the ids are the
        // output ids and need not be the contiguous range 0..length-1.
        Set<Integer> observed = new HashSet<>();
        for (int id : observedIds) {
            observed.add(id);
        }
        Set<Label> unobserved = new HashSet<>();
        for (Pair<Integer, Label> entry : outputInfo) {
            if (!observed.contains(entry.getA())) {
                unobserved.add(new Label(entry.getB().getLabel(), 0.0d));
            }
        }
        return Collections.unmodifiableSet(unobserved);
    }

    /**
     * Returns the number of support vectors held by the underlying svm_model.
     *
     * @return the length of the model's {@code SV} array.
     */
    public int getNumberOfSupportVectors() {
        return ((svm_model) this.models.get(0)).SV.length;
    }

    /**
     * Scores every class of the svm_model for the supplied example and returns the best one.
     * <p>
     * When the model generates probabilities the scores are the values written by
     * {@code svm_predict_probability}; otherwise each score is the number of pairwise
     * decisions the class won.
     *
     * @param example the example to classify.
     * @return a prediction holding a score for each class of the svm_model plus a zero
     *         score for each unobserved label.
     * @throws IllegalArgumentException if the example maps to no svm nodes.
     */
    public Prediction<Label> predict(Example<Label> example) {
        svm_model model = (svm_model) this.models.get(0);
        svm_node[] nodes = LibSVMTrainer.exampleToNodes(example, this.featureIDMap, (List) null);
        if (nodes.length == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        int[] classIds = model.label;
        double[] scores = new double[classIds.length];
        if (this.generatesProbabilities) {
            svm.svm_predict_probability(model, nodes, scores);
        } else {
            // One decision value per unordered pair of classes.
            double[] decisionValues = new double[(classIds.length * (classIds.length - 1)) / 2];
            svm.svm_predict_values(model, nodes, decisionValues);
            // Decision values are read in the pair order (0,1), (0,2), ..., (1,2), ...;
            // a positive value is a vote for the first class of the pair, otherwise
            // the second class gets the vote.
            int pair = 0;
            for (int first = 0; first < classIds.length; first++) {
                for (int second = first + 1; second < classIds.length; second++) {
                    if (decisionValues[pair] > 0.0d) {
                        scores[first] += 1.0d;
                    } else {
                        scores[second] += 1.0d;
                    }
                    pair++;
                }
            }
        }
        double bestScore = Double.NEGATIVE_INFINITY;
        Label bestLabel = null;
        LinkedHashMap<String, Label> labelScores = new LinkedHashMap<>();
        for (int i = 0; i < scores.length; i++) {
            String name = this.outputIDInfo.getOutput(classIds[i]).getLabel();
            Label candidate = new Label(name, scores[i]);
            labelScores.put(name, candidate);
            // Strict comparison: on ties the earliest class in the svm_model order wins.
            if (candidate.getScore() > bestScore) {
                bestScore = candidate.getScore();
                bestLabel = candidate;
            }
        }
        // Pad the prediction with the zero-scored labels the svm_model has no class for.
        for (Label unobserved : this.unobservedLabels) {
            labelScores.put(unobserved.getLabel(), unobserved);
        }
        return new Prediction<>(bestLabel, labelScores, nodes.length, example, this.generatesProbabilities);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] */
    /**
     * Creates a new model with the supplied name and provenance around a copy of the
     * underlying svm_model, sharing this model's feature and output domains.
     *
     * @param str the name for the new model.
     * @param modelProvenance the provenance for the new model.
     * @return the new model.
     */
    public LibSVMClassificationModel m0copy(String str, ModelProvenance modelProvenance) {
        return new LibSVMClassificationModel(str, modelProvenance, this.featureIDMap, this.outputIDInfo, Collections.singletonList(LibSVMModel.copyModel((svm_model) this.models.get(0))));
    }
}
