package cn.hserver.modelcontextprotocol.spec;

import cn.hserver.modelcontextprotocol.server.McpAsyncServerExchange;
import cn.hserver.modelcontextprotocol.spec.McpSchema;
import com.fasterxml.jackson.core.type.TypeReference;
import java.time.Duration;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;
import reactor.core.publisher.Sinks;

/* loaded from: input_file:cn/hserver/modelcontextprotocol/spec/McpServerSession.class */
public class McpServerSession implements McpSession {
    private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class);
    // Session identifier; returned by getId() and used as the prefix of generated request ids.
    private final String id;
    // Invoked when an incoming request's method equals McpSchema.METHOD_INITIALIZE.
    private final InitRequestHandler initRequestHandler;
    // Invoked when an incoming notification's method equals McpSchema.METHOD_NOTIFICATION_INITIALIZED.
    private final InitNotificationHandler initNotificationHandler;
    // Handlers for all other incoming requests, looked up by the request's method name.
    private final Map<String, RequestHandler<?>> requestHandlers;
    // Handlers for all other incoming notifications, looked up by the notification's method name.
    private final Map<String, NotificationHandler> notificationHandlers;
    // Transport used to send messages, unmarshal payloads and close the session.
    private final McpServerTransport transport;
    // Values stored in `state`: nothing received yet / initialize request seen / initialized notification seen.
    private static final int STATE_UNINITIALIZED = 0;
    private static final int STATE_INITIALIZING = 1;
    private static final int STATE_INITIALIZED = 2;
    // Sinks of outbound requests still waiting for a response, keyed by request id.
    // Entries are added in sendRequest() and removed in handle() when the matching response arrives.
    private final ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();
    // Monotonic counter appended to the session id to build request ids.
    private final AtomicLong requestCounter = new AtomicLong(0);
    // Receives the exchange object when the initialized notification is handled;
    // request and notification handlers subscribe to it before they are invoked.
    private final Sinks.One<McpAsyncServerExchange> exchangeSink = Sinks.one();
    // Client capabilities and client info recorded by init(...) and passed to the exchange when it is created.
    private final AtomicReference<McpSchema.ClientCapabilities> clientCapabilities = new AtomicReference<>();
    private final AtomicReference<McpSchema.Implementation> clientInfo = new AtomicReference<>();
    // Current lifecycle state (one of the STATE_* values). Written while handling the initialize request
    // and the initialized notification; no code visible in this file reads it back.
    private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED);

    /**
     * Factory that builds a {@link McpServerSession} for a supplied transport.
     */
    @FunctionalInterface
    public interface Factory {
        /**
         * Creates a session that communicates over the given transport.
         *
         * @param mcpServerTransport transport the new session should use
         * @return the created session
         */
        McpServerSession create(McpServerTransport mcpServerTransport);
    }

    /**
     * Callback run by the session when it receives a notification whose method equals
     * {@code McpSchema.METHOD_NOTIFICATION_INITIALIZED}.
     */
    public interface InitNotificationHandler {
        /**
         * Reacts to the initialized notification.
         *
         * @return a Mono that completes when the handler has finished
         */
        Mono<Void> handle();
    }

    /**
     * Callback run by the session when it receives a request whose method equals
     * {@code McpSchema.METHOD_INITIALIZE}.
     */
    public interface InitRequestHandler {
        /**
         * Processes the client's initialize request.
         *
         * @param initializeRequest the request payload unmarshalled from the incoming JSON-RPC params
         * @return a Mono emitting the result that the session wraps into the JSON-RPC response
         */
        Mono<McpSchema.InitializeResult> handle(McpSchema.InitializeRequest initializeRequest);
    }

    /**
     * Mutable value object describing a "method not found" condition: the requested method name,
     * a human readable message and optional additional data.
     *
     * <p>{@code equals}, {@code hashCode} and {@code toString} cover all three properties and read
     * them through the getters, so subclasses overriding a getter are honoured.
     */
    public static class MethodNotFoundError {
        String method;
        String message;
        Object data;

        /** Creates an instance with all properties {@code null}. */
        @Generated
        public MethodNotFoundError() {
        }

        /**
         * Creates a fully populated instance.
         *
         * @param str  name of the method that could not be resolved
         * @param str2 message describing the failure
         * @param obj  optional extra details, may be {@code null}
         */
        @Generated
        public MethodNotFoundError(String str, String str2, Object obj) {
            this.method = str;
            this.message = str2;
            this.data = obj;
        }

        @Generated
        public String getMethod() {
            return this.method;
        }

        @Generated
        public void setMethod(String str) {
            this.method = str;
        }

        @Generated
        public String getMessage() {
            return this.message;
        }

        @Generated
        public void setMessage(String str) {
            this.message = str;
        }

        @Generated
        public Object getData() {
            return this.data;
        }

        @Generated
        public void setData(Object obj) {
            this.data = obj;
        }

        /**
         * Two errors are equal when both sides accept each other via {@link #canEqual(Object)}
         * and method, message and data are pairwise equal (null-safe).
         */
        @Generated
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof MethodNotFoundError)) {
                return false;
            }
            MethodNotFoundError other = (MethodNotFoundError) obj;
            return other.canEqual(this)
                    && Objects.equals(getMethod(), other.getMethod())
                    && Objects.equals(getMessage(), other.getMessage())
                    && Objects.equals(getData(), other.getData());
        }

        /** Hook that lets subclasses refuse equality with instances of unrelated types. */
        @Generated
        protected boolean canEqual(Object obj) {
            return obj instanceof MethodNotFoundError;
        }

        /**
         * Hash over method, message and data using multiplier 59 and 43 as the hash of {@code null}.
         * The seed is the literal 1; it is intentionally not tied to any session state constant,
         * so renumbering those constants can never alter hash values.
         */
        @Generated
        @Override
        public int hashCode() {
            final int prime = 59;
            final int nullHash = 43;
            int result = 1;
            String currentMethod = getMethod();
            result = (result * prime) + (currentMethod == null ? nullHash : currentMethod.hashCode());
            String currentMessage = getMessage();
            result = (result * prime) + (currentMessage == null ? nullHash : currentMessage.hashCode());
            Object currentData = getData();
            result = (result * prime) + (currentData == null ? nullHash : currentData.hashCode());
            return result;
        }

        @Generated
        @Override
        public String toString() {
            return "McpServerSession.MethodNotFoundError(method=" + getMethod() + ", message=" + getMessage() + ", data=" + getData() + ")";
        }
    }

    /**
     * Handler for an incoming notification other than the initialized notification.
     * The session looks it up by the notification's method name.
     */
    public interface NotificationHandler {
        /**
         * Processes one notification.
         *
         * @param mcpAsyncServerExchange exchange object published once the session has been initialized
         * @param obj                    the notification's params exactly as received (not unmarshalled by the session)
         * @return a Mono that completes when handling has finished
         */
        Mono<Void> handle(McpAsyncServerExchange mcpAsyncServerExchange, Object obj);
    }

    /**
     * Handler for an incoming request other than the initialize request.
     * The session looks it up by the request's method name.
     *
     * @param <T> type of the result; the session places the emitted value into the JSON-RPC response
     */
    public interface RequestHandler<T> {
        /**
         * Processes one request.
         *
         * @param mcpAsyncServerExchange exchange object published once the session has been initialized
         * @param obj                    the request's params exactly as received (not unmarshalled by the session)
         * @return a Mono emitting the result of the request
         */
        Mono<T> handle(McpAsyncServerExchange mcpAsyncServerExchange, Object obj);
    }

    /**
     * Creates a session. All arguments are stored as given; no copies are made and no validation is performed.
     *
     * @param str                     session identifier
     * @param mcpServerTransport      transport used to exchange messages with the client
     * @param initRequestHandler      callback for the initialize request
     * @param initNotificationHandler callback for the initialized notification
     * @param map                     request handlers keyed by method name
     * @param map2                    notification handlers keyed by method name
     */
    public McpServerSession(
            String str,
            McpServerTransport mcpServerTransport,
            InitRequestHandler initRequestHandler,
            InitNotificationHandler initNotificationHandler,
            Map<String, RequestHandler<?>> map,
            Map<String, NotificationHandler> map2) {
        // Handler registries.
        this.requestHandlers = map;
        this.notificationHandlers = map2;
        // Initialization callbacks.
        this.initRequestHandler = initRequestHandler;
        this.initNotificationHandler = initNotificationHandler;
        // Identity and transport.
        this.transport = mcpServerTransport;
        this.id = str;
    }

    /**
     * @return the transport this session was constructed with
     */
    public McpServerTransport getTransport() {
        return transport;
    }

    /**
     * @return the identifier this session was constructed with
     */
    public String getId() {
        return id;
    }

    /**
     * Records what the client announced about itself. The values are later handed to the
     * exchange object created when the initialized notification arrives.
     *
     * @param clientCapabilities capabilities reported by the client
     * @param implementation     client implementation info
     */
    public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation implementation) {
        final AtomicReference<McpSchema.ClientCapabilities> capabilitiesHolder = this.clientCapabilities;
        capabilitiesHolder.lazySet(clientCapabilities);
        clientInfo.lazySet(implementation);
    }

    /**
     * Builds the next outbound request id in the form {@code <sessionId>-<sequence>},
     * where the sequence starts at 0 and grows by one per call.
     */
    private String generateRequestId() {
        final long sequence = requestCounter.getAndIncrement();
        return id + "-" + sequence;
    }

    /**
     * Sends a JSON-RPC request to the client and waits for the correlated response.
     *
     * <p>The request is registered in {@code pendingResponses} under a freshly generated id and is
     * completed by {@link #handle(McpSchema.JSONRPCMessage)} when a response with that id arrives.
     * The wait is bounded by a fixed 10 second timeout.
     *
     * @param str           second constructor argument of the outgoing {@code JSONRPCRequest} (the method name)
     * @param obj           request params, passed to the transport as-is
     * @param typeReference target type of the result; {@code Void} means the result is ignored
     * @param <T>           result type
     * @return a Mono emitting the unmarshalled result, completing empty for {@code Void}, or failing with
     *         {@link McpError} when the response carries an error, with the transport's error when sending
     *         fails, or with a timeout error when no response arrives in time
     */
    @Override // cn.hserver.modelcontextprotocol.spec.McpSession
    public <T> Mono<T> sendRequest(String str, Object obj, TypeReference<T> typeReference) {
        final String requestId = generateRequestId();
        // The explicit type witness is required: without it the sink is inferred as MonoSink<Object>
        // and can neither be stored in pendingResponses nor yield a JSONRPCResponse downstream.
        return Mono.<McpSchema.JSONRPCResponse>create(sink -> {
            this.pendingResponses.put(requestId, sink);
            // Drop the registration whenever the sink is disposed (terminated OR cancelled). The
            // timeout below cancels the sink, which previously left the entry in the map forever.
            sink.onDispose(() -> this.pendingResponses.remove(requestId));
            McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, str, requestId, obj);
            this.transport.sendMessage(request).subscribe(ignored -> {
            }, sendFailure -> {
                // Sending failed: no response will ever arrive, so fail the caller right away.
                this.pendingResponses.remove(requestId);
                sink.error(sendFailure);
            });
        }).timeout(Duration.ofSeconds(10L)).handle((response, downstream) -> {
            if (response.getError() != null) {
                downstream.error(new McpError(response.getError()));
            } else if (typeReference.getType().equals(Void.class)) {
                // Caller does not expect a payload; just signal completion.
                downstream.complete();
            } else {
                downstream.next(this.transport.unmarshalFrom(response.getResult(), typeReference));
            }
        });
    }

    /**
     * Sends a JSON-RPC notification to the client; no response is awaited.
     *
     * @param str second constructor argument of the outgoing {@code JSONRPCNotification} (the method name)
     * @param map notification params
     * @return the Mono returned by the transport for the send operation
     */
    @Override // cn.hserver.modelcontextprotocol.spec.McpSession
    public Mono<Void> sendNotification(String str, Map<String, Object> map) {
        final McpSchema.JSONRPCNotification notification =
                new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, str, map);
        return transport.sendMessage(notification);
    }

    /**
     * Entry point for every message received from the client.
     *
     * <ul>
     *   <li>Response: completes the pending outbound request with the same id; a response with an
     *       unknown (or missing) id is logged and ignored.</li>
     *   <li>Request: dispatched to {@code handleIncomingRequest}; the produced response is sent back
     *       through the transport. If handling fails, an INTERNAL_ERROR response is sent instead.</li>
     *   <li>Notification: dispatched to {@code handleIncomingNotification}; failures are logged and
     *       propagated to the subscriber.</li>
     *   <li>Anything else: logged and ignored.</li>
     * </ul>
     *
     * @param jSONRPCMessage the received message
     * @return a Mono that completes when the message has been processed
     */
    public Mono<Void> handle(McpSchema.JSONRPCMessage jSONRPCMessage) {
        return Mono.defer(() -> {
            if (jSONRPCMessage instanceof McpSchema.JSONRPCResponse) {
                McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) jSONRPCMessage;
                logger.debug("Received Response: {}", response);
                Object responseId = response.getId();
                // ConcurrentHashMap rejects null keys with a NullPointerException, so a response
                // without an id is treated like one with an unknown id.
                MonoSink<McpSchema.JSONRPCResponse> pending =
                        responseId == null ? null : this.pendingResponses.remove(responseId);
                if (pending == null) {
                    logger.warn("Unexpected response for unknown id {}", responseId);
                } else {
                    pending.success(response);
                }
                return Mono.empty();
            }
            if (jSONRPCMessage instanceof McpSchema.JSONRPCRequest) {
                McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) jSONRPCMessage;
                logger.debug("Received request: {}", request);
                return handleIncomingRequest(request).onErrorResume(failure -> {
                    McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.getId(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, failure.getMessage(), null));
                    // The error response is sent here; continuing empty keeps the flatMap below
                    // from sending a second message.
                    return this.transport.sendMessage(errorResponse).then(Mono.<McpSchema.JSONRPCResponse>empty());
                }).flatMap(this.transport::sendMessage);
            }
            if (!(jSONRPCMessage instanceof McpSchema.JSONRPCNotification)) {
                logger.warn("Received unknown message type: {}", jSONRPCMessage);
                return Mono.empty();
            }
            McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) jSONRPCMessage;
            logger.debug("Received notification: {}", notification);
            // Trailing throwable argument makes SLF4J log the stack trace as well as the message.
            return handleIncomingNotification(notification).doOnError(failure ->
                    logger.error("Error handling notification: {}", failure.getMessage(), failure));
        });
    }

    /**
     * Produces the JSON-RPC response for an incoming request.
     *
     * <p>An initialize request is unmarshalled, recorded via {@link #init} and passed to the
     * {@code initRequestHandler}. Any other request is routed to the handler registered for its
     * method, which is invoked once the exchange object has been published. When no handler is
     * registered, a METHOD_NOT_FOUND response is returned. Errors signalled by a handler are
     * converted into an INTERNAL_ERROR response.
     *
     * @param jSONRPCRequest the incoming request
     * @return a Mono emitting the response to send back to the client
     */
    private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCRequest jSONRPCRequest) {
        return Mono.defer(() -> {
            final String method = jSONRPCRequest.getMethod();
            // Wildcard type: the init handler yields InitializeResult while registered handlers
            // yield arbitrary result types, so neither concrete type fits both branches.
            Mono<?> resultMono;
            if (McpSchema.METHOD_INITIALIZE.equals(method)) {
                McpSchema.InitializeRequest initializeRequest = (McpSchema.InitializeRequest) this.transport.unmarshalFrom(jSONRPCRequest.getParams(), new TypeReference<McpSchema.InitializeRequest>() {
                });
                this.state.lazySet(STATE_INITIALIZING);
                init(initializeRequest.getCapabilities(), initializeRequest.getClientInfo());
                resultMono = this.initRequestHandler.handle(initializeRequest);
            } else {
                RequestHandler<?> handler = this.requestHandlers.get(method);
                if (handler == null) {
                    MethodNotFoundError notFound = getMethodNotFoundError(method);
                    return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jSONRPCRequest.getId(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, notFound.getMessage(), notFound.getData())));
                }
                // Handlers need the exchange, which only becomes available after the
                // initialized notification has been processed.
                resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, jSONRPCRequest.getParams()));
            }
            return resultMono
                    .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jSONRPCRequest.getId(), result, null))
                    .onErrorResume(failure -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jSONRPCRequest.getId(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, failure.getMessage(), null))));
        });
    }

    /**
     * Dispatches an incoming notification.
     *
     * <p>The initialized notification marks the session as initialized, publishes the exchange
     * object built from the recorded client capabilities and client info, and then runs the
     * {@code initNotificationHandler}. Every other notification goes to the handler registered
     * for its method once the exchange is available; if there is none, the fact is logged and the
     * notification is dropped.
     *
     * @param jSONRPCNotification the incoming notification
     * @return a Mono that completes when the notification has been handled
     */
    private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification jSONRPCNotification) {
        return Mono.defer(() -> {
            final boolean isInitialized =
                    McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(jSONRPCNotification.getMethod());
            if (!isInitialized) {
                final NotificationHandler handler = notificationHandlers.get(jSONRPCNotification.getMethod());
                if (handler == null) {
                    logger.error("No handler registered for notification method: {}", jSONRPCNotification.getMethod());
                    return Mono.empty();
                }
                return exchangeSink.asMono()
                        .flatMap(exchange -> handler.handle(exchange, jSONRPCNotification.getParams()));
            }

            state.lazySet(STATE_INITIALIZED);
            final McpAsyncServerExchange exchange =
                    new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get());
            exchangeSink.tryEmitValue(exchange);
            return initNotificationHandler.handle();
        });
    }

    /**
     * Builds the error description used when no request handler is registered for a method.
     *
     * <p>{@code McpSchema.METHOD_ROOTS_LIST} gets a dedicated "Roots not supported" message with a
     * {@code reason} entry in the data map; every other method (including {@code null}) gets the
     * generic "Method not found: &lt;method&gt;" message without data.
     *
     * @param str the requested method name
     * @return the populated error description
     */
    static MethodNotFoundError getMethodNotFoundError(String str) {
        // Constant-first equals keeps this null-safe.
        if (McpSchema.METHOD_ROOTS_LIST.equals(str)) {
            return new MethodNotFoundError(str, "Roots not supported", Collections.singletonMap("reason", "Client does not have roots capability"));
        }
        return new MethodNotFoundError(str, "Method not found: " + str, null);
    }

    /**
     * Delegates a graceful shutdown to the transport.
     *
     * @return the Mono returned by the transport's {@code closeGracefully()}
     */
    @Override // cn.hserver.modelcontextprotocol.spec.McpSession
    public Mono<Void> closeGracefully() {
        final Mono<Void> shutdown = transport.closeGracefully();
        return shutdown;
    }

    /**
     * Closes the session by closing the underlying transport.
     */
    @Override // cn.hserver.modelcontextprotocol.spec.McpSession
    public void close() {
        transport.close();
    }
}
