package io.trino.plugin.sqlserver;

import com.google.common.base.Enums;
import com.google.common.base.Joiner;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.MoreExecutors;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.plugin.jdbc.BaseJdbcClient;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ColumnMapping;
import io.trino.plugin.jdbc.ConnectionFactory;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcErrorCode;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcJoinCondition;
import io.trino.plugin.jdbc.JdbcOutputTableHandle;
import io.trino.plugin.jdbc.JdbcSplit;
import io.trino.plugin.jdbc.JdbcTableHandle;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.PredicatePushdownController;
import io.trino.plugin.jdbc.RemoteTableName;
import io.trino.plugin.jdbc.SliceWriteFunction;
import io.trino.plugin.jdbc.StandardColumnMappings;
import io.trino.plugin.jdbc.WriteMapping;
import io.trino.plugin.jdbc.expression.AggregateFunctionRewriter;
import io.trino.plugin.jdbc.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.expression.ImplementAvgDecimal;
import io.trino.plugin.jdbc.expression.ImplementAvgFloatingPoint;
import io.trino.plugin.jdbc.expression.ImplementCount;
import io.trino.plugin.jdbc.expression.ImplementCountAll;
import io.trino.plugin.jdbc.expression.ImplementMinMax;
import io.trino.plugin.jdbc.expression.ImplementSum;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TimeType;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;
import java.math.RoundingMode;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Types;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.inject.Inject;
import org.jdbi.v3.core.Handle;
import org.jdbi.v3.core.Jdbi;

/* loaded from: input_file:io/trino/plugin/sqlserver/SqlServerClient.class */
/**
 * JDBC client for the SQL Server connector.
 *
 * <p>Maps SQL Server column types to Trino types and back, pushes down aggregations, LIMIT and
 * selected join conditions, and switches read/write connections to SNAPSHOT isolation when the
 * remote database allows it.
 */
public class SqlServerClient extends BaseJdbcClient {
    public static final int SQL_SERVER_MAX_LIST_EXPRESSIONS = 500;
    private static final Joiner DOT_JOINER = Joiner.on(".");

    // Highest fractional-second precision of SQL Server datetime2.
    private static final int MAX_SUPPORTED_TIMESTAMP_PRECISION = 7;
    // Largest bounded length written as nvarchar(n)/nchar(n); longer values use nvarchar(max).
    private static final int MAX_NVARCHAR_LENGTH = 4000;
    // Trino's maximum decimal precision.
    private static final int MAX_DECIMAL_PRECISION = 38;
    // Vendor-specific isolation level value passed to Connection#setTransactionIsolation
    // (mssql-jdbc SQLServerConnection.TRANSACTION_SNAPSHOT).
    private static final int TRANSACTION_SNAPSHOT = 4096;
    // sys.databases.snapshot_isolation_state value meaning ALLOW_SNAPSHOT_ISOLATION is ON.
    private static final int SNAPSHOT_ISOLATION_STATE_ON = 1;

    // Single cached flag (5 minute TTL) saying whether SNAPSHOT isolation may be used.
    private final Cache<SnapshotIsolationEnabledCacheKey, Boolean> snapshotIsolationEnabled;
    private final AggregateFunctionRewriter aggregateFunctionRewriter;

    /** Sole key of {@link #snapshotIsolationEnabled}; the cache only ever holds one entry. */
    public enum SnapshotIsolationEnabledCacheKey {
        INSTANCE
    }

    @Inject
    public SqlServerClient(BaseJdbcConfig config, ConnectionFactory connectionFactory) {
        super(config, "\"", connectionFactory);
        this.snapshotIsolationEnabled = CacheBuilder.newBuilder()
                .maximumSize(1L)
                .expireAfterWrite(Duration.ofMinutes(5L))
                .build();

        // COUNT results are produced as SQL Server bigint.
        JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
        this.aggregateFunctionRewriter = new AggregateFunctionRewriter(
                this::quoted,
                ImmutableSet.<AggregateFunctionRule>builder()
                        .add(new ImplementCountAll(bigintTypeHandle))
                        .add(new ImplementCount(bigintTypeHandle))
                        .add(new ImplementMinMax())
                        .add(new ImplementSum(SqlServerClient::toTypeHandle))
                        .add(new ImplementAvgFloatingPoint())
                        .add(new ImplementAvgDecimal())
                        .add(new ImplementAvgBigint())
                        .add(new ImplementSqlServerStdev())
                        .add(new ImplementSqlServerStddevPop())
                        .add(new ImplementSqlServerVariance())
                        .add(new ImplementSqlServerVariancePop())
                        .build());
    }

