package cn.xfyun.api;

import cn.xfyun.base.websocket.AbstractClient;
import cn.xfyun.config.SparkModel;
import cn.xfyun.exception.BusinessException;
import cn.xfyun.model.sparkmodel.FunctionCall;
import cn.xfyun.model.sparkmodel.SparkChatParam;
import cn.xfyun.model.sparkmodel.WebSearch;
import cn.xfyun.model.sparkmodel.request.SparkChatPostRequest;
import cn.xfyun.model.sparkmodel.request.SparkChatRequest;
import cn.xfyun.util.StringUtils;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.net.MalformedURLException;
import java.security.SignatureException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import okhttp3.Callback;
import okhttp3.HttpUrl;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
import okhttp3.internal.Util;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:cn/xfyun/api/SparkChatClient.class */
/**
 * Client for the iFlytek Spark chat model.
 *
 * <p>Three request styles are offered:
 * <ul>
 *   <li>{@link #send(SparkChatParam, WebSocketListener)}: a WebSocket request whose events go to the supplied
 *       listener. Configure with {@link Builder#signatureWs}.</li>
 *   <li>{@link #send(SparkChatParam)}: a blocking, non-streaming HTTP POST that returns the raw response body.</li>
 *   <li>{@link #send(SparkChatParam, Callback)}: an asynchronous, streaming HTTP POST
 *       ({@code Accept: text/event-stream}).</li>
 * </ul>
 * Both HTTP variants send the API key as a {@code Bearer} token. Configure them with {@link Builder#signatureHttp}.
 *
 * <p>Generation settings are fixed when the client is built. The web search config and the function list can
 * also be supplied per request through {@link SparkChatParam}, in which case they take precedence over the
 * builder values.
 */
public class SparkChatClient extends AbstractClient {
    private static final Logger logger = LoggerFactory.getLogger(SparkChatClient.class);
    private final SparkModel sparkModel;
    private final float temperature;
    private final int maxTokens;
    private final int topK;
    /** Default web search config; a non-null per-request value overrides it. */
    private final WebSearch webSearch;
    /** Default function list; a non-empty per-request list overrides it. */
    private final List<FunctionCall> functions;
    private final int topP;
    private final float presencePenalty;
    private final float frequencyPenalty;
    private final boolean toolCallsSwitch;
    private final Object toolChoice;
    private final String responseType;
    private final List<String> suppressPlugin;
    private final boolean keepAlive;

    /**
     * Fluent builder for {@link SparkChatClient}.
     *
     * <p>Call {@link #signatureWs} or {@link #signatureHttp} first. Both select the model and may reset
     * model-specific defaults (temperature, penalties, host URL), so tuning calls made afterwards take precedence.
     * Timeouts and the ping interval are converted to milliseconds by OkHttp's {@code Util.checkDuration}.
     */
    public static final class Builder {
        /** HTTP chat-completions endpoint used for {@link SparkModel#SPARK_X1}. */
        public final String SPARK_X1_URL = "https://spark-api-open.xf-yun.com/v2/chat/completions";
        /** Default HTTP chat-completions endpoint. */
        public final String SPARK_URL = "https://spark-api-open.xf-yun.com/v1/chat/completions";

        private String appId;
        private String apiKey;
        private String apiSecret;
        private SparkModel sparkModel;
        private WebSearch webSearch;
        private List<FunctionCall> functions;
        private Object toolChoice;
        private String responseType;
        private List<String> suppressPlugin;
        private boolean retryOnConnectionFailure = true;
        private int callTimeout = 0;
        private int connectTimeout = 30000;
        private int readTimeout = 60000;
        private int writeTimeout = 30000;
        private int pingInterval = 0;
        private String hostUrl = SPARK_URL;
        private float temperature = 0.5f;
        private int maxTokens = 4096;
        private int topK = 4;
        private int topP = 1;
        private float presencePenalty = 0.0f;
        private float frequencyPenalty = 0.0f;
        private boolean toolCallsSwitch = false;
        private boolean keepAlive = false;

