package org.tribuo.regression.rtree;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.common.tree.protos.TreeNodeProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.la.SparseVector;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.protos.IndependentRegressionTreeModelProto;
import org.tribuo.regression.rtree.protos.TreeNodeListProto;

/* loaded from: input_file:org/tribuo/regression/rtree/IndependentRegressionTreeModel.class */
/**
 * A regression tree model holding one tree per output dimension.
 * <p>
 * Trees are keyed by an output-dimension name. At prediction time every tree is walked to a leaf.
 * Each leaf's {@link Regressor.DimensionTuple} is then combined into a single multi-dimensional
 * {@link Regressor}.
 */
public final class IndependentRegressionTreeModel extends TreeModel<Regressor> {
    private static final long serialVersionUID = 1L;

    /**
     * Protobuf serialization version supported by this class.
     */
    public static final int CURRENT_VERSION = 0;

    /** One tree root per output dimension, keyed by the output-dimension name. */
    private final Map<String, Node<Regressor>> roots;

    /**
     * Constructs a model from a map of per-output-dimension tree roots.
     * <p>
     * The active features of each tree are computed from {@code roots} during construction and
     * passed to the superclass.
     * NOTE(review): the decompiler reports this constructor was package-private originally; it is
     * kept public so existing callers continue to compile.
     *
     * @param name The model name.
     * @param provenance The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param generatesProbabilities Whether the model generates probabilities.
     * @param roots The tree root for each output dimension. Stored as-is, not copied.
     */
    public IndependentRegressionTreeModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, boolean generatesProbabilities, Map<String, Node<Regressor>> roots) {
        super(name, provenance, featureIDMap, outputIDInfo, generatesProbabilities, gatherActiveFeatures(featureIDMap, roots));
        this.roots = roots;
    }

    /**
     * Deserializes the model from the supplied protobuf.
     *
     * @param version The serialized object version.
     * @param className The class name that was serialized (not used by this method).
     * @param message The serialized data.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message is not an
     *     {@link IndependentRegressionTreeModelProto}.
     * @throws IllegalArgumentException If the version is not supported.
     * @throws IllegalStateException If the protobuf does not describe a valid model. This covers
     *     an output domain that is not a regression domain, no trees, a tree with no nodes, or a
     *     tree count different from the number of output dimensions.
     */
    public static IndependentRegressionTreeModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        IndependentRegressionTreeModelProto proto = message.unpack(IndependentRegressionTreeModelProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        if (!carrier.outputDomain().getOutput(0).getClass().equals(Regressor.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + carrier.outputDomain().getClass());
        }
        // Guarded by the Regressor class check directly above.
        @SuppressWarnings("unchecked")
        ImmutableOutputInfo<Regressor> outputDomain = (ImmutableOutputInfo<Regressor>) carrier.outputDomain();
        if (proto.getNodesCount() == 0) {
            throw new IllegalStateException("Invalid protobuf, tree must contain nodes");
        }
        if (proto.getNodesCount() != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, must have one tree per output dimension, found " + proto.getNodesCount());
        }
        Map<String, Node<Regressor>> roots = new HashMap<>();
        for (Map.Entry<String, TreeNodeListProto> e : proto.getNodesMap().entrySet()) {
            List<TreeNodeProto> nodeProtos = e.getValue().getNodesList();
            if (nodeProtos.isEmpty()) {
                throw new IllegalStateException("Invalid protobuf, tree must contain nodes");
            }
            // The first deserialized node is used as the tree root.
            List<Node<Regressor>> nodes = deserializeFromProtos(nodeProtos, Regressor.class);
            roots.put(e.getKey(), nodes.get(0));
        }
        return new IndependentRegressionTreeModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), outputDomain, carrier.generatesProbabilities(), roots);
    }

    /**
     * Collects the split (non-leaf) nodes of a tree in breadth-first order.
     * <p>
     * At each split, the greater-than child is visited before the less-than-or-equal child.
     * {@code null} children and leaves are skipped.
     *
     * @param root The tree root (may be {@code null}).
     * @return The split nodes in visiting order.
     */
    private static List<SplitNode<Regressor>> collectSplitNodes(Node<Regressor> root) {
        List<SplitNode<Regressor>> splits = new ArrayList<>();
        Queue<Node<Regressor>> pending = new ArrayDeque<>();
        if (root != null) {
            pending.offer(root);
        }
        while (!pending.isEmpty()) {
            Node<Regressor> node = pending.poll();
            if (node.isLeaf()) {
                continue;
            }
            SplitNode<Regressor> split = (SplitNode<Regressor>) node;
            splits.add(split);
            // ArrayDeque rejects null elements, so only enqueue children that exist.
            Node<Regressor> greater = split.getGreaterThan();
            if (greater != null) {
                pending.offer(greater);
            }
            Node<Regressor> lessOrEqual = split.getLessThanOrEqual();
            if (lessOrEqual != null) {
                pending.offer(lessOrEqual);
            }
        }
        return splits;
    }

    /**
     * Computes, for each tree, the distinct names of the features it splits on.
     *
     * @param featureMap The feature domain used to resolve feature ids to names.
     * @param treeRoots The tree roots keyed by output-dimension name.
     * @return Per-tree feature names in breadth-first first-encounter order.
     */
    private static Map<String, List<String>> gatherActiveFeatures(ImmutableFeatureMap featureMap, Map<String, Node<Regressor>> treeRoots) {
        Map<String, List<String>> activeFeatures = new HashMap<>();
        for (Map.Entry<String, Node<Regressor>> e : treeRoots.entrySet()) {
            Set<String> names = new LinkedHashSet<>();
            for (SplitNode<Regressor> split : collectSplitNodes(e.getValue())) {
                names.add(featureMap.get(split.getFeatureID()).getName());
            }
            activeFeatures.put(e.getKey(), new ArrayList<>(names));
        }
        return activeFeatures;
    }

    /**
     * Returns the maximum depth over all trees, or 0 if there are no trees.
     * <p>
     * Depth for each tree is as computed by {@code computeDepth(0, root)}.
     *
     * @return The maximum tree depth.
     */
    public int getDepth() {
        int maxDepth = 0;
        for (Node<Regressor> root : roots.values()) {
            maxDepth = Math.max(maxDepth, computeDepth(0, root));
        }
        return maxDepth;
    }

    /**
     * Walks a tree from {@code root} until {@link Node#getNextNode} returns {@code null}.
     *
     * @param root The tree root.
     * @param vec The example's feature vector.
     * @param visitedFeatures If non-null, receives the name of each split feature on the path, in
     *     root-to-leaf order.
     * @return The last node reached. The traversal is expected to end on a {@link LeafNode}.
     */
    private LeafNode<Regressor> findLeaf(Node<Regressor> root, SparseVector vec, List<String> visitedFeatures) {
        Node<Regressor> last = root;
        Node<Regressor> current = root;
        while (current != null) {
            last = current;
            if (visitedFeatures != null && current instanceof SplitNode) {
                SplitNode<Regressor> split = (SplitNode<Regressor>) current;
                visitedFeatures.add(featureIDMap.get(split.getFeatureID()).getName());
            }
            current = last.getNextNode(vec);
        }
        return (LeafNode<Regressor>) last;
    }

    /**
     * Predicts every output dimension by walking each tree to a leaf and combining the results.
     *
     * @param example The example to predict.
     * @return The combined multi-dimensional prediction.
     * @throws IllegalArgumentException If none of the example's features are known to the model.
     */
    public Prediction<Regressor> predict(Example<Regressor> example) {
        SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
        if (vec.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        List<Prediction<Regressor>> predictions = new ArrayList<>(roots.size());
        for (Node<Regressor> root : roots.values()) {
            predictions.add(findLeaf(root, vec, null).getPrediction(vec.numActiveElements(), example));
        }
        return combine(predictions);
    }

    /**
     * Returns, per output dimension, the features used most often as split points in that tree.
     *
     * @param n The number of features to return per tree. A negative value means all features.
     * @return For each output dimension, (feature name, split count) pairs, highest count first.
     *     The lists are empty when {@code n == 0}.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() : n;
        Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        Map<String, List<Pair<String, Double>>> outputMap = new HashMap<>();
        for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
            Map<String, Integer> featureCounts = new HashMap<>();
            for (SplitNode<Regressor> split : collectSplitNodes(e.getValue())) {
                featureCounts.merge(featureIDMap.get(split.getFeatureID()).getName(), 1, Integer::sum);
            }
            List<Pair<String, Double>> top = new ArrayList<>();
            if (maxFeatures > 0) {
                // PriorityQueue rejects a capacity below 1 and pre-allocates the requested
                // capacity, so size it by the actual number of candidates.
                int capacity = Math.max(1, Math.min(maxFeatures, featureCounts.size()));
                PriorityQueue<Pair<String, Double>> queue = new PriorityQueue<>(capacity, comparator);
                for (Map.Entry<String, Integer> count : featureCounts.entrySet()) {
                    Pair<String, Double> candidate = new Pair<>(count.getKey(), (double) count.getValue());
                    if (queue.size() < maxFeatures) {
                        queue.offer(candidate);
                    } else if (comparator.compare(candidate, queue.peek()) > 0) {
                        // Min-heap of the current top features: evict the smallest.
                        queue.poll();
                        queue.offer(candidate);
                    }
                }
                while (!queue.isEmpty()) {
                    top.add(queue.poll());
                }
                Collections.reverse(top);
            }
            outputMap.put(e.getKey(), top);
        }
        return outputMap;
    }

    /**
     * Explains a prediction by listing the split features on each tree's decision path.
     * <p>
     * Within each output dimension, the root split gets weight {@code pathLength + 1} and each
     * following split gets a weight one lower.
     *
     * @param example The example to explain.
     * @return The excuse, or empty if none of the example's features are known to the model.
     */
    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
        SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
        if (vec.numActiveElements() == 0) {
            return Optional.empty();
        }
        List<Prediction<Regressor>> predictions = new ArrayList<>(roots.size());
        Map<String, List<Pair<String, Double>>> weights = new HashMap<>();
        for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
            List<String> path = new ArrayList<>();
            LeafNode<Regressor> leaf = findLeaf(e.getValue(), vec, path);
            predictions.add(leaf.getPrediction(vec.numActiveElements(), example));
            List<Pair<String, Double>> pathWeights = new ArrayList<>(path.size());
            int weight = path.size() + 1;
            for (String featureName : path) {
                pathWeights.add(new Pair<>(featureName, (double) weight));
                weight--;
            }
            weights.put(e.getKey(), pathWeights);
        }
        return Optional.of(new Excuse<>(example, combine(predictions), weights));
    }

    /**
     * Returns a deep copy of this model with a new name and provenance. Each tree is copied via
     * {@link Node#copy()}.
     * NOTE(review): JADX renamed this method from {@code copy} when it merged bridge methods. The
     * decompiled name is kept to avoid changing the visible interface.
     *
     * @param newName The new model name.
     * @param newProvenance The new model provenance.
     * @return The copied model.
     */
    public IndependentRegressionTreeModel m7copy(String newName, ModelProvenance newProvenance) {
        Map<String, Node<Regressor>> copiedRoots = new HashMap<>();
        for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
            copiedRoots.put(e.getKey(), e.getValue().copy());
        }
        return new IndependentRegressionTreeModel(newName, newProvenance, featureIDMap, outputIDInfo, generatesProbabilities, copiedRoots);
    }

    /**
     * Merges one per-dimension prediction per tree into a single multi-dimensional prediction.
     * <p>
     * The combined number of used features is the maximum over the inputs. The example is taken
     * from the first prediction.
     *
     * @param predictions One prediction per tree. Each output must be a
     *     {@link Regressor.DimensionTuple}.
     * @return The combined prediction.
     * @throws IllegalStateException If {@code predictions} is empty or contains a non-tuple output.
     */
    private Prediction<Regressor> combine(List<Prediction<Regressor>> predictions) {
        if (predictions.isEmpty()) {
            throw new IllegalStateException("Cannot combine an empty list of predictions, the model contains no trees");
        }
        Regressor.DimensionTuple[] tuples = new Regressor.DimensionTuple[predictions.size()];
        int numUsed = 0;
        int i = 0;
        for (Prediction<Regressor> prediction : predictions) {
            numUsed = Math.max(numUsed, prediction.getNumActiveFeatures());
            Regressor output = prediction.getOutput();
            if (!(output instanceof Regressor.DimensionTuple)) {
                throw new IllegalStateException("All the leaves should contain DimensionTuple not Regressor");
            }
            tuples[i++] = (Regressor.DimensionTuple) output;
        }
        return new Prediction<>(new Regressor(tuples), numUsed, predictions.get(0).getExample());
    }

    /**
     * Returns the names of all features used as a split in any tree.
     *
     * @return The set of split feature names.
     */
    public Set<String> getFeatures() {
        Set<String> features = new HashSet<>();
        for (Node<Regressor> root : roots.values()) {
            for (SplitNode<Regressor> split : collectSplitNodes(root)) {
                features.add(featureIDMap.get(split.getFeatureID()).getName());
            }
        }
        return features;
    }

    @Override
    public String toString() {
        StringBuilder trees = new StringBuilder();
        for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
            trees.append("Output '")
                 .append(e.getKey())
                 .append("' - tree = ")
                 .append(e.getValue().toString())
                 .append('\n');
        }
        return "IndependentTreeModel(description=" + provenance.toString() + ",\n" + trees.toString() + ")";
    }

    /**
     * Returns an unmodifiable view of the tree roots, keyed by output-dimension name.
     *
     * @return The tree roots.
     */
    public Map<String, Node<Regressor>> getRoots() {
        return Collections.unmodifiableMap(roots);
    }

    /**
     * Always returns {@code null}, because this model has one root per output dimension. Use
     * {@link #getRoots()} instead.
     *
     * @return {@code null}.
     */
    public Node<Regressor> getRoot() {
        return null;
    }

    /**
     * Serializes this model to a {@link ModelProto}.
     * NOTE(review): JADX renamed this method from {@code serialize}, and renamed the builder
     * methods to {@code m105build}/{@code m57build}. Those names are kept because they come from
     * outside this block.
     *
     * @return The serialized model.
     */
    public ModelProto m8serialize() {
        ModelDataCarrier<Regressor> carrier = createDataCarrier();
        IndependentRegressionTreeModelProto.Builder modelBuilder = IndependentRegressionTreeModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        for (Map.Entry<String, Node<Regressor>> e : roots.entrySet()) {
            TreeNodeListProto treeProto = TreeNodeListProto.newBuilder().addAllNodes(serializeToNodes(e.getValue())).m105build();
            modelBuilder.putNodes(e.getKey(), treeProto);
        }
        ModelProto.Builder wrapper = ModelProto.newBuilder();
        wrapper.setSerializedData(Any.pack(modelBuilder.m57build()));
        wrapper.setClassName(IndependentRegressionTreeModel.class.getName());
        wrapper.setVersion(CURRENT_VERSION);
        return wrapper.build();
    }
}
