package org.tribuo.regression.rtree;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.SplittableRandom;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.SparseModel;
import org.tribuo.common.tree.AbstractCARTTrainer;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.impl.RegressorTrainingNode;
import org.tribuo.regression.rtree.impurity.MeanSquaredError;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/regression/rtree/CARTRegressionTrainer.class */
/**
 * A CART trainer for regression. It grows one independent tree per output dimension and
 * combines them into an {@link IndependentRegressionTreeModel}.
 * <p>
 * NOTE(review): this source appears to be decompiler output. The {@code m1train}..{@code m4train}
 * and {@code m5getProvenance} methods are renamed synthetic/bridge methods. They are retained only
 * so that any existing references to those names keep compiling.
 */
public final class CARTRegressionTrainer extends AbstractCARTTrainer<Regressor> {

    /**
     * Sentinel passed to {@link #train(Dataset, Map, int)} meaning "do not reset the invocation
     * count; just use (and then increment) the current one".
     */
    private static final int NO_INVOCATION_COUNT_OVERRIDE = -1;

    /**
     * Impurity measure used to score candidate splits. Defaults to {@link MeanSquaredError}.
     * Every constructor in this class replaces it with its {@code impurity} argument.
     */
    @Config(description = "Regression impurity measure used to determine split quality.")
    private RegressorImpurity impurity = new MeanSquaredError();

