package org.tribuo.classification.dtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.dtree.impurity.LabelImpurity;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.impl.IntArrayContainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/classification/dtree/impl/ClassifierTrainingNode.class */
/**
 * A decision tree node used while training classification trees.
 * <p>
 * The examples reaching this node are stored in an inverted, per-feature layout
 * ({@link TreeFeature}) together with the node's weighted label counts, total example
 * weight and impurity. Once {@link #buildTree} has run, the inverted data is released.
 * <p>
 * This is a runtime-only class; attempting to Java-serialize it throws
 * {@link NotSerializableException}.
 */
public class ClassifierTrainingNode extends AbstractTrainingNode<Label> {
    private static final long serialVersionUID = 1L;

    private static final Logger logger = Logger.getLogger(ClassifierTrainingNode.class.getName());

    /** Initial capacity of the per-thread index buffers. */
    private static final int INITIAL_BUFFER_SIZE = 16;

    // Per-thread scratch buffers reused by splitAtBest so each split does not allocate new index arrays.
    private static final ThreadLocal<IntArrayContainer> mergeBufferOne =
            ThreadLocal.withInitial(() -> new IntArrayContainer(INITIAL_BUFFER_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferTwo =
            ThreadLocal.withInitial(() -> new IntArrayContainer(INITIAL_BUFFER_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferThree =
            ThreadLocal.withInitial(() -> new IntArrayContainer(INITIAL_BUFFER_SIZE));

    // Inverted per-feature view of this node's examples; set to null once the node has been processed.
    private transient ArrayList<TreeFeature> data;

    private final ImmutableOutputInfo<Label> labelIDMap;

    private final ImmutableFeatureMap featureIDMap;

    private final LabelImpurity impurity;

    // Weighted count of each label id among the examples in this node.
    private final float[] weightedLabelCounts;

    // Sum of weightedLabelCounts.
    private final float weightSum;

    /**
     * Constructs the root training node for the supplied dataset.
     *
     * @param labelImpurity  The impurity function used to score nodes and candidate splits.
     * @param dataset        The training data; it is inverted into per-feature form.
     * @param leafDeterminer The criteria used to decide when a node should become a leaf.
     */
    public ClassifierTrainingNode(LabelImpurity labelImpurity, Dataset<Label> dataset, AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        this(labelImpurity, invertData(dataset), dataset.size(), 0, dataset.getFeatureIDMap(), dataset.getOutputIDInfo(), leafDeterminer);
    }

    /**
     * Constructs a node whose label counts, weight sum and impurity are derived from {@code data}.
     */
    private ClassifierTrainingNode(LabelImpurity impurity, ArrayList<TreeFeature> data, int numExamples, int depth,
                                   ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap,
                                   AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        super(depth, numExamples, leafDeterminer);
        this.data = data;
        this.featureIDMap = featureIDMap;
        this.labelIDMap = labelIDMap;
        this.impurity = impurity;
        // invertData records a value for every example in every feature, so the first
        // feature's label counts are taken as this node's totals.
        this.weightedLabelCounts = data.get(0).getWeightedLabelCounts();
        this.weightSum = Util.sum(this.weightedLabelCounts);
        this.impurityScore = impurity.impurity(this.weightedLabelCounts);
    }

    /**
     * Constructs a child node whose statistics were already computed while choosing the split.
     */
    private ClassifierTrainingNode(LabelImpurity impurity, ArrayList<TreeFeature> data, int numExamples, int depth,
                                   ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap,
                                   AbstractTrainingNode.LeafDeterminer leafDeterminer,
                                   float[] weightedLabelCounts, float weightSum, double impurityScore) {
        super(depth, numExamples, leafDeterminer);
        this.data = data;
        this.featureIDMap = featureIDMap;
        this.labelIDMap = labelIDMap;
        this.impurity = impurity;
        this.weightedLabelCounts = weightedLabelCounts;
        this.weightSum = weightSum;
        this.impurityScore = impurityScore;
    }

    /**
     * Returns the total weight of the examples in this node.
     *
     * @return The sum of the weighted label counts.
     */
    public float getWeightSum() {
        return weightSum;
    }

    /**
     * Returns the impurity of this node.
     *
     * @return The impurity score.
     */
    public double getImpurity() {
        return impurityScore;
    }

    /**
     * Tries to split this node on one of the supplied features.
     *
     * @param featureIDs           The candidate feature ids.
     * @param rng                  The random source, used only when {@code useRandomSplitPoints} is true.
     * @param useRandomSplitPoints If true, pick one random threshold per feature; otherwise search all thresholds.
     * @return The child nodes that still need expanding (possibly empty).
     */
    public List<AbstractTrainingNode<Label>> buildTree(int[] featureIDs, SplittableRandom rng, boolean useRandomSplitPoints) {
        if (useRandomSplitPoints) {
            return buildRandomTree(featureIDs, rng);
        }
        return buildGreedyTree(featureIDs);
    }

    /**
     * Searches every threshold of every candidate feature for the split with the lowest weighted impurity.
     *
     * @param featureIDs The candidate feature ids.
     * @return The child nodes that still need expanding (possibly empty).
     */
    private List<AbstractTrainingNode<Label>> buildGreedyTree(int[] featureIDs) {
        int numLabels = weightedLabelCounts.length;
        int bestIndex = -1;
        double bestSplitValue = 0.0;
        double bestScore = getImpurity();
        float[] bestLeftCounts = new float[numLabels];
        float[] bestRightCounts = new float[numLabels];
        float[] leftCounts = new float[numLabels];
        float[] rightCounts = new float[numLabels];

        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
            Arrays.fill(leftCounts, 0.0f);
            System.arraycopy(weightedLabelCounts, 0, rightCounts, 0, numLabels);
            // Each threshold is the midpoint between consecutive values, so the final value
            // never starts a split.
            for (int j = 0; j < feature.size() - 1; j++) {
                InvertedFeature current = feature.get(j);
                float[] counts = current.getWeightedLabelCounts();
                Util.inPlaceAdd(leftCounts, counts);
                Util.inPlaceSubtract(rightCounts, counts);
                double score = (impurity.impurityWeighted(leftCounts) + impurity.impurityWeighted(rightCounts)) / weightSum;
                if (score < bestScore) {
                    bestIndex = i;
                    bestScore = score;
                    System.arraycopy(leftCounts, 0, bestLeftCounts, 0, numLabels);
                    System.arraycopy(rightCounts, 0, bestRightCounts, 0, numLabels);
                    bestSplitValue = (current.value + feature.get(j + 1).value) / 2.0;
                }
            }
        }
        return splitIfImproved(featureIDs, bestIndex, bestSplitValue, bestScore, bestLeftCounts, bestRightCounts);
    }

    /**
     * Picks one random threshold per candidate feature and splits on the best of them.
     *
     * @param featureIDs The candidate feature ids.
     * @param rng        The random source used to choose thresholds.
     * @return The child nodes that still need expanding (possibly empty).
     */
    public List<AbstractTrainingNode<Label>> buildRandomTree(int[] featureIDs, SplittableRandom rng) {
        int numLabels = weightedLabelCounts.length;
        int bestIndex = -1;
        double bestSplitValue = 0.0;
        double bestScore = getImpurity();
        float[] bestLeftCounts = new float[numLabels];
        float[] bestRightCounts = new float[numLabels];
        float[] leftCounts = new float[numLabels];
        float[] rightCounts = new float[numLabels];

        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
            // A feature with a single value offers no threshold to split on.
            if (feature.size() == 1) {
                continue;
            }
            Arrays.fill(leftCounts, 0.0f);
            System.arraycopy(weightedLabelCounts, 0, rightCounts, 0, numLabels);
            int splitIdx = rng.nextInt(feature.size() - 1);
            for (int j = 0; j <= splitIdx; j++) {
                float[] counts = feature.get(j).getWeightedLabelCounts();
                Util.inPlaceAdd(leftCounts, counts);
                Util.inPlaceSubtract(rightCounts, counts);
            }
            double score = (impurity.impurityWeighted(leftCounts) + impurity.impurityWeighted(rightCounts)) / weightSum;
            if (score < bestScore) {
                bestIndex = i;
                bestScore = score;
                System.arraycopy(leftCounts, 0, bestLeftCounts, 0, numLabels);
                System.arraycopy(rightCounts, 0, bestRightCounts, 0, numLabels);
                bestSplitValue = (feature.get(splitIdx).value + feature.get(splitIdx + 1).value) / 2.0;
            }
        }
        return splitIfImproved(featureIDs, bestIndex, bestSplitValue, bestScore, bestLeftCounts, bestRightCounts);
    }

    /**
     * Splits at the chosen point if one was found and it reduces the weighted impurity by at
     * least the leaf determiner's scaled minimum; otherwise leaves the node unsplit. In both
     * cases the inverted data is released afterwards.
     *
     * @return The child nodes that still need expanding (possibly empty).
     */
    private List<AbstractTrainingNode<Label>> splitIfImproved(int[] featureIDs, int bestIndex, double bestSplitValue,
                                                              double bestScore, float[] bestLeftCounts, float[] bestRightCounts) {
        List<AbstractTrainingNode<Label>> children;
        if (bestIndex == -1 || weightSum * (getImpurity() - bestScore) < leafDeterminer.getScaledMinImpurityDecrease()) {
            children = Collections.emptyList();
        } else {
            children = splitAtBest(featureIDs, bestIndex, bestSplitValue, bestLeftCounts, bestRightCounts);
        }
        // splitAtBest has already copied what the children need, so this node's inverted data can go.
        data = null;
        return children;
    }

    /**
     * Records the split on this node and builds its children. Sides that should be leaves get
     * leaves; the other sides get new training nodes.
     *
     * @return The non-leaf children, which still need expanding.
     */
    private List<AbstractTrainingNode<Label>> splitAtBest(int[] featureIDs, int bestIndex, double bestSplitValue,
                                                          float[] leftCounts, float[] rightCounts) {
        splitID = featureIDs[bestIndex];
        split = true;
        splitValue = bestSplitValue;
        float leftWeightSum = Util.sum(leftCounts);
        double leftImpurity = impurity.impurity(leftCounts);
        float rightWeightSum = Util.sum(rightCounts);
        double rightImpurity = impurity.impurity(rightCounts);
        boolean leftIsLeaf = shouldMakeLeaf(leftImpurity, leftWeightSum);
        boolean rightIsLeaf = shouldMakeLeaf(rightImpurity, rightWeightSum);

        // When both sides are leaves, there is no need to partition the data.
        if (leftIsLeaf && rightIsLeaf) {
            lessThanOrEqual = createLeaf(leftImpurity, leftCounts);
            greaterThan = createLeaf(rightImpurity, rightCounts);
            return Collections.emptyList();
        }

        // Collect the indices of examples whose split-feature value is below the threshold.
        // Breaking at the first value >= splitValue assumes TreeFeature.sort() (run in
        // invertData) orders the InvertedFeatures by ascending value.
        IntArrayContainer leftIndices = mergeBufferOne.get();
        leftIndices.size = 0;
        IntArrayContainer mergeOutput = mergeBufferTwo.get();
        mergeOutput.size = 0;
        for (InvertedFeature invertedFeature : data.get(splitID)) {
            if (invertedFeature.value >= splitValue) {
                break;
            }
            IntArrayContainer.merge(leftIndices, invertedFeature.indices(), mergeOutput);
            // The merge writes into mergeOutput; swap so leftIndices holds the accumulated result.
            IntArrayContainer tmp = leftIndices;
            leftIndices = mergeOutput;
            mergeOutput = tmp;
        }

        IntArrayContainer splitBuffer = mergeBufferThree.get();
        splitBuffer.grow(leftIndices.size);
        ArrayList<TreeFeature> leftData = new ArrayList<>(data.size());
        ArrayList<TreeFeature> rightData = new ArrayList<>(data.size());
        for (TreeFeature feature : data) {
            Pair<TreeFeature, TreeFeature> halves = feature.split(leftIndices, mergeOutput, splitBuffer);
            leftData.add(halves.getA());
            rightData.add(halves.getB());
        }

        List<AbstractTrainingNode<Label>> children = new ArrayList<>(2);
        if (leftIsLeaf) {
            lessThanOrEqual = createLeaf(leftImpurity, leftCounts);
        } else {
            ClassifierTrainingNode leftChild = new ClassifierTrainingNode(impurity, leftData, leftIndices.size, depth + 1,
                    featureIDMap, labelIDMap, leafDeterminer, leftCounts, leftWeightSum, leftImpurity);
            lessThanOrEqual = leftChild;
            children.add(leftChild);
        }
        if (rightIsLeaf) {
            greaterThan = createLeaf(rightImpurity, rightCounts);
        } else {
            ClassifierTrainingNode rightChild = new ClassifierTrainingNode(impurity, rightData, numExamples - leftIndices.size, depth + 1,
                    featureIDMap, labelIDMap, leafDeterminer, rightCounts, rightWeightSum, rightImpurity);
            greaterThan = rightChild;
            children.add(rightChild);
        }
        return children;
    }

    /**
     * Creates a leaf whose scores are the normalized label counts. The predicted label is the
     * one with the highest score; on a tie, the lowest label id wins.
     *
     * @param leafImpurity The impurity to record on the leaf.
     * @param labelCounts  Weighted label counts indexed by label id.
     * @return The leaf node.
     */
    private LeafNode<Label> createLeaf(double leafImpurity, float[] labelCounts) {
        double[] distribution = Util.normalizeToDistribution(labelCounts);
        LinkedHashMap<String, Label> scores = new LinkedHashMap<>();
        Label maxLabel = null;
        double maxScore = Double.NEGATIVE_INFINITY;
        for (int id = 0; id < labelCounts.length; id++) {
            double score = distribution[id];
            String name = labelIDMap.getOutput(id).getLabel();
            Label scored = new Label(name, score);
            scores.put(name, scored);
            // The strict comparison keeps the first (lowest id) label on ties.
            if (score > maxScore) {
                maxScore = score;
                maxLabel = scored;
            }
        }
        return new LeafNode<>(leafImpurity, maxLabel, scores, true);
    }

    /**
     * Converts this training node into its immutable tree representation.
     *
     * @return A split node if this node was split, otherwise a leaf.
     */
    public Node<Label> convertTree() {
        if (split) {
            return createSplitNode();
        }
        return createLeaf(getImpurity(), weightedLabelCounts);
    }

    /**
     * Inverts the dataset into one {@link TreeFeature} per feature id. Every feature records a
     * value for every example; features missing from an example's sparse vector are recorded
     * as 0.0.
     *
     * @param dataset The training data.
     * @return The inverted features, indexed by feature id.
     * @throws IllegalStateException If an example's sparse features are repeated or out of order.
     */
    private static ArrayList<TreeFeature> invertData(Dataset<Label> dataset) {
        ImmutableFeatureMap featureIDMap = dataset.getFeatureIDMap();
        ImmutableOutputInfo<Label> labelIDMap = dataset.getOutputIDInfo();
        int numLabels = labelIDMap.size();
        int numFeatures = featureIDMap.size();
        int numExamples = dataset.size();

        int[] labels = new int[numExamples];
        float[] weights = new float[numExamples];
        int exampleIdx = 0;
        for (Example<Label> example : dataset) {
            weights[exampleIdx] = example.getWeight();
            labels[exampleIdx] = labelIDMap.getID(example.getOutput());
            exampleIdx++;
        }

        logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " classes");
        ArrayList<TreeFeature> data = new ArrayList<>(numFeatures);
        for (int featureID = 0; featureID < numFeatures; featureID++) {
            data.add(new TreeFeature(featureID, numLabels, labels, weights));
        }

        for (int i = 0; i < numExamples; i++) {
            Example<Label> example = dataset.getExample(i);
            // One past the last feature id seen for this example, i.e. the next id not yet recorded.
            int nextID = 0;
            VectorIterator it = SparseVector.createSparseVector(example, featureIDMap, false).iterator();
            while (it.hasNext()) {
                VectorTuple tuple = it.next();
                int curID = tuple.index;
                // Validate before mutating any feature. The repeat check must come first:
                // a repeat (curID == nextID - 1) also satisfies curID < nextID.
                if (curID == nextID - 1) {
                    logger.severe("Example = " + example);
                    throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + nextID + ", curID = " + curID);
                }
                if (curID < nextID) {
                    logger.severe("Example = " + example);
                    throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + nextID + ", curID = " + curID);
                }
                // Features absent from the sparse vector between the previous id and this one are zero.
                for (int j = nextID; j < curID; j++) {
                    data.get(j).observeValue(0.0, i);
                }
                data.get(curID).observeValue(tuple.value, i);
                nextID = curID + 1;
            }
            // Trailing absent features are also zero.
            for (int j = nextID; j < numFeatures; j++) {
                data.get(j).observeValue(0.0, i);
            }
            if (i % 1000 == 0) {
                logger.fine("Processed example " + i);
            }
        }

        logger.fine("Sorting features");
        data.forEach(TreeFeature::sort);
        logger.fine("Fixing InvertedFeature sizes");
        data.forEach(TreeFeature::fixSize);
        logger.fine("Built initial List<TreeFeature>");
        return data;
    }

    /**
     * Blocks Java serialization; training nodes are runtime-only objects.
     *
     * @throws NotSerializableException Always.
     */
    private void writeObject(ObjectOutputStream stream) throws IOException {
        throw new NotSerializableException("ClassifierTrainingNode is a runtime class only, and should not be serialized.");
    }
}
