package fun.fengwk.chatjava.core.client.tool;

import fun.fengwk.chatjava.core.client.ChatClient;
import fun.fengwk.chatjava.core.client.ChatClientOptions;
import fun.fengwk.chatjava.core.client.ChatCompletionsResponse;
import fun.fengwk.chatjava.core.client.StreamChatListener;
import fun.fengwk.chatjava.core.client.request.ChatMessage;
import fun.fengwk.chatjava.core.client.request.ChatRequest;
import fun.fengwk.chatjava.core.client.response.ChatToolCall;
import fun.fengwk.chatjava.core.client.response.ChatToolCallFunction;
import fun.fengwk.chatjava.core.client.util.ChatUtils;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:fun/fengwk/chatjava/core/client/tool/ToolChatClient.class */
/**
 * A {@link ChatClient} decorator that attaches the tools exposed by a {@link ToolFunctionHandlerRegistryView}
 * to every request and executes the function tool calls the model asks for.
 *
 * <p>Each call works on {@code chatRequest.copy()} rather than on the caller's instance. Whenever the delegate
 * answers with a successful tool-call response, the following are appended to that copy, and the request is
 * then sent again:
 * <ul>
 *   <li>the assistant message, and</li>
 *   <li>one tool message per function call, carrying the registered handler's result.</li>
 * </ul>
 * This repeats until one of the following happens:
 * <ul>
 *   <li>the model stops requesting tools;</li>
 *   <li>the delegate reports a failure; or</li>
 *   <li>the number of tool-call rounds exceeds {@code ChatClientOptions#getMaxFunctionCallTimes()}. In that
 *       case an unsuccessful response wrapping an {@link IllegalStateException} is returned.</li>
 * </ul>
 */
public class ToolChatClient implements ChatClient {
    private static final Logger log = LoggerFactory.getLogger(ToolChatClient.class);
    private final ChatClient delegate;
    private final ToolFunctionHandlerRegistryView registryView;

    /**
     * Creates a tool-aware client.
     *
     * @param chatClient the client that actually talks to the model; must not be {@code null}
     * @param toolFunctionHandlerRegistryView supplies the tool definitions sent with each request and the
     *     handlers that execute tool calls; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    public ToolChatClient(ChatClient chatClient, ToolFunctionHandlerRegistryView toolFunctionHandlerRegistryView) {
        this.delegate = Objects.requireNonNull(chatClient, "chatClient");
        // Fail fast: every request dereferences the registry view in setTools.
        this.registryView = Objects.requireNonNull(toolFunctionHandlerRegistryView, "toolFunctionHandlerRegistryView");
    }

    /** Returns the delegate's client options. */
    @Override
    public ChatClientOptions getClientOptions() {
        return this.delegate.getClientOptions();
    }

    /** Replaces the delegate's client options. */
    @Override
    public void setClientOptions(ChatClientOptions chatClientOptions) {
        this.delegate.setClientOptions(chatClientOptions);
    }

    /**
     * Performs a blocking chat completion and resolves tool calls until the model produces a final answer.
     *
     * @param chatRequest the request to send; a copy is used, so tool messages are added to the copy only
     * @param chatClientOptions options forwarded to the delegate; also supplies the maximum number of
     *     tool-call rounds
     * @return the delegate's final response, or an unsuccessful response if the tool-call limit was exceeded
     */
    @Override
    public ChatCompletionsResponse chatCompletions(ChatRequest chatRequest, ChatClientOptions chatClientOptions) {
        ChatRequest request = chatRequest.copy();
        setTools(request);
        ChatCompletionsResponse response = this.delegate.chatCompletions(request, chatClientOptions);
        // Number of the tool-call round about to be executed. It must advance every iteration, otherwise the
        // limit below is never reached and a model that keeps requesting tools loops forever.
        int callTimes = 1;
        // A failed response is returned as-is, mirroring the streaming path.
        while (ChatUtils.isToolCalls(response) && response.isSuccess()) {
            if (callTimes > chatClientOptions.getMaxFunctionCallTimes()) {
                return handleExceedsCallTimes(response, chatClientOptions.getMaxFunctionCallTimes());
            }
            ChatMessage assistantMessage = response.getChatResponse().getChoices().get(0).getMessage();
            appendToolCallResults(request, assistantMessage);
            response = this.delegate.chatCompletions(request, chatClientOptions);
            callTimes++;
        }
        return response;
    }