    /**
     * Creates a CART regression trainer.
     * <p>
     * The first five numeric/boolean arguments and the seed are forwarded, in this order, to the
     * {@link AbstractCARTTrainer} constructor. {@code postConfig()} is then invoked to finish
     * initialisation.
     *
     * @param maxDepth                maximum depth a tree may reach.
     * @param minChildWeight          minimum total example weight a node must carry to be split further.
     * @param minImpurityDecrease     minimum impurity decrease for a split. It is scaled by the total
     *                                example weight at training time.
     * @param fractionFeaturesInSplit fraction of the features examined at each split. If it is below 1,
     *                                a random subset is drawn per split.
     * @param useRandomSplitPoints    flag forwarded to the tree-building step to select random split points.
     * @param impurity                impurity measure used to score splits.
     * @param seed                    seed for the trainer's random number generator.
     */
    public CARTRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease,
                                 float fractionFeaturesInSplit, boolean useRandomSplitPoints,
                                 RegressorImpurity impurity, long seed) {
        super(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, useRandomSplitPoints, seed);
        this.impurity = impurity;
        postConfig();
    }

    /**
     * Creates a CART regression trainer which uses exact (non-random) split points.
     *
     * @param maxDepth                maximum depth a tree may reach.
     * @param minChildWeight          minimum total example weight a node must carry to be split further.
     * @param minImpurityDecrease     minimum impurity decrease for a split. It is scaled by the total
     *                                example weight at training time.
     * @param fractionFeaturesInSplit fraction of the features examined at each split.
     * @param impurity                impurity measure used to score splits.
     * @param seed                    seed for the trainer's random number generator.
     */
    public CARTRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease,
                                 float fractionFeaturesInSplit, RegressorImpurity impurity, long seed) {
        this(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, false, impurity, seed);
    }

    /**
     * Creates a CART regression trainer with effectively unbounded depth
     * ({@code Integer.MAX_VALUE}) and the defaults of {@link #CARTRegressionTrainer(int)}.
     */
    public CARTRegressionTrainer() {
        this(Integer.MAX_VALUE);
    }

    /**
     * Creates a CART regression trainer with the given maximum depth and the following defaults:
     * <ul>
     *   <li>minimum child weight 5;</li>
     *   <li>minimum impurity decrease 0;</li>
     *   <li>all features examined at each split;</li>
     *   <li>exact split points;</li>
     *   <li>{@link MeanSquaredError} impurity;</li>
     *   <li>seed 12345.</li>
     * </ul>
     *
     * @param maxDepth maximum depth a tree may reach.
     */
    public CARTRegressionTrainer(int maxDepth) {
        this(maxDepth, 5.0f, 0.0f, 1.0f, false, new MeanSquaredError(), 12345L);
    }

    /**
     * Not used by this trainer. {@link #train(Dataset, Map, int)} builds its own root nodes (one
     * per output dimension) instead of going through this factory.
     *
     * @throws IllegalStateException always.
     */
    @Override
    protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples,
                                                             AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        throw new IllegalStateException("Shouldn't reach here.");
    }

    /**
     * Trains a model using the trainer's current invocation count.
     *
     * @param examples      training data.
     * @param runProvenance extra provenance recorded on the model.
     * @return the trained model.
     */
    @Override
    public TreeModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        return train(examples, runProvenance, NO_INVOCATION_COUNT_OVERRIDE);
    }

    /**
     * Trains one independent regression tree per output dimension of {@code examples}.
     *
     * @param examples        training data. It must not contain unknown outputs.
     * @param runProvenance   extra provenance recorded on the model.
     * @param invocationCount invocation count to set before training, or {@code -1} to keep the
     *                        current count.
     * @return an {@link IndependentRegressionTreeModel} holding one tree per output dimension.
     * @throws IllegalArgumentException if the dataset contains unknown outputs.
     */
    @Override
    public TreeModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance,
                                      int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }

        // The shared RNG and invocation counter are only touched under the trainer's lock.
        // Everything after this block works from the local RNG split and provenance snapshot.
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (invocationCount != NO_INVOCATION_COUNT_OVERRIDE) {
                setInvocationCount(invocationCount);
            }
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }

        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<Regressor> outputIDInfo = examples.getOutputIDInfo();
        Set<Regressor> domain = outputIDInfo.getDomain();

        int numFeatures = featureIDMap.size();
        // Number of features examined per split, capped at the total feature count.
        int numFeaturesInSplit = Math.min(Math.round(fractionFeaturesInSplit * numFeatures), numFeatures);
        int[] originalIndices = new int[numFeatures];
        for (int j = 0; j < numFeatures; j++) {
            originalIndices[j] = j;
        }
        // When every feature is used, the identity index array is passed straight through.
        // Otherwise a random subset is copied into this smaller buffer before each split.
        int[] indices = (numFeaturesInSplit != numFeatures) ? new int[numFeaturesInSplit] : originalIndices;

        // The impurity-decrease threshold is expressed relative to the total example weight.
        float weightSum = totalWeight(examples);
        AbstractTrainingNode.LeafDeterminer leafDeterminer = new AbstractTrainingNode.LeafDeterminer(
                maxDepth, minChildWeight, getMinImpurityDecrease() * weightSum);

        // Computed once and shared by every per-dimension tree.
        RegressorTrainingNode.InvertedData invertedData = RegressorTrainingNode.invertData(examples);

        Map<String, Node<Regressor>> roots = new HashMap<>();
        for (Regressor target : domain) {
            String dimensionName = target.getNames()[0];
            RegressorTrainingNode root = new RegressorTrainingNode(impurity, invertedData,
                    outputIDInfo.getID(target), dimensionName, examples.size(), featureIDMap,
                    outputIDInfo, leafDeterminer);
            growTree(root, originalIndices, indices, numFeaturesInSplit, localRNG);
            roots.put(dimensionName, root.convertTree());
        }

        ModelProvenance provenance = new ModelProvenance(TreeModel.class.getName(), OffsetDateTime.now(),
                examples.getProvenance(), trainerProvenance, runProvenance);
        return new IndependentRegressionTreeModel("cart-tree", provenance, featureIDMap, outputIDInfo, false, roots);
    }

    /**
     * Sums the example weights of {@code examples}, in iteration order.
     */
    private static float totalWeight(Dataset<Regressor> examples) {
        float sum = 0.0f;
        for (Example<Regressor> example : examples) {
            sum += example.getWeight();
        }
        return sum;
    }

    /**
     * Expands {@code root} in place until no node satisfies the split criteria.
     * <p>
     * {@code originalIndices} is shuffled in place and persists across calls. When feature
     * subsampling is active, the first {@code numFeaturesInSplit} entries are copied into
     * {@code indices} before each split.
     */
    private void growTree(AbstractTrainingNode<Regressor> root, int[] originalIndices, int[] indices,
                          int numFeaturesInSplit, SplittableRandom localRNG) {
        Deque<AbstractTrainingNode<Regressor>> queue = new ArrayDeque<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            AbstractTrainingNode<Regressor> node = queue.poll();
            if (node.getImpurity() > 0.0 && node.getDepth() < maxDepth && node.getWeightSum() >= minChildWeight) {
                if (numFeaturesInSplit != originalIndices.length) {
                    Util.randpermInPlace(originalIndices, localRNG);
                    System.arraycopy(originalIndices, 0, indices, 0, numFeaturesInSplit);
                }
                // Children are pushed to the front, so the deque behaves as a stack (depth-first growth).
                for (AbstractTrainingNode<Regressor> child : node.buildTree(indices, localRNG, getUseRandomSplitPoints())) {
                    queue.addFirst(child);
                }
            }
        }
    }

    @Override
    public String toString() {
        return "CARTRegressionTrainer(maxDepth=" + maxDepth
                + ",minChildWeight=" + minChildWeight
                + ",minImpurityDecrease=" + minImpurityDecrease
                + ",fractionFeaturesInSplit=" + fractionFeaturesInSplit
                + ",useRandomSplitPoints=" + useRandomSplitPoints
                + ",impurity=" + impurity.toString()
                + ",seed=" + seed + ")";
    }

    /**
     * Returns a provenance object describing this trainer's configuration.
     *
     * @return a new {@link TrainerProvenanceImpl} wrapping this trainer.
     */
    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Decompiler-renamed alias of {@link #getProvenance()}, retained for compatibility.
     *
     * @return the same value as {@link #getProvenance()}.
     */
    public TrainerProvenance m5getProvenance() {
        return getProvenance();
    }

    /** Decompiler-emitted bridge (renamed from {@code train}); delegates to the typed overload. */
    @SuppressWarnings("unchecked") // raw bridge signature; the cast mirrors the compiler-generated bridge
    public /* bridge */ /* synthetic */ SparseModel m1train(Dataset dataset, Map map, int i) {
        return train((Dataset<Regressor>) dataset, (Map<String, Provenance>) map, i);
    }

    /** Decompiler-emitted bridge (renamed from {@code train}); delegates to the typed overload. */
    @SuppressWarnings("unchecked") // raw bridge signature; the cast mirrors the compiler-generated bridge
    public /* bridge */ /* synthetic */ SparseModel m2train(Dataset dataset, Map map) {
        return train((Dataset<Regressor>) dataset, (Map<String, Provenance>) map);
    }

    /** Decompiler-emitted bridge (renamed from {@code train}); delegates to the typed overload. */
    @SuppressWarnings("unchecked") // raw bridge signature; the cast mirrors the compiler-generated bridge
    public /* bridge */ /* synthetic */ Model m3train(Dataset dataset, Map map, int i) {
        return train((Dataset<Regressor>) dataset, (Map<String, Provenance>) map, i);
    }

    /** Decompiler-emitted bridge (renamed from {@code train}); delegates to the typed overload. */
    @SuppressWarnings("unchecked") // raw bridge signature; the cast mirrors the compiler-generated bridge
    public /* bridge */ /* synthetic */ Model m4train(Dataset dataset, Map map) {
        return train((Dataset<Regressor>) dataset, (Map<String, Provenance>) map);
    }
}
