package org.springframework.web.reactive.socket.server.support;

import io.netty.handler.codec.http.HttpHeaders;
import java.net.URI;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.Lifecycle;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.MethodNotAllowedException;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.ServerWebInputException;
import reactor.core.publisher.Mono;

/* loaded from: input_file:BOOT-INF/lib/spring-webflux-5.3.18.jar:org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.class */
/**
 * {@link WebSocketService} that validates a WebSocket handshake request, negotiates
 * a sub-protocol, and delegates the actual upgrade to a {@link RequestUpgradeStrategy}.
 *
 * <p>The strategy is either passed in explicitly or, with the no-arg constructor,
 * chosen by probing the classpath for Tomcat, Jetty, Jetty 10, Undertow and
 * Reactor Netty, in that order.
 *
 * <p>If a session attribute predicate is configured, the matching {@code WebSession}
 * attributes are copied into the {@link HandshakeInfo}.
 *
 * <p>When the upgrade strategy is itself a {@link Lifecycle}, {@link #start()} and
 * {@link #stop()} are propagated to it.
 */
public class HandshakeWebSocketService implements WebSocketService, Lifecycle {

    private static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key";

    private static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";

    /** Required value of the {@code Upgrade} header, compared case-insensitively. */
    private static final String WEBSOCKET_UPGRADE = "WebSocket";

    /** Token that must be present in the {@code Connection} header, compared case-insensitively. */
    private static final String CONNECTION_UPGRADE = "Upgrade";

    private static final Mono<Map<String, Object>> EMPTY_ATTRIBUTES = Mono.just(Collections.emptyMap());

    private static final boolean tomcatPresent;

    private static final boolean jettyPresent;

    private static final boolean jetty10Present;

    private static final boolean undertowPresent;

    private static final boolean reactorNettyPresent;

    protected static final Log logger;

    static {
        ClassLoader classLoader = HandshakeWebSocketService.class.getClassLoader();
        tomcatPresent = ClassUtils.isPresent("org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", classLoader);
        jettyPresent = ClassUtils.isPresent("org.eclipse.jetty.websocket.server.WebSocketServerFactory", classLoader);
        jetty10Present = ClassUtils.isPresent("org.eclipse.jetty.websocket.server.JettyWebSocketServerContainer", classLoader);
        undertowPresent = ClassUtils.isPresent("io.undertow.websockets.WebSocketProtocolHandshakeHandler", classLoader);
        reactorNettyPresent = ClassUtils.isPresent("reactor.netty.http.server.HttpServerResponse", classLoader);
        logger = LogFactory.getLog(HandshakeWebSocketService.class);
    }

    private final RequestUpgradeStrategy upgradeStrategy;

    @Nullable
    private Predicate<String> sessionAttributePredicate;

    private volatile boolean running;

    /**
     * Create a service whose upgrade strategy is auto-detected from the classpath.
     * @throws IllegalStateException if no supported server is found or the
     * strategy cannot be instantiated
     */
    public HandshakeWebSocketService() {
        this(initUpgradeStrategy());
    }

    /**
     * Create a service that uses the given upgrade strategy.
     * @param requestUpgradeStrategy the strategy to delegate to (must not be {@code null})
     */
    public HandshakeWebSocketService(RequestUpgradeStrategy requestUpgradeStrategy) {
        Assert.notNull(requestUpgradeStrategy, "RequestUpgradeStrategy is required");
        this.upgradeStrategy = requestUpgradeStrategy;
    }