    /**
     * Renames a table with {@code sp_rename}. Only renames within the same schema are supported.
     *
     * @throws TrinoException with {@code NOT_SUPPORTED} if the target schema differs
     */
    @Override
    protected void renameTable(ConnectorSession session, String catalogName, String schemaName, String tableName, SchemaTableName newTable) {
        if (!schemaName.equals(newTable.getSchemaName())) {
            throw new TrinoException(StandardErrorCode.NOT_SUPPORTED, "Table rename across schemas is not supported");
        }
        execute(session, String.format("sp_rename %s, %s", singleQuote(catalogName, schemaName, tableName), singleQuote(newTable.getTableName())));
    }

    /** Renames a column with {@code sp_rename ..., 'COLUMN'}. */
    @Override
    public void renameColumn(ConnectorSession session, JdbcTableHandle tableHandle, JdbcColumnHandle column, String newColumnName) {
        String objectName = singleQuote(tableHandle.getCatalogName(), tableHandle.getSchemaName(), tableHandle.getTableName(), column.getColumnName());
        execute(session, String.format("sp_rename %s, %s, 'COLUMN'", objectName, singleQuote(newColumnName)));
    }

    /**
     * Creates an empty copy of the given columns of {@code tableName} as {@code newTableName},
     * using {@code SELECT ... INTO ... WHERE 0 = 1} so no rows are copied.
     */
    @Override
    protected void copyTableSchema(Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List<String> columnNames) {
        String columns = columnNames.stream()
                .map(this::quoted)
                .collect(Collectors.joining(", "));
        execute(connection, String.format("SELECT %s INTO %s FROM %s WHERE 0 = 1", columns, quoted(catalogName, schemaName, newTableName), quoted(catalogName, schemaName, tableName)));
    }

