package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/* loaded from: input_file:org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.class */
/**
 * Classification {@link EvaluationMetric} that computes precision per predicted class and the average precision.
 *
 * <p>The metric is computed over several {@link #aggs}/{@link #process} rounds and keeps its progress in
 * {@link SetOnce} fields:
 * <ol>
 *   <li>While the actual class names are unknown, {@link #aggs} requests a {@code terms} aggregation over the actual
 *       field (at most {@link #MAX_CLASSES_CARDINALITY} buckets, ordered by descending doc count, then ascending key).
 *       {@link #process} stores the sorted bucket keys, or fails with a bad request if any documents fell outside
 *       those buckets.</li>
 *   <li>Next, {@link #aggs} requests one filter bucket per class name (documents whose predicted field matches that
 *       name), each with an {@code avg} sub-aggregation over
 *       {@code PainlessScripts.buildIsEqualScript(actualField, predictedField)}, plus an {@code avg_bucket} pipeline
 *       aggregation over those per-bucket averages. {@link #process} turns these into a {@link Result}.</li>
 *   <li>Once a result is available, {@link #aggs} requests nothing further.</li>
 * </ol>
 *
 * <p>This intermediate state is not serialized: {@link #writeTo} writes nothing.
 */
public class Precision implements EvaluationMetric {

    private static final String AGG_NAME_PREFIX = "classification_precision_";
    static final String ACTUAL_CLASSES_NAMES_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
    static final String BY_PREDICTED_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_predicted_class";
    static final String PER_PREDICTED_CLASS_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "per_predicted_class_precision";
    static final String AVG_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "avg_precision";

    /** Maximum number of distinct actual classes requested from the terms aggregation. */
    private static final int MAX_CLASSES_CARDINALITY = 1000;

    /** Actual field name recorded by the first {@link #aggs} call; used in the cardinality error message. */
    private final SetOnce<String> actualField = new SetOnce<>();
    /** Sorted actual class names discovered in the first round. */
    private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
    /** Final metric result, set once the per-class aggregations have been processed. */
    private final SetOnce<Result> result = new SetOnce<>();

    public static final ParseField NAME = new ParseField("precision");