    /**
     * Pick the first upgrade strategy whose server runtime is on the classpath and
     * instantiate it reflectively.
     */
    private static RequestUpgradeStrategy initUpgradeStrategy() {
        String simpleName;
        if (tomcatPresent) {
            simpleName = "TomcatRequestUpgradeStrategy";
        }
        else if (jettyPresent) {
            simpleName = "JettyRequestUpgradeStrategy";
        }
        else if (jetty10Present) {
            simpleName = "Jetty10RequestUpgradeStrategy";
        }
        else if (undertowPresent) {
            simpleName = "UndertowRequestUpgradeStrategy";
        }
        else if (reactorNettyPresent) {
            simpleName = "ReactorNettyRequestUpgradeStrategy";
        }
        else {
            throw new IllegalStateException("No suitable default RequestUpgradeStrategy found");
        }

        String className = "org.springframework.web.reactive.socket.server.upgrade." + simpleName;
        try {
            Class<?> clazz = ClassUtils.forName(className, HandshakeWebSocketService.class.getClassLoader());
            return (RequestUpgradeStrategy) ReflectionUtils.accessibleConstructor(clazz).newInstance();
        }
        catch (Throwable ex) {
            throw new IllegalStateException("Failed to instantiate RequestUpgradeStrategy: " + className, ex);
        }
    }

    /**
     * Return the {@link RequestUpgradeStrategy} used for WebSocket upgrades.
     */
    public RequestUpgradeStrategy getUpgradeStrategy() {
        return this.upgradeStrategy;
    }

    /**
     * Configure a predicate that selects which {@code WebSession} attributes are
     * copied into the {@link HandshakeInfo}. The default is {@code null}, which
     * copies no attributes and does not resolve the session at all.
     * @param predicate the attribute-name predicate, or {@code null} to disable copying
     */
    public void setSessionAttributePredicate(@Nullable Predicate<String> predicate) {
        this.sessionAttributePredicate = predicate;
    }

    /**
     * Return the configured session attribute predicate, if any.
     */
    @Nullable
    public Predicate<String> getSessionAttributePredicate() {
        return this.sessionAttributePredicate;
    }

    /**
     * Mark this service as running and start the upgrade strategy if it is a
     * {@link Lifecycle}. Does nothing if already running.
     */
    @Override
    public void start() {
        if (isRunning()) {
            return;
        }
        this.running = true;
        doStart();
    }

    /**
     * Start hook invoked by {@link #start()}; propagates to a {@link Lifecycle} upgrade strategy.
     */
    protected void doStart() {
        if (getUpgradeStrategy() instanceof Lifecycle) {
            ((Lifecycle) getUpgradeStrategy()).start();
        }
    }

    /**
     * Mark this service as stopped and stop the upgrade strategy if it is a
     * {@link Lifecycle}. Does nothing if not running.
     */
    @Override
    public void stop() {
        if (isRunning()) {
            this.running = false;
            doStop();
        }
    }

    /**
     * Stop hook invoked by {@link #stop()}; propagates to a {@link Lifecycle} upgrade strategy.
     */
    protected void doStop() {
        if (getUpgradeStrategy() instanceof Lifecycle) {
            ((Lifecycle) getUpgradeStrategy()).stop();
        }
    }

    @Override
    public boolean isRunning() {
        return this.running;
    }

    /**
     * Validate the handshake request, negotiate a sub-protocol, and delegate the
     * upgrade to the configured {@link RequestUpgradeStrategy}.
     * @return a {@code Mono} that errors with {@link MethodNotAllowedException} for
     * non-GET requests, or with {@link ServerWebInputException} when the
     * {@code Upgrade}, {@code Connection} or {@code Sec-WebSocket-Key} headers are
     * invalid or missing
     */
    @Override
    public Mono<Void> handleRequest(ServerWebExchange serverWebExchange, WebSocketHandler webSocketHandler) {
        ServerHttpRequest request = serverWebExchange.getRequest();
        HttpMethod method = request.getMethod();
        // Fully qualified: the simple name HttpHeaders is imported from both Netty and Spring.
        org.springframework.http.HttpHeaders headers = request.getHeaders();

        if (HttpMethod.GET != method) {
            return Mono.error(new MethodNotAllowedException(
                    request.getMethodValue(), Collections.singleton(HttpMethod.GET)));
        }

        // Put only the offending value into the message. The full header map may hold
        // credentials (Authorization, Cookie) that must not reach logs or error responses.
        String upgrade = headers.getUpgrade();
        if (!WEBSOCKET_UPGRADE.equalsIgnoreCase(upgrade)) {
            return handleBadRequest(serverWebExchange, "Invalid 'Upgrade' header: " + upgrade);
        }

        List<String> connection = headers.getConnection();
        if (!containsIgnoreCase(connection, CONNECTION_UPGRADE)) {
            return handleBadRequest(serverWebExchange, "Invalid 'Connection' header: " + connection);
        }

        if (headers.getFirst(SEC_WEBSOCKET_KEY) == null) {
            return handleBadRequest(serverWebExchange, "Missing \"Sec-WebSocket-Key\" header");
        }

        String protocol = selectProtocol(headers, webSocketHandler);
        return initAttributes(serverWebExchange).flatMap(attributes ->
                this.upgradeStrategy.upgrade(serverWebExchange, webSocketHandler, protocol,
                        () -> createHandshakeInfo(serverWebExchange, request, protocol, attributes)));
    }

