package dev.langchain4j.mcp.client;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.mcp.client.logging.DefaultMcpLogMessageHandler;
import dev.langchain4j.mcp.client.logging.McpLogMessageHandler;
import dev.langchain4j.mcp.client.protocol.CancellationNotification;
import dev.langchain4j.mcp.client.protocol.InitializeParams;
import dev.langchain4j.mcp.client.protocol.McpCallToolRequest;
import dev.langchain4j.mcp.client.protocol.McpGetPromptRequest;
import dev.langchain4j.mcp.client.protocol.McpInitializeRequest;
import dev.langchain4j.mcp.client.protocol.McpListPromptsRequest;
import dev.langchain4j.mcp.client.protocol.McpListResourceTemplatesRequest;
import dev.langchain4j.mcp.client.protocol.McpListResourcesRequest;
import dev.langchain4j.mcp.client.protocol.McpListToolsRequest;
import dev.langchain4j.mcp.client.protocol.McpPingRequest;
import dev.langchain4j.mcp.client.protocol.McpReadResourceRequest;
import dev.langchain4j.mcp.client.transport.McpOperationHandler;
import dev.langchain4j.mcp.client.transport.McpTransport;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:dev/langchain4j/mcp/client/DefaultMcpClient.class */
/**
 * Default {@link McpClient} implementation that talks to an MCP server through an {@link McpTransport}.
 *
 * <p><b>Connection lifecycle.</b> The constructor starts the transport and performs the MCP
 * {@code initialize} handshake. When the transport reports a failure, the client:
 * <ol>
 *   <li>waits {@code reconnectInterval};</li>
 *   <li>repeats the handshake, unless {@link #close()} has been called in the meantime.</li>
 * </ol>
 *
 * <p><b>Request tracking.</b> Every request gets a fresh id from an internal counter. The
 * {@code pendingOperations} map is shared with the {@link McpOperationHandler}. This class removes a
 * request's entry once the call completes, fails or times out.
 *
 * <p><b>Caching.</b> Resources, resource templates and prompts are fetched once and then cached. The tool
 * list is cached until the callback handed to the {@link McpOperationHandler} marks it out of date.
 * NOTE(review): presumably that happens on a server "tools list changed" notification; confirm in
 * {@code McpOperationHandler}.
 */
public class DefaultMcpClient implements McpClient {

    private static final Logger log = LoggerFactory.getLogger(DefaultMcpClient.class);

    static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /** Wait used instead of a configured timeout of zero, which means "no timeout" (about 24.8 days). */
    private static final long NO_TIMEOUT_MILLIS = Integer.MAX_VALUE;

    private final McpTransport transport;
    private final String key;
    private final String clientName;
    private final String clientVersion;
    private final String protocolVersion;
    private final Duration initializationTimeout;
    private final Duration toolExecutionTimeout;
    private final Duration resourcesTimeout;
    private final Duration promptsTimeout;
    private final Duration pingTimeout;
    private final String toolExecutionTimeoutErrorMessage;
    private final McpOperationHandler messageHandler;
    private final McpLogMessageHandler logHandler;
    private final Duration reconnectInterval;

    private final AtomicLong idGenerator = new AtomicLong(0);
    // Shared with McpOperationHandler; this class removes an entry once its request is finished.
    private final Map<Long, CompletableFuture<JsonNode>> pendingOperations = new ConcurrentHashMap<>();
    // Lazily populated caches; null means "not fetched yet".
    private final AtomicReference<List<McpResource>> resourceRefs = new AtomicReference<>();
    private final AtomicReference<List<McpResourceTemplate>> resourceTemplateRefs = new AtomicReference<>();
    private final AtomicReference<List<McpPrompt>> promptRefs = new AtomicReference<>();
    private final AtomicReference<List<ToolSpecification>> toolListRefs = new AtomicReference<>();
    private final AtomicBoolean toolListOutOfDate = new AtomicBoolean(true);
    // Non-null while one thread is fetching the tool list; always completed normally when that fetch ends.
    private final AtomicReference<CompletableFuture<Void>> toolListUpdateInProgress = new AtomicReference<>(null);
    private volatile boolean closed = false;
    // Synthetic tools/call response handed to ToolExecutionHelper when a tool call times out.
    private final ObjectNode RESULT_TIMEOUT = JsonNodeFactory.instance.objectNode();

