package io.trino.plugin.ai.functions;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.UncheckedExecutionException;
import io.airlift.json.JsonCodec;
import io.trino.cache.SafeCaches;
import io.trino.spi.TrinoException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;

/**
 * Base {@link AiClient} that implements every AI function by building a text prompt and sending
 * it to a completion model through {@link #generateCompletion(String, String)}.
 *
 * <p>Each function uses its own model from {@link AiConfig}, falling back to
 * {@link AiConfig#getModel()} when no function-specific model is configured. Completions are
 * cached in memory (at most 1000 entries), keyed by the model and the exact prompt text.
 */
public abstract class AbstractAiClient implements AiClient {
    protected static final JsonCodec<List<String>> LIST_CODEC = JsonCodec.listJsonCodec(String.class);
    protected static final JsonCodec<Map<String, String>> MAP_CODEC = JsonCodec.mapJsonCodec(String.class, String.class);
    protected static final JsonCodec<String> STRING_CODEC = JsonCodec.jsonCodec(String.class);

    // Prompt templates. Each is a format string; the user-supplied text is always the last %s.
    private static final String ANALYZE_SENTIMENT_PROMPT = """
            Classify the text below into one of the following labels: [positive, negative, neutral, mixed]
            Output only the label.
            =====
            %s
            """;

    private static final String CLASSIFY_PROMPT = """
            Classify the text below into one of the following JSON encoded labels: %s
            Output the label as a JSON string (not a JSON object).
            Output only the label.
            =====
            %s
            """;

    private static final String EXTRACT_PROMPT = """
            Extract a value for each of the JSON encoded labels from the text below.
            For each label, only extract a single value.
            Labels: %s
            Output the extracted values as a JSON object.
            Output only the JSON.
            Do not output a code block for the JSON.
            =====
            %s
            """;

    private static final String FIX_GRAMMAR_PROMPT = """
            Fix the grammar in the text below.
            Output only the text.
            =====
            %s
            """;

    private static final String MASK_PROMPT = """
            Mask the values for each of the JSON encoded labels in the text below.
            Labels: %s
            Replace the values with the text "[MASKED]".
            Output only the masked text.
            Do not output anything else.
            =====
            %s
            """;

    private static final String TRANSLATE_PROMPT = """
            Translate the text below to the language specified.
            The language is encoded as a JSON string.
            Output only the translated text.
            Language: %s
            =====
            %s
            """;

    // Joins model and prompt into one cache key. Splitting at the first NUL recovers the pair,
    // so distinct (model, prompt) pairs cannot collide as long as model names contain no NUL.
    private static final String CACHE_KEY_SEPARATOR = "\0";

    private static final int MAX_CACHED_COMPLETIONS = 1000;

    protected final String analyzeSentimentModel;
    protected final String classifyModel;
    protected final String extractModel;
    protected final String fixGrammarModel;
    protected final String generateModel;
    protected final String maskModel;
    protected final String translateModel;

