package org.opensearch.ml.common.connector;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

/* loaded from: input_file:org/opensearch/ml/common/connector/MLPostProcessFunction.class */
/**
 * Registry of built-in connector post-processing functions.
 *
 * <p>Each built-in function name maps to two things:
 * <ul>
 *   <li>a JSONPath expression used to extract the embeddings from a remote model response
 *       ({@link #getResponseFilter(String)});</li>
 *   <li>a function that converts the extracted embeddings into {@link ModelTensor}s
 *       ({@link #get(String)}).</li>
 * </ul>
 */
public class MLPostProcessFunction {
    /** Built-in post-process function name for Cohere embedding responses. */
    public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
    /** Built-in post-process function name for OpenAI embedding responses. */
    public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";
    /** Built-in post-process function name for responses whose root is an array of embeddings. */
    public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";

    /** Function name -> JSONPath expression selecting the embeddings in the response. */
    private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();
    /** Function name -> converter from extracted embeddings to model tensors. */
    private static final Map<String, Function<List<List<Float>>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS =
        new HashMap<>();

    // Populated after the maps above are initialized (static initializers run in textual order).
    static {
        JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
        JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
        JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
        POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());
        POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());
        POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());
    }

    /**
     * Builds a function that turns a list of embeddings into one {@code FLOAT32} tensor per embedding.
     *
     * <p>Each produced tensor is named {@code "sentence_embedding"} and has a one-dimensional shape
     * equal to the embedding's length.
     *
     * @return a converter from embeddings to model tensors. The converter throws
     *     {@link IllegalArgumentException} if the outer list, or any embedding in it, is {@code null}.
     */
    public static Function<List<List<Float>>, List<ModelTensor>> buildModelTensorList() {
        return embeddings -> {
            if (embeddings == null) {
                throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
            }
            List<ModelTensor> modelTensors = new ArrayList<>(embeddings.size());
            int index = 0;
            for (List<Float> embedding : embeddings) {
                if (embedding == null) {
                    // Fail with a descriptive message instead of an NPE on embedding.size().
                    throw new IllegalArgumentException(
                        "The embedding at index " + index + " is null when using the built-in post-processing function."
                    );
                }
                modelTensors.add(
                    ModelTensor.builder()
                        .name("sentence_embedding")
                        .dataType(MLResultDataType.FLOAT32)
                        .shape(new long[] { embedding.size() })
                        .data(embedding.toArray(new Number[0]))
                        .build()
                );
                index++;
            }
            return modelTensors;
        };
    }

    /**
     * Returns the JSONPath expression registered for a built-in post-process function.
     *
     * @param str the post-process function name
     * @return the JSONPath expression, or {@code null} if {@code str} is not a built-in function
     */
    public static String getResponseFilter(String str) {
        return JSON_PATH_EXPRESSION.get(str);
    }

    /**
     * Returns the converter registered for a built-in post-process function.
     *
     * @param str the post-process function name
     * @return the converter, or {@code null} if {@code str} is not a built-in function
     */
    public static Function<List<List<Float>>, List<ModelTensor>> get(String str) {
        return POST_PROCESS_FUNCTIONS.get(str);
    }

    /**
     * Reports whether a name refers to a built-in post-process function.
     *
     * @param str the post-process function name
     * @return {@code true} if a converter is registered under {@code str}
     */
    public static boolean contains(String str) {
        return POST_PROCESS_FUNCTIONS.containsKey(str);
    }
}