    /**
     * Builder for {@link DefaultMcpClient}. Only {@link #transport(McpTransport)} is required.
     * Every other setting falls back to the default noted on its setter.
     */
    public static class Builder {
        private String toolExecutionTimeoutErrorMessage;
        private McpTransport transport;
        private String key;
        private String clientName;
        private String clientVersion;
        private String protocolVersion;
        private Duration initializationTimeout;
        private Duration toolExecutionTimeout;
        private Duration resourcesTimeout;
        private Duration pingTimeout;
        private Duration promptsTimeout;
        private McpLogMessageHandler logHandler;
        private Duration reconnectInterval;

        /** Sets the transport used to reach the server (required). */
        public Builder transport(McpTransport mcpTransport) {
            this.transport = mcpTransport;
            return this;
        }

        /** Sets the value returned by {@link DefaultMcpClient#key()}; defaults to a random UUID. */
        public Builder key(String str) {
            this.key = str;
            return this;
        }

        /** Sets the client name sent during initialization; defaults to {@code "langchain4j"}. */
        public Builder clientName(String str) {
            this.clientName = str;
            return this;
        }

        /** Sets the client version sent during initialization; defaults to {@code "1.0"}. */
        public Builder clientVersion(String str) {
            this.clientVersion = str;
            return this;
        }

        /** Sets the protocol version sent during initialization; defaults to {@code "2024-11-05"}. */
        public Builder protocolVersion(String str) {
            this.protocolVersion = str;
            return this;
        }

        /** Sets how long to wait for the initialize handshake; defaults to 30 seconds. */
        public Builder initializationTimeout(Duration duration) {
            this.initializationTimeout = duration;
            return this;
        }

        /**
         * Sets the timeout for tool calls and tool listing; defaults to 60 seconds.
         * Zero means no timeout.
         */
        public Builder toolExecutionTimeout(Duration duration) {
            this.toolExecutionTimeout = duration;
            return this;
        }

        /**
         * Sets the timeout for resource and resource-template operations; defaults to 60 seconds.
         * Zero means no timeout.
         */
        public Builder resourcesTimeout(Duration duration) {
            this.resourcesTimeout = duration;
            return this;
        }

        /** Sets the timeout for prompt operations; defaults to 60 seconds. Zero means no timeout. */
        public Builder promptsTimeout(Duration duration) {
            this.promptsTimeout = duration;
            return this;
        }

        /** Sets the text returned as the tool result when a tool call times out. */
        public Builder toolExecutionTimeoutErrorMessage(String str) {
            this.toolExecutionTimeoutErrorMessage = str;
            return this;
        }

        /** Sets the handler for server log messages; defaults to {@link DefaultMcpLogMessageHandler}. */
        public Builder logHandler(McpLogMessageHandler mcpLogMessageHandler) {
            this.logHandler = mcpLogMessageHandler;
            return this;
        }

        /** Sets how long {@link DefaultMcpClient#checkHealth()} waits for a ping; defaults to 10 seconds. */
        public Builder pingTimeout(Duration duration) {
            this.pingTimeout = duration;
            return this;
        }

        /** Sets the delay before reconnecting after a transport failure; defaults to 5 seconds. */
        public Builder reconnectInterval(Duration duration) {
            this.reconnectInterval = duration;
            return this;
        }

        /** Builds the client; this connects to the server immediately. */
        public DefaultMcpClient build() {
            return new DefaultMcpClient(this);
        }
    }

