package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/* loaded from: input_file:org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.class */
/**
 * A single decision-tree trained model.
 * <p>
 * The tree is stored as a flat list of {@link TreeNode}s. Node {@code 0} is the root, and every junction refers
 * to its children by their index in that list; child indices that are negative are treated as "no child".
 * Instances can be parsed from XContent (strictly or leniently), read from / written to the transport layer,
 * and checked for structural consistency with {@link #validate()}.
 */
public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel, Accountable {

    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Tree.class);

    public static final ParseField NAME = new ParseField("tree");
    public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
    public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure");
    public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");

    // Must stay below the ParseField constants: createParser() reads them during static initialisation.
    private static final ObjectParser<Builder, Void> LENIENT_PARSER = createParser(true);
    private static final ObjectParser<Builder, Void> STRICT_PARSER = createParser(false);

    private final List<String> featureNames;
    private final List<TreeNode> nodes;
    private final TargetType targetType;
    // May be null when no labels were supplied.
    private final List<String> classificationLabels;

    /**
     * Mutable builder for {@link Tree}.
     * <p>
     * A fresh builder already contains a root leaf at index {@code 0} with value {@code 0.0}. Each call to
     * {@link #addJunction} reserves the next two free indices for the junction's children; every reserved slot
     * must be filled (for example with {@link #addLeaf}) before {@link #build()} is called, otherwise
     * {@code build()} rejects the tree because it contains null nodes.
     */
    public static class Builder {
        private List<String> featureNames;
        // Next free node index; new junction children are allocated from here.
        private int numNodes;
        private List<String> classificationLabels;
        private TargetType targetType = TargetType.REGRESSION;
        // Node builders keyed by position; a null entry is a reserved slot that has not been filled yet.
        private ArrayList<TreeNode.Builder> nodes = new ArrayList<>();

        public Builder() {
            nodes.add(null);
            addLeaf(0, 0.0d);
            numNodes = 1;
        }

        public Builder setFeatureNames(List<String> featureNames) {
            this.featureNames = featureNames;
            return this;
        }

        /** Replaces the node at index {@code 0} (the root). */
        public Builder setRoot(TreeNode.Builder root) {
            nodes.set(0, root);
            return this;
        }

        /** Appends a node at the end of the node list. */
        public Builder addNode(TreeNode.Builder node) {
            nodes.add(node);
            return this;
        }

        /**
         * Replaces all nodes with a copy of {@code nodes}.
         *
         * @throws RuntimeException from {@code ExceptionsHelper.requireNonNull} if {@code nodes} is null
         */
        public Builder setNodes(List<TreeNode.Builder> nodes) {
            this.nodes = new ArrayList<>(ExceptionsHelper.requireNonNull(nodes, Tree.TREE_STRUCTURE.getPreferredName()));
            return this;
        }

        public Builder setNodes(TreeNode.Builder... nodes) {
            return setNodes(Arrays.asList(nodes));
        }

        public Builder setTargetType(TargetType targetType) {
            this.targetType = targetType;
            return this;
        }

        public Builder setClassificationLabels(List<String> classificationLabels) {
            this.classificationLabels = classificationLabels;
            return this;
        }

        // Used by the XContent parser, which supplies the target type as a string.
        private void setTargetType(String targetType) {
            this.targetType = TargetType.fromString(targetType);
        }

        /**
         * Places a junction at {@code nodeIndex} whose two children get the next two free node indices.
         * Null placeholders are appended so that both child slots exist in the node list.
         *
         * @return the builder of the new junction node, for further customisation
         */
        public TreeNode.Builder addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) {
            // Children are allocated before anything else, exactly as callers observe via numNodes.
            int leftChild = numNodes++;
            int rightChild = numNodes++;
            ensureNodeSlots(nodeIndex + 1);

            TreeNode.Builder junction = TreeNode.builder(nodeIndex)
                .setDefaultLeft(isDefaultLeft)
                .setLeftChild(leftChild)
                .setRightChild(rightChild)
                .setSplitFeature(featureIndex)
                .setThreshold(decisionThreshold);
            nodes.set(nodeIndex, junction);

            // Reserve the child slots so they can be filled in later.
            ensureNodeSlots(rightChild + 1);
            return junction;
        }

        /** Places a single-valued leaf at {@code nodeIndex}. */
        public Builder addLeaf(int nodeIndex, double value) {
            return addLeaf(nodeIndex, Arrays.asList(value));
        }

        /** Places a leaf with the given value(s) at {@code nodeIndex}, padding the node list if necessary. */
        public Builder addLeaf(int nodeIndex, List<Double> value) {
            ensureNodeSlots(nodeIndex + 1);
            nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(value));
            return this;
        }

        // Appends null placeholders until the node list holds at least minSize entries.
        private void ensureNodeSlots(int minSize) {
            nodes.ensureCapacity(minSize);
            while (nodes.size() < minSize) {
                nodes.add(null);
            }
        }

        /**
         * Builds the tree.
         *
         * @throws RuntimeException a bad-request exception if any reserved node slot is still empty
         */
        public Tree build() {
            if (nodes.stream().anyMatch(Objects::isNull)) {
                throw ExceptionsHelper.badRequestException("[tree] cannot contain null nodes");
            }
            List<TreeNode> builtNodes = nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList());
            return new Tree(featureNames, builtNodes, targetType, classificationLabels);
        }
    }

    /**
     * Creates the XContent parser for a {@link Builder}.
     *
     * @param lenient passed to {@code ObjectParser} as its ignore-unknown-fields flag and to
     *                {@code TreeNode.fromXContent} for the nested nodes
     */
    private static ObjectParser<Builder, Void> createParser(boolean lenient) {
        ObjectParser<Builder, Void> parser = new ObjectParser<>(NAME.getPreferredName(), lenient, Builder::new);
        parser.declareStringArray((builder, names) -> builder.setFeatureNames(names), FEATURE_NAMES);
        parser.declareObjectArray(
            (builder, treeNodes) -> builder.setNodes(treeNodes),
            (nodeParser, context) -> TreeNode.fromXContent(nodeParser, lenient),
            TREE_STRUCTURE
        );
        parser.declareString((builder, type) -> builder.setTargetType(type), TargetType.TARGET_TYPE);
        parser.declareStringArray((builder, labels) -> builder.setClassificationLabels(labels), CLASSIFICATION_LABELS);
        return parser;
    }

    public static Tree fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null).build();
    }

    public static Tree fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null).build();
    }

    /**
     * Creates a tree. All lists are copied before being wrapped as unmodifiable, so later changes to the
     * caller's lists cannot leak into this (otherwise immutable) instance.
     *
     * @throws IllegalArgumentException if {@code nodes} is empty
     */
    Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
        this.featureNames = Collections.unmodifiableList(new ArrayList<>(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)));
        if (ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).isEmpty()) {
            throw new IllegalArgumentException("[tree_structure] must not be empty");
        }
        this.nodes = Collections.unmodifiableList(new ArrayList<>(nodes));
        this.targetType = ExceptionsHelper.requireNonNull(targetType, TargetType.TARGET_TYPE);
        this.classificationLabels = classificationLabels == null
            ? null
            : Collections.unmodifiableList(new ArrayList<>(classificationLabels));
    }

    /** Reads a tree in the format written by {@link #writeTo(StreamOutput)}. */
    public Tree(StreamInput in) throws IOException {
        this.featureNames = in.readCollectionAsImmutableList(StreamInput::readString);
        this.nodes = in.readCollectionAsImmutableList(TreeNode::new);
        this.targetType = TargetType.fromStream(in);
        // Presence flag precedes the optional label list.
        if (in.readBoolean()) {
            this.classificationLabels = in.readCollectionAsImmutableList(StreamInput::readString);
        } else {
            this.classificationLabels = null;
        }
    }

    @Override // org.elasticsearch.xpack.core.ml.utils.NamedXContentObject
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public TargetType targetType() {
        return targetType;
    }

    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    public void writeTo(StreamOutput out) throws IOException {
        out.writeStringCollection(featureNames);
        out.writeCollection(nodes);
        targetType.writeTo(out);
        // Presence flag followed by the labels, mirrored by the StreamInput constructor.
        out.writeBoolean(classificationLabels != null);
        if (classificationLabels != null) {
            out.writeStringCollection(classificationLabels);
        }
    }

    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
        builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
        builder.field(TargetType.TARGET_TYPE.getPreferredName(), targetType.toString());
        if (classificationLabels != null) {
            builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Tree that = (Tree) o;
        return Objects.equals(featureNames, that.featureNames)
            && Objects.equals(nodes, that.nodes)
            && Objects.equals(targetType, that.targetType)
            && Objects.equals(classificationLabels, that.classificationLabels);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureNames, nodes, targetType, classificationLabels);
    }

    public static Builder builder() {
        return new Builder();
    }

    /**
     * Checks the tree for structural consistency: feature indices are within {@code feature_names}, a tree with
     * more than one node has features, the target type agrees with the labels / leaf arity, every referenced
     * child index exists, no node is reachable twice from the root, and all leaves have the same number of values.
     *
     * @throws RuntimeException a bad-request exception describing the first failed check
     */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public void validate() {
        int maxFeatureIndex = maxFeatureIndex();
        if (maxFeatureIndex >= featureNames.size()) {
            throw ExceptionsHelper.badRequestException(
                "feature index [{}] is out of bounds for the [{}] array",
                maxFeatureIndex,
                FEATURE_NAMES.getPreferredName()
            );
        }
        if (nodes.size() > 1 && featureNames.isEmpty()) {
            throw ExceptionsHelper.badRequestException(
                "[{}] is empty and the tree has > 1 nodes; num nodes [{}]. The model Must have features if tree is not a stump",
                FEATURE_NAMES.getPreferredName(),
                nodes.size()
            );
        }
        checkTargetType();
        // Must run before detectCycle(), which dereferences child indices without a bounds check.
        detectMissingNodes();
        detectCycle();
        verifyLeafNodeUniformity();
    }

    /**
     * Rough operation-count estimate: {@code ceil(ln(number of nodes))} plus the number of features.
     * NOTE(review): uses the natural logarithm rather than log2 — kept as-is to preserve existing estimates.
     */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public long estimatedNumOperations() {
        return ((long) Math.ceil(Math.log(nodes.size()))) + featureNames.size();
    }

    /** Returns the largest split-feature index over all nodes, or {@code -1} if none is larger than that. */
    int maxFeatureIndex() {
        int max = -1;
        for (TreeNode node : nodes) {
            max = Math.max(max, node.getSplitFeature());
        }
        return max;
    }

    // Labels and multi-valued leaves are only meaningful for classification trees.
    private void checkTargetType() {
        if (classificationLabels != null && targetType != TargetType.CLASSIFICATION) {
            throw ExceptionsHelper.badRequestException(
                "[target_type] should be [classification] if [classification_labels] are provided"
            );
        }
        if (targetType != TargetType.CLASSIFICATION && nodes.stream().anyMatch(node -> node.getLeafValue().length > 1)) {
            throw ExceptionsHelper.badRequestException("[target_type] should be [classification] if leaf nodes have multiple values");
        }
    }

    // Breadth-first walk from the root; reaching any node index twice (a cycle or a node with two parents) is rejected.
    private void detectCycle() {
        Set<Integer> visited = Sets.newHashSetWithExpectedSize(nodes.size());
        ArrayDeque<Integer> toVisit = new ArrayDeque<>(nodes.size());
        toVisit.add(0);
        while (!toVisit.isEmpty()) {
            Integer nodeIdx = toVisit.remove();
            if (!visited.add(nodeIdx)) {
                throw ExceptionsHelper.badRequestException("[tree] contains cycle at node {}", nodeIdx);
            }
            TreeNode node = nodes.get(nodeIdx);
            if (node.getLeftChild() >= 0) {
                toVisit.add(node.getLeftChild());
            }
            if (node.getRightChild() >= 0) {
                toVisit.add(node.getRightChild());
            }
        }
    }

    // Collects every child index that points past the end of the node list and reports them together.
    private void detectMissingNodes() {
        List<Integer> missingNodes = new ArrayList<>();
        for (TreeNode node : nodes) {
            if (node == null) {
                continue;
            }
            if (nodeMissing(node.getLeftChild(), nodes)) {
                missingNodes.add(node.getLeftChild());
            }
            if (nodeMissing(node.getRightChild(), nodes)) {
                missingNodes.add(node.getRightChild());
            }
        }
        if (!missingNodes.isEmpty()) {
            throw ExceptionsHelper.badRequestException("[tree] contains missing nodes {}", missingNodes);
        }
    }

    // All leaves must carry the same number of values (the first leaf found sets the expected arity).
    private void verifyLeafNodeUniformity() {
        Integer expectedLeafValueLength = null;
        for (TreeNode node : nodes) {
            if (!node.isLeaf()) {
                continue;
            }
            int length = node.getLeafValue().length;
            if (expectedLeafValueLength == null) {
                expectedLeafValueLength = length;
            } else if (expectedLeafValueLength != length) {
                throw ExceptionsHelper.badRequestException("[tree.tree_structure] all leaf nodes must have the same number of values");
            }
        }
    }

    private static boolean nodeMissing(int nodeIdx, List<TreeNode> nodes) {
        return nodeIdx >= nodes.size();
    }

    public long ramBytesUsed() {
        long size = SHALLOW_SIZE;
        size += RamUsageEstimator.sizeOfCollection(classificationLabels);
        size += RamUsageEstimator.sizeOfCollection(featureNames);
        size += RamUsageEstimator.sizeOfCollection(nodes);
        return size;
    }

    public Collection<Accountable> getChildResources() {
        List<Accountable> children = new ArrayList<>(nodes.size());
        for (TreeNode node : nodes) {
            children.add(Accountables.namedAccountable("tree_node_" + node.getNodeIndex(), node));
        }
        return Collections.unmodifiableCollection(children);
    }

    /** Multi-valued leaves require {@code V_7_7_0}; otherwise {@code V_7_6_0} suffices. */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public TransportVersion getMinimalCompatibilityVersion() {
        boolean hasMultiValuedLeaf = nodes.stream().filter(TreeNode::isLeaf).anyMatch(node -> node.getLeafValue().length > 1);
        return hasMultiValuedLeaf ? TransportVersions.V_7_7_0 : TransportVersions.V_7_6_0;
    }
}
