package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;

/* loaded from: input_file:org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineProjections.class */
/**
 * Optimizer rule, applied bottom-up to every {@link UnaryPlan}, that merges adjacent projection-like nodes.
 * <ul>
 *   <li>{@code Project} over {@code Project}: the lower projection is folded into the upper one.</li>
 *   <li>{@code Project} over {@code Aggregate}: the projection is folded into the aggregate's output list,
 *       unless two projected entries resolve to the same attribute (see {@link #projectAggregations}).</li>
 *   <li>{@code Aggregate} over {@code Project}: the project is removed and its renames are pushed into the
 *       aggregate's groupings and aggregates.</li>
 * </ul>
 */
public final class CombineProjections extends OptimizerRules.OptimizerRule<UnaryPlan> {

    public CombineProjections() {
        super(OptimizerRules.TransformDirection.UP);
    }

    /**
     * Attempts to merge {@code plan} with its child.
     *
     * @param plan a unary plan node; only {@link Project} and {@link Aggregate} nodes are rewritten
     * @return the rewritten plan, or {@code plan} itself when no combination applies
     * @throws EsqlIllegalArgumentException if an aggregate over a project has a grouping that is neither an
     *         {@link Attribute} nor an {@link Alias} of a {@link Categorize}
     */
    @Override
    public LogicalPlan rule(UnaryPlan plan) {
        LogicalPlan child = plan.child();

        if (plan instanceof Project) {
            Project project = (Project) plan;
            if (child instanceof Project) {
                Project lower = (Project) child;
                // Drop the lower project, but first substitute its aliases into the upper projection list.
                project = lower.withProjections(combineProjections(project.projections(), lower.projections()));
                child = project.child();
                plan = project;
                // Deliberately no early return: the new child may be an Aggregate the projection can fold into.
            }
            if (child instanceof Aggregate) {
                Aggregate aggregate = (Aggregate) child;
                List<? extends NamedExpression> aggregates = aggregate.aggregates();
                List<? extends NamedExpression> newAggregates = projectAggregations(project.projections(), aggregates);
                // null means the projection cannot be folded into the aggregate.
                if (newAggregates != null) {
                    List<Expression> newGroupings = replacePrunedAliasesUsedInGroupBy(aggregate.groupings(), aggregates, newAggregates);
                    plan = new Aggregate(aggregate.source(), aggregate.child(), aggregate.aggregateType(), newGroupings, newAggregates);
                }
            }
            return plan;
        }

        // Aggregate on top of a Project (group by over a sub-query): remove the project.
        if (plan instanceof Aggregate && child instanceof Project) {
            Aggregate aggregate = (Aggregate) plan;
            Project project = (Project) child;
            List<NamedExpression> groupings = new ArrayList<>(aggregate.groupings().size());
            for (Expression grouping : aggregate.groupings()) {
                if (grouping instanceof Attribute) {
                    groupings.add((Attribute) grouping);
                } else if (grouping instanceof Alias && ((Alias) grouping).child() instanceof Categorize) {
                    groupings.add((Alias) grouping);
                } else {
                    // Only plain attributes and CATEGORIZE aliases are supported as groupings here.
                    throw new EsqlIllegalArgumentException("Expected an Attribute, got {}", grouping);
                }
            }
            plan = new Aggregate(
                aggregate.source(),
                project.child(),
                aggregate.aggregateType(),
                combineUpperGroupingsAndLowerProjections(groupings, project.projections()),
                combineProjections(aggregate.aggregates(), project.projections())
            );
        }
        return plan;
    }

    /**
     * Folds an upper projection into the aggregations below it.
     *
     * @return the combined aggregation list, or {@code null} when two upper entries unwrap to the same
     *         attribute (e.g. {@code c, c AS x}); folding such a projection would presumably duplicate an
     *         aggregation, so the caller keeps the projection instead
     */
    private static List<? extends NamedExpression> projectAggregations(
        List<? extends NamedExpression> upperProjection,
        List<? extends NamedExpression> lowerAggregations
    ) {
        AttributeSet seen = new AttributeSet();
        for (NamedExpression upper : upperProjection) {
            Expression unwrapped = Alias.unwrap(upper);
            if (seen.contains(unwrapped)) {
                return null;
            }
            seen.add(Expressions.attribute(unwrapped));
        }
        return combineProjections(upperProjection, lowerAggregations);
    }

