package org.apache.pulsar.broker.authentication.oidc;

import com.auth0.jwk.InvalidPublicKeyException;
import com.auth0.jwk.Jwk;
import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.AlgorithmMismatchException;
import com.auth0.jwt.exceptions.InvalidClaimException;
import com.auth0.jwt.exceptions.JWTDecodeException;
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.auth0.jwt.exceptions.SignatureVerificationException;
import com.auth0.jwt.exceptions.TokenExpiredException;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.Verification;
import io.kubernetes.client.openapi.ApiClient;
import io.kubernetes.client.util.Config;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.io.IOException;
import java.net.SocketAddress;
import java.security.PublicKey;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import javax.naming.AuthenticationException;
import javax.net.ssl.SSLSession;
import org.apache.commons.lang.StringUtils;
import org.apache.pulsar.broker.ServiceConfiguration;
import org.apache.pulsar.broker.authentication.AuthenticationDataSource;
import org.apache.pulsar.broker.authentication.AuthenticationProvider;
import org.apache.pulsar.broker.authentication.AuthenticationProviderToken;
import org.apache.pulsar.broker.authentication.AuthenticationState;
import org.apache.pulsar.broker.authentication.metrics.AuthenticationMetrics;
import org.apache.pulsar.common.api.AuthData;
import org.asynchttpclient.AsyncHttpClient;
import org.asynchttpclient.DefaultAsyncHttpClient;
import org.asynchttpclient.DefaultAsyncHttpClientConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenID.class */
public class AuthenticationProviderOpenID implements AuthenticationProvider {
    private static final Logger log = LoggerFactory.getLogger(AuthenticationProviderOpenID.class);
    private static final String SIMPLE_NAME = AuthenticationProviderOpenID.class.getSimpleName();
    private static final String AUTH_METHOD_NAME = "token";
    private Set<String> issuers;
    private OpenIDProviderMetadataCache openIDProviderMetadataCache;
    private JwksCache jwksCache;
    private volatile AsyncHttpClient httpClient;
    private static final String ALG_RS256 = "RS256";
    private static final String ALG_RS384 = "RS384";
    private static final String ALG_RS512 = "RS512";
    private static final String ALG_ES256 = "ES256";
    private static final String ALG_ES384 = "ES384";
    private static final String ALG_ES512 = "ES512";
    private long acceptedTimeLeewaySeconds;
    private FallbackDiscoveryMode fallbackDiscoveryMode;
    private boolean isRoleClaimNotSubject;
    static final String ALLOWED_TOKEN_ISSUERS = "openIDAllowedTokenIssuers";
    static final String ISSUER_TRUST_CERTS_FILE_PATH = "openIDTokenIssuerTrustCertsFilePath";
    static final String FALLBACK_DISCOVERY_MODE = "openIDFallbackDiscoveryMode";
    static final String ALLOWED_AUDIENCES = "openIDAllowedAudiences";
    static final String ROLE_CLAIM = "openIDRoleClaim";
    static final String ROLE_CLAIM_DEFAULT = "sub";
    static final String ACCEPTED_TIME_LEEWAY_SECONDS = "openIDAcceptedTimeLeewaySeconds";
    static final int ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT = 0;
    static final String CACHE_SIZE = "openIDCacheSize";
    static final int CACHE_SIZE_DEFAULT = 5;
    static final String CACHE_REFRESH_AFTER_WRITE_SECONDS = "openIDCacheRefreshAfterWriteSeconds";
    static final int CACHE_REFRESH_AFTER_WRITE_SECONDS_DEFAULT = 64800;
    static final String CACHE_EXPIRATION_SECONDS = "openIDCacheExpirationSeconds";
    static final int CACHE_EXPIRATION_SECONDS_DEFAULT = 86400;
    static final String KEY_ID_CACHE_MISS_REFRESH_SECONDS = "openIDKeyIdCacheMissRefreshSeconds";
    static final int KEY_ID_CACHE_MISS_REFRESH_SECONDS_DEFAULT = 300;
    static final String HTTP_CONNECTION_TIMEOUT_MILLIS = "openIDHttpConnectionTimeoutMillis";
    static final int HTTP_CONNECTION_TIMEOUT_MILLIS_DEFAULT = 10000;
    static final String HTTP_READ_TIMEOUT_MILLIS = "openIDHttpReadTimeoutMillis";
    static final int HTTP_READ_TIMEOUT_MILLIS_DEFAULT = 10000;
    static final String REQUIRE_HTTPS = "openIDRequireIssuersUseHttps";
    static final boolean REQUIRE_HTTPS_DEFAULT = true;
    private String[] allowedAudiences;
    private final JWT jwtLibrary = new JWT();
    private String roleClaim = ROLE_CLAIM_DEFAULT;