        /** Creates the client from the current builder state. */
        public SparkChatClient build() {
            return new SparkChatClient(this);
        }

        /**
         * Configures credentials and model for WebSocket requests.
         *
         * <p>Sets the host URL to {@code sparkModel.getUrl()}. For {@link SparkModel#SPARK_X1} it also sets
         * temperature 1.2, frequency penalty 0.001 and presence penalty 2.01.
         *
         * @param appId      application id
         * @param apiKey     API key
         * @param apiSecret  API secret
         * @param sparkModel model to use; must not be null
         * @return this builder
         * @throws NullPointerException if {@code sparkModel} is null
         */
        public Builder signatureWs(String appId, String apiKey, String apiSecret, SparkModel sparkModel) {
            Objects.requireNonNull(sparkModel, "sparkModel");
            this.appId = appId;
            this.apiKey = apiKey;
            this.apiSecret = apiSecret;
            this.sparkModel = sparkModel;
            this.hostUrl = sparkModel.getUrl();
            if (sparkModel == SparkModel.SPARK_X1) {
                this.temperature = 1.2f;
                this.frequencyPenalty = 0.001f;
                this.presencePenalty = 2.01f;
            }
            return this;
        }

        /**
         * Configures the bearer API key and model for HTTP requests.
         *
         * <p>Resets temperature to 1.0. For {@link SparkModel#SPARK_X1} it instead sets temperature 1.2,
         * frequency penalty 0.001, presence penalty 2.01 and switches the host URL to {@link #SPARK_X1_URL}.
         * For other models the host URL is left unchanged.
         *
         * @param apiKey     key sent as {@code Authorization: Bearer <apiKey>}
         * @param sparkModel model to use; must not be null
         * @return this builder
         * @throws NullPointerException if {@code sparkModel} is null
         */
        public Builder signatureHttp(String apiKey, SparkModel sparkModel) {
            Objects.requireNonNull(sparkModel, "sparkModel");
            this.apiKey = apiKey;
            this.sparkModel = sparkModel;
            this.temperature = 1.0f;
            if (sparkModel == SparkModel.SPARK_X1) {
                this.temperature = 1.2f;
                this.frequencyPenalty = 0.001f;
                this.presencePenalty = 2.01f;
                this.hostUrl = SPARK_X1_URL;
            }
            return this;
        }

        /** Sets the whole-call timeout (default 0). */
        public Builder callTimeout(long timeout, TimeUnit unit) {
            this.callTimeout = Util.checkDuration("timeout", timeout, unit);
            return this;
        }

        /** Sets the connect timeout (default 30 s). */
        public Builder connectTimeout(long timeout, TimeUnit unit) {
            this.connectTimeout = Util.checkDuration("timeout", timeout, unit);
            return this;
        }

        /** Sets the read timeout (default 60 s). */
        public Builder readTimeout(long timeout, TimeUnit unit) {
            this.readTimeout = Util.checkDuration("timeout", timeout, unit);
            return this;
        }

        /** Sets the write timeout (default 30 s). */
        public Builder writeTimeout(long timeout, TimeUnit unit) {
            this.writeTimeout = Util.checkDuration("timeout", timeout, unit);
            return this;
        }

        /** Sets the OkHttp ping interval (default 0). */
        public Builder pingInterval(long interval, TimeUnit unit) {
            this.pingInterval = Util.checkDuration("interval", interval, unit);
            return this;
        }

        /** Enables or disables OkHttp's retry on connection failure (default {@code true}). */
        public Builder retryOnConnectionFailure(boolean retry) {
            this.retryOnConnectionFailure = retry;
            return this;
        }

        public Builder temperature(float temperature) {
            this.temperature = temperature;
            return this;
        }

        public Builder maxTokens(int maxTokens) {
            this.maxTokens = maxTokens;
            return this;
        }

        public Builder topK(int topK) {
            this.topK = topK;
            return this;
        }

