package io.squashql.query;

import io.squashql.query.database.AQueryEngine;
import io.squashql.query.database.DatabaseQuery;
import io.squashql.query.database.SQLTranslator;
import io.squashql.query.database.SqlUtils;
import io.squashql.query.dto.ConditionDto;
import io.squashql.query.dto.ConditionType;
import io.squashql.query.dto.CriteriaDto;
import io.squashql.query.dto.JoinDto;
import io.squashql.query.dto.JoinMappingDto;
import io.squashql.query.dto.JoinType;
import io.squashql.query.dto.TableDto;
import io.squashql.query.dto.VirtualTableDto;
import io.squashql.store.Field;
import io.squashql.store.Store;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link SQLTranslator}: each test builds a {@link DatabaseQuery} and asserts the exact
 * SQL string produced by {@code SQLTranslator.translate}.
 */
public class TestSQLTranslator {

    private static final String BASE_STORE_NAME = "baseStore";
    private static final String BASE_STORE_NAME_ESCAPED = SqlUtils.backtickEscape(BASE_STORE_NAME);

    /**
     * Resolves a field name into a {@link Field}. A name of the form {@code table.field} yields a
     * field qualified by {@code table}; an unqualified name yields a field with a {@code null} store.
     * Only the first two dot-separated segments are used. The Java type comes from {@link #fieldType}.
     */
    private static final Function<String, Field> fieldProvider = name -> {
        String[] parts = name.split("\\.");
        if (parts.length > 1) {
            return new Field(parts[0], parts[1], fieldType(parts[1]));
        }
        return new Field((String) null, name, fieldType(name));
    };

    /**
     * Returns the Java type assigned to a test field, looked up by its unqualified name:
     * {@code double} for "pnl", {@link Double} for "delta", and {@link String} for anything else.
     */
    private static Class<?> fieldType(String fieldName) {
        switch (fieldName) {
            case "pnl":
                return double.class;
            case "delta":
                return Double.class;
            default:
                return String.class;
        }
    }