    private static final ObjectParser<Precision, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Precision::new);

    /**
     * Result of the precision metric: a list of per-class precision values and the average precision.
     */
    public static class Result implements EvaluationMetricResult {

        private static final ParseField CLASSES = new ParseField("classes");
        private static final ParseField AVG_PRECISION = new ParseField("avg_precision");

        @SuppressWarnings("unchecked") // args[0] is produced by the declareObjectArray registration below
        private static final ConstructingObjectParser<Result, Void> PARSER = new ConstructingObjectParser<>(
            "precision_result",
            true,
            args -> new Result((List<PerClassSingleValue>) args[0], (double) args[1])
        );

        static {
            PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), PerClassSingleValue.PARSER, CLASSES);
            PARSER.declareDouble(ConstructingObjectParser.constructorArg(), AVG_PRECISION);
        }

        /**
         * Parses a {@link Result} from its XContent form ({@code classes} and {@code avg_precision} fields).
         *
         * @param parser the parser positioned on the result object
         * @return the parsed result
         */
        public static Result fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        /** List of per-class precision values; stored as an unmodifiable view. */
        private final List<PerClassSingleValue> classes;
        /** Average precision across classes. */
        private final double avgPrecision;

        /**
         * @param classes      per-class precision values; must not be null
         * @param avgPrecision average precision
         */
        public Result(List<PerClassSingleValue> classes, double avgPrecision) {
            this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
            this.avgPrecision = avgPrecision;
        }

        /** Deserializes a result written by {@link #writeTo}. */
        public Result(StreamInput in) throws IOException {
            this.classes = Collections.unmodifiableList(in.readList(PerClassSingleValue::new));
            this.avgPrecision = in.readDouble();
        }

        public String getWriteableName() {
            return MlEvaluationNamedXContentProvider.registeredMetricName(Classification.NAME, Precision.NAME);
        }

        @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult
        public String getMetricName() {
            return Precision.NAME.getPreferredName();
        }

        /** @return the per-class precision values (unmodifiable) */
        public List<PerClassSingleValue> getClasses() {
            return classes;
        }

        /** @return the average precision */
        public double getAvgPrecision() {
            return avgPrecision;
        }

        public void writeTo(StreamOutput out) throws IOException {
            out.writeList(classes);
            out.writeDouble(avgPrecision);
        }

        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(CLASSES.getPreferredName(), classes);
            builder.field(AVG_PRECISION.getPreferredName(), avgPrecision);
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Result that = (Result) o;
            // Double.compare instead of == keeps equals consistent with hashCode (0.0 vs -0.0)
            // and lets NaN values compare equal, e.g. after a serialization round-trip.
            return Objects.equals(classes, that.classes) && Double.compare(avgPrecision, that.avgPrecision) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hash(classes, avgPrecision);
        }
    }

    /**
     * Parses a {@link Precision} metric; it has no configurable fields.
     *
     * @param parser the parser positioned on the metric object
     * @return a new metric instance
     */
    public static Precision fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    public Precision() {}

    /** Deserializing constructor; nothing is read because {@link #writeTo} writes nothing. */
    public Precision(StreamInput in) throws IOException {}

    public String getWriteableName() {
        return MlEvaluationNamedXContentProvider.registeredMetricName(Classification.NAME, NAME);
    }

    @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric
    public Set<String> getRequiredFields() {
        return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
    }

    /**
     * Returns the aggregations needed for the current round.
     *
     * <p>The method returns the class-discovery terms aggregation while class names are unknown, and nothing once a
     * result exists. Otherwise, it returns the per-predicted-class filters aggregation plus the average-precision
     * pipeline aggregation.
     *
     * @param parameters evaluation parameters (not used by this metric)
     * @param fields     supplies the actual and predicted field names
     * @return the aggregation builders and pipeline aggregation builders to run
     */
    @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric
    public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(
        EvaluationParameters parameters,
        EvaluationFields fields
    ) {
        String actualFieldName = fields.getActualField();
        String predictedFieldName = fields.getPredictedField();
        // Only the first call records the field name; process() uses it in the cardinality error.
        actualField.trySet(actualFieldName);

        if (topActualClassNames.get() == null) {
            // Round 1: discover the actual class names.
            return Tuple.tuple(
                List.of(
                    AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
                        .field(actualFieldName)
                        .order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
                        .size(MAX_CLASSES_CARDINALITY)
                ),
                List.of()
            );
        }
        if (result.get() != null) {
            // Result already computed: nothing more to request.
            return Tuple.tuple(List.of(), List.of());
        }

        // Round 2: one filter bucket per class name, holding the documents predicted as that class.
        FiltersAggregator.KeyedFilter[] keyedFiltersPredicted = topActualClassNames.get()
            .stream()
            .map(
                className -> new FiltersAggregator.KeyedFilter(
                    className,
                    QueryBuilders.matchQuery(predictedFieldName, className).lenient(true)
                )
            )
            .toArray(FiltersAggregator.KeyedFilter[]::new);
        return Tuple.tuple(
            List.of(
                AggregationBuilders.filters(BY_PREDICTED_CLASS_AGG_NAME, keyedFiltersPredicted)
                    .subAggregation(
                        AggregationBuilders.avg(PER_PREDICTED_CLASS_PRECISION_AGG_NAME)
                            .script(PainlessScripts.buildIsEqualScript(actualFieldName, predictedFieldName))
                    )
            ),
            List.of(
                PipelineAggregatorBuilders.avgBucket(
                    AVG_PRECISION_AGG_NAME,
                    BY_PREDICTED_CLASS_AGG_NAME + ">" + PER_PREDICTED_CLASS_PRECISION_AGG_NAME
                )
            )
        );
    }

    /**
     * Consumes the aggregation results of a round and updates this metric's state.
     *
     * @param aggregations the aggregation results returned for the aggregations requested by {@link #aggs}
     * @throws org.elasticsearch.ElasticsearchStatusException (bad request) if the actual field has more distinct
     *         values than {@link #MAX_CLASSES_CARDINALITY}
     */
    @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric
    public void process(Aggregations aggregations) {
        // Check the type before assigning: the generic get() casts on assignment.
        if (topActualClassNames.get() == null && aggregations.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
            Terms topActualClassesAgg = aggregations.get(ACTUAL_CLASSES_NAMES_AGG_NAME);
            if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) {
                // Documents outside the top MAX_CLASSES_CARDINALITY buckets mean there are too many classes.
                throw ExceptionsHelper.badRequestException(
                    "Cannot calculate average precision. Cardinality of field [{}] is too high",
                    actualField.get()
                );
            }
            topActualClassNames.set(
                topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList())
            );
        }

        if (result.get() == null
            && aggregations.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters
            && aggregations.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
            Filters byPredictedClassAgg = aggregations.get(BY_PREDICTED_CLASS_AGG_NAME);
            NumericMetricsAggregation.SingleValue avgPrecisionAgg = aggregations.get(AVG_PRECISION_AGG_NAME);
            List<PerClassSingleValue> classes = new ArrayList<>(byPredictedClassAgg.getBuckets().size());
            for (Filters.Bucket bucket : byPredictedClassAgg.getBuckets()) {
                NumericMetricsAggregation.SingleValue precisionAgg = bucket.getAggregations().get(PER_PREDICTED_CLASS_PRECISION_AGG_NAME);
                double precision = precisionAgg.value();
                // Classes whose precision is NaN/infinite are omitted from the per-class list.
                if (Double.isFinite(precision)) {
                    classes.add(new PerClassSingleValue(bucket.getKeyAsString(), precision));
                }
            }
            result.set(new Result(classes, avgPrecisionAgg.value()));
        }
    }

    @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric
    public Optional<Result> getResult() {
        return Optional.ofNullable(result.get());
    }

    /** Writes nothing: the metric has no configuration, and its per-evaluation state is not serialized. */
    public void writeTo(StreamOutput out) throws IOException {}

    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.endObject();
        return builder;
    }

    /** All {@code Precision} instances are equal, since the metric has no configuration. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        return o != null && getClass() == o.getClass();
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(NAME.getPreferredName());
    }
}