    /**
     * Creates the client and connects: starts the transport and performs the initialize handshake.
     *
     * @param builder configuration; {@code builder.transport} must not be null
     * @throws RuntimeException if the handshake fails or does not finish within the initialization timeout
     */
    public DefaultMcpClient(Builder builder) {
        this.transport = ValidationUtils.ensureNotNull(builder.transport, "transport");
        this.key = Utils.getOrDefault(builder.key, () -> UUID.randomUUID().toString());
        this.clientName = Utils.getOrDefault(builder.clientName, "langchain4j");
        this.clientVersion = Utils.getOrDefault(builder.clientVersion, "1.0");
        this.protocolVersion = Utils.getOrDefault(builder.protocolVersion, "2024-11-05");
        this.initializationTimeout = Utils.getOrDefault(builder.initializationTimeout, Duration.ofSeconds(30L));
        this.toolExecutionTimeout = Utils.getOrDefault(builder.toolExecutionTimeout, Duration.ofSeconds(60L));
        this.resourcesTimeout = Utils.getOrDefault(builder.resourcesTimeout, Duration.ofSeconds(60L));
        this.promptsTimeout = Utils.getOrDefault(builder.promptsTimeout, Duration.ofSeconds(60L));
        this.logHandler = Utils.getOrDefault(builder.logHandler, new DefaultMcpLogMessageHandler());
        this.pingTimeout = Utils.getOrDefault(builder.pingTimeout, Duration.ofSeconds(10L));
        this.reconnectInterval = Utils.getOrDefault(builder.reconnectInterval, Duration.ofSeconds(5L));
        this.toolExecutionTimeoutErrorMessage = Utils.getOrDefault(
                builder.toolExecutionTimeoutErrorMessage, "There was a timeout executing the tool");
        this.messageHandler = new McpOperationHandler(
                pendingOperations, transport, logHandler::handleLogMessage, () -> toolListOutOfDate.set(true));
        RESULT_TIMEOUT
                .putObject("result")
                .putArray("content")
                .addObject()
                .put("type", "text")
                .put("text", toolExecutionTimeoutErrorMessage);
        transport.onFailure(this::reconnectAfterFailure);
        initialize();
    }

    /**
     * Transport-failure callback: waits {@code reconnectInterval}, then re-runs the handshake.
     * Does nothing if the client is closed before or during the wait.
     */
    private void reconnectAfterFailure() {
        if (closed) {
            return;
        }
        try {
            TimeUnit.MILLISECONDS.sleep(reconnectInterval.toMillis());
        } catch (InterruptedException e) {
            throw propagate(e);
        }
        // close() may have been called while we slept; do not restart a transport the user shut down.
        if (closed) {
            return;
        }
        log.info("Trying to reconnect...");
        initialize();
    }

    /**
     * Starts the transport and sends the {@code initialize} request. Waits up to
     * {@code initializationTimeout} for the response and logs the server's {@code result} at debug level.
     *
     * @throws RuntimeException wrapping any failure, including a timeout
     */
    private void initialize() {
        transport.start(messageHandler);
        long id = idGenerator.getAndIncrement();
        McpInitializeRequest request = new McpInitializeRequest(id);
        request.setParams(createInitializeParams());
        try {
            JsonNode response =
                    transport.initialize(request).get(initializationTimeout.toMillis(), TimeUnit.MILLISECONDS);
            log.debug("MCP server capabilities: {}", response.get("result"));
        } catch (Exception e) {
            throw propagate(e);
        } finally {
            pendingOperations.remove(id);
        }
    }

    /** Builds the initialize parameters: protocol version, client info, and roots with listChanged=false. */
    private InitializeParams createInitializeParams() {
        InitializeParams initializeParams = new InitializeParams();
        initializeParams.setProtocolVersion(protocolVersion);

        InitializeParams.ClientInfo clientInfo = new InitializeParams.ClientInfo();
        clientInfo.setName(clientName);
        clientInfo.setVersion(clientVersion);
        initializeParams.setClientInfo(clientInfo);

        InitializeParams.Capabilities capabilities = new InitializeParams.Capabilities();
        InitializeParams.Capabilities.Roots roots = new InitializeParams.Capabilities.Roots();
        roots.setListChanged(false);
        capabilities.setRoots(roots);
        initializeParams.setCapabilities(capabilities);
        return initializeParams;
    }

    @Override
    public String key() {
        return key;
    }

    /**
     * Returns the server's tools, re-fetching them when the cached list is marked out of date.
     *
     * <p>At most one thread fetches at a time. Concurrent callers wait for that fetch and then re-check the
     * flag. If the fetch failed or the list changed again, a waiting caller performs the next fetch itself.
     *
     * @return the cached tool list
     * @throws RuntimeException if fetching the tool list fails
     */
    @Override
    public List<ToolSpecification> listTools() {
        while (toolListOutOfDate.get()) {
            CompletableFuture<Void> ownUpdate = new CompletableFuture<>();
            if (toolListUpdateInProgress.compareAndSet(null, ownUpdate)) {
                refreshToolList(ownUpdate);
                return toolListRefs.get();
            }
            CompletableFuture<Void> otherUpdate = toolListUpdateInProgress.get();
            if (otherUpdate != null) {
                // refreshToolList always completes the future normally, so join() does not throw here.
                otherUpdate.join();
            }
        }
        return toolListRefs.get();
    }

