package schemacrawler.tools.command.aichat.langchain4j;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.service.tool.ToolExecutor;
import java.sql.Connection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
import java.util.logging.Level;
import java.util.logging.Logger;
import schemacrawler.schema.Catalog;
import schemacrawler.schemacrawler.exceptions.SchemaCrawlerException;
import schemacrawler.tools.command.aichat.ChatAssistant;
import schemacrawler.tools.command.aichat.langchain4j.AiModelFactoryUtility;
import schemacrawler.tools.command.aichat.options.AiChatCommandOptions;
import us.fatehi.utility.IOUtility;
import us.fatehi.utility.Utility;
import us.fatehi.utility.string.StringFormat;

/* loaded from: input_file:schemacrawler/tools/command/aichat/langchain4j/Langchain4JChatAssistant.class */
/**
 * {@link ChatAssistant} implementation that sends user prompts to a LangChain4j
 * {@link ChatModel}, offers the model a fixed set of tools, and runs whichever
 * tools the model asks for.
 *
 * <p>Every request is prefixed with a system message built from the
 * {@code /metadata-priming.txt} resource plus whatever the configured
 * {@link ContentRetriever} returns for the prompt. Only the most recent
 * {@code chatContextWindowSize} messages from the chat memory are sent.
 */
public class Langchain4JChatAssistant implements ChatAssistant {
    /** Placeholder stored in chat memory instead of the real tool-call message. */
    private static final AiMessage TOOL_CALL_MEMORY_MESSAGE = AiMessage.from("(Information on tool calls is redacted for security reasons.)");
    private static final Logger LOGGER = Logger.getLogger(Langchain4JChatAssistant.class.getCanonicalName());
    private final ChatModel model;
    private final ChatMemory chatMemory;
    private final List<ToolSpecification> toolSpecifications;
    private final Map<String, ToolExecutor> toolExecutors;
    private final ContentRetriever contentRetriever;
    private final String metadataPriming;
    private final int chatContextWindowSize;
    /** Set once the model requests a tool whose name starts with {@code "exit"}; never reset. */
    private boolean shouldExit;

    /**
     * Creates an assistant wired to the model chosen from the given options.
     *
     * @param aiChatCommandOptions chat options; supplies the model factory choice,
     *     the context window size and whether catalog metadata should be retrieved
     * @param catalog catalog handed to the content retriever and the tool executors
     * @param connection database connection handed to the tool executors
     * @throws NullPointerException if any argument is {@code null}
     * @throws SchemaCrawlerException if no model factory can be chosen for the options
     */
    public Langchain4JChatAssistant(AiChatCommandOptions aiChatCommandOptions, Catalog catalog, Connection connection) {
        Objects.requireNonNull(aiChatCommandOptions, "AI chat options not provided");
        Objects.requireNonNull(catalog, "No catalog provided");
        Objects.requireNonNull(connection, "No connection provided");
        AiModelFactoryUtility.AiModelFactory modelFactory = AiModelFactoryUtility.chooseAiModelFactory(aiChatCommandOptions);
        if (modelFactory == null) {
            throw new SchemaCrawlerException("No models found");
        }
        this.chatContextWindowSize = aiChatCommandOptions.context();
        this.model = modelFactory.newChatModel();
        this.chatMemory = modelFactory.newChatMemory();
        if (aiChatCommandOptions.useMetadata()) {
            // The embedding model is optional; the retriever receives null when none is available.
            this.contentRetriever = new FullTextCatalogContentRetriever(modelFactory.hasEmbeddingModel() ? modelFactory.newEmbeddingModel() : null, catalog);
        } else {
            // Metadata retrieval disabled: contribute nothing to the system message.
            this.contentRetriever = query -> Collections.emptyList();
        }
        this.toolSpecifications = Langchain4JUtility.tools();
        this.toolExecutors = Langchain4JUtility.toolExecutors(catalog, connection);
        this.metadataPriming = IOUtility.readResourceFully("/metadata-priming.txt");
    }

