package io.trino.plugin.ai.functions;

import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import com.google.common.net.MediaType;
import com.google.inject.Inject;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.HttpUriBuilder;
import io.airlift.http.client.JsonBodyGenerator;
import io.airlift.http.client.JsonResponseHandler;
import io.airlift.http.client.Request;
import io.airlift.json.JsonCodec;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.SpanKind;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Scope;
import io.opentelemetry.semconv.incubating.GenAiIncubatingAttributes;
import io.trino.spi.TrinoException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.net.URI;
import java.util.List;
import java.util.Objects;

/* loaded from: input_file:io/trino/plugin/ai/functions/OpenAiClient.class */
public class OpenAiClient extends AbstractAiClient {
    // Codec used to serialize the outgoing chat request body to JSON.
    private static final JsonCodec<ChatRequest> CHAT_REQUEST_CODEC = JsonCodec.jsonCodec(ChatRequest.class);
    // Codec used to parse the JSON chat response returned by the provider.
    private static final JsonCodec<ChatResponse> CHAT_RESPONSE_CODEC = JsonCodec.jsonCodec(ChatResponse.class);
    // HTTP client used to execute the chat completion request.
    private final HttpClient httpClient;
    // Tracer used to create one CLIENT span per chat completion request.
    private final Tracer tracer;
    // Base URI of the provider; "/v1/chat/completions" is appended to it for each request.
    private final URI endpoint;
    // API key sent as "Bearer <apiKey>" in the Authorization header; must not be logged.
    private final String apiKey;

    /* loaded from: input_file:io/trino/plugin/ai/functions/OpenAiClient$ChatRequest.class */
    /**
     * JSON body of a chat completion request.
     *
     * <p>Declared as a real {@code record}: the canonical constructor, accessors,
     * {@code equals}, {@code hashCode} and {@code toString} are generated by the compiler
     * instead of being spelled out as decompiled {@code invokedynamic} bootstrap calls
     * (a class cannot legally declare {@code extends Record} in source).
     *
     * @param model name of the model the request is addressed to
     * @param messages conversation messages sent to the model, in order
     * @param seed seed value sent with the request
     */
    public record ChatRequest(String model, List<Message> messages, int seed) {
        /**
         * A single message of the conversation.
         *
         * @param role role of the message author (this client sends {@code "user"})
         * @param content text of the message
         */
        public record Message(String role, String content) {}
    }

    /**
     * JSON body of a chat completion response.
     *
     * <p>Declared as a real {@code record} rather than a class extending {@code Record} with
     * decompiled {@code invokedynamic} bootstrap bodies. Property names are mapped with the
     * snake-case naming strategy, so e.g. {@code serviceTier} binds to {@code service_tier}.
     *
     * @param id identifier of the response
     * @param model model reported by the provider for this response
     * @param choices generated choices; the client uses the first one
     * @param usage token usage reported for the request
     * @param serviceTier service tier reported by the provider
     * @param systemFingerprint system fingerprint reported by the provider
     */
    @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
    public record ChatResponse(
            String id,
            String model,
            List<Choice> choices,
            Usage usage,
            String serviceTier,
            String systemFingerprint) {
        /**
         * One generated alternative.
         *
         * @param message message produced for this choice
         */
        public record Choice(Message message) {
            /**
             * Message produced by the model.
             *
             * @param content generated text
             * @param refusal refusal text; the client treats a non-null value as a refusal
             */
            public record Message(String content, String refusal) {}
        }

        /**
         * Token accounting for the request, bound with snake-case property names.
         *
         * @param promptTokens number of tokens in the prompt
         * @param completionTokens number of tokens in the completion
         */
        @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
        public record Usage(int promptTokens, int completionTokens) {}
    }

    /**
     * Creates the client from its injected collaborators.
     *
     * @param httpClient HTTP client used for provider requests; must not be null
     * @param tracer tracer used to create request spans; must not be null
     * @param openAiConfig configuration supplying the endpoint and the API key
     * @param aiConfig configuration passed to the superclass constructor
     * @throws NullPointerException if {@code httpClient} or {@code tracer} is null
     */
    @Inject
    public OpenAiClient(@ForAiClient HttpClient httpClient, Tracer tracer, OpenAiConfig openAiConfig, AiConfig aiConfig) {
        super(aiConfig);

        // Validate in the same order as before so the first failing argument is reported.
        Objects.requireNonNull(httpClient, "httpClient is null");
        this.httpClient = httpClient;
        Objects.requireNonNull(tracer, "tracer is null");
        this.tracer = tracer;

        this.endpoint = openAiConfig.getEndpoint();
        this.apiKey = openAiConfig.getApiKey();
    }