    /**
     * Return {@code true} if any element of {@code values} equals {@code token},
     * ignoring case. HTTP header tokens are case-insensitive.
     */
    private static boolean containsIgnoreCase(List<String> values, String token) {
        for (String value : values) {
            if (token.equalsIgnoreCase(value)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Log the reason at debug level and return an error {@code Mono} carrying a
     * {@link ServerWebInputException} with that reason.
     */
    private Mono<Void> handleBadRequest(ServerWebExchange serverWebExchange, String str) {
        if (logger.isDebugEnabled()) {
            logger.debug(serverWebExchange.getLogPrefix() + str);
        }
        return Mono.error(new ServerWebInputException(str));
    }

    /**
     * Return the first sub-protocol offered by the client (in request order) that
     * the handler supports, or {@code null} if none matches or none was offered.
     * <p>All {@code Sec-WebSocket-Protocol} header lines are considered. Each
     * comma-separated token is trimmed, so {@code "a, b"} offers {@code "b"}, not {@code " b"}.
     */
    @Nullable
    private String selectProtocol(org.springframework.http.HttpHeaders httpHeaders, WebSocketHandler webSocketHandler) {
        List<String> headerValues = httpHeaders.get(SEC_WEBSOCKET_PROTOCOL);
        if (headerValues == null) {
            return null;
        }
        List<String> supportedProtocols = webSocketHandler.getSubProtocols();
        for (String headerValue : headerValues) {
            // tokenizeToStringArray trims tokens and drops empty ones.
            for (String protocol : StringUtils.tokenizeToStringArray(headerValue, ",")) {
                if (supportedProtocols.contains(protocol)) {
                    return protocol;
                }
            }
        }
        return null;
    }

    /**
     * Resolve the {@code WebSession} attributes to pass to the WebSocket session.
     * Returns an empty map, without touching the session, when no predicate is configured.
     */
    private Mono<Map<String, Object>> initAttributes(ServerWebExchange serverWebExchange) {
        // Read the field once. The filter lambda runs later, when the session is resolved,
        // and must not see a concurrent setSessionAttributePredicate(null).
        Predicate<String> predicate = this.sessionAttributePredicate;
        if (predicate == null) {
            return EMPTY_ATTRIBUTES;
        }
        return serverWebExchange.getSession().map(webSession ->
                webSession.getAttributes().entrySet().stream()
                        .filter(entry -> predicate.test(entry.getKey()))
                        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
    }

    /**
     * Build the {@link HandshakeInfo}. Request headers are copied into a new
     * {@code HttpHeaders} instance rather than shared with the request.
     */
    private HandshakeInfo createHandshakeInfo(ServerWebExchange serverWebExchange, ServerHttpRequest serverHttpRequest,
            @Nullable String str, Map<String, Object> map) {

        URI uri = serverHttpRequest.getURI();
        org.springframework.http.HttpHeaders headersCopy = new org.springframework.http.HttpHeaders();
        headersCopy.addAll(serverHttpRequest.getHeaders());
        return new HandshakeInfo(uri, headersCopy, serverHttpRequest.getCookies(), serverWebExchange.getPrincipal(),
                str, serverHttpRequest.getRemoteAddress(), map, serverWebExchange.getLogPrefix());
    }
}