    /**
     * Sends one prompt to the model and returns the reply.
     *
     * <p>If the model answers with tool execution requests, the tools are run and
     * their concatenated output is returned; a redacted placeholder is stored in
     * chat memory instead of the tool-call message. Otherwise the model's text is
     * returned and the message is stored in chat memory.
     *
     * @param str the user prompt
     * @return the reply text; an empty string for a blank prompt; a fixed apology
     *     message if anything fails while handling the prompt
     */
    public String chat(String str) {
        try {
            if (Utility.isBlank(str)) {
                return "";
            }
            this.chatMemory.add(UserMessage.from(str));
            List<ChatMessage> chatContext = getChatContext();
            chatContext.add(0, createSystemMessage(str));
            ChatRequest request = ChatRequest.builder().messages(chatContext).toolSpecifications(this.toolSpecifications).build();
            ChatResponse response = this.model.chat(request);
            LOGGER.log(Level.INFO, new StringFormat("%s", response.tokenUsage()));
            AiMessage aiMessage = response.aiMessage();
            if (!aiMessage.hasToolExecutionRequests()) {
                String text = aiMessage.text();
                this.chatMemory.add(aiMessage);
                return text;
            }
            String toolOutput = executeTools(aiMessage.toolExecutionRequests());
            this.chatMemory.add(TOOL_CALL_MEMORY_MESSAGE);
            return toolOutput;
        } catch (Exception e) {
            // The logger records the stack trace; no separate printStackTrace() is needed.
            LOGGER.log(Level.WARNING, e, new StringFormat("Exception handling prompt:%n%s", str));
            return "There was a problem. Please try again.";
        }
    }

    /** Does nothing; this class holds no resources that it closes itself. */
    public void close() {
    }

    /**
     * @return {@code true} once the model has requested a tool whose name starts
     *     with {@code "exit"}
     */
    public boolean shouldExit() {
        return this.shouldExit;
    }

    /**
     * Runs each requested tool in order and concatenates the results.
     *
     * @param requests tool execution requests taken from the model's reply
     * @return the concatenated tool outputs
     * @throws IllegalStateException if a requested tool has no registered executor
     */
    private String executeTools(List<ToolExecutionRequest> requests) {
        StringBuilder output = new StringBuilder();
        for (ToolExecutionRequest request : requests) {
            String toolName = request.name();
            // Sticky flag: a later non-exit tool must not clear an earlier exit request.
            if (toolName.startsWith("exit")) {
                this.shouldExit = true;
            }
            ToolExecutor executor = this.toolExecutors.get(toolName);
            if (executor == null) {
                // Fail with the tool name so the log shows what the model asked for.
                throw new IllegalStateException("No executor registered for tool <" + toolName + ">");
            }
            output.append(executor.execute(request, (Object) null));
        }
        return output.toString();
    }

    /**
     * Builds the system message for a prompt: the priming text, a blank line,
     * then the text of each content item retrieved for the prompt, one per line.
     *
     * @param str the user prompt used as the retrieval query
     * @return the system message to place first in the chat request
     */
    private SystemMessage createSystemMessage(String str) {
        StringBuilder sb = new StringBuilder();
        sb.append(this.metadataPriming).append("\n");
        for (Content content : this.contentRetriever.retrieve(Query.from(str))) {
            sb.append("\n").append(content.textSegment().text());
        }
        return SystemMessage.from(sb.toString());
    }

    /**
     * Returns a fresh, mutable list holding the most recent messages from chat
     * memory, limited to the configured context window size.
     *
     * @return at most {@code chatContextWindowSize} trailing messages; empty when
     *     the window size is zero or negative
     */
    private List<ChatMessage> getChatContext() {
        List<ChatMessage> messages = this.chatMemory.messages();
        int size = messages.size();
        // Clamp so a negative window cannot produce fromIndex > toIndex in subList.
        int window = Math.max(0, this.chatContextWindowSize);
        return new ArrayList<>(messages.subList(Math.max(0, size - window), size));
    }
}
