package org.elasticsearch.xpack.esql.optimizer.rules.logical.local;

import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.stats.SearchStats;

/**
 * Local optimizer rule that inserts a {@link Filter} beneath an ungrouped {@link Aggregate}.
 * <p>
 * The filter keeps only rows where at least one aggregated field is not null. It is a disjunction
 * of {@link IsNotNull} checks over the distinct fields read by the aggregate functions.
 * <p>
 * The rewrite is only applied when every {@link AggregateFunction} in the aggregate list reads a
 * non-foldable {@link FieldAttribute} that {@link SearchStats#isIndexed} reports as indexed.
 * If any aggregate function reads something else (for example a literal), the plan is returned
 * unchanged.
 * <p>
 * NOTE(review): the rewrite relies on the aggregate functions ignoring rows where all of their
 * input fields are null; confirm this holds for every {@link AggregateFunction} reaching here.
 */
public class InferNonNullAggConstraint extends OptimizerRules.ParameterizedOptimizerRule<Aggregate, LocalLogicalOptimizerContext> {
    public InferNonNullAggConstraint() {
        super(OptimizerRules.TransformDirection.UP);
    }

    /**
     * Applies the rule to a single {@link Aggregate} node.
     *
     * @param aggregate the aggregate node being visited
     * @param context   local optimizer context; only its {@link SearchStats} are read
     * @return {@code aggregate} unchanged, or a copy of it whose child is wrapped in a
     *         {@link Filter} with an OR of {@link IsNotNull} over the aggregated fields
     */
    @Override
    public LogicalPlan rule(Aggregate aggregate, LocalLogicalOptimizerContext context) {
        // Only ungrouped aggregations are rewritten; presumably because filtering rows under a
        // grouping could remove entire groups from the output.
        if (aggregate.groupings().size() > 0) {
            return aggregate;
        }

        SearchStats stats = context.searchStats();
        List<? extends NamedExpression> aggs = aggregate.aggregates();
        // Linked set: de-duplicates fields while keeping a deterministic (declaration) order.
        Set<Expression> nonNullAggFields = Sets.newLinkedHashSetWithExpectedSize(aggs.size());
        for (NamedExpression agg : aggs) {
            if (Alias.unwrap(agg) instanceof AggregateFunction af) {
                Expression field = af.field();
                // Ignore literals (foldable, e.g. COUNT(1)) and require an indexed source field.
                if (field.foldable() == false && field instanceof FieldAttribute fa && stats.isIndexed(fa.name())) {
                    nonNullAggFields.add(field);
                } else {
                    // The disjunction must cover every aggregate's input; otherwise rows this
                    // aggregate needs could be filtered out, so bail out without rewriting.
                    return aggregate;
                }
            }
        }

        if (nonNullAggFields.isEmpty()) {
            return aggregate;
        }

        // Cast is needed so the stream is typed as Expression for Predicates.combineOr.
        Expression condition = Predicates.combineOr(
            nonNullAggFields.stream().map(f -> (Expression) new IsNotNull(aggregate.source(), f)).toList()
        );
        return aggregate.replaceChild(new Filter(aggregate.source(), aggregate.child(), condition));
    }
}