    /**
     * Fetches the tool list on behalf of the thread that owns {@code update}.
     * Always clears the in-progress marker and completes {@code update} afterwards.
     */
    private void refreshToolList(CompletableFuture<Void> update) {
        try {
            // Clear the flag before fetching, so a change notification that arrives while the request
            // is in flight marks the list stale again instead of being overwritten afterwards.
            toolListOutOfDate.set(false);
            obtainToolList();
        } catch (RuntimeException | Error e) {
            // Keep the list marked stale so the next call retries instead of returning a missing list.
            toolListOutOfDate.set(true);
            throw e;
        } finally {
            toolListUpdateInProgress.set(null);
            update.complete(null);
        }
    }

    /**
     * Calls a tool on the server.
     *
     * <p>Blank or null arguments are sent as an empty JSON object. On timeout, a cancellation notification
     * is sent and the configured timeout message is returned as the tool result (no exception is thrown).
     *
     * @return the tool result as extracted by {@code ToolExecutionHelper}
     * @throws RuntimeException if the arguments are not valid JSON, or the call fails or is interrupted
     */
    @Override
    public String executeTool(ToolExecutionRequest toolExecutionRequest) {
        ObjectNode arguments = parseArguments(toolExecutionRequest.arguments());
        long id = idGenerator.getAndIncrement();
        McpCallToolRequest request = new McpCallToolRequest(id, toolExecutionRequest.name(), arguments);
        try {
            JsonNode response = transport.executeOperationWithResponse(request)
                    .get(timeoutMillis(toolExecutionTimeout), TimeUnit.MILLISECONDS);
            return ToolExecutionHelper.extractResult(response);
        } catch (TimeoutException e) {
            transport.executeOperationWithoutResponse(new CancellationNotification(id, "Timeout"));
            return ToolExecutionHelper.extractResult(RESULT_TIMEOUT);
        } catch (InterruptedException | ExecutionException e) {
            throw propagate(e);
        } finally {
            pendingOperations.remove(id);
        }
    }

