package org.elasticsearch.xpack.esql.optimizer.rules.physical.local;

import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.geometry.Circle;
import org.elasticsearch.geometry.Geometry;
import org.elasticsearch.geometry.Point;
import org.elasticsearch.geometry.utils.WellKnownBinary;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialDisjoint;
import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialIntersects;
import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesUtils;
import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.StDistance;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules;
import org.elasticsearch.xpack.esql.parser.EsqlBaseParser;
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
import org.elasticsearch.xpack.esql.plan.physical.FilterExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;

/* loaded from: input_file:org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.class */
/**
 * Local physical-optimizer rule that rewrites comparisons between {@code ST_DISTANCE} and a foldable number into
 * {@code ST_INTERSECTS} / {@code ST_DISJOINT} tests against a circle literal, so that the filter can be placed directly
 * on top of the {@link EsQueryExec} source.
 * <p>
 * Two plan shapes are handled:
 * <ul>
 *   <li>{@code FilterExec -> EsQueryExec}: the distance function appears directly in the filter condition.</li>
 *   <li>{@code FilterExec -> EvalExec -> EsQueryExec}: the distance is computed in the EVAL and referenced by name in
 *       the filter; rewritable conjuncts are moved below the EVAL, the rest stay above it.</li>
 * </ul>
 */
public class EnableSpatialDistancePushdown extends PhysicalOptimizerRules.ParameterizedOptimizerRule<FilterExec, LocalPhysicalOptimizerContext> {

    /**
     * The kind of distance comparison being rewritten, described by which side(s) of the threshold are accepted:
     * below it ({@code lt}), above it ({@code gt}) and/or exactly at it ({@code eq}).
     */
    public enum ComparisonType {
        LTE(true, false, true),
        LT(true, false, false),
        GTE(false, true, true),
        GT(false, true, false),
        EQ(false, false, true);

        private final boolean lt;
        private final boolean gt;
        private final boolean eq;

        ComparisonType(boolean lt, boolean gt, boolean eq) {
            this.lt = lt;
            this.gt = gt;
            this.eq = eq;
        }

        /**
         * Maps a binary comparison operator onto the comparison type used by the rewrite.
         *
         * @param op the operator of the comparison {@code <distance> op <number>}
         * @return the matching type, or {@code null} when the operator cannot be expressed as a single circle test
         *     and the comparison must be left untouched
         */
        static ComparisonType from(EsqlBinaryComparison.BinaryComparisonOperation op) {
            switch (op) {
                case LT:
                    return LT;
                case LTE:
                    return LTE;
                case GT:
                    return GT;
                case GTE:
                    return GTE;
                case EQ:
                    return EQ;
                default:
                    // Notably NEQ: treating it as EQ would turn "dist != d" into the opposite test "dist == d".
                    return null;
            }
        }

        /**
         * Returns the comparison that holds once the operands are swapped: {@code d OP dist} becomes
         * {@code dist invert(OP) d}, e.g. {@code 100 > dist} becomes {@code dist < 100}.
         */
        static ComparisonType invert(ComparisonType comparisonType) {
            switch (comparisonType) {
                case LT:
                    return GT;
                case LTE:
                    return GTE;
                case GT:
                    return LT;
                case GTE:
                    return LTE;
                default:
                    return EQ;
            }
        }
    }

    /**
     * Applies the rewrite when the filter sits directly on an {@link EsQueryExec}, or on an {@link EvalExec} whose child
     * is an {@link EsQueryExec}; any other plan shape is returned unchanged.
     */
    @Override // org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules.ParameterizedOptimizerRule
    public PhysicalPlan rule(FilterExec filterExec, LocalPhysicalOptimizerContext context) {
        PhysicalPlan child = filterExec.child();
        if (child instanceof EsQueryExec) {
            LucenePushdownPredicates predicates = LucenePushdownPredicates.from(context.searchStats());
            return rewrite(context.foldCtx(), filterExec, (EsQueryExec) child, predicates);
        }
        if (child instanceof EvalExec) {
            EvalExec evalExec = (EvalExec) child;
            PhysicalPlan grandChild = evalExec.child();
            if (grandChild instanceof EsQueryExec) {
                LucenePushdownPredicates predicates = LucenePushdownPredicates.from(context.searchStats());
                return rewriteBySplittingFilter(context.foldCtx(), filterExec, evalExec, (EsQueryExec) grandChild, predicates);
            }
        }
        return filterExec;
    }