        /** Sets the default web search config; a non-null per-request value overrides it. */
        public Builder webSearch(WebSearch webSearch) {
            this.webSearch = webSearch;
            return this;
        }

        /** Sets the default function list; a non-empty per-request list overrides it. */
        public Builder functions(List<FunctionCall> functions) {
            this.functions = functions;
            return this;
        }

        /** Overrides the request endpoint. */
        public Builder hostUrl(String hostUrl) {
            this.hostUrl = hostUrl;
            return this;
        }

        public Builder topP(int topP) {
            this.topP = topP;
            return this;
        }

        public Builder presencePenalty(float presencePenalty) {
            this.presencePenalty = presencePenalty;
            return this;
        }

        public Builder frequencyPenalty(float frequencyPenalty) {
            this.frequencyPenalty = frequencyPenalty;
            return this;
        }

        public Builder toolCallsSwitch(boolean toolCallsSwitch) {
            this.toolCallsSwitch = toolCallsSwitch;
            return this;
        }

        public Builder toolChoice(Object toolChoice) {
            this.toolChoice = toolChoice;
            return this;
        }

        public Builder responseType(String responseType) {
            this.responseType = responseType;
            return this;
        }

        public Builder suppressPlugin(List<String> suppressPlugin) {
            this.suppressPlugin = suppressPlugin;
            return this;
        }

        /** Sets the keep-alive flag; it is only sent on HTTP requests for {@link SparkModel#SPARK_X1}. */
        public Builder keepAlive(boolean keepAlive) {
            this.keepAlive = keepAlive;
            return this;
        }
    }

    /**
     * Creates a client from the given builder and builds its {@link OkHttpClient}.
     *
     * @param builder configured builder
     */
    public SparkChatClient(Builder builder) {
        this.okHttpClient = new OkHttpClient.Builder()
                .callTimeout(builder.callTimeout, TimeUnit.MILLISECONDS)
                // Must be builder.connectTimeout: callTimeout defaults to 0, which OkHttp treats as "no timeout".
                .connectTimeout(builder.connectTimeout, TimeUnit.MILLISECONDS)
                .readTimeout(builder.readTimeout, TimeUnit.MILLISECONDS)
                .writeTimeout(builder.writeTimeout, TimeUnit.MILLISECONDS)
                // Apply the builder's ping interval so the option takes effect; 0 (the default) disables pings.
                .pingInterval(builder.pingInterval, TimeUnit.MILLISECONDS)
                .retryOnConnectionFailure(builder.retryOnConnectionFailure)
                .build();
        this.appId = builder.appId;
        this.apiKey = builder.apiKey;
        this.apiSecret = builder.apiSecret;
        this.originHostUrl = builder.hostUrl;
        this.sparkModel = builder.sparkModel;
        this.temperature = builder.temperature;
        this.maxTokens = builder.maxTokens;
        this.topK = builder.topK;
        this.webSearch = builder.webSearch;
        this.functions = builder.functions;
        this.topP = builder.topP;
        this.presencePenalty = builder.presencePenalty;
        this.frequencyPenalty = builder.frequencyPenalty;
        this.toolCallsSwitch = builder.toolCallsSwitch;
        this.toolChoice = builder.toolChoice;
        this.responseType = builder.responseType;
        this.suppressPlugin = builder.suppressPlugin;
        this.keepAlive = builder.keepAlive;
        this.retryOnConnectionFailure = builder.retryOnConnectionFailure;
        this.callTimeout = builder.callTimeout;
        this.connectTimeout = builder.connectTimeout;
        this.readTimeout = builder.readTimeout;
        this.writeTimeout = builder.writeTimeout;
        this.pingInterval = builder.pingInterval;
    }

    public String getResponseType() {
        return this.responseType;
    }

    public List<String> getSuppressPlugin() {
        return this.suppressPlugin;
    }

    public Integer getTopP() {
        return Integer.valueOf(this.topP);
    }

    public Float getPresencePenalty() {
        return Float.valueOf(this.presencePenalty);
    }

