package org.elasticsearch.xpack.esql.optimizer.rules;

import java.time.DateTimeException;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.ArithmeticOperation;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub;

/**
 * Optimizer rule that simplifies a {@link BinaryComparison} whose right-hand side is a {@link Literal} and whose
 * left-hand side is either an {@link ArithmeticOperation} or a {@link Neg}.
 * <p>
 * For an arithmetic operation with a literal operand, the non-literal operand is kept on the left. The right-hand
 * side becomes the comparison literal combined with the other operand through the operation's comparison inverse,
 * folded when possible. For example, {@code x + 1 > 3} is expected to become {@code x > 2}. For a negation, the
 * literal is negated and the comparison reversed, so {@code -x > 3} becomes a comparison of {@code x} against
 * {@code -3}.
 * <p>
 * The rewrite is skipped, and the original comparison returned, whenever it is deemed unsafe or whenever folding
 * the new right-hand side fails with an {@link ArithmeticException} or {@link DateTimeException}.
 */
public final class SimplifyComparisonsArithmetics extends OptimizerRules.OptimizerExpressionRule<BinaryComparison> {
    /**
     * Compatibility check between the comparison literal's type and the operation literal's type, invoked with
     * arguments in that order. A rewrite is only attempted when it returns {@code true}.
     */
    BiFunction<DataType, DataType, Boolean> typesCompatible;

    /**
     * Simplifier for {@code (a + b) <cmp> literal} and {@code (a - b) <cmp> literal}.
     */
    public static class AddSubSimplifier extends OperationSimplifier {
        AddSubSimplifier(BinaryComparison binaryComparison) {
            super(binaryComparison);
        }

        /**
         * Rejects floating-point additions/subtractions. For a subtraction whose right operand is not a literal
         * (such as {@code 1 - x > c}), it also rejects the rewrite when {@code opLeft - bcLiteral} cannot be folded.
         */
        @Override
        boolean isOpUnsafe() {
            // Never rewrite floating-point additions/subtractions: moving the literal could change the result
            // through rounding.
            if (operation.dataType().isRationalNumber()) {
                return true;
            }
            boolean isSub = operation.symbol().equals(DefaultBinaryArithmeticOperation.SUB.symbol());
            if (isSub && !(opRight instanceof Literal)) {
                // Skip early if combining the two literals would already fail (e.g. on overflow).
                return tryFolding(new Sub(Source.EMPTY, opLeft, bcLiteral)) == null;
            }
            return false;
        }
    }

    /**
     * Simplifier for {@code (a * b) <cmp> literal} and {@code (a / b) <cmp> literal}. When the right operand of
     * the operation is negative, the resulting comparison is reversed.
     */
    public static class MulDivSimplifier extends OperationSimplifier {
        /** {@code true} for a division, {@code false} for a multiplication. */
        private final boolean isDiv;
        /** Sign of {@link #opRight} as estimated by {@link #sign(Object)}: -1, 0 or 1. */
        private final int opRightSign;

        MulDivSimplifier(BinaryComparison binaryComparison) {
            super(binaryComparison);
            this.isDiv = operation.symbol().equals(DefaultBinaryArithmeticOperation.DIV.symbol());
            this.opRightSign = sign(opRight);
        }

        @Override
        boolean isOpUnsafe() {
            // Integer division truncates, so it cannot be inverted exactly:
            // x / 5 > 1 is not equivalent to x > 5 for x in [6, 9].
            if (operation.dataType().isWholeNumber() && isDiv) {
                return true;
            }
            if (!isDiv && opLeft.dataType().isWholeNumber()) {
                // The inverse of an integer multiplication is a division: only safe if it divides exactly.
                long opLiteralValue = ((Number) opLiteral.value()).longValue();
                return opLiteralValue == 0 || ((Number) bcLiteral.value()).longValue() % opLiteralValue != 0;
            }
            // A zero operand cannot be moved across a multiplication/division.
            return opRightSign == 0;
        }

        /** Reverses the comparison when the right operand is negative, since that flips the inequality. */
        @Override
        Expression postProcess(BinaryComparison binaryComparison) {
            return opRightSign < 0 ? binaryComparison.reverse() : binaryComparison;
        }