    /**
     * Sends a single-message chat completion request and returns the content of the first choice.
     *
     * <p>The prompt is sent as one {@code "user"} message with seed {@code 0} in a POST to
     * {@code /v1/chat/completions} under the configured endpoint. The call is wrapped in a
     * CLIENT span that records request and response attributes; the span is always ended.
     *
     * @param model name of the model to query
     * @param prompt text sent as the content of the user message
     * @return content of the first choice's message
     * @throws TrinoException with {@code AI_ERROR} if the request fails, the response contains
     *         no choices, or the provider refused to generate a response
     */
    @Override // io.trino.plugin.ai.functions.AbstractAiClient
    protected String generateCompletion(String model, String prompt) {
        URI uri = HttpUriBuilder.uriBuilderFrom(this.endpoint)
                .appendPath("/v1/chat/completions")
                .build();

        ChatRequest body = new ChatRequest(model, List.of(new ChatRequest.Message("user", prompt)), 0);

        Request request = Request.Builder.preparePost()
                .setUri(uri)
                .setHeader("Authorization", "Bearer " + this.apiKey)
                .setHeader("Content-Type", MediaType.JSON_UTF_8.toString())
                .setBodyGenerator(JsonBodyGenerator.jsonBodyGenerator(CHAT_REQUEST_CODEC, body))
                .build();

        Span span = this.tracer.spanBuilder("chat " + model)
                .setAttribute(GenAiIncubatingAttributes.GEN_AI_OPERATION_NAME, "chat")
                .setAttribute(GenAiIncubatingAttributes.GEN_AI_SYSTEM, "openai")
                .setAttribute(GenAiIncubatingAttributes.GEN_AI_REQUEST_MODEL, model)
                .setAttribute(GenAiIncubatingAttributes.GEN_AI_REQUEST_SEED, body.seed())
                .setSpanKind(SpanKind.CLIENT)
                .startSpan();

        try {
            ChatResponse response;
            // The scope is closed exactly once, as soon as the HTTP exchange is done.
            try (Scope ignored = span.makeCurrent()) {
                response = this.httpClient.execute(request, JsonResponseHandler.createJsonResponseHandler(CHAT_RESPONSE_CODEC));
                span.setAttribute(GenAiIncubatingAttributes.GEN_AI_RESPONSE_ID, response.id());
                span.setAttribute(GenAiIncubatingAttributes.GEN_AI_RESPONSE_MODEL, response.model());
                span.setAttribute(GenAiIncubatingAttributes.GEN_AI_OPENAI_RESPONSE_SERVICE_TIER, response.serviceTier());
                span.setAttribute(GenAiIncubatingAttributes.GEN_AI_OPENAI_RESPONSE_SYSTEM_FINGERPRINT, response.systemFingerprint());
                span.setAttribute(GenAiIncubatingAttributes.GEN_AI_USAGE_INPUT_TOKENS, response.usage().promptTokens());
                span.setAttribute(GenAiIncubatingAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, response.usage().completionTokens());
            }

            if (response.choices().isEmpty()) {
                throw new TrinoException(AiErrorCode.AI_ERROR, "No response from AI provider at %s for model %s".formatted(uri, model));
            }
            ChatResponse.Choice choice = response.choices().getFirst();
            String refusal = choice.message().refusal();
            if (refusal != null) {
                throw new TrinoException(AiErrorCode.AI_ERROR, "AI provider at %s for model %s refused to generate response: %s".formatted(uri, model, refusal));
            }
            return choice.message().content();
        } catch (TrinoException e) {
            // Raised by the validation above: it already carries a specific message, so mark
            // the span as failed and rethrow it unchanged instead of hiding it behind the
            // generic "request failed" wrapper below.
            span.setStatus(StatusCode.ERROR, e.getMessage());
            span.recordException(e);
            throw e;
        } catch (RuntimeException e) {
            span.setStatus(StatusCode.ERROR, e.getMessage());
            span.recordException(e);
            throw new TrinoException(AiErrorCode.AI_ERROR, "Request to AI provider at %s for model %s failed".formatted(uri, model), e);
        } finally {
            span.end();
        }
    }
}