    public Float getFrequencyPenalty() {
        return Float.valueOf(this.frequencyPenalty);
    }

    public Boolean getToolCallsSwitch() {
        return Boolean.valueOf(this.toolCallsSwitch);
    }

    public Object getToolChoice() {
        return this.toolChoice;
    }

    public SparkModel getSparkModel() {
        return this.sparkModel;
    }

    public float getTemperature() {
        return this.temperature;
    }

    public int getMaxTokens() {
        return this.maxTokens;
    }

    public int getTopK() {
        return this.topK;
    }

    public WebSearch getWebSearch() {
        return this.webSearch;
    }

    public List<FunctionCall> getFunctions() {
        return this.functions;
    }

    public Boolean getKeepAlive() {
        return Boolean.valueOf(this.keepAlive);
    }

    /**
     * Sends a chat request over WebSocket. Responses are delivered to {@code webSocketListener}.
     *
     * <p>The socket is opened before the payload is built. If building or sending the payload throws, the
     * error is only logged: it is not rethrown and the listener is not notified by this method.
     *
     * @param sparkChatParam    request parameters; messages must be non-empty
     * @param webSocketListener receives the WebSocket events
     * @throws BusinessException     if the parameters fail validation
     * @throws MalformedURLException propagated from opening the WebSocket
     * @throws SignatureException    propagated from opening the WebSocket
     */
    public void send(SparkChatParam sparkChatParam, WebSocketListener webSocketListener)
            throws MalformedURLException, SignatureException {
        paramCheck(sparkChatParam, false);
        WebSocket webSocket = newWebSocket(webSocketListener);
        try {
            String payload = buildParam(sparkChatParam);
            logger.debug("星火文本大模型ws请求参数：{}", payload);
            webSocket.send(payload);
        } catch (Exception e) {
            logger.error("ws消息发送失败", e);
        }
    }

    /**
     * Sends a blocking, non-streaming HTTP POST and returns the raw response body.
     *
     * <p>The HTTP status code is not checked: an error body is returned just like a success body.
     *
     * @param sparkChatParam request parameters; messages must be non-empty
     * @return the response body as a string
     * @throws BusinessException    if the parameters fail validation
     * @throws NullPointerException if the response has no body
     * @throws IOException          on network failure
     */
    public String send(SparkChatParam sparkChatParam) throws IOException {
        paramCheck(sparkChatParam, true);
        String payload = buildPostParam(sparkChatParam, false);
        logger.debug("{}post请求URL：{}，参数：{}", this.sparkModel.getDesc(), this.originHostUrl, payload);
        try (Response response = this.okHttpClient.newCall(getRequest(payload, false)).execute()) {
            ResponseBody body = Objects.requireNonNull(response.body(), this.sparkModel.getDesc() + "post请求返回结果为空");
            return body.string();
        }
    }

    /**
     * Sends an asynchronous, streaming HTTP POST ({@code "stream": true}, {@code Accept: text/event-stream}).
     *
     * @param sparkChatParam request parameters; messages must be non-empty
     * @param callback       OkHttp callback that receives the streaming response
     * @throws BusinessException if the parameters fail validation
     */
    public void send(SparkChatParam sparkChatParam, Callback callback) {
        paramCheck(sparkChatParam, true);
        String payload = buildPostParam(sparkChatParam, true);
        logger.debug("{}post请求URL：{}，参数：{}", this.sparkModel.getDesc(), this.originHostUrl, payload);
        this.okHttpClient.newCall(getRequest(payload, true)).enqueue(callback);
    }

