package org.tribuo.classification.dtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.dtree.impurity.LabelImpurity;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.impl.IntArrayContainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/classification/dtree/impl/ClassifierTrainingNode.class */
/**
 * A decision tree node used while training classification trees.
 * <p>
 * The node holds its training data in an inverted, per-feature layout ({@link TreeFeature}).
 * It searches that data for the split that best reduces impurity and partitions it into
 * child nodes. Once training finishes, {@link #convertTree()} turns the training tree into
 * {@link SplitNode}s and {@link LeafNode}s.
 * <p>
 * This is a runtime-only class. Java serialization is refused in {@link #writeObject}.
 */
public class ClassifierTrainingNode extends AbstractTrainingNode<Label> {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(ClassifierTrainingNode.class.getName());

    /**
     * Smallest total label weight a child may have for a split to be considered.
     * This rejects empty (or zero-weight) children while still allowing pure children.
     */
    private static final double MIN_CHILD_WEIGHT = 1.0e-10;

    // Per-thread scratch buffers, reused when computing the example indices that fall on
    // the "less than" side of a split, so each split does not allocate new index arrays.
    private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(16));
    private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(16));
    private static final ThreadLocal<IntArrayContainer> mergeBufferThree = ThreadLocal.withInitial(() -> new IntArrayContainer(16));

    // Inverted training data, indexed by feature id. Set to null at the end of buildTree,
    // after it has been partitioned into the children.
    private transient ArrayList<TreeFeature> data;
    private final ImmutableOutputInfo<Label> labelIDMap;
    private final ImmutableFeatureMap featureIDMap;
    private final LabelImpurity impurity;
    // Label counts for the examples in this node, indexed by label id.
    private final float[] labelCounts;

    /**
     * Creates the root training node for a dataset.
     *
     * @param impurity The impurity function used to score candidate splits.
     * @param examples The training data. Must have at least one feature.
     */
    public ClassifierTrainingNode(LabelImpurity impurity, Dataset<Label> examples) {
        this(impurity, invertData(examples), examples.size(), 0, examples.getFeatureIDMap(), examples.getOutputIDInfo());
    }

    /**
     * Creates a node over an already-inverted partition of the data.
     *
     * @param impurity The impurity function.
     * @param data The inverted data, one {@link TreeFeature} per feature id (must be non-empty).
     * @param numExamples The number of examples in this node.
     * @param depth The depth of this node in the tree.
     * @param featureIDMap The feature domain.
     * @param labelIDMap The label domain.
     */
    private ClassifierTrainingNode(LabelImpurity impurity, ArrayList<TreeFeature> data, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) {
        super(depth, numExamples);
        this.data = data;
        this.featureIDMap = featureIDMap;
        this.labelIDMap = labelIDMap;
        this.impurity = impurity;
        // invertData makes every feature observe every example (absent features are
        // recorded as zero), so the first feature's label counts cover the whole node.
        // NOTE(review): assumes TreeFeature.split preserves this for child partitions.
        this.labelCounts = data.get(0).getLabelCounts();
    }

    /**
     * Searches the given features for the split that most reduces impurity. If one is found,
     * this node's data is partitioned into two child nodes.
     * <p>
     * A candidate split is accepted only when both children have positive total weight and
     * the weight-normalised child impurity is strictly lower than this node's impurity.
     * After this call the node's training data is released.
     *
     * @param featureIDs The ids of the features to consider for the split.
     * @return The two children (less-than-or-equal first, then greater-than), or an empty
     *     list if no improving split exists and this node is a leaf.
     */
    public List<AbstractTrainingNode<Label>> buildTree(int[] featureIDs) {
        int bestID = -1;
        double bestSplitValue = 0.0;
        double bestScore = impurity.impurity(labelCounts);
        double weightSum = Util.sum(labelCounts);
        float[] lessThanCounts = new float[labelCounts.length];
        float[] greaterThanCounts = new float[labelCounts.length];

        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
            Arrays.fill(lessThanCounts, 0.0f);
            System.arraycopy(labelCounts, 0, greaterThanCounts, 0, labelCounts.length);
            // Sweep the feature's value groups in order, moving each group's counts from
            // the "greater than" side to the "less than" side. Candidate thresholds are the
            // midpoints between adjacent groups, so the last group is never a left side.
            for (int j = 0; j < feature.size() - 1; j++) {
                InvertedFeature current = feature.get(j);
                float[] groupCounts = current.getLabelCounts();
                Util.inPlaceAdd(lessThanCounts, groupCounts);
                Util.inPlaceSubtract(greaterThanCounts, groupCounts);
                // Guard against empty children by their weight, not their impurity. A pure
                // child has zero impurity, and testing impurity would wrongly reject the
                // best possible splits.
                if (Util.sum(lessThanCounts) > MIN_CHILD_WEIGHT && Util.sum(greaterThanCounts) > MIN_CHILD_WEIGHT) {
                    double score = (impurity.impurityWeighted(lessThanCounts) + impurity.impurityWeighted(greaterThanCounts)) / weightSum;
                    if (score < bestScore) {
                        bestID = i;
                        bestScore = score;
                        bestSplitValue = (current.value + feature.get(j + 1).value) / 2.0;
                    }
                }
            }
        }

        if (bestID == -1) {
            // No split improves on this node's impurity, so it stays a leaf.
            data = null;
            return Collections.emptyList();
        }

        splitID = featureIDs[bestID];
        split = true;
        splitValue = bestSplitValue;

        // Collect the indices of every example below the threshold. The value groups are
        // iterated in ascending order, so stop at the first group at or above it.
        IntArrayContainer lessThanIndices = mergeBufferOne.get();
        lessThanIndices.size = 0;
        IntArrayContainer buffer = mergeBufferTwo.get();
        buffer.size = 0;
        for (InvertedFeature group : data.get(splitID)) {
            if (group.value >= splitValue) {
                break;
            }
            IntArrayContainer.merge(lessThanIndices, group.indices(), buffer);
            // merge wrote the union into buffer; swap so lessThanIndices holds the result.
            IntArrayContainer tmp = lessThanIndices;
            lessThanIndices = buffer;
            buffer = tmp;
        }

        IntArrayContainer secondBuffer = mergeBufferThree.get();
        secondBuffer.grow(lessThanIndices.size);
        ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size());
        ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size());
        for (TreeFeature treeFeature : data) {
            Pair<TreeFeature, TreeFeature> childFeatures = treeFeature.split(lessThanIndices, buffer, secondBuffer);
            lessThanData.add(childFeatures.getA());
            greaterThanData.add(childFeatures.getB());
        }

        lessThanOrEqual = new ClassifierTrainingNode(impurity, lessThanData, lessThanIndices.size, depth + 1, featureIDMap, labelIDMap);
        greaterThan = new ClassifierTrainingNode(impurity, greaterThanData, numExamples - lessThanIndices.size, depth + 1, featureIDMap, labelIDMap);
        List<AbstractTrainingNode<Label>> children = new ArrayList<>(2);
        children.add(lessThanOrEqual);
        children.add(greaterThan);
        data = null;
        return children;
    }

    /**
     * Converts this training node, and recursively its children, into immutable tree nodes.
     * <p>
     * A split node becomes a {@link SplitNode}. A leaf becomes a {@link LeafNode} that
     * predicts the label with the highest normalised count and carries the full label
     * distribution as scores.
     *
     * @return The converted node.
     */
    public Node<Label> convertTree() {
        double nodeImpurity = impurity.impurity(labelCounts);
        if (split) {
            Node<Label> newGreaterThan = greaterThan.convertTree();
            Node<Label> newLessThan = lessThanOrEqual.convertTree();
            return new SplitNode<>(splitValue, splitID, nodeImpurity, newGreaterThan, newLessThan);
        }

        double[] distribution = Util.normalizeToDistribution(labelCounts);
        LinkedHashMap<String, Label> scores = new LinkedHashMap<>();
        Label maxLabel = null;
        double maxScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < labelCounts.length; i++) {
            String name = labelIDMap.getOutput(i).getLabel();
            Label scored = new Label(name, distribution[i]);
            scores.put(name, scored);
            // Strict comparison: ties keep the label with the lowest id.
            if (scored.getScore() > maxScore) {
                maxScore = scored.getScore();
                maxLabel = scored;
            }
        }
        return new LeafNode<>(nodeImpurity, maxLabel, scores, true);
    }

    /**
     * Returns the impurity of this node's label counts.
     *
     * @return The node impurity.
     */
    public double getImpurity() {
        return impurity.impurity(labelCounts);
    }

    /**
     * Converts a dataset into the inverted, per-feature layout used during training.
     * <p>
     * Produces one {@link TreeFeature} per feature id. Each one observes a value for every
     * example, with zero used for features absent from the example's sparse vector. The
     * features are then sorted and their sizes fixed.
     *
     * @param examples The dataset to invert.
     * @return The inverted data, indexed by feature id.
     * @throws IllegalStateException If an example's sparse feature ids are repeated or out of order.
     */
    private static ArrayList<TreeFeature> invertData(Dataset<Label> examples) {
        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
        ImmutableOutputInfo<Label> labelInfo = examples.getOutputIDInfo();
        int numLabels = labelInfo.size();
        int numFeatures = featureInfos.size();
        int numExamples = examples.size();

        int[] labels = new int[numExamples];
        float[] weights = new float[numExamples];
        int k = 0;
        for (Example<Label> example : examples) {
            weights[k] = example.getWeight();
            labels[k] = labelInfo.getID(example.getOutput());
            k++;
        }

        logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " classes");
        ArrayList<TreeFeature> data = new ArrayList<>(numFeatures);
        for (int i = 0; i < numFeatures; i++) {
            data.add(new TreeFeature(i, numLabels, labels, weights));
        }

        for (int i = 0; i < numExamples; i++) {
            Example<Label> example = examples.getExample(i);
            SparseVector vector = SparseVector.createSparseVector(example, featureInfos, false);
            int lastID = 0;
            for (VectorTuple tuple : vector) {
                int curID = tuple.index;
                // Defensive checks; these should not occur for a well-formed SparseVector.
                // The repeated-id check must run first, because a repeated id also
                // satisfies curID < lastID and would otherwise be reported as disorder.
                if (curID == lastID - 1) {
                    logger.severe("Example = " + example.toString());
                    throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
                }
                if (curID < lastID) {
                    logger.severe("Example = " + example.toString());
                    throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
                }
                // Features skipped by the sparse vector are recorded as zero.
                for (int j = lastID; j < curID; j++) {
                    data.get(j).observeValue(0.0, i);
                }
                data.get(curID).observeValue(tuple.value, i);
                lastID = curID + 1;
            }
            for (int j = lastID; j < numFeatures; j++) {
                data.get(j).observeValue(0.0, i);
            }
            if (i % 1000 == 0) {
                logger.fine("Processed example " + i);
            }
        }

        logger.fine("Sorting features");
        data.forEach(TreeFeature::sort);
        logger.fine("Fixing InvertedFeature sizes");
        data.forEach(TreeFeature::fixSize);
        logger.fine("Built initial List<TreeFeature>");
        return data;
    }

    /**
     * Always throws: training nodes are runtime-only and must not be serialized.
     *
     * @param objectOutputStream The stream (unused).
     * @throws IOException Always, as a {@link NotSerializableException}.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        throw new NotSerializableException("ClassifierTrainingNode is a runtime class only, and should not be serialized.");
    }
}