    public void initialize(ServiceConfiguration serviceConfiguration) throws IOException {
        this.allowedAudiences = validateAllowedAudiences(ConfigUtils.getConfigValueAsSet(serviceConfiguration, ALLOWED_AUDIENCES));
        this.roleClaim = ConfigUtils.getConfigValueAsString(serviceConfiguration, ROLE_CLAIM, ROLE_CLAIM_DEFAULT);
        this.isRoleClaimNotSubject = !ROLE_CLAIM_DEFAULT.equals(this.roleClaim);
        this.acceptedTimeLeewaySeconds = ConfigUtils.getConfigValueAsInt(serviceConfiguration, ACCEPTED_TIME_LEEWAY_SECONDS, ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT);
        boolean configValueAsBoolean = ConfigUtils.getConfigValueAsBoolean(serviceConfiguration, REQUIRE_HTTPS, true);
        this.fallbackDiscoveryMode = FallbackDiscoveryMode.valueOf(ConfigUtils.getConfigValueAsString(serviceConfiguration, FALLBACK_DISCOVERY_MODE, FallbackDiscoveryMode.DISABLED.name()));
        this.issuers = validateIssuers(ConfigUtils.getConfigValueAsSet(serviceConfiguration, ALLOWED_TOKEN_ISSUERS), configValueAsBoolean, this.fallbackDiscoveryMode != FallbackDiscoveryMode.DISABLED);
        int configValueAsInt = ConfigUtils.getConfigValueAsInt(serviceConfiguration, HTTP_CONNECTION_TIMEOUT_MILLIS, 10000);
        int configValueAsInt2 = ConfigUtils.getConfigValueAsInt(serviceConfiguration, HTTP_READ_TIMEOUT_MILLIS, 10000);
        String configValueAsString = ConfigUtils.getConfigValueAsString(serviceConfiguration, ISSUER_TRUST_CERTS_FILE_PATH, null);
        SslContext sslContext = ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT;
        if (StringUtils.isNotBlank(configValueAsString)) {
            sslContext = SslContextBuilder.forClient().trustManager(new File(configValueAsString)).build();
        }
        this.httpClient = new DefaultAsyncHttpClient(new DefaultAsyncHttpClientConfig.Builder().setConnectTimeout(configValueAsInt).setReadTimeout(configValueAsInt2).setSslContext(sslContext).build());
        ApiClient defaultClient = this.fallbackDiscoveryMode != FallbackDiscoveryMode.DISABLED ? Config.defaultClient() : null;
        this.openIDProviderMetadataCache = new OpenIDProviderMetadataCache(serviceConfiguration, this.httpClient, defaultClient);
        this.jwksCache = new JwksCache(serviceConfiguration, this.httpClient, defaultClient);
    }

    public String getAuthMethodName() {
        return AUTH_METHOD_NAME;
    }