        /**
         * Estimates the sign of a number, a literal, a negation, or a product/quotient of such terms.
         * <p>
         * NOTE(review): any other expression (e.g. a field) is assumed positive. When the literal sits on the
         * left of a division (e.g. {@code 6 / x > 2}), {@code opRight} is the non-literal operand, and this
         * assumption may not hold for negative values. Confirm that callers normalise literals to the right or
         * that this case is acceptable.
         *
         * @return -1, 0 or 1
         */
        private static int sign(Object obj) {
            if (obj instanceof Number) {
                return (int) Math.signum(((Number) obj).doubleValue());
            }
            if (obj instanceof Literal) {
                return sign(((Literal) obj).value());
            }
            if (obj instanceof Neg) {
                return -sign(((Neg) obj).field());
            }
            if (obj instanceof ArithmeticOperation) {
                ArithmeticOperation arithmeticOperation = (ArithmeticOperation) obj;
                if (isMulOrDiv(arithmeticOperation.symbol())) {
                    return sign(arithmeticOperation.left()) * sign(arithmeticOperation.right());
                }
            }
            // Unknown sign: assumed positive.
            return 1;
        }
    }

    /**
     * Common state and algorithm for moving the literal operand of {@code operation <cmp> bcLiteral} to the
     * right-hand side of the comparison.
     */
    public abstract static class OperationSimplifier {
        /** The comparison being simplified. */
        final BinaryComparison comparison;
        /** The literal on the right-hand side of {@link #comparison}. */
        final Literal bcLiteral;
        /** The arithmetic operation on the left-hand side of {@link #comparison}. */
        final ArithmeticOperation operation;
        /** Left operand of {@link #operation}. */
        final Expression opLeft;
        /** Right operand of {@link #operation}. */
        final Expression opRight;
        /** The literal operand of {@link #operation} (left checked first), or {@code null} if there is none. */
        final Literal opLiteral;

        /**
         * @param binaryComparison a comparison whose left side is an {@link ArithmeticOperation} and whose right
         *                         side is a {@link Literal}, as checked by {@code rule()} before construction
         */
        OperationSimplifier(BinaryComparison binaryComparison) {
            this.comparison = binaryComparison;
            this.operation = (ArithmeticOperation) binaryComparison.left();
            this.bcLiteral = (Literal) binaryComparison.right();
            this.opLeft = operation.left();
            this.opRight = operation.right();
            if (opLeft instanceof Literal) {
                this.opLiteral = (Literal) opLeft;
            } else if (opRight instanceof Literal) {
                this.opLiteral = (Literal) opRight;
            } else {
                this.opLiteral = null;
            }
        }

        /**
         * Checks whether the simplification must be skipped.
         *
         * @param typesCompatible compatibility check, called as {@code (bcLiteral type, opLiteral type)}
         * @return {@code true} if the comparison must be left unchanged
         */
        final boolean isUnsafe(BiFunction<DataType, DataType, Boolean> typesCompatible) {
            // One operand must be a literal, otherwise there is nothing to move across the comparison.
            if (opLiteral == null) {
                return true;
            }
            // A null literal value cannot be combined into a bound. It would also throw a NullPointerException when
            // unboxed as a Number (MulDivSimplifier.isOpUnsafe(), apply()).
            if (opLiteral.value() == null || bcLiteral.value() == null) {
                return true;
            }
            // Only fixed-point literals are supported: rewriting floating-point arithmetic can change the outcome.
            if (opLiteral.dataType().isRationalNumber() || bcLiteral.dataType().isRationalNumber()) {
                return true;
            }
            // The operation's literal ends up combined with the comparison literal, so their types must match up.
            if (!typesCompatible.apply(bcLiteral.dataType(), opLiteral.dataType())) {
                return true;
            }
            return isOpUnsafe();
        }

        /**
         * Rewrites {@code opLeft <op> opRight <cmp> bcLiteral} into
         * {@code opLeft <cmp> (bcLiteral <inverse-op> opRight)}, folding the new right-hand side when possible.
         *
         * @return the rewritten (and post-processed) comparison, or the original one if folding failed
         */
        final Expression apply() {
            // Force floating-point folding when the operation itself produces a floating-point value.
            Literal bcl = operation.dataType().isRationalNumber()
                ? new Literal(bcLiteral.source(), Double.valueOf(((Number) bcLiteral.value()).doubleValue()), DataType.DOUBLE)
                : bcLiteral;
            Expression newRight = tryFolding(operation.binaryComparisonInverse().create(bcl.source(), bcl, opRight));
            if (newRight == null) {
                return comparison;
            }
            return postProcess((BinaryComparison) comparison.replaceChildren(List.of(opLeft, newRight)));
        }