    /**
     * Validates request parameters against the configured model.
     *
     * @param sparkChatParam request parameters
     * @param postRequest    {@code true} for the HTTP variants, {@code false} for WebSocket
     * @throws BusinessException if the parameters are missing, no model was configured, or the model does not
     *                           support the requested transport or feature
     */
    private void paramCheck(SparkChatParam sparkChatParam, boolean postRequest) {
        if (sparkChatParam == null) {
            throw new BusinessException("参数不能为空");
        }
        if (sparkChatParam.getMessages() == null || sparkChatParam.getMessages().isEmpty()) {
            throw new BusinessException("文本内容不能为空");
        }
        // Without a model every later step would fail with a bare NullPointerException.
        if (this.sparkModel == null) {
            throw new BusinessException("未设置星火大模型，请先调用signatureWs或signatureHttp");
        }
        if (postRequest && this.sparkModel == SparkModel.CHAT_MULTILANG) {
            throw new BusinessException(this.sparkModel.getDesc() + "暂不支持post请求");
        }
        WebSearch requestWebSearch = sparkChatParam.getWebSearch();
        if (requestWebSearch != null && requestWebSearch.isEnable() && !this.sparkModel.isWebSearchEnable()) {
            throw new BusinessException(this.sparkModel.getDesc() + "暂不支持联网搜索");
        }
        if (sparkChatParam.getFunctions() != null && !this.sparkModel.isFunctionEnable()) {
            throw new BusinessException(this.sparkModel.getDesc() + "暂不支持function调用");
        }
    }

    /**
     * Builds the HTTP POST request to {@code originHostUrl} carrying a JSON body and a bearer token.
     *
     * @param body   JSON request body
     * @param stream whether to add {@code Accept: text/event-stream}
     * @return the request
     * @throws NullPointerException if {@code originHostUrl} cannot be parsed as an HTTP(S) URL
     */
    private Request getRequest(String body, boolean stream) {
        HttpUrl url = Objects.requireNonNull(HttpUrl.parse(this.originHostUrl), "请求地址错误：" + this.originHostUrl);
        Request.Builder request = new Request.Builder()
                .url(url)
                .post(RequestBody.create(MediaType.get("application/json; charset=utf-8"), body));
        request.addHeader("Authorization", "Bearer " + this.apiKey);
        request.addHeader("Content-type", "application/json");
        if (stream) {
            request.addHeader("Accept", "text/event-stream");
        }
        return request.build();
    }

    /**
     * Serializes the HTTP (OpenAI-style) request body.
     *
     * <p>The web search config and each function are both emitted as entries of {@code tools}. Unlike the
     * WebSocket path, the web search tool is added whether or not the model reports web search support.
     *
     * @param sparkChatParam request parameters
     * @param stream         value of the {@code stream} field
     * @return JSON string
     */
    private String buildPostParam(SparkChatParam sparkChatParam, boolean stream) {
        SparkChatPostRequest request = new SparkChatPostRequest(this);
        request.setModel(this.sparkModel.getDomain());
        request.setStream(Boolean.valueOf(stream));
        request.setUser(sparkChatParam.getUserId());
        request.setMessages(sparkChatParam.getMessages());
        request.setTopP(Integer.valueOf(this.topP));
        request.setPresencePenalty(Float.valueOf(this.presencePenalty));
        request.setFrequencyPenalty(Float.valueOf(this.frequencyPenalty));
        request.setToolCallsSwitch(Boolean.valueOf(this.toolCallsSwitch));
        request.setToolChoice(this.toolChoice);
        request.setSuppressPlugin(this.suppressPlugin);
        if (this.sparkModel == SparkModel.SPARK_X1) {
            request.setKeepAlive(Boolean.valueOf(this.keepAlive));
        }

        List<JsonObject> tools = new ArrayList<>();
        WebSearch webSearch = resolveWebSearch(sparkChatParam);
        if (webSearch != null) {
            tools.add(webSearchTool(webSearch));
        }
        List<FunctionCall> functionList = resolveFunctions(sparkChatParam);
        if (functionList != null) {
            for (FunctionCall function : functionList) {
                JsonObject tool = new JsonObject();
                tool.addProperty("type", "function");
                tool.add("function", StringUtils.gson.toJsonTree(function));
                tools.add(tool);
            }
        }
        if (!tools.isEmpty()) {
            // Pass a fresh list; the diamond lets the element type match whatever setTools declares.
            request.setTools(new ArrayList<>(tools));
        }

        // NOTE(review): only emptiness of responseType is checked; its value is not passed to ResponseFormat.
        // Confirm the default ResponseFormat covers every supported responseType.
        if (!StringUtils.isNullOrEmpty(this.responseType)) {
            request.setResponseFormat(new SparkChatPostRequest.ResponseFormat());
        }
        return StringUtils.gson.toJson(request);
    }