    /**
     * Performs a streaming chat completion and resolves tool calls until the model produces a final answer.
     *
     * @param chatRequest the request to send; a copy is used, so tool messages are added to the copy only
     * @param streamChatListener receives streamed events via a {@code ToolStreamChatListener} wrapper, and
     *     {@code onError} when the delegate fails or the tool-call limit is exceeded
     * @param chatClientOptions options forwarded to the delegate; also supplies the maximum number of
     *     tool-call rounds
     * @return a future completing with the final response
     */
    @Override
    public CompletableFuture<ChatCompletionsResponse> streamChatCompletions(ChatRequest chatRequest, StreamChatListener streamChatListener, ChatClientOptions chatClientOptions) {
        ChatRequest request = chatRequest.copy();
        setTools(request);
        return doStreamChatCompletions(request, streamChatListener, chatClientOptions, 1);
    }

    /**
     * Runs one streaming round and, if the model requested tools, executes them and recurses.
     *
     * @param callTimes number of the tool-call round about to be executed, starting at 1
     */
    private CompletableFuture<ChatCompletionsResponse> doStreamChatCompletions(ChatRequest chatRequest, StreamChatListener streamChatListener, ChatClientOptions chatClientOptions, int callTimes) {
        return this.delegate.streamChatCompletions(chatRequest, new ToolStreamChatListener(streamChatListener), chatClientOptions).thenCompose(response -> {
            if (!ChatUtils.isToolCalls(response)) {
                return CompletableFuture.completedStage(response);
            }
            if (!response.isSuccess()) {
                streamChatListener.onError(response.getError());
                return CompletableFuture.completedStage(response);
            }
            if (callTimes > chatClientOptions.getMaxFunctionCallTimes()) {
                ChatCompletionsResponse exceeded = handleExceedsCallTimes(response, chatClientOptions.getMaxFunctionCallTimes());
                streamChatListener.onError(exceeded.getError());
                return CompletableFuture.completedStage(exceeded);
            }
            // In streaming mode the tool-call message is read from the first choice's delta.
            ChatMessage assistantMessage = response.getChatResponse().getChoices().get(0).getDelta();
            appendToolCallResults(chatRequest, assistantMessage);
            return doStreamChatCompletions(chatRequest, streamChatListener, chatClientOptions, callTimes + 1);
        });
    }

    /**
     * Appends the assistant's tool-call message to the request's messages. It then appends one tool message
     * per function tool call, containing the registered handler's result for that call's arguments.
     * Tool calls that are not function calls are skipped.
     */
    private void appendToolCallResults(ChatRequest request, ChatMessage assistantMessage) {
        request.getMessages().add(assistantMessage);
        for (ChatToolCall toolCall : assistantMessage.getTool_calls()) {
            if (!ChatUtils.isFunctionCall(toolCall)) {
                continue;
            }
            ChatToolCallFunction function = toolCall.getFunction();
            // NOTE(review): getHandlerRequired presumably throws when no handler is registered — confirm.
            request.getMessages().add(ChatMessage.newToolMessage(toolCall.getId(),
                    this.registryView.getHandlerRequired(function.getName()).call(function.getArguments())));
        }
    }

    /** Replaces the request's tool list with the tools exposed by the registry view. */
    private void setTools(ChatRequest chatRequest) {
        chatRequest.setTools(this.registryView.getTools());
    }

    /**
     * Builds the unsuccessful response returned when the tool-call limit is exceeded, and logs a warning.
     *
     * @param chatCompletionsResponse the tool-call response that would have required another round
     * @param i the configured maximum number of tool-call rounds, used in the error message
     * @return an unsuccessful response keeping the original chat response and wrapping an
     *     {@link IllegalStateException}
     */
    private ChatCompletionsResponse handleExceedsCallTimes(ChatCompletionsResponse chatCompletionsResponse, int i) {
        IllegalStateException illegalStateException = new IllegalStateException(String.format("The number of calls exceeds %d times", i));
        log.warn("{}, resp: {}", illegalStateException.getMessage(), chatCompletionsResponse);
        return new ChatCompletionsResponse(false, chatCompletionsResponse.getChatResponse(), illegalStateException.getMessage(), illegalStateException);
    }
}