    /**
     * Rewrites every {@code ST_DISTANCE <op> <foldable>} (or mirrored) comparison in the filter condition.
     *
     * @return a new filter on {@code esQueryExec} when something was rewritten and the result is accepted by
     *     {@code PushFiltersToSource.canPushToSource}; otherwise the original {@code filterExec}
     */
    private FilterExec rewrite(FoldContext foldContext, FilterExec filterExec, EsQueryExec esQueryExec, LucenePushdownPredicates lucenePushdownPredicates) {
        Expression rewritten = filterExec.condition().transformDown(EsqlBinaryComparison.class, comparison -> {
            ComparisonType type = ComparisonType.from(comparison.getFunctionType());
            if (type == null) {
                return comparison;
            }
            if (comparison.left() instanceof StDistance && comparison.right().foldable()) {
                return rewriteComparison(foldContext, comparison, (StDistance) comparison.left(), comparison.right(), type);
            }
            if (comparison.right() instanceof StDistance && comparison.left().foldable()) {
                // The distance is on the right, so the operator has to be mirrored.
                return rewriteComparison(foldContext, comparison, (StDistance) comparison.right(), comparison.left(), ComparisonType.invert(type));
            }
            return comparison;
        });
        if (rewritten.equals(filterExec.condition()) || PushFiltersToSource.canPushToSource(rewritten, lucenePushdownPredicates) == false) {
            return filterExec;
        }
        return new FilterExec(filterExec.source(), esQueryExec, rewritten);
    }

    /**
     * Handles {@code FilterExec -> EvalExec -> EsQueryExec} where the EVAL defines pushable {@code ST_DISTANCE}
     * aliases. The filter is split into AND-ed conjuncts. Each conjunct is handled as follows:
     * <ul>
     *   <li>References are first resolved through {@code PushFiltersToSource.getAliasReplacedBy(evalExec)}.</li>
     *   <li>If the distance rewrite changed the conjunct and the result can be pushed to source, the rewritten form
     *       goes into a new filter placed directly on the {@link EsQueryExec}, below the EVAL.</li>
     *   <li>Otherwise the original conjunct stays in a filter above the EVAL.</li>
     * </ul>
     */
    private PhysicalPlan rewriteBySplittingFilter(FoldContext foldContext, FilterExec filterExec, EvalExec evalExec, EsQueryExec esQueryExec, LucenePushdownPredicates lucenePushdownPredicates) {
        Map<NameId, StDistance> distances = getPushableDistances(evalExec.fields(), lucenePushdownPredicates);
        if (distances.isEmpty()) {
            return filterExec;
        }
        AttributeMap<Attribute> aliasReplacedBy = PushFiltersToSource.getAliasReplacedBy(evalExec);
        List<Expression> pushable = new ArrayList<>();
        List<Expression> nonPushable = new ArrayList<>();
        for (Expression conjunct : Predicates.splitAnd(filterExec.condition())) {
            Expression resolved = conjunct.transformUp(ReferenceAttribute.class, ref -> aliasReplacedBy.resolve(ref, ref));
            Expression rewritten = rewriteDistanceFilters(foldContext, resolved, distances);
            if (rewritten.equals(resolved) == false && PushFiltersToSource.canPushToSource(rewritten, lucenePushdownPredicates)) {
                pushable.add(rewritten);
            } else {
                // Keep the unresolved original: it stays above the EVAL, where the aliases are still defined.
                nonPushable.add(conjunct);
            }
        }
        if (pushable.isEmpty()) {
            return filterExec;
        }
        FilterExec distanceFilter = new FilterExec(filterExec.source(), esQueryExec, Predicates.combineAnd(pushable));
        EvalExec newEval = new EvalExec(evalExec.source(), distanceFilter, evalExec.fields());
        return nonPushable.isEmpty() ? newEval : new FilterExec(filterExec.source(), newEval, Predicates.combineAnd(nonPushable));
    }

    /**
     * Collects the EVAL aliases that evaluate to a translatable {@code ST_DISTANCE}. The result is keyed by alias id.
     * An alias that is a plain reference to an earlier collected alias is mapped to the same distance.
     */
    private Map<NameId, StDistance> getPushableDistances(List<Alias> aliases, LucenePushdownPredicates lucenePushdownPredicates) {
        Map<NameId, StDistance> distances = new LinkedHashMap<>();
        for (Alias alias : aliases) {
            Expression child = alias.child();
            if (child instanceof StDistance && ((StDistance) child).translatable(lucenePushdownPredicates)) {
                distances.put(alias.id(), (StDistance) child);
            } else if (child instanceof ReferenceAttribute) {
                StDistance referenced = distances.get(((ReferenceAttribute) child).id());
                if (referenced != null) {
                    distances.put(alias.id(), referenced);
                }
            }
        }
        return distances;
    }

    /**
     * Rewrites comparisons between a reference to one of {@code distances} and a foldable value. Other comparisons
     * are returned unchanged.
     */
    private Expression rewriteDistanceFilters(FoldContext foldContext, Expression expression, Map<NameId, StDistance> distances) {
        return expression.transformDown(EsqlBinaryComparison.class, comparison -> {
            ComparisonType type = ComparisonType.from(comparison.getFunctionType());
            if (type == null) {
                return comparison;
            }
            StDistance leftDistance = referencedDistance(comparison.left(), distances);
            if (leftDistance != null && comparison.right().foldable()) {
                return rewriteComparison(foldContext, comparison, leftDistance, comparison.right(), type);
            }
            StDistance rightDistance = referencedDistance(comparison.right(), distances);
            if (rightDistance != null && comparison.left().foldable()) {
                return rewriteComparison(foldContext, comparison, rightDistance, comparison.left(), ComparisonType.invert(type));
            }
            return comparison;
        });
    }