    private final Cache<String, String> completionCache =
            SafeCaches.buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(MAX_CACHED_COMPLETIONS));

    /**
     * Resolves the model used by each AI function.
     *
     * @param aiConfig plugin configuration; any function-specific model that is {@code null}
     *     falls back to {@link AiConfig#getModel()}
     * @throws NullPointerException if both a function-specific model and the default model are
     *     {@code null}
     */
    protected AbstractAiClient(AiConfig aiConfig) {
        String defaultModel = aiConfig.getModel();
        this.analyzeSentimentModel = Objects.requireNonNullElse(aiConfig.getAnalyzeSentimentModel(), defaultModel);
        this.classifyModel = Objects.requireNonNullElse(aiConfig.getClassifyModel(), defaultModel);
        this.extractModel = Objects.requireNonNullElse(aiConfig.getExtractModel(), defaultModel);
        this.fixGrammarModel = Objects.requireNonNullElse(aiConfig.getFixGrammarModel(), defaultModel);
        this.generateModel = Objects.requireNonNullElse(aiConfig.getGenerateModel(), defaultModel);
        this.maskModel = Objects.requireNonNullElse(aiConfig.getMaskModel(), defaultModel);
        this.translateModel = Objects.requireNonNullElse(aiConfig.getTranslateModel(), defaultModel);
    }

    /**
     * Asks the model to label {@code text} as positive, negative, neutral, or mixed.
     *
     * @return the model's answer, lower-cased with {@link Locale#ROOT}
     */
    @Override
    public String analyzeSentiment(String text) {
        return completion(analyzeSentimentModel, ANALYZE_SENTIMENT_PROMPT.formatted(text))
                .toLowerCase(Locale.ROOT);
    }

    /**
     * Asks the model to pick one of {@code labels} for {@code text}.
     *
     * @return the label decoded from the model's JSON-string response
     * @throws TrinoException with {@code AI_ERROR} if the labels or the response cannot be
     *     converted to/from JSON
     */
    @Override
    public String classify(String text, List<String> labels) {
        try {
            String prompt = CLASSIFY_PROMPT.formatted(LIST_CODEC.toJson(labels), text);
            return STRING_CODEC.fromJson(completion(classifyModel, prompt));
        } catch (IllegalArgumentException e) {
            throw new TrinoException(AiErrorCode.AI_ERROR, "Failed to parse AI response", e);
        }
    }

    /**
     * Asks the model to extract a value for each of {@code labels} from {@code text}.
     *
     * @return the decoded label-to-value map, excluding labels whose value is JSON {@code null}
     *     (a filtered view from {@link Maps#filterValues})
     * @throws TrinoException with {@code AI_ERROR} if the labels or the response cannot be
     *     converted to/from JSON
     */
    @Override
    public Map<String, String> extract(String text, List<String> labels) {
        try {
            String prompt = EXTRACT_PROMPT.formatted(LIST_CODEC.toJson(labels), text);
            Map<String, String> extracted = MAP_CODEC.fromJson(completion(extractModel, prompt));
            return Maps.filterValues(extracted, Objects::nonNull);
        } catch (IllegalArgumentException e) {
            throw new TrinoException(AiErrorCode.AI_ERROR, "Failed to parse AI response", e);
        }
    }

    /** Asks the model to return {@code text} with its grammar corrected. */
    @Override
    public String fixGrammar(String text) {
        return completion(fixGrammarModel, FIX_GRAMMAR_PROMPT.formatted(text));
    }

    /** Sends {@code prompt} to the generate model verbatim, without any template. */
    @Override
    public String generate(String prompt) {
        return completion(generateModel, prompt);
    }

    /** Asks the model to replace the values for each of {@code labels} in {@code text} with "[MASKED]". */
    @Override
    public String mask(String text, List<String> labels) {
        return completion(maskModel, MASK_PROMPT.formatted(LIST_CODEC.toJson(labels), text));
    }

    /**
     * Asks the model to translate {@code text} into {@code language}. The language is
     * JSON-encoded in the prompt.
     */
    @Override
    public String translate(String text, String language) {
        return completion(translateModel, TRANSLATE_PROMPT.formatted(STRING_CODEC.toJson(language), text));
    }

    /**
     * Returns the cached completion for ({@code model}, {@code prompt}), calling
     * {@link #generateCompletion(String, String)} on a cache miss.
     *
     * <p>A {@link TrinoException} thrown by the loader is unwrapped from Guava's
     * {@link UncheckedExecutionException} and rethrown as-is; other failures propagate wrapped.
     */
    private String completion(String model, String prompt) {
        try {
            return completionCache.get(model + CACHE_KEY_SEPARATOR + prompt, () -> generateCompletion(model, prompt));
        } catch (ExecutionException e) {
            throw new UncheckedExecutionException(e);
        } catch (UncheckedExecutionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof TrinoException) {
                throw (TrinoException) cause;
            }
            throw e;
        }
    }

    /**
     * Produces the completion text for {@code prompt} using {@code model}.
     *
     * <p>Called only on a completion-cache miss. The returned value is cached under
     * ({@code model}, {@code prompt}). A {@link TrinoException} thrown here reaches the AI
     * function's caller unchanged.
     */
    protected abstract String generateCompletion(String model, String prompt);
}
