package io.trino.aws.proxy.server.credentials;

import com.google.common.base.Preconditions;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.trino.aws.proxy.server.remote.RemoteS3Facade;
import io.trino.aws.proxy.spi.credentials.Credential;
import io.trino.aws.proxy.spi.credentials.Credentials;
import io.trino.aws.proxy.spi.credentials.CredentialsProvider;
import io.trino.aws.proxy.spi.remote.RemoteSessionRole;
import jakarta.annotation.PreDestroy;
import java.io.Closeable;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.SwitchBootstraps;
import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;

/**
 * Resolves proxy credentials and, when the resolved credentials carry a {@link RemoteSessionRole},
 * runs the caller's work under assumed-role (STS) credentials instead of the raw remote credential.
 *
 * <p>One {@link Session} (an {@link StsClient} plus an {@link StsAssumeRoleCredentialsProvider}) is
 * cached per emulated access key in {@code remoteSessions} and reused across calls.
 */
public class CredentialsController {
    private static final Logger log = Logger.get(CredentialsController.class);
    private final RemoteS3Facade remoteS3Facade;
    private final CredentialsProvider credentialsProvider;
    // Keyed by the emulated access key; see internalRemoteSession.
    private final Map<String, Session> remoteSessions = new ConcurrentHashMap<>();

    /**
     * A cached assume-role session: owns the STS client and the refreshing credentials provider
     * built on top of it. Tracks in-flight usage so it is not closed while being used.
     */
    private final class Session implements Closeable {
        private final String sessionName;
        private final StsClient stsClient;
        private final StsAssumeRoleCredentialsProvider credentialsProvider;
        private final AtomicLong useCount = new AtomicLong();
        private volatile Instant lastUsage = Instant.now();

        private Session(String sessionName, StsClient stsClient, StsAssumeRoleCredentialsProvider credentialsProvider) {
            this.sessionName = Objects.requireNonNull(sessionName, "sessionName is null");
            this.stsClient = Objects.requireNonNull(stsClient, "stsClient is null");
            this.credentialsProvider = Objects.requireNonNull(credentialsProvider, "credentialsProvider is null");
        }

        /**
         * Removes this session from the cache and releases its AWS resources.
         *
         * @throws IllegalStateException if the session is in use or was already closed
         */
        @Override
        public void close() {
            Preconditions.checkState(useCount.get() == 0, "Session is currently being used");
            // remove(key, value): never evict a different (newer) session cached under the same key
            Preconditions.checkState(remoteSessions.remove(sessionName, this), "Session was already closed");
            try {
                // The provider is built with async refresh enabled and must be closed to stop it
                credentialsProvider.close();
            } finally {
                stsClient.close();
            }
        }

        /**
         * Applies {@code function} to credentials whose remote part is the current assumed-role
         * credential, keeping the session marked as in use for the duration of the call.
         */
        private <T> Optional<T> withUsage(Credentials credentials, Function<Credentials, Optional<T>> function) {
            incrementUsage();
            try {
                return function.apply(Credentials.build(credentials.emulated(), currentCredential()));
            } finally {
                decrementUsage();
            }
        }

        private void incrementUsage() {
            lastUsage = Instant.now();
            useCount.incrementAndGet();
        }

        private void decrementUsage() {
            if (useCount.decrementAndGet() < 0) {
                throw new IllegalStateException("Session useCount has gone negative");
            }
        }

        /**
         * Returns the provider's current credentials as a {@link Credential}, including the session
         * token when the resolved credentials are session credentials.
         */
        private Credential currentCredential() {
            Preconditions.checkState(remoteSessions.containsKey(sessionName), "Session is closed");
            AwsCredentials resolved = Objects.requireNonNull(credentialsProvider.resolveCredentials(), "resolved credentials are null");
            if (resolved instanceof AwsSessionCredentials sessionCredentials) {
                return new Credential(sessionCredentials.accessKeyId(), sessionCredentials.secretAccessKey(), Optional.of(sessionCredentials.sessionToken()));
            }
            return new Credential(resolved.accessKeyId(), resolved.secretAccessKey());
        }
    }