    /** Parses tool-call arguments as a JSON object, treating null or blank input as {@code "{}"}. */
    private static ObjectNode parseArguments(String arguments) {
        String json = Utils.isNullOrBlank(arguments) ? "{}" : arguments;
        try {
            return OBJECT_MAPPER.readValue(json, ObjectNode.class);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the server's resources, fetching and caching them on first use. */
    @Override
    public List<McpResource> listResources() {
        if (resourceRefs.get() == null) {
            obtainResourceList();
        }
        return resourceRefs.get();
    }

    /**
     * Reads one resource from the server, waiting up to {@code resourcesTimeout} (zero = no limit).
     *
     * @throws RuntimeException if the request fails, times out or is interrupted
     */
    @Override
    public McpReadResourceResult readResource(String uri) {
        long id = idGenerator.getAndIncrement();
        McpReadResourceRequest request = new McpReadResourceRequest(id, uri);
        try {
            JsonNode response = transport.executeOperationWithResponse(request)
                    .get(timeoutMillis(resourcesTimeout), TimeUnit.MILLISECONDS);
            return ResourcesHelper.parseResourceContents(response);
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            throw propagate(e);
        } finally {
            pendingOperations.remove(id);
        }
    }

    /** Returns the server's prompts, fetching and caching them on first use. */
    @Override
    public List<McpPrompt> listPrompts() {
        if (promptRefs.get() == null) {
            obtainPromptList();
        }
        return promptRefs.get();
    }

    /**
     * Renders a prompt on the server, waiting up to {@code promptsTimeout} (zero = no limit).
     *
     * @throws RuntimeException if the request fails, times out or is interrupted
     */
    @Override
    public McpGetPromptResult getPrompt(String name, Map<String, Object> arguments) {
        long id = idGenerator.getAndIncrement();
        McpGetPromptRequest request = new McpGetPromptRequest(id, name, arguments);
        try {
            JsonNode response = transport.executeOperationWithResponse(request)
                    .get(timeoutMillis(promptsTimeout), TimeUnit.MILLISECONDS);
            return PromptsHelper.parsePromptContents(response);
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            throw propagate(e);
        } finally {
            pendingOperations.remove(id);
        }
    }

    /**
     * Checks the transport, then sends a ping.
     *
     * <p>Waits up to {@code pingTimeout}. Unlike the other timeouts, zero is not treated as "no limit".
     *
     * @throws RuntimeException if the ping fails, times out or is interrupted
     */
    @Override
    public void checkHealth() {
        transport.checkHealth();
        long id = idGenerator.getAndIncrement();
        try {
            transport.executeOperationWithResponse(new McpPingRequest(id))
                    .get(pingTimeout.toMillis(), TimeUnit.MILLISECONDS);
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            throw propagate(e);
        } finally {
            pendingOperations.remove(id);
        }
    }

    /** Returns the server's resource templates, fetching and caching them on first use. */
    @Override
    public List<McpResourceTemplate> listResourceTemplates() {
        if (resourceTemplateRefs.get() == null) {
            obtainResourceTemplateList();
        }
        return resourceTemplateRefs.get();
    }

    /**
     * Fetches the tool list and stores it in {@code toolListRefs}.
     *
     * <p>Bounded by {@code toolExecutionTimeout} (zero = no limit). This is important because the method
     * holds this object's monitor, so an unanswered request would also block the other {@code obtain*}
     * methods.
     */
    private synchronized void obtainToolList() {
        long id = idGenerator.getAndIncrement();
        McpListToolsRequest request = new McpListToolsRequest(id);
        try {
            JsonNode response = transport.executeOperationWithResponse(request)
                    .get(timeoutMillis(toolExecutionTimeout), TimeUnit.MILLISECONDS);
            toolListRefs.set(ToolSpecificationHelper.toolSpecificationListFromMcpResponse(
                    response.get("result").get("tools")));
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            throw propagate(e);
        } finally {
            pendingOperations.remove(id);
        }
    }

    /** Fetches and caches the resource list unless another thread already did. */
    private synchronized void obtainResourceList() {
        if (resourceRefs.get() != null) {
            return;
        }
        long id = idGenerator.getAndIncrement();
        McpListResourcesRequest request = new McpListResourcesRequest(id);
        try {
            JsonNode response = transport.executeOperationWithResponse(request)
                    .get(timeoutMillis(resourcesTimeout), TimeUnit.MILLISECONDS);
            resourceRefs.set(ResourcesHelper.parseResourceRefs(response));
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            throw propagate(e);
        } finally {
            pendingOperations.remove(id);
        }
    }

    /** Fetches and caches the resource-template list unless another thread already did. */
    private synchronized void obtainResourceTemplateList() {
        if (resourceTemplateRefs.get() != null) {
            return;
        }
        long id = idGenerator.getAndIncrement();
        McpListResourceTemplatesRequest request = new McpListResourceTemplatesRequest(id);
        try {
            // Resource templates are a resources operation, so they use resourcesTimeout, not toolExecutionTimeout.
            JsonNode response = transport.executeOperationWithResponse(request)
                    .get(timeoutMillis(resourcesTimeout), TimeUnit.MILLISECONDS);
            resourceTemplateRefs.set(ResourcesHelper.parseResourceTemplateRefs(response));
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            throw propagate(e);
        } finally {
            pendingOperations.remove(id);
        }
    }

    /** Fetches and caches the prompt list unless another thread already did. */
    private synchronized void obtainPromptList() {
        if (promptRefs.get() != null) {
            return;
        }
        long id = idGenerator.getAndIncrement();
        McpListPromptsRequest request = new McpListPromptsRequest(id);
        try {
            JsonNode response = transport.executeOperationWithResponse(request)
                    .get(timeoutMillis(promptsTimeout), TimeUnit.MILLISECONDS);
            promptRefs.set(PromptsHelper.parsePromptRefs(response));
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            throw propagate(e);
        } finally {
            pendingOperations.remove(id);
        }
    }

    /** Converts a configured timeout to milliseconds, mapping zero to {@link #NO_TIMEOUT_MILLIS}. */
    private static long timeoutMillis(Duration timeout) {
        long millis = timeout.toMillis();
        return millis == 0 ? NO_TIMEOUT_MILLIS : millis;
    }

    /**
     * Wraps a checked exception in a {@link RuntimeException}.
     * If it is an {@link InterruptedException}, the thread's interrupt status is restored first.
     */
    private static RuntimeException propagate(Exception e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
        return new RuntimeException(e);
    }

    /**
     * Marks the client closed (which stops automatic reconnection) and closes the transport.
     * A failure to close the transport is logged, not thrown.
     */
    @Override
    public void close() {
        closed = true;
        try {
            transport.close();
        } catch (Exception e) {
            log.warn("Cannot close MCP transport", e);
        }
    }
}
