package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiFunction;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;

/**
 * Propagates the nullability facts stated inside a conjunction ({@code AND}) to its other terms.
 *
 * <p>The conjunction is split into its individual terms, and the fields of every
 * {@code IS NULL} / {@code IS NOT NULL} term are collected. If a field is required to be both
 * null and non-null, the whole conjunction is replaced by a {@code FALSE} literal. Otherwise the
 * remaining terms that contain a field known to be null are rewritten via {@link #nullify}, and
 * those containing a field known to be non-null via {@link #nonNullify}. If any term changed,
 * the conjunction is rebuilt from the updated terms.
 */
public class PropagateNullable extends OptimizerRules.OptimizerExpressionRule<And> {

    /** Creates the rule, registered with {@code TransformDirection.DOWN}. */
    public PropagateNullable() {
        super(OptimizerRules.TransformDirection.DOWN);
    }

    /**
     * Rewrites {@code and} based on the {@code IS NULL} / {@code IS NOT NULL} terms it contains.
     *
     * @param and the conjunction to optimize
     * @param logicalOptimizerContext optimizer context (not used by this rule)
     * @return a {@code FALSE} literal if some field must be both null and non-null; a rebuilt
     *     conjunction if any term was rewritten; otherwise {@code and} itself
     */
    @Override
    public Expression rule(And and, LogicalOptimizerContext logicalOptimizerContext) {
        List<Expression> splits = Predicates.splitAnd(and);

        Set<Expression> nullExpressions = new LinkedHashSet<>();
        Set<Expression> notNullExpressions = new LinkedHashSet<>();
        // replace() uses positional get/set on this list, so an ArrayList keeps that O(1).
        List<Expression> others = new ArrayList<>();

        for (Expression ex : splits) {
            if (ex instanceof IsNull) {
                nullExpressions.add(((IsNull) ex).field());
            } else if (ex instanceof IsNotNull) {
                notNullExpressions.add(((IsNotNull) ex).field());
            } else {
                others.add(ex);
            }
        }

        // "x IS NULL AND x IS NOT NULL" can never hold.
        if (Sets.haveNonEmptyIntersection(nullExpressions, notNullExpressions)) {
            return Literal.of(and, Boolean.FALSE);
        }

        // Both passes must always run: do NOT collapse this into a short-circuiting '||'.
        boolean modified = replace(nullExpressions, others, splits, this::nullify);
        modified |= replace(notNullExpressions, others, splits, this::nonNullify);

        return modified ? Predicates.combineAnd(splits) : and;
    }

    /**
     * Applies {@code replacer} to every target term that contains one of the pattern expressions.
     * A term "contains" a pattern when {@code term.anyMatch(pattern::semanticEquals)} is true.
     *
     * <p>Whenever the replacer returns a different instance, the term is updated in
     * {@code target}. Every semantically-equal term in {@code originalExpressions} is updated as
     * well, so the caller can rebuild the conjunction from that list.
     *
     * @param patterns expressions whose nullability is known
     * @param target the candidate terms; updated in place
     * @param originalExpressions the full list of conjunction terms; updated in place
     * @param replacer called as {@code replacer.apply(term, pattern)}; returning {@code term}
     *     itself means "no change"
     * @return {@code true} if at least one term was replaced
     */
    private static boolean replace(
        Iterable<Expression> patterns,
        List<Expression> target,
        List<Expression> originalExpressions,
        BiFunction<Expression, Expression, Expression> replacer
    ) {
        boolean modified = false;
        for (Expression pattern : patterns) {
            for (int i = 0; i < target.size(); i++) {
                Expression current = target.get(i);
                if (current.anyMatch(pattern::semanticEquals)) {
                    Expression replacement = replacer.apply(current, pattern);
                    // Identity comparison: the replacer signals "unchanged" by returning its input.
                    if (replacement != current) {
                        modified = true;
                        target.set(i, replacement);
                        originalExpressions.replaceAll(e -> current.semanticEquals(e) ? replacement : e);
                    }
                }
            }
        }
        return modified;
    }

    /**
     * Hook for terms that contain an expression known to be non-null. This implementation
     * returns {@code exp} unchanged, so it never triggers a rewrite. Subclasses may override it.
     *
     * @param exp the term containing {@code nonNullExp}
     * @param nonNullExp the expression known to be non-null
     * @return the (possibly rewritten) term
     */
    protected Expression nonNullify(Expression exp, Expression nonNullExp) {
        return exp;
    }

    /**
     * Rewrites a term that contains an expression known to be null.
     *
     * <p>For a {@link Coalesce}, the arguments semantically equal to {@code nullExp} are dropped,
     * provided at least one was dropped and at least one argument remains. In every other case
     * the whole term becomes a null literal.
     *
     * <p>NOTE(review): any non-COALESCE term that merely contains {@code nullExp} somewhere in
     * its tree is replaced wholesale by a null literal. That is sound for null-propagating
     * operators, but looks questionable for terms like {@code (x > 1 OR y > 2)} or
     * {@code COALESCE(x, y) > 1}, whose value need not be null when {@code x} is null.
     * Confirm against upstream tests before relying on it.
     *
     * @param exp the term containing {@code nullExp}
     * @param nullExp the expression known to be null
     * @return the rewritten term
     */
    protected Expression nullify(Expression exp, Expression nullExp) {
        if (exp instanceof Coalesce) {
            List<Expression> remaining = new ArrayList<>(exp.children());
            remaining.removeIf(child -> child.semanticEquals(nullExp));
            // Keep the COALESCE only if something was removed and at least one argument is left.
            if (remaining.size() != exp.children().size() && remaining.isEmpty() == false) {
                return exp.replaceChildren(remaining);
            }
        }
        return Literal.of(exp, (Object) null);
    }
}