    @Inject
    public CredentialsController(RemoteS3Facade remoteS3Facade, CredentialsProvider credentialsProvider) {
        this.remoteS3Facade = Objects.requireNonNull(remoteS3Facade, "remoteS3Facade is null");
        this.credentialsProvider = Objects.requireNonNull(credentialsProvider, "credentialsProvider is null");
    }

    /**
     * Closes every cached session. A failure to close one session (e.g. it is still in use) is
     * logged and does not prevent the remaining sessions from being closed.
     */
    @PreDestroy
    public void shutdown() {
        remoteSessions.values().forEach(session -> {
            try {
                session.close();
            } catch (RuntimeException e) {
                log.warn(e, "Failed to close remote session %s", session.sessionName);
            }
        });
    }

    /**
     * Looks up credentials for an emulated access key and applies {@code credentialsConsumer} to them.
     *
     * <p>If the credentials have a remote session role, the consumer is invoked exactly once, with the
     * assumed-role credentials; its result (even if empty) is returned as-is. Otherwise the consumer
     * is invoked with the credentials as found.
     *
     * @param emulatedAccessKey the emulated access key to look up
     * @param session optional session identifier passed through to the credentials provider
     * @param credentialsConsumer work to run with the resolved credentials
     * @return the consumer's result, or empty if no credentials were found
     */
    public <T> Optional<T> withCredentials(String emulatedAccessKey, Optional<String> session, Function<Credentials, Optional<T>> credentialsConsumer) {
        Optional<Credentials> credentials = credentialsProvider.credentials(emulatedAccessKey, session);
        credentials.ifPresentOrElse(
                ignored -> log.debug("Credentials found. EmulatedAccessKey: %s", emulatedAccessKey),
                () -> log.debug("Credentials not found. EmulatedAccessKey: %s", emulatedAccessKey));

        // map/orElseGet (not flatMap/or): an empty result under the assumed role must not trigger a
        // second call with the raw, non-assumed credentials.
        return credentials.flatMap(found -> found.remoteSessionRole()
                .map(remoteSessionRole -> internalRemoteSession(remoteSessionRole, found).withUsage(found, credentialsConsumer))
                .orElseGet(() -> credentialsConsumer.apply(found)));
    }

    /** Returns the cached session for the credentials' emulated access key, creating it if absent. */
    private Session internalRemoteSession(RemoteSessionRole remoteSessionRole, Credentials credentials) {
        String emulatedAccessKey = credentials.emulated().accessKey();
        return remoteSessions.computeIfAbsent(emulatedAccessKey,
                key -> internalStartRemoteSession(remoteSessionRole, credentials.requiredRemoteCredential(), key));
    }

    /**
     * Builds an STS client authenticated with {@code remoteCredential} (pointed at the endpoint from
     * {@link RemoteS3Facade#remoteUri}) and an assume-role provider that refreshes credentials for
     * {@code remoteSessionRole} asynchronously, using {@code sessionName} as the role session name.
     */
    private Session internalStartRemoteSession(RemoteSessionRole remoteSessionRole, Credential remoteCredential, String sessionName) {
        AwsCredentials stsCredentials = remoteCredential.session()
                .<AwsCredentials>map(sessionToken -> AwsSessionCredentials.create(remoteCredential.accessKey(), remoteCredential.secretKey(), sessionToken))
                .orElseGet(() -> AwsBasicCredentials.create(remoteCredential.accessKey(), remoteCredential.secretKey()));

        StsClient stsClient = StsClient.builder()
                .region(Region.of(remoteSessionRole.region()))
                .credentialsProvider(StaticCredentialsProvider.create(stsCredentials))
                .endpointProvider(endpointParams -> CompletableFuture.completedFuture(
                        Endpoint.builder().url(remoteS3Facade.remoteUri(remoteSessionRole.region())).build()))
                .build();

        StsAssumeRoleCredentialsProvider assumeRoleProvider = StsAssumeRoleCredentialsProvider.builder()
                .refreshRequest(request -> {
                    request.roleArn(remoteSessionRole.roleArn()).roleSessionName(sessionName);
                    remoteSessionRole.externalId().ifPresent(request::externalId);
                })
                .stsClient(stsClient)
                .asyncCredentialUpdateEnabled(true)
                .build();

        return new Session(sessionName, stsClient, assumeRoleProvider);
    }
}