    /** Returns the distance that {@code expression} refers to by name, or {@code null} if it is not such a reference. */
    private static StDistance referencedDistance(Expression expression, Map<NameId, StDistance> distances) {
        return expression instanceof ReferenceAttribute ? distances.get(((ReferenceAttribute) expression).id()) : null;
    }

    /**
     * Folds {@code expression} and, if it is a number and one of the distance arguments is foldable, delegates to
     * {@link #rewriteDistanceFilter}. In all other cases the comparison is returned unchanged.
     */
    private Expression rewriteComparison(FoldContext foldContext, EsqlBinaryComparison esqlBinaryComparison, StDistance stDistance, Expression expression, ComparisonType comparisonType) {
        Object folded = expression.fold(foldContext);
        if (folded instanceof Number) {
            Number number = (Number) folded;
            if (stDistance.right().foldable()) {
                return rewriteDistanceFilter(foldContext, esqlBinaryComparison, stDistance.left(), stDistance.right(), number, comparisonType);
            }
            if (stDistance.left().foldable()) {
                return rewriteDistanceFilter(foldContext, esqlBinaryComparison, stDistance.right(), stDistance.left(), number, comparisonType);
            }
        }
        return esqlBinaryComparison;
    }

    /**
     * Replaces the comparison with circle tests around the literal point, when the literal is a {@link Point}:
     * <ul>
     *   <li>{@code <}/{@code <=}: intersects a circle of radius {@code nextDown(d)} / {@code d}.</li>
     *   <li>{@code >}/{@code >=}: disjoint from a circle of radius {@code nextUp(d)} / {@code d}.</li>
     *   <li>{@code ==}: intersects radius {@code d} AND disjoint from radius {@code nextDown(d)}.</li>
     * </ul>
     * Non-point literals leave the comparison unchanged.
     */
    private Expression rewriteDistanceFilter(FoldContext foldContext, EsqlBinaryComparison esqlBinaryComparison, Expression spatialExp, Expression literalExp, Number number, ComparisonType comparisonType) {
        DataType shapeDataType = getShapeDataType(spatialExp);
        Geometry geometry = SpatialRelatesUtils.makeGeometryFromLiteral(foldContext, literalExp);
        if (geometry instanceof Point == false) {
            return esqlBinaryComparison;
        }
        Point point = (Point) geometry;
        double distance = number.doubleValue();
        Source source = esqlBinaryComparison.source();
        if (comparisonType.lt) {
            // dist < d is the same as dist <= nextDown(d) in double precision.
            double radius = comparisonType.eq ? distance : Math.nextDown(distance);
            return new SpatialIntersects(source, spatialExp, makeCircleLiteral(point, radius, literalExp, shapeDataType));
        }
        if (comparisonType.gt) {
            // NOTE(review): GT uses nextUp(d) and GTE uses d. If ST_DISJOINT treats the circle boundary as inside,
            // GTE would drop points at exactly d and GT could use d directly. Confirm against SpatialDisjoint.
            double radius = comparisonType.eq ? distance : Math.nextUp(distance);
            return new SpatialDisjoint(source, spatialExp, makeCircleLiteral(point, radius, literalExp, shapeDataType));
        }
        if (comparisonType.eq) {
            return new And(
                source,
                new SpatialIntersects(source, spatialExp, makeCircleLiteral(point, distance, literalExp, shapeDataType)),
                new SpatialDisjoint(source, spatialExp, makeCircleLiteral(point, Math.nextDown(distance), literalExp, shapeDataType))
            );
        }
        return esqlBinaryComparison;
    }

    /** Builds a little-endian WKB literal of a circle of {@code radius} around {@code point}, typed as {@code dataType}. */
    private Literal makeCircleLiteral(Point point, double radius, Expression expression, DataType dataType) {
        Circle circle = new Circle(point.getX(), point.getY(), radius);
        return new Literal(expression.source(), new BytesRef(WellKnownBinary.toWKB(circle, ByteOrder.LITTLE_ENDIAN)), dataType);
    }

    /**
     * Maps the type of the spatial argument onto the matching shape type for the circle literal.
     *
     * @throws IllegalArgumentException if the type is not a geo or cartesian point/shape
     */
    private DataType getShapeDataType(Expression expression) {
        DataType dataType = expression.dataType();
        switch (dataType) {
            case GEO_POINT:
            case GEO_SHAPE:
                return DataType.GEO_SHAPE;
            case CARTESIAN_POINT:
            case CARTESIAN_SHAPE:
                return DataType.CARTESIAN_SHAPE;
            default:
                throw new IllegalArgumentException("Unsupported spatial data type: " + dataType);
        }
    }
}