        /** Operation-specific safety check, run after the generic checks in {@link #isUnsafe(BiFunction)}. */
        abstract boolean isOpUnsafe();

        /** Operation-specific adjustment of the rewritten comparison; identity by default. */
        Expression postProcess(BinaryComparison binaryComparison) {
            return binaryComparison;
        }
    }

    /**
     * @param biFunction type compatibility check, called as {@code (comparison literal type, operation literal type)}
     */
    public SimplifyComparisonsArithmetics(BiFunction<DataType, DataType, Boolean> biFunction) {
        super(OptimizerRules.TransformDirection.UP);
        this.typesCompatible = biFunction;
    }

    /**
     * Only comparisons against a literal are considered. Arithmetic left-hand sides are delegated to
     * {@link #simplifyBinaryComparison(BinaryComparison)}, and negations to
     * {@link #foldNegation(BinaryComparison)}. Anything else is returned unchanged.
     */
    @Override
    public Expression rule(BinaryComparison binaryComparison) {
        if (binaryComparison.right() instanceof Literal) {
            if (binaryComparison.left() instanceof ArithmeticOperation) {
                return simplifyBinaryComparison(binaryComparison);
            }
            if (binaryComparison.left() instanceof Neg) {
                return foldNegation(binaryComparison);
            }
        }
        return binaryComparison;
    }

    /**
     * Picks the simplifier matching the operation's symbol and applies it if safe. Modulo and unrecognised
     * operations are left untouched.
     */
    private Expression simplifyBinaryComparison(BinaryComparison binaryComparison) {
        ArithmeticOperation operation = (ArithmeticOperation) binaryComparison.left();
        String symbol = operation.symbol();
        // Modulo cannot be inverted.
        if (symbol.equals(DefaultBinaryArithmeticOperation.MOD.symbol())) {
            return binaryComparison;
        }
        OperationSimplifier operationSimplifier = null;
        if (isMulOrDiv(symbol)) {
            operationSimplifier = new MulDivSimplifier(binaryComparison);
        } else if (symbol.equals(DefaultBinaryArithmeticOperation.ADD.symbol())
            || symbol.equals(DefaultBinaryArithmeticOperation.SUB.symbol())) {
            operationSimplifier = new AddSubSimplifier(binaryComparison);
        }
        if (operationSimplifier == null || operationSimplifier.isUnsafe(typesCompatible)) {
            return binaryComparison;
        }
        return operationSimplifier.apply();
    }

    /** @return whether {@code str} is the multiplication or division symbol */
    private static boolean isMulOrDiv(String str) {
        return str.equals(DefaultBinaryArithmeticOperation.MUL.symbol())
            || str.equals(DefaultBinaryArithmeticOperation.DIV.symbol());
    }

    /**
     * Rewrites {@code -field <cmp> literal} into {@code field <reversed cmp> -literal}, unless negating the literal
     * fails to fold.
     */
    private static Expression foldNegation(BinaryComparison binaryComparison) {
        Literal bcLiteral = (Literal) binaryComparison.right();
        Expression negatedLiteral = tryFolding(new Neg(bcLiteral.source(), bcLiteral));
        if (negatedLiteral == null) {
            return binaryComparison;
        }
        Expression field = ((Neg) binaryComparison.left()).field();
        return binaryComparison.reverse().replaceChildren(Arrays.asList(field, negatedLiteral));
    }

    /**
     * Folds {@code expression} into a {@link Literal} if it is foldable.
     *
     * @return the folded literal; {@code expression} itself if it is not foldable; or {@code null} if folding threw
     *         an {@link ArithmeticException} or {@link DateTimeException}. A {@code null} return tells callers to
     *         skip the simplification.
     */
    private static Expression tryFolding(Expression expression) {
        if (!expression.foldable()) {
            return expression;
        }
        try {
            return new Literal(expression.source(), expression.fold(), expression.dataType());
        } catch (ArithmeticException | DateTimeException ignored) {
            // Folding failed (e.g. on overflow): signal the caller to keep the original expression.
            return null;
        }
    }
}
