package org.fnlp.ml.classifier.hier.inf;

import gnu.trove.iterator.TIntIterator;
import gnu.trove.set.TIntSet;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.fnlp.ml.classifier.hier.Predict;
import org.fnlp.ml.classifier.hier.Tree;
import org.fnlp.ml.classifier.linear.inf.Inferencer;
import org.fnlp.ml.feature.Generator;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.alphabet.LabelAlphabet;
import org.fnlp.ml.types.sv.HashSparseVector;
import org.fnlp.ml.types.sv.ISparseVector;

/**
 * Inferencer that scores each label with its own sparse weight vector.
 *
 * <p>Each label index {@code c} receives the score {@code featureVector · weights[c]}. When a
 * {@link Tree} is supplied, a candidate leaf's score is the sum of the scores of the nodes
 * returned by {@link Tree#getPath(int)} for that leaf. Without a tree, it is the leaf's own score.
 *
 * <p>The per-label dot products run in parallel on a fixed-size thread pool. The pool is created
 * in the constructor and re-created after deserialization. Call {@link #shutdown()} once the
 * inferencer is no longer needed so that the pool's threads are released.
 */
public class MultiLinearMax extends Inferencer implements Serializable {
    private static final long serialVersionUID = 460812009958228912L;
    private static final Logger LOG = Logger.getLogger(MultiLinearMax.class.getName());

    private LabelAlphabet alphabet;
    private Tree tree;
    /** Number of worker threads used to compute label scores. */
    int numThread;
    /** Worker pool. It is transient, so it is rebuilt in {@link #readObject}. */
    private transient ExecutorService pool;
    /** One weight vector per label index; installed via {@link #setWeight}. */
    private HashSparseVector[] weights;
    private Generator featureGen;
    /** Number of labels scored per instance (the alphabet size at construction time). */
    private int numClass;
    /** Label indices that may be predicted. It is transient, so it is rebuilt after deserialization. */
    transient TIntSet leafs;
    private boolean isUseTarget = true;

    /** Computes the score of one label: the dot product of a feature vector with that label's weights. */
    public class Multiplesolve implements Callable<Float> {
        ISparseVector fv;
        int idx;

        public Multiplesolve(ISparseVector iSparseVector, int i) {
            this.fv = iSparseVector;
            this.idx = i;
        }

        @Override
        public Float call() {
            return Float.valueOf(this.fv.dotProduct(MultiLinearMax.this.weights[this.idx]));
        }
    }

    /**
     * Creates an inferencer and starts its worker pool.
     *
     * @param generator turns an {@link Instance} into a sparse feature vector
     * @param labelAlphabet label alphabet; its current size fixes the number of scored labels
     * @param tree label hierarchy, or {@code null} to treat every alphabet entry as a leaf
     * @param i number of worker threads. {@link Executors#newFixedThreadPool(int)} rejects
     *     values {@code <= 0}.
     */
    public MultiLinearMax(Generator generator, LabelAlphabet labelAlphabet, Tree tree, int i) {
        this.featureGen = generator;
        this.alphabet = labelAlphabet;
        this.numThread = i;
        this.tree = tree;
        this.pool = Executors.newFixedThreadPool(this.numThread);
        this.numClass = labelAlphabet.size();
        this.leafs = computeLeafs();
    }

    /** Equivalent to {@code getBest(instance, 1)}. */
    @Override
    public Predict getBest(Instance instance) {
        return getBest(instance, 1);
    }

    /**
     * Scores every candidate leaf label for {@code instance}.
     *
     * <p>When target usage is enabled and the instance has a target, that target's score goes into
     * a separate {@link Predict}. That Predict is stored via {@link Instance#setTempData}, and the
     * target is left out of the returned Predict.
     *
     * @param instance the instance to classify. Its target is read as an {@link Integer} when
     *     target usage is enabled.
     * @param i value passed to the {@link Predict} constructor (presumably the number of best
     *     labels to keep; confirm against {@code Predict})
     * @return the predictions for all non-target leaves, or {@code null} if scoring failed or was
     *     interrupted
     */
    @Override
    public Predict getBest(Instance instance, int i) {
        Integer target = this.isUseTarget ? (Integer) instance.getTarget() : null;
        ISparseVector vector = this.featureGen.getVector(instance);
        float[] scores = computeScores(vector);
        if (scores == null) {
            return null;
        }

        Predict predict = new Predict(i);
        Predict targetPredict = target != null ? new Predict(i) : null;
        TIntIterator it = this.leafs.iterator();
        while (it.hasNext()) {
            int leaf = it.next();
            float score = leafScore(leaf, scores);
            if (target != null && target.intValue() == leaf) {
                targetPredict.add(leaf, score);
            } else {
                predict.add(leaf, score);
            }
        }
        if (target != null) {
            instance.setTempData(targetPredict);
        }
        return predict;
    }

    /**
     * Computes the per-label dot products in parallel.
     *
     * @return an array of length {@code alphabet.size()} whose first {@code numClass} entries hold
     *     the label scores, or {@code null} if any task failed or the caller was interrupted
     */
    private float[] computeScores(ISparseVector vector) {
        float[] scores = new float[this.alphabet.size()];
        List<Multiplesolve> tasks = new ArrayList<>(this.numClass);
        for (int c = 0; c < this.numClass; c++) {
            tasks.add(new Multiplesolve(vector, c));
        }
        try {
            // invokeAll waits for every task, so no work is left running if one of them fails.
            List<Future<Float>> futures = this.pool.invokeAll(tasks);
            for (int c = 0; c < this.numClass; c++) {
                scores[c] = futures.get(c).get();
            }
        } catch (InterruptedException e) {
            // Restore the flag so that callers further up can observe the interruption.
            Thread.currentThread().interrupt();
            LOG.log(Level.WARNING, "Interrupted while scoring labels", e);
            return null;
        } catch (ExecutionException e) {
            LOG.log(Level.SEVERE, "Failed to score a label", e.getCause());
            return null;
        }
        return scores;
    }

    /** Returns a leaf's score. With a tree, this is the sum of node scores along the leaf's path. */
    private float leafScore(int leaf, float[] scores) {
        if (this.tree == null) {
            return scores[leaf];
        }
        float sum = 0.0f;
        for (int node : this.tree.getPath(leaf)) {
            sum += scores[node];
        }
        return sum;
    }

    /** Returns the candidate label set: the tree's leaves, or every alphabet entry if there is no tree. */
    private TIntSet computeLeafs() {
        if (this.tree == null) {
            return this.alphabet.toTSet();
        }
        return this.tree.getLeafs();
    }

    /** Installs the per-label weight vectors, indexed by label id. */
    public void setWeight(HashSparseVector[] hashSparseVectorArr) {
        this.weights = hashSparseVectorArr;
    }

    /** Enables or disables separating the instance's target score in {@link #getBest(Instance, int)}. */
    @Override
    public void isUseTarget(boolean z) {
        this.isUseTarget = z;
    }

    /**
     * Shuts down the worker pool and releases its threads. Tasks already running are allowed to
     * finish. After this call, {@link #getBest(Instance, int)} can no longer submit work.
     */
    public void shutdown() {
        if (this.pool != null) {
            this.pool.shutdown();
        }
    }

    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.defaultWriteObject();
    }

    /** Restores the serialized fields and rebuilds the transient leaf set and worker pool. */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        this.leafs = computeLeafs();
        this.pool = Executors.newFixedThreadPool(this.numThread);
    }
}