    public CompletableFuture<String> authenticateAsync(AuthenticationDataSource authenticationDataSource) {
        return authenticateTokenAsync(authenticationDataSource).thenApply(this::getRole);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public CompletableFuture<DecodedJWT> authenticateTokenAsync(AuthenticationDataSource authenticationDataSource) {
        try {
            return authenticateToken(AuthenticationProviderToken.getToken(authenticationDataSource)).whenComplete((decodedJWT, th) -> {
                if (decodedJWT != null) {
                    AuthenticationMetrics.authenticateSuccess(getClass().getSimpleName(), getAuthMethodName());
                }
            });
        } catch (AuthenticationException e) {
            incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
            return CompletableFuture.failedFuture(e);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public String getRole(DecodedJWT decodedJWT) {
        try {
            Claim claim = decodedJWT.getClaim(this.roleClaim);
            if (claim.isNull()) {
                return null;
            }
            String asString = claim.asString();
            if (asString != null) {
                return asString;
            }
            List asList = decodedJWT.getClaim(this.roleClaim).asList(String.class);
            if (asList == null || asList.size() == 0) {
                return null;
            }
            if (asList.size() == REQUIRE_HTTPS_DEFAULT) {
                return (String) asList.get(ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT);
            }
            log.debug("JWT for subject [{}] has multiple roles; using the first one.", decodedJWT.getSubject());
            return (String) asList.get(ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT);
        } catch (JWTDecodeException e) {
            log.error("Exception while retrieving role from JWT", e);
            return null;
        }
    }

    DecodedJWT decodeJWT(String str) throws AuthenticationException {
        if (str == null) {
            incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
            throw new AuthenticationException("Invalid token: cannot be null");
        }
        try {
            return this.jwtLibrary.decodeJwt(str);
        } catch (JWTDecodeException e) {
            incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
            throw new AuthenticationException("Unable to decode JWT: " + e.getMessage());
        }
    }

    private CompletableFuture<DecodedJWT> authenticateToken(String str) {
        if (str == null) {
            incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
            return CompletableFuture.failedFuture(new AuthenticationException("JWT cannot be null"));
        }
        try {
            DecodedJWT decodeJWT = decodeJWT(str);
            return verifyIssuerAndGetJwk(decodeJWT).thenCompose(jwk -> {
                try {
                    if (jwk.getAlgorithm() == null || decodeJWT.getAlgorithm().equals(jwk.getAlgorithm())) {
                        return CompletableFuture.completedFuture(verifyJWT(jwk.getPublicKey(), decodeJWT.getAlgorithm(), decodeJWT));
                    }
                    incrementFailureMetric(AuthenticationExceptionCode.ALGORITHM_MISMATCH);
                    return CompletableFuture.failedFuture(new AuthenticationException("JWK's alg [" + jwk.getAlgorithm() + "] does not match JWT's alg [" + decodeJWT.getAlgorithm() + "]"));
                } catch (InvalidPublicKeyException e) {
                    incrementFailureMetric(AuthenticationExceptionCode.INVALID_PUBLIC_KEY);
                    return CompletableFuture.failedFuture(new AuthenticationException("Invalid public key: " + e.getMessage()));
                } catch (AuthenticationException e2) {
                    return CompletableFuture.failedFuture(e2);
                }
            });
        } catch (AuthenticationException e) {
            incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
            return CompletableFuture.failedFuture(e);
        }
    }

    private CompletableFuture<Jwk> verifyIssuerAndGetJwk(DecodedJWT decodedJWT) {
        if (decodedJWT.getIssuer() == null) {
            incrementFailureMetric(AuthenticationExceptionCode.UNSUPPORTED_ISSUER);
            return CompletableFuture.failedFuture(new AuthenticationException("Issuer cannot be null"));
        }
        if (this.issuers.contains(decodedJWT.getIssuer())) {
            return this.openIDProviderMetadataCache.getOpenIDProviderMetadataForIssuer(decodedJWT.getIssuer()).thenCompose(openIDProviderMetadata -> {
                return this.jwksCache.getJwk(openIDProviderMetadata.getJwksUri(), decodedJWT.getKeyId());
            });
        }
        if (this.fallbackDiscoveryMode == FallbackDiscoveryMode.KUBERNETES_DISCOVER_TRUSTED_ISSUER) {
            return this.openIDProviderMetadataCache.getOpenIDProviderMetadataForKubernetesApiServer(decodedJWT.getIssuer()).thenCompose(openIDProviderMetadata2 -> {
                return this.openIDProviderMetadataCache.getOpenIDProviderMetadataForIssuer(openIDProviderMetadata2.getIssuer());
            }).thenCompose((Function<? super U, ? extends CompletionStage<U>>) openIDProviderMetadata3 -> {
                return this.jwksCache.getJwk(openIDProviderMetadata3.getJwksUri(), decodedJWT.getKeyId());
            });
        }
        if (this.fallbackDiscoveryMode == FallbackDiscoveryMode.KUBERNETES_DISCOVER_PUBLIC_KEYS) {
            return this.openIDProviderMetadataCache.getOpenIDProviderMetadataForKubernetesApiServer(decodedJWT.getIssuer()).thenCompose(openIDProviderMetadata4 -> {
                return this.jwksCache.getJwkFromKubernetesApiServer(decodedJWT.getKeyId());
            });
        }
        incrementFailureMetric(AuthenticationExceptionCode.UNSUPPORTED_ISSUER);
        return CompletableFuture.failedFuture(new AuthenticationException("Issuer not allowed: " + decodedJWT.getIssuer()));
    }

    public AuthenticationState newAuthState(AuthData authData, SocketAddress socketAddress, SSLSession sSLSession) throws AuthenticationException {
        return new AuthenticationStateOpenID(this, socketAddress, sSLSession);
    }

    public void close() throws IOException {
        this.httpClient.close();
    }

    DecodedJWT verifyJWT(PublicKey publicKey, String str, DecodedJWT decodedJWT) throws AuthenticationException {
        Algorithm ECDSA512;
        if (str == null) {
            incrementFailureMetric(AuthenticationExceptionCode.UNSUPPORTED_ALGORITHM);
            throw new AuthenticationException("PublicKey algorithm cannot be null");
        }
        try {
            boolean z = -1;
            switch (str.hashCode()) {
                case 66245349:
                    if (str.equals(ALG_ES256)) {
                        z = 3;
                        break;
                    }
                    break;
                case 66246401:
                    if (str.equals(ALG_ES384)) {
                        z = 4;
                        break;
                    }
                    break;
                case 66248104:
                    if (str.equals(ALG_ES512)) {
                        z = CACHE_SIZE_DEFAULT;
                        break;
                    }
                    break;
                case 78251122:
                    if (str.equals(ALG_RS256)) {
                        z = ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT;
                        break;
                    }
                    break;
                case 78252174:
                    if (str.equals(ALG_RS384)) {
                        z = REQUIRE_HTTPS_DEFAULT;
                        break;
                    }
                    break;
                case 78253877:
                    if (str.equals(ALG_RS512)) {
                        z = 2;
                        break;
                    }
                    break;
            }
            switch (z) {
                case ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT /* 0 */:
                    ECDSA512 = Algorithm.RSA256((RSAPublicKey) publicKey, (RSAPrivateKey) null);
                    break;
                case REQUIRE_HTTPS_DEFAULT /* 1 */:
                    ECDSA512 = Algorithm.RSA384((RSAPublicKey) publicKey, (RSAPrivateKey) null);
                    break;
                case true:
                    ECDSA512 = Algorithm.RSA512((RSAPublicKey) publicKey, (RSAPrivateKey) null);
                    break;
                case true:
                    ECDSA512 = Algorithm.ECDSA256((ECPublicKey) publicKey, (ECPrivateKey) null);
                    break;
                case true:
                    ECDSA512 = Algorithm.ECDSA384((ECPublicKey) publicKey, (ECPrivateKey) null);
                    break;
                case CACHE_SIZE_DEFAULT /* 5 */:
                    ECDSA512 = Algorithm.ECDSA512((ECPublicKey) publicKey, (ECPrivateKey) null);
                    break;
                default:
                    incrementFailureMetric(AuthenticationExceptionCode.UNSUPPORTED_ALGORITHM);
                    throw new AuthenticationException("Unsupported algorithm: " + str);
            }
            Verification withClaimPresence = JWT.require(ECDSA512).acceptLeeway(this.acceptedTimeLeewaySeconds).withAnyOfAudience(this.allowedAudiences).withClaimPresence("iat").withClaimPresence("exp").withClaimPresence("nbf").withClaimPresence(ROLE_CLAIM_DEFAULT);
            if (this.isRoleClaimNotSubject) {
                withClaimPresence = withClaimPresence.withClaimPresence(this.roleClaim);
            }
            try {
                return withClaimPresence.build().verify(decodedJWT);
            } catch (JWTDecodeException e) {
                incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
                throw new AuthenticationException("Error while decoding JWT: " + e.getMessage());
            } catch (InvalidClaimException e2) {
                incrementFailureMetric(AuthenticationExceptionCode.INVALID_JWT_CLAIM);
                throw new AuthenticationException("JWT contains invalid claim: " + e2.getMessage());
            } catch (AlgorithmMismatchException e3) {
                incrementFailureMetric(AuthenticationExceptionCode.ALGORITHM_MISMATCH);
                throw new AuthenticationException("JWT algorithm does not match Public Key algorithm: " + e3.getMessage());
            } catch (SignatureVerificationException e4) {
                incrementFailureMetric(AuthenticationExceptionCode.ERROR_VERIFYING_JWT_SIGNATURE);
                throw new AuthenticationException("JWT signature verification exception: " + e4.getMessage());
            } catch (JWTVerificationException | IllegalArgumentException e5) {
                incrementFailureMetric(AuthenticationExceptionCode.ERROR_VERIFYING_JWT);
                throw new AuthenticationException("JWT verification failed: " + e5.getMessage());
            } catch (TokenExpiredException e6) {
                incrementFailureMetric(AuthenticationExceptionCode.EXPIRED_JWT);
                throw new AuthenticationException("JWT expired: " + e6.getMessage());
            }
        } catch (ClassCastException e7) {
            incrementFailureMetric(AuthenticationExceptionCode.ALGORITHM_MISMATCH);
            throw new AuthenticationException("Expected PublicKey alg [" + str + "] does match actual alg.");
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static void incrementFailureMetric(AuthenticationExceptionCode authenticationExceptionCode) {
        AuthenticationMetrics.authenticateFailure(SIMPLE_NAME, AUTH_METHOD_NAME, authenticationExceptionCode);
    }

    private Set<String> validateIssuers(Set<String> set, boolean z, boolean z2) {
        if (set == null || (set.isEmpty() && !z2)) {
            throw new IllegalArgumentException("Missing configured value for: openIDAllowedTokenIssuers");
        }
        for (String str : set) {
            if (!str.toLowerCase().startsWith("https://")) {
                log.warn("Allowed issuer is not using https scheme: {}", str);
                if (z) {
                    throw new IllegalArgumentException("Issuer URL does not use https, but must: " + str);
                }
            }
        }
        return set;
    }

    String[] validateAllowedAudiences(Set<String> set) {
        if (set == null || set.isEmpty()) {
            throw new IllegalArgumentException("Missing configured value for: openIDAllowedAudiences");
        }
        return (String[]) set.toArray(new String[ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT]);
    }
}