    /**
     * Maps a remote SQL Server column type to a Trino column mapping.
     *
     * <p>A forced varchar mapping (from {@code getForcedMappingToVarchar}) wins. Next, the type name
     * {@code varbinary} is matched, then the JDBC type code. Anything unhandled, including decimals
     * whose effective precision exceeds 38, is delegated to {@code legacyToPrestoType}.
     *
     * @throws TrinoException with {@code JDBC_ERROR} if the type handle carries no type name
     */
    @Override
    public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle) {
        Optional<ColumnMapping> forcedMapping = getForcedMappingToVarchar(typeHandle);
        if (forcedMapping.isPresent()) {
            return forcedMapping;
        }

        String jdbcTypeName = typeHandle.getJdbcTypeName()
                .orElseThrow(() -> new TrinoException(JdbcErrorCode.JDBC_ERROR, "Type name is missing: " + typeHandle));
        if (jdbcTypeName.equals("varbinary")) {
            return Optional.of(varbinaryColumnMapping());
        }

        switch (typeHandle.getJdbcType()) {
            case Types.NCHAR:
            case Types.CHAR:
                return Optional.of(StandardColumnMappings.defaultCharColumnMapping(typeHandle.getRequiredColumnSize(), false));
            case Types.NVARCHAR:
            case Types.VARCHAR:
                return Optional.of(StandardColumnMappings.defaultVarcharColumnMapping(typeHandle.getRequiredColumnSize(), false));
            case Types.BIT:
                return Optional.of(StandardColumnMappings.booleanColumnMapping());
            case Types.TINYINT:
                return Optional.of(StandardColumnMappings.tinyintColumnMapping());
            case Types.BIGINT:
                return Optional.of(StandardColumnMappings.bigintColumnMapping());
            case Types.LONGVARBINARY:
            case Types.VARBINARY:
            case Types.BINARY:
                return Optional.of(varbinaryColumnMapping());
            case Types.NUMERIC:
            case Types.DECIMAL: {
                int columnSize = typeHandle.getRequiredColumnSize();
                int decimalDigits = typeHandle.getRequiredDecimalDigits();
                // A negative scale is folded into the precision so the mapped type has scale >= 0.
                int precision = columnSize + Math.max(-decimalDigits, 0);
                if (precision <= MAX_DECIMAL_PRECISION) {
                    return Optional.of(StandardColumnMappings.decimalColumnMapping(DecimalType.createDecimalType(precision, Math.max(decimalDigits, 0)), RoundingMode.UNNECESSARY));
                }
                break;
            }
            case Types.INTEGER:
                return Optional.of(StandardColumnMappings.integerColumnMapping());
            case Types.SMALLINT:
                return Optional.of(StandardColumnMappings.smallintColumnMapping());
            case Types.REAL:
                return Optional.of(StandardColumnMappings.realColumnMapping());
            case Types.DOUBLE:
                return Optional.of(StandardColumnMappings.doubleColumnMapping());
            case Types.DATE:
                return Optional.of(StandardColumnMappings.dateColumnMapping());
            case Types.TIME:
                return Optional.of(StandardColumnMappings.timeColumnMapping(TimeType.TIME));
            case Types.TIMESTAMP:
                return Optional.of(StandardColumnMappings.timestampColumnMapping(TimestampType.createTimestampType(typeHandle.getRequiredDecimalDigits())));
            default:
                break;
        }
        return legacyToPrestoType(session, connection, typeHandle);
    }

    /**
     * Maps a Trino type to the SQL Server column type and write function used for inserts.
     * Types not handled here are delegated to {@code legacyToWriteMapping}.
     */
    @Override
    public WriteMapping toWriteMapping(ConnectorSession session, Type type) {
        if (type == BooleanType.BOOLEAN) {
            return WriteMapping.booleanMapping("bit", StandardColumnMappings.booleanWriteFunction());
        }
        if (type == BigintType.BIGINT) {
            return WriteMapping.longMapping("bigint", StandardColumnMappings.bigintWriteFunction());
        }
        if (type == IntegerType.INTEGER) {
            return WriteMapping.longMapping("integer", StandardColumnMappings.integerWriteFunction());
        }
        if (type == SmallintType.SMALLINT) {
            return WriteMapping.longMapping("smallint", StandardColumnMappings.smallintWriteFunction());
        }
        if (type == TinyintType.TINYINT) {
            return WriteMapping.longMapping("tinyint", StandardColumnMappings.tinyintWriteFunction());
        }
        if (type == DoubleType.DOUBLE) {
            return WriteMapping.doubleMapping("double precision", StandardColumnMappings.doubleWriteFunction());
        }
        if (type instanceof VarcharType) {
            VarcharType varcharType = (VarcharType) type;
            String dataType = (varcharType.isUnbounded() || varcharType.getBoundedLength() > MAX_NVARCHAR_LENGTH)
                    ? "nvarchar(max)"
                    : "nvarchar(" + varcharType.getBoundedLength() + ")";
            return WriteMapping.sliceMapping(dataType, StandardColumnMappings.varcharWriteFunction());
        }
        if (type instanceof CharType) {
            CharType charType = (CharType) type;
            String dataType = charType.getLength() > MAX_NVARCHAR_LENGTH
                    ? "nvarchar(max)"
                    : "nchar(" + charType.getLength() + ")";
            return WriteMapping.sliceMapping(dataType, StandardColumnMappings.charWriteFunction());
        }
        if (type instanceof VarbinaryType) {
            return WriteMapping.sliceMapping("varbinary(max)", varbinaryWriteFunction());
        }
        if (type == DateType.DATE) {
            return WriteMapping.longMapping("date", StandardColumnMappings.dateWriteFunction());
        }
        if (type instanceof TimestampType) {
            TimestampType timestampType = (TimestampType) type;
            // datetime2 supports at most 7 fractional digits; higher Trino precisions are capped.
            String dataType = String.format("datetime2(%d)", Math.min(timestampType.getPrecision(), MAX_SUPPORTED_TIMESTAMP_PRECISION));
            if (timestampType.getPrecision() <= 6) {
                return WriteMapping.longMapping(dataType, StandardColumnMappings.timestampWriteFunction(timestampType));
            }
            return WriteMapping.objectMapping(dataType, StandardColumnMappings.longTimestampWriteFunction(timestampType));
        }
        return legacyToWriteMapping(session, type);
    }

    /** Attempts to rewrite an aggregate using the rules registered in the constructor. */
    @Override
    public Optional<JdbcExpression> implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map<String, ColumnHandle> assignments) {
        return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
    }

    /** Remote {@code decimal} type handle used as the result type of pushed-down SUM over decimals. */
    private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType) {
        return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
    }

    /** Wraps a query as {@code SELECT TOP n * FROM (query) o}. */
    @Override
    protected Optional<BiFunction<String, Long, String>> limitFunction() {
        return Optional.of((sql, limit) -> String.format("SELECT TOP %s * FROM (%s) o", limit, sql));
    }

    @Override
    public boolean isLimitGuaranteed(ConnectorSession session) {
        return true;
    }

    /**
     * Join conditions are pushed down unless the operator is IS DISTINCT FROM or either side is a
     * char/varchar column. NOTE(review): the character exclusion presumably guards against SQL
     * Server collations comparing text differently (e.g. case-insensitively) than Trino.
     */
    @Override
    protected boolean isSupportedJoinCondition(JdbcJoinCondition joinCondition) {
        if (joinCondition.getOperator() == JoinCondition.Operator.IS_DISTINCT_FROM) {
            return false;
        }
        return Stream.of(joinCondition.getLeftColumn(), joinCondition.getRightColumn())
                .map(JdbcColumnHandle::getColumnType)
                .noneMatch(columnType -> columnType instanceof CharType || columnType instanceof VarcharType);
    }

    /** Builds CREATE TABLE, appending {@code WITH (DATA_COMPRESSION = ...)} when the property is set. */
    @Override
    protected String createTableSql(RemoteTableName remoteTableName, List<String> columns, ConnectorTableMetadata tableMetadata) {
        String compressionClause = SqlServerTableProperties.getDataCompression(tableMetadata.getProperties())
                .map(dataCompression -> String.format("WITH (DATA_COMPRESSION = %s)", dataCompression))
                .orElse("");
        return String.format("CREATE TABLE %s (%s) %s", quoted(remoteTableName), String.join(", ", columns), compressionClause);
    }

    /**
     * Returns the table's data_compression property, or an empty map when it cannot be determined
     * or the handle is not a named relation.
     *
     * @throws TrinoException with {@code JDBC_ERROR} on any SQL failure
     */
    @Override
    public Map<String, Object> getTableProperties(ConnectorSession session, JdbcTableHandle tableHandle) {
        if (!tableHandle.isNamedRelation()) {
            return ImmutableMap.of();
        }
        // try-with-resources guarantees the connection is closed even when the query fails.
        try (Connection connection = configureConnectionTransactionIsolation(connectionFactory.openConnection(session));
                Handle handle = Jdbi.open(connection)) {
            return getTableDataCompression(handle, tableHandle)
                    .<Map<String, Object>>map(dataCompression -> ImmutableMap.of(SqlServerTableProperties.DATA_COMPRESSION, dataCompression))
                    .orElseGet(ImmutableMap::of);
        } catch (SQLException e) {
            throw new TrinoException(JdbcErrorCode.JDBC_ERROR, e);
        }
    }

    /** Aborts the connection, running the abort work on the calling thread. */
    @Override
    public void abortReadConnection(Connection connection) throws SQLException {
        connection.abort(MoreExecutors.directExecutor());
    }

    @Override
    public Connection getConnection(ConnectorSession session, JdbcOutputTableHandle handle) throws SQLException {
        return configureConnectionTransactionIsolation(super.getConnection(session, handle));
    }

    @Override
    public Connection getConnection(ConnectorSession session, JdbcSplit split) throws SQLException {
        return configureConnectionTransactionIsolation(super.getConnection(session, split));
    }

    /**
     * Switches the connection to SNAPSHOT isolation when the database allows it. On failure the
     * connection is closed before the exception is rethrown.
     */
    private Connection configureConnectionTransactionIsolation(Connection connection) throws SQLException {
        try {
            if (hasSnapshotIsolationEnabled(connection)) {
                connection.setTransactionIsolation(TRANSACTION_SNAPSHOT);
            }
            return connection;
        } catch (SQLException e) {
            connection.close();
            throw e;
        }
    }

    /**
     * Returns whether ALLOW_SNAPSHOT_ISOLATION is ON for the connection's current database, caching
     * the answer. SNAPSHOT isolation requires this option; READ_COMMITTED_SNAPSHOT
     * ({@code is_read_committed_snapshot_on}) is a separate setting and does not permit it.
     */
    private boolean hasSnapshotIsolationEnabled(Connection connection) throws SQLException {
        try {
            return snapshotIsolationEnabled.get(SnapshotIsolationEnabledCacheKey.INSTANCE, () -> {
                // The Handle is intentionally not closed: closing a Jdbi Handle closes the wrapped
                // Connection, which the caller still owns.
                Handle handle = Jdbi.open(connection);
                return handle.createQuery("SELECT snapshot_isolation_state FROM sys.databases WHERE name = :name")
                        .bind("name", connection.getCatalog())
                        .mapTo(Integer.class)
                        .findOne()
                        .map(state -> state == SNAPSHOT_ISOLATION_STATE_ON)
                        .orElse(false);
            });
        } catch (ExecutionException e) {
            throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, e);
        }
    }

    /** Joins the parts with '.' and renders them as one escaped SQL string literal. */
    private static String singleQuote(String... parts) {
        return singleQuote(DOT_JOINER.join(parts));
    }

    /** Renders a SQL string literal, doubling embedded single quotes so the value cannot end the literal. */
    private static String singleQuote(String literal) {
        return "'" + literal.replace("'", "''") + "'";
    }

    /** Varbinary read/write mapping; predicate pushdown is disabled for these columns. */
    public static ColumnMapping varbinaryColumnMapping() {
        return ColumnMapping.sliceMapping(
                VarbinaryType.VARBINARY,
                (resultSet, columnIndex) -> Slices.wrappedBuffer(resultSet.getBytes(columnIndex)),
                varbinaryWriteFunction(),
                PredicatePushdownController.DISABLE_PUSHDOWN);
    }

    /** Writes slices with setBytes, including NULLs (bound as {@code setBytes(index, null)}). */
    private static SliceWriteFunction varbinaryWriteFunction() {
        return new SliceWriteFunction() {
            @Override
            public void set(PreparedStatement statement, int index, Slice value) throws SQLException {
                statement.setBytes(index, value.getBytes());
            }

            @Override
            public void setNull(PreparedStatement statement, int index) throws SQLException {
                statement.setBytes(index, null);
            }
        };
    }

    /**
     * Reads data_compression_desc for a non-partitioned heap table (index_id 0, index type 0),
     * returning empty when no single row is found or the value is not a {@link DataCompression} constant.
     */
    private static Optional<DataCompression> getTableDataCompression(Handle handle, JdbcTableHandle tableHandle) {
        String sql = "SELECT data_compression_desc FROM sys.partitions p "
                + "INNER JOIN sys.tables t ON p.object_id = t.object_id "
                + "INNER JOIN sys.schemas s ON t.schema_id = s.schema_id "
                + "INNER JOIN sys.indexes i ON t.object_id = i.object_id "
                + "WHERE s.name = :schema AND t.name = :table_name AND p.index_id = 0 AND i.type = 0 "
                + "AND i.data_space_id NOT IN (SELECT data_space_id FROM sys.partition_schemes)";
        return handle.createQuery(sql)
                .bind("schema", tableHandle.getSchemaName())
                .bind("table_name", tableHandle.getTableName())
                .mapTo(String.class)
                .findOne()
                .flatMap(compression -> Enums.getIfPresent(DataCompression.class, compression).toJavaUtil());
    }
}