    /**
     * Serializes the WebSocket request body (header / parameter / payload sections).
     *
     * <p>If no user id is given, the first 10 characters of a random UUID are used. topP and both penalties
     * are only sent for {@link SparkModel#SPARK_X1}. Web search and functions are only sent when the model
     * reports support for them.
     *
     * @param sparkChatParam request parameters
     * @return JSON string
     */
    private String buildParam(SparkChatParam sparkChatParam) {
        String userId = sparkChatParam.getUserId();
        if (userId == null) {
            userId = UUID.randomUUID().toString().substring(0, 10);
        }
        SparkChatRequest request = new SparkChatRequest();
        request.setHeader(new SparkChatRequest.Header(this.appId, userId));

        SparkChatRequest.Parameter.Chat chat = new SparkChatRequest.Parameter.Chat();
        chat.setDomain(this.sparkModel.getDomain());
        chat.setTemperature(Float.valueOf(this.temperature));
        chat.setMaxTokens(Integer.valueOf(this.maxTokens));
        chat.setTopK(Integer.valueOf(this.topK));
        if (this.sparkModel == SparkModel.SPARK_X1) {
            chat.setTopP(Integer.valueOf(this.topP));
            chat.setFrequencyPenalty(Float.valueOf(this.frequencyPenalty));
            chat.setPresencePenalty(Float.valueOf(this.presencePenalty));
        }
        WebSearch webSearch = resolveWebSearch(sparkChatParam);
        if (webSearch != null && this.sparkModel.isWebSearchEnable()) {
            chat.setTools(Collections.singletonList(webSearchTool(webSearch)));
        }
        chat.setChatId(sparkChatParam.getChatId());
        SparkChatRequest.Parameter parameter = new SparkChatRequest.Parameter();
        parameter.setChat(chat);
        request.setParameter(parameter);

        SparkChatRequest.Payload payload = new SparkChatRequest.Payload();
        SparkChatRequest.Payload.Message message = new SparkChatRequest.Payload.Message();
        message.setText(sparkChatParam.getMessages());
        List<FunctionCall> functionList = resolveFunctions(sparkChatParam);
        if (functionList != null && !functionList.isEmpty() && this.sparkModel.isFunctionEnable()) {
            SparkChatRequest.Payload.Function function = new SparkChatRequest.Payload.Function();
            function.setText(functionList);
            payload.setFunctions(function);
        }
        payload.setMessage(message);
        request.setPayload(payload);
        return StringUtils.gson.toJson(request);
    }

    /** Returns the per-request web search config if non-null, otherwise the client default (may be null). */
    private WebSearch resolveWebSearch(SparkChatParam sparkChatParam) {
        WebSearch perRequest = sparkChatParam.getWebSearch();
        return perRequest != null ? perRequest : this.webSearch;
    }

    /** Returns the per-request function list if non-empty, otherwise the client default (may be null). */
    private List<FunctionCall> resolveFunctions(SparkChatParam sparkChatParam) {
        List<FunctionCall> perRequest = sparkChatParam.getFunctions();
        return (perRequest == null || perRequest.isEmpty()) ? this.functions : perRequest;
    }

    /** Wraps a web search config as a {@code {"type":"web_search","web_search":{...}}} tool entry. */
    private static JsonObject webSearchTool(WebSearch webSearch) {
        JsonObject tool = new JsonObject();
        tool.addProperty("type", "web_search");
        tool.add("web_search", StringUtils.gson.toJsonTree(webSearch));
        return tool;
    }
}