    /**
     * Substitutes the aliases defined in {@code lower} into {@code upper}.
     * Only the upper entries survive. Aliases nested below the top level are then trimmed.
     */
    private static List<NamedExpression> combineProjections(
        List<? extends NamedExpression> upper,
        List<? extends NamedExpression> lower
    ) {
        // Alias declarations of the lower list, with their children resolved through earlier lower aliases.
        AttributeMap<NamedExpression> namedExpressions = new AttributeMap<>();
        // attribute -> unwrapped source expression, used to resolve chains like f1 = 1, f2 = f1
        AttributeMap<Expression> aliases = new AttributeMap<>();
        for (NamedExpression ne : lower) {
            aliases.put(ne.toAttribute(), Alias.unwrap(ne));
            if (ne instanceof Alias) {
                Alias alias = (Alias) ne;
                Expression aliasChild = alias.child();
                namedExpressions.put(ne.toAttribute(), alias.replaceChild(aliases.resolve(aliasChild, aliasChild)));
            }
        }

        List<NamedExpression> replaced = new ArrayList<>(upper.size());
        for (NamedExpression ne : upper) {
            Expression substituted = ne.transformUp(Attribute.class, a -> namedExpressions.resolve(a, a));
            // A top-level Alias is kept by trimNonTopLevelAliases, so the result is still a NamedExpression.
            replaced.add((NamedExpression) trimNonTopLevelAliases(substituted));
        }
        return replaced;
    }

    /**
     * Propagates renames from a lower projection into upper groupings.
     * Insertion order is preserved, and groupings that become equal after substitution are kept only once.
     */
    private static List<Expression> combineUpperGroupingsAndLowerProjections(
        List<? extends NamedExpression> upperGroupings,
        List<? extends NamedExpression> lowerProjections
    ) {
        assert upperGroupings.size() <= 1
            || upperGroupings.stream().noneMatch(group -> group.anyMatch(expr -> expr instanceof Categorize))
            : "CombineProjections only tested with a single CATEGORIZE with no additional groups";

        AttributeMap<Expression> aliases = new AttributeMap<>();
        for (NamedExpression ne : lowerProjections) {
            aliases.put(ne.toAttribute(), Alias.unwrap(ne));
        }

        // Deduplication uses plain equals() via the set, not attribute-id equality.
        LinkedHashSet<Expression> resolved = new LinkedHashSet<>();
        for (NamedExpression ne : upperGroupings) {
            resolved.add(ne.transformUp(Attribute.class, a -> aliases.resolve(a, a)));
        }
        return new ArrayList<>(resolved);
    }

    /**
     * Re-inlines aliases that were removed from the aggregate list (i.e. present in {@code oldAggs} but no longer
     * in {@code newAggs}) into any grouping that still references them. Groupings that become equal as attributes
     * after substitution are kept only once.
     *
     * @return {@code groupings} itself when no alias was removed, otherwise a new list
     */
    private static List<Expression> replacePrunedAliasesUsedInGroupBy(
        List<Expression> groupings,
        List<? extends NamedExpression> oldAggs,
        List<? extends NamedExpression> newAggs
    ) {
        AttributeMap<Expression> removedAliases = new AttributeMap<>();
        AttributeSet currentAttributes = new AttributeSet(Expressions.asAttributes(newAggs));

        for (NamedExpression ne : oldAggs) {
            if (ne instanceof Alias) {
                Attribute attr = ne.toAttribute();
                if (!currentAttributes.contains(attr)) {
                    removedAliases.put(attr, ((Alias) ne).child());
                }
            }
        }

        if (removedAliases.isEmpty()) {
            return groupings;
        }

        List<Expression> newGroupings = new ArrayList<>(groupings.size());
        for (Expression group : groupings) {
            Expression transformed = group.transformUp(Attribute.class, a -> removedAliases.resolve(a, a));
            if (!Expressions.anyMatch(newGroupings, g -> Expressions.equalsAsAttribute(g, transformed))) {
                newGroupings.add(transformed);
            }
        }
        return newGroupings;
    }

    /**
     * Removes every {@link Alias} inside {@code expression} except a top-level one.
     *
     * @param expression expression to clean up
     * @return the expression with nested aliases replaced by their children
     */
    public static Expression trimNonTopLevelAliases(Expression expression) {
        if (expression instanceof Alias) {
            Alias alias = (Alias) expression;
            return alias.replaceChild(trimAliases(alias.child()));
        }
        return trimAliases(expression);
    }

    /** Replaces every {@link Alias} in {@code expression}, top-down, with its child. */
    private static Expression trimAliases(Expression expression) {
        return expression.transformDown(Alias.class, Alias::child);
    }
}
