package io.kroxylicious.proxy.internal;

import io.kroxylicious.proxy.frame.BareSaslRequest;
import io.kroxylicious.proxy.frame.BareSaslResponse;
import io.kroxylicious.proxy.frame.DecodedRequestFrame;
import io.kroxylicious.proxy.frame.DecodedResponseFrame;
import io.kroxylicious.proxy.internal.codec.KafkaRequestEncoder;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.kafka.common.errors.IllegalSaslStateException;
import org.apache.kafka.common.errors.InvalidRequestException;
import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.errors.UnsupportedSaslMechanismException;
import org.apache.kafka.common.message.ResponseHeaderData;
import org.apache.kafka.common.message.SaslAuthenticateRequestData;
import org.apache.kafka.common.message.SaslAuthenticateResponseData;
import org.apache.kafka.common.message.SaslHandshakeRequestData;
import org.apache.kafka.common.message.SaslHandshakeResponseData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.plain.internals.PlainSaslServerProvider;
import org.apache.kafka.common.security.scram.internals.ScramMechanism;
import org.apache.kafka.common.security.scram.internals.ScramSaslServerProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Netty inbound handler that performs SASL authentication of a Kafka client inside the proxy.
 *
 * <p>Progress through the Kafka authentication exchange is tracked in {@link #lastSeen} and
 * validated by {@link #doTransition(Channel, State)}. ApiVersions requests are passed further along
 * the pipeline. SaslHandshake and SaslAuthenticate requests, and bare (unframed) SASL tokens, are
 * answered by this handler using a {@link SaslServer} for the mechanism chosen in the handshake.
 * Any other request is answered with an error until authentication succeeds; after that it is
 * passed along unchanged. A successful exchange fires an {@link AuthenticationEvent} as a user
 * event.</p>
 *
 * <p>NOTE(review): the mutable state is unsynchronized. This assumes an instance is confined to its
 * channel's event loop — confirm.</p>
 */
public class KafkaAuthnHandler extends ChannelInboundHandlerAdapter {

    /** Error reported for non-authentication requests received before authentication succeeds. */
    private static final IllegalSaslStateException NOT_AUTHENTICATED_EXCEPTION = new IllegalSaslStateException("Not authenticated");

    private static final Logger LOG = LoggerFactory.getLogger(KafkaAuthnHandler.class);

    static {
        // Sasl.createSaslServer finds servers through registered security providers, so Kafka's
        // PLAIN and SCRAM server providers are registered once, when this class is initialised.
        PlainSaslServerProvider.initialize();
        ScramSaslServerProvider.initialize();
    }

    /** Mechanism names offered to clients; reported in every SaslHandshake response. */
    private final List<String> enabledMechanisms;

    /** Callback handler to use for each enabled mechanism, keyed by mechanism name. */
    private final Map<String, AuthenticateCallbackHandler> mechanismHandlers;

    /** Server created by the last handshake that named an enabled mechanism; null before that. */
    SaslServer saslServer;

    /** The most recent state of the authentication exchange. */
    State lastSeen;

    /** The SASL mechanisms this handler knows how to serve. */
    public enum SaslMechanism {
        PLAIN("PLAIN", null) {
            @Override
            public Map<String, Object> negotiatedProperties(SaslServer saslServer) {
                return Map.of();
            }
        },
        SCRAM_SHA_256("SCRAM-SHA-256", ScramMechanism.SCRAM_SHA_256) {
            @Override
            public Map<String, Object> negotiatedProperties(SaslServer saslServer) {
                return credentialLifetimeProperty(saslServer);
            }
        },
        SCRAM_SHA_512("SCRAM-SHA-512", ScramMechanism.SCRAM_SHA_512) {
            @Override
            public Map<String, Object> negotiatedProperties(SaslServer saslServer) {
                return credentialLifetimeProperty(saslServer);
            }
        };

        /** Negotiated property that the SCRAM mechanisms propagate, when the server sets it. */
        private static final String CREDENTIAL_LIFETIME_MS = "CREDENTIAL.LIFETIME.MS";

        private final String name;
        private final ScramMechanism scramMechanism;

        SaslMechanism(String name, ScramMechanism scramMechanism) {
            this.name = name;
            this.scramMechanism = scramMechanism;
        }

        /** Returns the IANA-style mechanism name, e.g. {@code "SCRAM-SHA-256"}. */
        public String mechanismName() {
            return this.name;
        }

        /**
         * Looks up a mechanism by its {@link #mechanismName()}.
         *
         * @param mechanismName the mechanism name, e.g. as returned by {@link SaslServer#getMechanismName()}
         * @return the matching mechanism
         * @throws UnsupportedSaslMechanismException if no mechanism has that name
         */
        static SaslMechanism fromMechanismName(String mechanismName) {
            for (SaslMechanism mechanism : values()) {
                if (mechanism.name.equals(mechanismName)) {
                    return mechanism;
                }
            }
            throw new UnsupportedSaslMechanismException(mechanismName);
        }

        /** Returns Kafka's SCRAM mechanism for this constant, or {@code null} for PLAIN. */
        public ScramMechanism scramMechanism() {
            return this.scramMechanism;
        }

        /**
         * Extracts the properties to publish in the {@link AuthenticationEvent} from a server that
         * has completed authentication.
         *
         * @param saslServer a completed server for this mechanism
         * @return the published properties; never {@code null}, possibly empty
         */
        public abstract Map<String, Object> negotiatedProperties(SaslServer saslServer);

        /** Returns a singleton map of the credential lifetime, or an empty map when it is unset. */
        private static Map<String, Object> credentialLifetimeProperty(SaslServer saslServer) {
            Object lifetime = saslServer.getNegotiatedProperty(CREDENTIAL_LIFETIME_MS);
            return lifetime == null ? Map.of() : Map.of(CREDENTIAL_LIFETIME_MS, lifetime);
        }
    }

    /** States of the authentication exchange; legal transitions are enforced by {@code doTransition}. */
    enum State {
        /** Initial state of a new connection. */
        START,
        /** An ApiVersions request was seen before authentication. */
        API_VERSIONS,
        /** A version 0 SaslHandshake was seen; SASL tokens then arrive unframed. */
        SASL_HANDSHAKE_v0,
        /** A version 1+ SaslHandshake was seen; SASL tokens then arrive in SaslAuthenticate requests. */
        SASL_HANDSHAKE_v1_PLUS,
        /** An unframed SASL token is being processed. */
        UNFRAMED_SASL_AUTHENTICATE,
        /** A SaslAuthenticate request is being processed. */
        FRAMED_SASL_AUTHENTICATE,
        /** Authentication failed or the client broke the protocol; no further transitions are legal. */
        FAILED,
        /** Authentication succeeded. */
        AUTHN_SUCCESS
    }

    /**
     * Creates a handler in the {@link State#START} state.
     *
     * @param channel the channel being authenticated (used for logging)
     * @param mechanismHandlers the enabled mechanisms and the callback handler for each
     */
    public KafkaAuthnHandler(Channel channel, Map<SaslMechanism, AuthenticateCallbackHandler> mechanismHandlers) {
        this(channel, State.START, mechanismHandlers);
    }

    /**
     * Creates a handler in an arbitrary initial state.
     *
     * @param channel the channel being authenticated (used for logging)
     * @param initialState the state to start in
     * @param mechanismHandlers the enabled mechanisms and the callback handler for each
     */
    KafkaAuthnHandler(Channel channel, State initialState, Map<SaslMechanism, AuthenticateCallbackHandler> mechanismHandlers) {
        this.lastSeen = initialState;
        LOG.debug("{}: Initial state {}", channel, this.lastSeen);
        this.mechanismHandlers = mechanismHandlers.entrySet().stream()
                .collect(Collectors.toMap(entry -> entry.getKey().mechanismName(), Map.Entry::getValue));
        this.enabledMechanisms = List.copyOf(this.mechanismHandlers.keySet());
    }

    /**
     * Moves to {@link State#FAILED} and builds the exception describing the rejected transition.
     * The message names the state as it was before the failure.
     */
    private InvalidRequestException illegalTransition(State next) {
        InvalidRequestException exception = new InvalidRequestException("Illegal state transition from " + this.lastSeen + " to " + next);
        this.lastSeen = State.FAILED;
        return exception;
    }

    /**
     * Validates and performs a transition from {@link #lastSeen} to {@code next}.
     *
     * @throws InvalidRequestException if the transition is not legal (the state becomes FAILED)
     */
    private void doTransition(Channel channel, State next) {
        State current = this.lastSeen;
        switch (next) {
            case API_VERSIONS:
                if (current != State.START) {
                    throw illegalTransition(next);
                }
                break;
            case SASL_HANDSHAKE_v0:
            case SASL_HANDSHAKE_v1_PLUS:
                if (current != State.START && current != State.API_VERSIONS) {
                    throw illegalTransition(next);
                }
                break;
            case UNFRAMED_SASL_AUTHENTICATE:
                if (current != State.START && current != State.SASL_HANDSHAKE_v0 && current != State.UNFRAMED_SASL_AUTHENTICATE) {
                    throw illegalTransition(next);
                }
                break;
            case FRAMED_SASL_AUTHENTICATE:
                if (current != State.SASL_HANDSHAKE_v1_PLUS && current != State.FRAMED_SASL_AUTHENTICATE) {
                    throw illegalTransition(next);
                }
                break;
            case FAILED:
                // Failing is always permitted.
                break;
            case AUTHN_SUCCESS:
                if (current != State.FRAMED_SASL_AUTHENTICATE && current != State.UNFRAMED_SASL_AUTHENTICATE) {
                    throw illegalTransition(next);
                }
                break;
            default:
                // Includes START: nothing may transition back to the initial state.
                throw illegalTransition(next);
        }
        LOG.debug("{}: Transition from {} to {}", channel, this.lastSeen, next);
        this.lastSeen = next;
    }

    /**
     * Routes an inbound message. Bare SASL tokens and decoded request frames are handled here.
     * Anything else is passed along only once authentication has succeeded.
     *
     * @throws IllegalStateException if another message type arrives before authentication succeeds
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof BareSaslRequest) {
            handleBareRequest(ctx, (BareSaslRequest) msg);
        }
        else if (msg instanceof DecodedRequestFrame) {
            handleFramedRequest(ctx, (DecodedRequestFrame<?>) msg);
        }
        else if (this.lastSeen == State.AUTHN_SUCCESS) {
            ctx.fireChannelRead(msg);
        }
        else {
            throw new IllegalStateException("Unexpected message " + msg.getClass());
        }
    }

    /** Dispatches a decoded request frame according to its API key and the current state. */
    private void handleFramedRequest(ChannelHandlerContext ctx, DecodedRequestFrame<?> frame) throws SaslException {
        switch (frame.apiKey()) {
            case API_VERSIONS:
                if (this.lastSeen != State.AUTHN_SUCCESS) {
                    doTransition(ctx.channel(), State.API_VERSIONS);
                }
                // ApiVersions is not answered here; it continues along the pipeline.
                ctx.fireChannelRead(frame);
                return;
            case SASL_HANDSHAKE: {
                doTransition(ctx.channel(), frame.apiVersion() == 0 ? State.SASL_HANDSHAKE_v0 : State.SASL_HANDSHAKE_v1_PLUS);
                @SuppressWarnings("unchecked") // the API key guarantees the body type
                DecodedRequestFrame<SaslHandshakeRequestData> handshake = (DecodedRequestFrame<SaslHandshakeRequestData>) frame;
                onSaslHandshakeRequest(ctx, handshake);
                return;
            }
            case SASL_AUTHENTICATE: {
                doTransition(ctx.channel(), State.FRAMED_SASL_AUTHENTICATE);
                @SuppressWarnings("unchecked") // the API key guarantees the body type
                DecodedRequestFrame<SaslAuthenticateRequestData> authenticate = (DecodedRequestFrame<SaslAuthenticateRequestData>) frame;
                onSaslAuthenticateRequest(ctx, authenticate);
                return;
            }
            default:
                if (this.lastSeen == State.AUTHN_SUCCESS) {
                    ctx.fireChannelRead(frame);
                }
                else {
                    writeFramedResponse(ctx, frame, KafkaProxyExceptionMapper.errorResponseMessage(frame, NOT_AUTHENTICATED_EXCEPTION));
                }
        }
    }

    /**
     * Handles an unframed SASL token, which is only accepted after a v0 handshake or a previous
     * unframed token.
     *
     * @throws InvalidRequestException if no such handshake preceded it (the state becomes FAILED)
     * @throws SaslAuthenticationException if the token is rejected
     */
    private void handleBareRequest(ChannelHandlerContext ctx, BareSaslRequest request) throws SaslException {
        if (this.lastSeen != State.SASL_HANDSHAKE_v0 && this.lastSeen != State.UNFRAMED_SASL_AUTHENTICATE) {
            this.lastSeen = State.FAILED;
            throw new InvalidRequestException("Bare SASL bytes without GSSAPI support or prior SaslHandshake");
        }
        doTransition(ctx.channel(), State.UNFRAMED_SASL_AUTHENTICATE);
        writeBareResponse(ctx, doEvaluateResponse(ctx, request.bytes()));
    }

    /** Writes an unframed SASL challenge back to the client. */
    private void writeBareResponse(ChannelHandlerContext ctx, byte[] challenge) {
        ctx.writeAndFlush(new BareSaslResponse(challenge));
    }

    /**
     * Answers a SaslHandshake. If the requested mechanism is enabled, it creates the
     * {@link SaslServer} for that mechanism. The response always lists the enabled mechanisms.
     * Afterwards it requests another read from the channel.
     *
     * @throws IllegalStateException if no provider can create a server for an enabled mechanism
     */
    private void onSaslHandshakeRequest(ChannelHandlerContext ctx, DecodedRequestFrame<SaslHandshakeRequestData> frame) throws SaslException {
        String mechanism = frame.body().mechanism();
        Errors error;
        // NOTE(review): unreachable while doTransition rejects handshakes after AUTHN_SUCCESS;
        // kept as a defensive guard.
        if (this.lastSeen == State.AUTHN_SUCCESS) {
            error = Errors.ILLEGAL_SASL_STATE;
        }
        else if (this.enabledMechanisms.contains(mechanism)) {
            this.saslServer = Sasl.createSaslServer(mechanism, "kafka", null, null, this.mechanismHandlers.get(mechanism));
            if (this.saslServer == null) {
                throw new IllegalStateException("SASL mechanism had no providers: " + mechanism);
            }
            error = Errors.NONE;
        }
        else {
            error = Errors.UNSUPPORTED_SASL_MECHANISM;
        }
        writeFramedResponse(ctx, frame, new SaslHandshakeResponseData().setMechanisms(this.enabledMechanisms).setErrorCode(error.code()));
        ctx.channel().read();
    }

    /**
     * Answers a SaslAuthenticate request with the server's challenge, or with
     * {@code SASL_AUTHENTICATION_FAILED} if evaluation fails. Afterwards it requests another read
     * from the channel.
     */
    private void onSaslAuthenticateRequest(ChannelHandlerContext ctx, DecodedRequestFrame<SaslAuthenticateRequestData> frame) {
        byte[] challenge = new byte[0];
        Errors error;
        String errorMessage;
        try {
            challenge = doEvaluateResponse(ctx, frame.body().authBytes());
            error = Errors.NONE;
            errorMessage = null;
        }
        catch (SaslAuthenticationException e) {
            error = Errors.SASL_AUTHENTICATION_FAILED;
            errorMessage = e.getMessage();
        }
        catch (SaslException e) {
            // Do not reveal the details of unexpected SASL failures to the client.
            error = Errors.SASL_AUTHENTICATION_FAILED;
            errorMessage = "An error occurred";
        }
        writeFramedResponse(ctx, frame, new SaslAuthenticateResponseData().setErrorCode(error.code()).setErrorMessage(errorMessage).setAuthBytes(challenge));
        ctx.channel().read();
    }

    /** Writes {@code body} as the response to {@code frame}, echoing its API version and correlation id. */
    private static void writeFramedResponse(ChannelHandlerContext ctx, DecodedRequestFrame<?> frame, ApiMessage body) {
        ctx.writeAndFlush(new DecodedResponseFrame<>(
                frame.apiVersion(),
                frame.correlationId(),
                new ResponseHeaderData().setCorrelationId(frame.correlationId()),
                body));
    }

    /**
     * Feeds a client token to the current {@link SaslServer}.
     *
     * <p>If the exchange completes, it transitions to {@link State#AUTHN_SUCCESS}, fires an
     * {@link AuthenticationEvent}, and disposes of the server. On any failure it transitions to
     * {@link State#FAILED} and disposes of the server. Each path disposes exactly once.</p>
     *
     * @param response the client's SASL token
     * @return the challenge to send back to the client
     * @throws SaslAuthenticationException if no mechanism was negotiated or the token is rejected;
     *         other failures are wrapped in this type with their cause preserved
     * @throws SaslException declared for callers; not thrown by the current implementation
     */
    private byte[] doEvaluateResponse(ChannelHandlerContext ctx, byte[] response) throws SaslException {
        SaslServer server = this.saslServer;
        if (server == null) {
            // No server exists when the handshake named an unsupported mechanism (or no handshake
            // preceded a bare token). Fail cleanly instead of with a NullPointerException.
            LOG.debug("{}: Authentication failed", ctx.channel());
            doTransition(ctx.channel(), State.FAILED);
            throw new SaslAuthenticationException("No SASL mechanism has been negotiated");
        }

        byte[] challenge;
        String authorizationId;
        Map<String, Object> negotiatedProperties;
        try {
            challenge = server.evaluateResponse(response);
            if (!server.isComplete()) {
                return challenge;
            }
            authorizationId = server.getAuthorizationID();
            negotiatedProperties = SaslMechanism.fromMechanismName(server.getMechanismName()).negotiatedProperties(server);
        }
        catch (SaslAuthenticationException e) {
            failAuthentication(ctx, server);
            throw e;
        }
        catch (Exception e) {
            failAuthentication(ctx, server);
            throw new SaslAuthenticationException(e.getMessage(), e);
        }

        doTransition(ctx.channel(), State.AUTHN_SUCCESS);
        LOG.debug("{}: Authentication successful, authorizationId={}, negotiatedProperties={}", ctx.channel(), authorizationId, negotiatedProperties);
        ctx.fireUserEventTriggered(new AuthenticationEvent(authorizationId, negotiatedProperties));
        disposeQuietly(ctx, server);
        return challenge;
    }

    /** Records an authentication failure and releases the server. */
    private void failAuthentication(ChannelHandlerContext ctx, SaslServer server) {
        LOG.debug("{}: Authentication failed", ctx.channel());
        doTransition(ctx.channel(), State.FAILED);
        disposeQuietly(ctx, server);
    }

    /**
     * Disposes of {@code server}, logging instead of propagating a failure. A cleanup failure
     * must not mask the authentication outcome that has already been decided.
     */
    private static void disposeQuietly(ChannelHandlerContext ctx, SaslServer server) {
        try {
            server.dispose();
        }
        catch (SaslException e) {
            LOG.warn("{}: Failed to dispose SASL server", ctx.channel(), e);
        }
    }
}