    @Test
    void testGrandTotal() {
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .aggregatedMeasure("pnl.sum", "pnl", "sum")
                        .aggregatedMeasure("delta.sum", "delta", "sum")
                        .aggregatedMeasure("pnl.avg", "pnl", "avg")
                        .aggregatedMeasure("mean pnl", "pnl", "avg")
                        .table(BASE_STORE_NAME), fieldProvider))
                .isEqualTo("select sum(`pnl`) as `pnl.sum`, sum(`delta`) as `delta.sum`, avg(`pnl`) as `pnl.avg`,"
                        + " avg(`pnl`) as `mean pnl` from " + BASE_STORE_NAME_ESCAPED);
    }

    @Test
    void testLimit() {
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .aggregatedMeasure("pnl.sum", "pnl", "sum")
                        .limit(8)
                        .table(BASE_STORE_NAME), fieldProvider))
                .isEqualTo("select sum(`pnl`) as `pnl.sum` from `baseStore` limit 8");
    }

    @Test
    void testGroupBy() {
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .withSelect(fieldProvider.apply("scenario"))
                        .withSelect(fieldProvider.apply("type"))
                        .aggregatedMeasure("pnl.sum", "pnl", "sum")
                        .aggregatedMeasure("delta.sum", "delta", "sum")
                        .aggregatedMeasure("pnl.avg", "pnl", "avg")
                        .table(BASE_STORE_NAME), fieldProvider))
                .isEqualTo("select `scenario`, `type`, sum(`pnl`) as `pnl.sum`, sum(`delta`) as `delta.sum`,"
                        + " avg(`pnl`) as `pnl.avg` from " + BASE_STORE_NAME_ESCAPED
                        + " group by `scenario`, `type`");
    }

    @Test
    void testGroupByWithFullName() {
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .withSelect(fieldProvider.apply(SqlUtils.getFieldFullName(BASE_STORE_NAME, "scenario")))
                        .withSelect(fieldProvider.apply(SqlUtils.getFieldFullName(BASE_STORE_NAME, "type")))
                        .aggregatedMeasure("pnl.sum", SqlUtils.getFieldFullName(BASE_STORE_NAME, "pnl"), "sum")
                        .aggregatedMeasure("delta.sum", SqlUtils.getFieldFullName(BASE_STORE_NAME, "delta"), "sum")
                        .table(BASE_STORE_NAME), fieldProvider))
                .isEqualTo("select `baseStore`.`scenario`, `baseStore`.`type`, sum(`baseStore`.`pnl`) as `pnl.sum`,"
                        + " sum(`baseStore`.`delta`) as `delta.sum` from " + BASE_STORE_NAME_ESCAPED
                        + " group by `baseStore`.`scenario`, `baseStore`.`type`");
    }

    @Test
    void testDifferentMeasures() {
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .table(BASE_STORE_NAME)
                        .aggregatedMeasure("pnl.sum", "pnl", "sum")
                        .expressionMeasure("indice", "100 * sum(`delta`) / sum(`pnl`)"), fieldProvider))
                .isEqualTo("select sum(`pnl`) as `pnl.sum`, 100 * sum(`delta`) / sum(`pnl`) as `indice` from "
                        + BASE_STORE_NAME_ESCAPED);
    }

    @Test
    void testWithFullRollup() {
        // Expected SQL is built by plain concatenation. The previous version embedded "\n" and then
        // stripped System.lineSeparator(), which does not match "\n" on Windows ("\r\n").
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .withSelect(fieldProvider.apply("scenario"))
                        .withSelect(fieldProvider.apply("type"))
                        .withRollup(fieldProvider.apply("scenario"))
                        .withRollup(fieldProvider.apply("type"))
                        .aggregatedMeasure("pnl.sum", "price", "sum")
                        .table(BASE_STORE_NAME), fieldProvider))
                .isEqualTo("select `scenario`, `type`,"
                        + " grouping(`scenario`), grouping(`type`),"
                        + " sum(`price`) as `pnl.sum`"
                        + " from `baseStore` group by rollup(`scenario`, `type`)");
    }

    @Test
    void testWithFullRollupWithFullName() {
        String scenario = SqlUtils.getFieldFullName(BASE_STORE_NAME, "scenario");
        String type = SqlUtils.getFieldFullName(BASE_STORE_NAME, "type");
        // Plain concatenation keeps the expected SQL independent of the platform line separator.
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .withSelect(fieldProvider.apply(scenario))
                        .withSelect(fieldProvider.apply(type))
                        .withRollup(fieldProvider.apply(scenario))
                        .withRollup(fieldProvider.apply(type))
                        .aggregatedMeasure("pnl.sum", "price", "sum")
                        .table(BASE_STORE_NAME), fieldProvider))
                .isEqualTo("select `baseStore`.`scenario`, `baseStore`.`type`,"
                        + " grouping(`baseStore`.`scenario`), grouping(`baseStore`.`type`),"
                        + " sum(`price`) as `pnl.sum`"
                        + " from `baseStore` group by rollup(`baseStore`.`scenario`, `baseStore`.`type`)");
    }

    @Test
    void testWithPartialRollup() {
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .withSelect(fieldProvider.apply("scenario"))
                        .withSelect(fieldProvider.apply("type"))
                        .withRollup(fieldProvider.apply("scenario"))
                        .aggregatedMeasure("pnl.sum", "price", "sum")
                        .table(BASE_STORE_NAME), fieldProvider))
                .isEqualTo("select `scenario`, `type`, grouping(`scenario`), sum(`price`) as `pnl.sum`"
                        + " from `baseStore` group by `type`, rollup(`scenario`)");
    }

    @Test
    void testJoins() {
        TableDto base = new TableDto(BASE_STORE_NAME);
        TableDto table1 = new TableDto("table1");
        TableDto table2 = new TableDto("table2");
        TableDto table3 = new TableDto("table3");
        TableDto table4 = new TableDto("table4");

        JoinMappingDto baseToTable1 = new JoinMappingDto(base.name + ".id", table1.name + ".table1_id");
        JoinMappingDto baseToTable2 = new JoinMappingDto(base.name + ".id", table2.name + ".table2_id");
        JoinMappingDto table2ToTable3 = new JoinMappingDto(table2.name + ".table2_field_1", table3.name + ".table3_id");
        // table1 -> table4 is a composite (two-column) equi-join.
        List<JoinMappingDto> table1ToTable4 = List.of(
                new JoinMappingDto(table1.name + ".table1_field_2", table4.name + ".table4_id_1"),
                new JoinMappingDto(table1.name + ".table1_field_3", table4.name + ".table4_id_2"));

        base.joins.add(new JoinDto(table1, JoinType.INNER, baseToTable1));
        base.joins.add(new JoinDto(table2, JoinType.LEFT, baseToTable2));
        table1.joins.add(new JoinDto(table4, JoinType.INNER, table1ToTable4));
        table2.joins.add(new JoinDto(table3, JoinType.INNER, table2ToTable3));

        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .table(base)
                        .aggregatedMeasure("pnl.avg", "pnl", "avg"), fieldProvider))
                .isEqualTo("select avg(`pnl`) as `pnl.avg` from " + BASE_STORE_NAME_ESCAPED
                        + " inner join `table1` on " + BASE_STORE_NAME_ESCAPED + ".`id` = `table1`.`table1_id`"
                        + " inner join `table4` on `table1`.`table1_field_2` = `table4`.`table4_id_1`"
                        + " and `table1`.`table1_field_3` = `table4`.`table4_id_2`"
                        + " left join `table2` on " + BASE_STORE_NAME_ESCAPED + ".`id` = `table2`.`table2_id`"
                        + " inner join `table3` on `table2`.`table2_field_1` = `table3`.`table3_id`");
    }

    @Test
    void testJoinsEquijoinsMultipleCondCrossTables() {
        TableDto a = new TableDto("A");
        TableDto b = new TableDto("B");
        TableDto c = new TableDto("C");
        JoinMappingDto aToB = new JoinMappingDto(a.name + ".a_id", b.name + ".b_id");
        // The join condition on C references columns of both B and A.
        List<JoinMappingDto> cConditions = List.of(
                new JoinMappingDto(c.name + ".c_other_id", b.name + ".b_other_id"),
                new JoinMappingDto(c.name + ".c_f", a.name + ".a_f"));
        a.join(b, JoinType.INNER, aToB);
        a.join(c, JoinType.LEFT, cConditions);

        Function<String, Field> fieldSupplier = AQueryEngine.createFieldSupplier(Map.of(
                "A", new Store("A", List.of(
                        new Field("A", "a_id", int.class),
                        new Field("A", "a_f", int.class),
                        new Field("A", "y", int.class))),
                "B", new Store("B", List.of(
                        new Field("B", "b_id", int.class),
                        new Field("B", "b_other_id", int.class))),
                "C", new Store("C", List.of(
                        new Field("C", "a_id", int.class),
                        new Field("C", "c_f", int.class),
                        new Field("C", "c_other_id", int.class)))));

        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .table(a)
                        .withSelect(fieldSupplier.apply("A.y")), fieldSupplier))
                .isEqualTo("select `A`.`y` from `A` inner join `B` on `A`.`a_id` = `B`.`b_id`"
                        + " left join `C` on `C`.`c_other_id` = `B`.`b_other_id` and `C`.`c_f` = `A`.`a_f`"
                        + " group by `A`.`y`");
    }

    @Test
    void testConditionsWithValue() {
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .withSelect(fieldProvider.apply("scenario"))
                        .withSelect(fieldProvider.apply("type"))
                        .aggregatedMeasure("pnl.sum", "pnl", "sum")
                        .whereCriteria(Functions.all(
                                Functions.criterion("scenario",
                                        Functions.or(Functions.eq("base"), Functions.eq("s1"), Functions.eq("s2"))),
                                Functions.criterion("delta", Functions.ge(123.0)),
                                Functions.criterion("type", Functions.or(Functions.eq("A"), Functions.eq("B"))),
                                Functions.criterion("pnl", Functions.lt(10.0))))
                        .table(BASE_STORE_NAME), fieldProvider))
                .isEqualTo("select `scenario`, `type`, sum(`pnl`) as `pnl.sum` from " + BASE_STORE_NAME_ESCAPED
                        + " where (((`scenario` = 'base' or `scenario` = 's1') or `scenario` = 's2')"
                        + " and `delta` >= 123.0 and (`type` = 'A' or `type` = 'B') and `pnl` < 10.0)"
                        + " group by `scenario`, `type`");
    }

    @Test
    void testConditionWithValueFullPath() {
        Field scenario = new Field(BASE_STORE_NAME, "scenario", String.class);
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .withSelect(scenario)
                        .aggregatedMeasure("pnl.sum", "pnl", "sum")
                        .whereCriteria(Functions.criterion(SqlUtils.getFieldFullName(scenario),
                                Functions.and(Functions.eq("base"), Functions.eq("s2"))))
                        .table(BASE_STORE_NAME), fieldProvider))
                .isEqualTo("select `baseStore`.`scenario`, sum(`pnl`) as `pnl.sum` from " + BASE_STORE_NAME_ESCAPED
                        + " where (`baseStore`.`scenario` = 'base' and `baseStore`.`scenario` = 's2')"
                        + " group by `baseStore`.`scenario`");
    }

    @Test
    void testSelectFromSelect() {
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .subQuery(new DatabaseQuery()
                                .table(new TableDto("a"))
                                .withSelect(fieldProvider.apply("c1"))
                                .withSelect(fieldProvider.apply("c3"))
                                .withMeasure(Functions.avg("mean", "c2")))
                        .withSelect(fieldProvider.apply("c3"))
                        .withMeasure(Functions.sum("sum GT", "mean"))
                        .whereCriteria(Functions.criterion("type", Functions.eq("myType"))), fieldProvider))
                .isEqualTo("select `c3`, sum(`mean`) as `sum GT` from"
                        + " (select `c1`, `c3`, avg(`c2`) as `mean` from `a` group by `c1`, `c3`)"
                        + " where `type` = 'myType' group by `c3`");
    }

    @Test
    void testBinaryOperationMeasure() {
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .table(new TableDto("a"))
                        .withMeasure(Functions.plus("plus",
                                Functions.sum("pnl.sum", "pnl"),
                                Functions.avg("delta.avg", "delta"))), fieldProvider))
                .isEqualTo("select sum(`pnl`)+avg(`delta`) as `plus` from `a`");
    }

    @Test
    void testAggregatedMeasures() {
        TableDto base = new TableDto(BASE_STORE_NAME);

        // Unqualified field names.
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .table(base)
                        .withMeasure(Functions.sum("pnlSum", "pnl"))
                        .withMeasure(Functions.sumIf("pnlSumFiltered", "pnl",
                                Functions.criterion("country", Functions.eq("france")))), fieldProvider))
                .isEqualTo("select sum(`pnl`) as `pnlSum`,"
                        + " sum(case when `country` = 'france' then `pnl` end) as `pnlSumFiltered` from `baseStore`");

        // Table-qualified field names.
        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .table(base)
                        .withMeasure(Functions.sum("pnlSum", base.name + ".pnl"))
                        .withMeasure(Functions.sumIf("pnlSumFiltered", base.name + ".pnl",
                                Functions.criterion(base.name + ".country", Functions.eq("france")))), fieldProvider))
                .isEqualTo(String.format("select sum(`%1$s`.`pnl`) as `pnlSum`,"
                        + " sum(case when `%1$s`.`country` = 'france' then `%1$s`.`pnl` end) as `pnlSumFiltered`"
                        + " from `%1$s`", BASE_STORE_NAME));
    }

    @Test
    void testVirtualTable() {
        TableDto base = new TableDto(BASE_STORE_NAME);
        VirtualTableDto virtual = new VirtualTableDto("virtual", List.of("a", "b"),
                List.of(List.of(0, "0"), List.of(1, "1")));
        base.join(new TableDto(virtual.name), JoinType.INNER, new JoinMappingDto("id", "a", ConditionType.EQ));

        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .table(base)
                        .virtualTable(virtual)
                        .withMeasure(Functions.sum("pnl.sum", "pnl"))
                        .withSelect(fieldProvider.apply("id"))
                        .withSelect(fieldProvider.apply("b")), fieldProvider))
                .isEqualTo(String.format("with %2$s as (select 0 as `a`, '0' as `b` union all select 1 as `a`, '1' as `b`)"
                        + " select `id`, `b`, sum(`pnl`) as `pnl.sum` from `%1$s` inner join %2$s on `id` = `a`"
                        + " group by `id`, `b`", BASE_STORE_NAME, virtual.name));
    }

    @Test
    void testVirtualTableFullName() {
        TableDto base = new TableDto(BASE_STORE_NAME);
        VirtualTableDto virtual = new VirtualTableDto("virtual", List.of("a", "b"),
                List.of(List.of(0, "0"), List.of(1, "1")));
        base.join(new TableDto(virtual.name), JoinType.INNER,
                new JoinMappingDto(BASE_STORE_NAME + ".id", virtual.name + ".a", ConditionType.EQ));

        Assertions.assertThat(SQLTranslator.translate(new DatabaseQuery()
                        .table(base)
                        .virtualTable(virtual)
                        .withMeasure(Functions.sum("pnl.sum", "pnl"))
                        .withSelect(fieldProvider.apply("id"))
                        .withSelect(fieldProvider.apply("b")), fieldProvider))
                .isEqualTo(String.format("with %2$s as (select 0 as `a`, '0' as `b` union all select 1 as `a`, '1' as `b`)"
                        + " select `id`, `b`, sum(`pnl`) as `pnl.sum` from `%1$s` inner join %2$s"
                        + " on `%1$s`.`id` = %2$s.`a` group by `id`, `b`", BASE_STORE_NAME, virtual.name));
    }
}
