package fi.evolver.ai.taskchain.step;

import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.embedding.EmbeddingApi;
import fi.evolver.ai.spring.embedding.EmbeddingCache;
import fi.evolver.ai.spring.prompt.template.TemplateUtils;
import fi.evolver.ai.spring.util.TokenUtils;
import fi.evolver.ai.taskchain.model.StepState;
import fi.evolver.ai.taskchain.model.Value;
import fi.evolver.ai.taskchain.model.value.ListValue;
import fi.evolver.ai.taskchain.step.EmbeddingStrategy;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.util.ProxyUtils;
import org.springframework.stereotype.Component;

@Component
/**
 * Step runner that splits a list of texts into token-bounded entries and sends the
 * entries whose content has changed to the {@link EmbeddingApi}.
 *
 * <p>Parameters:
 * <ul>
 *   <li>{@code value} (required, implicit) - list of texts to embed</li>
 *   <li>{@code dataset} (required) - dataset name the embeddings are stored under</li>
 *   <li>{@code model} (required) - embedding model name</li>
 *   <li>{@code max_tokens} (required) - maximum token count of a single entry</li>
 *   <li>{@code token_limit} (optional) - passed to {@code TemplateUtils.inferTokenLimit}</li>
 *   <li>{@code tokenizer} (optional, default {@code cl100k_base}) - passed to
 *       {@code TemplateUtils.inferTokenizer}</li>
 *   <li>{@code provider} (optional) - passed through to {@code createEmbeddings}</li>
 *   <li>{@code strategies} (optional, default {@code [sentence, line]}) - names of the
 *       split strategies to apply, in order, to texts exceeding {@code max_tokens}</li>
 * </ul>
 */
public class CreateEmbeddingsStepRunner implements StepRunner {
    private static final Logger LOG = LoggerFactory.getLogger(CreateEmbeddingsStepRunner.class);
    private static final String PARAM_VALUE = "value";
    private static final String PARAM_DATASET = "dataset";
    private static final String PARAM_MODEL = "model";
    private static final String PARAM_MAX_TOKEN = "max_tokens";
    private static final String PARAM_TOKEN_LIMIT = "token_limit";
    private static final String PARAM_TOKENIZER = "tokenizer";
    private static final String PARAM_PROVIDER = "provider";
    private static final String PARAM_STRATEGIES = "strategies";

    private static final String DEFAULT_TOKENIZER = "cl100k_base";
    private static final List<String> DEFAULT_STRATEGIES = List.of("sentence", "line");
    private static final Duration EMBEDDING_TIMEOUT = Duration.ofSeconds(60L);

    /** Strategies keyed by derived name; lookups are case-insensitive. */
    private final Map<String, EmbeddingStrategy> strategiesByName;
    private final EmbeddingApi embeddingApi;

    /**
     * Creates the runner.
     *
     * @param embeddingApi API used to create the embeddings
     * @param list all available split strategies; when two strategies derive the same
     *             name, the first one in the list wins
     */
    @Autowired
    public CreateEmbeddingsStepRunner(EmbeddingApi embeddingApi, List<EmbeddingStrategy> list) {
        this.embeddingApi = embeddingApi;
        Map<String, EmbeddingStrategy> byName = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        for (EmbeddingStrategy strategy : list) {
            byName.putIfAbsent(strategyName(strategy), strategy);
        }
        this.strategiesByName = byName;
    }

    /**
     * Derives the lookup name of a strategy: its (proxy-unwrapped) simple class name
     * with a trailing {@code Strategy} removed, lower-cased.
     */
    private static String strategyName(EmbeddingStrategy strategy) {
        // Unwrap Spring proxies so e.g. "SentenceStrategy$$SpringCGLIB$$0" still maps to "sentence".
        String simpleName = ProxyUtils.getUserClass(strategy.getClass()).getSimpleName();
        return simpleName.replaceFirst("Strategy$", "").toLowerCase(Locale.ROOT);
    }

    /**
     * Splits the input texts into entries and creates embeddings for entries whose
     * content differs from the embedding cache.
     *
     * @param stepState state providing the step parameters and the embedding cache
     * @return always {@link ListValue#EMPTY}
     * @throws IllegalArgumentException if a requested (or default) strategy is unknown
     */
    @Override
    public Value run(StepState stepState) {
        Value value = stepState.expectParameter(PARAM_VALUE);
        String dataset = stepState.expectParameter(PARAM_DATASET).asString();
        String modelName = stepState.expectParameter(PARAM_MODEL).asString();
        int maxTokens = stepState.expectParameter(PARAM_MAX_TOKEN).asInt();
        String tokenLimit = stepState.acceptParameter(PARAM_TOKEN_LIMIT).map(Value::asString).orElse(null);
        String tokenizer = stepState.acceptParameter(PARAM_TOKENIZER).map(Value::asString).orElse(DEFAULT_TOKENIZER);
        String provider = stepState.acceptParameter(PARAM_PROVIDER).map(Value::asString).orElse(null);
        // Resolve eagerly so a misconfigured strategy name fails fast with a clear message.
        List<EmbeddingStrategy> strategies = resolveStrategies(stepState);

        Model<EmbeddingApi> model = new Model<>(
                modelName,
                TemplateUtils.inferTokenLimit(modelName, tokenLimit),
                TemplateUtils.inferTokenizer(modelName, tokenizer));
        EmbeddingCache embeddingCache = stepState.getEmbeddingCache(model, dataset, embeddingApi);

        // Only send entries the cache reports as changed.
        Map<String, String> changedEntries = createDataEntries(value, maxTokens, strategies).entrySet().stream()
                .filter(entry -> embeddingCache.hasChanged(entry.getKey(), entry.getValue()))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
        embeddingApi.createEmbeddings(provider, model, dataset, changedEntries, EMBEDDING_TIMEOUT);
        return ListValue.EMPTY;
    }

    /**
     * Resolves the strategies named by the {@code strategies} parameter, falling back to
     * {@link #DEFAULT_STRATEGIES} when the parameter is absent or empty.
     */
    private List<EmbeddingStrategy> resolveStrategies(StepState stepState) {
        List<String> names = stepState.acceptParameter(PARAM_STRATEGIES)
                .map(Value::asList)
                .stream()
                .flatMap(List::stream)
                .map(Value::asString)
                .toList();
        List<String> effectiveNames = names.isEmpty() ? DEFAULT_STRATEGIES : names;
        return effectiveNames.stream().map(this::requireStrategy).toList();
    }

    /** Looks up a strategy by name, failing with the list of known names when absent. */
    private EmbeddingStrategy requireStrategy(String name) {
        EmbeddingStrategy strategy = strategiesByName.get(name);
        if (strategy == null) {
            throw new IllegalArgumentException(
                    "Unknown embedding strategy '%s' (available: %s)".formatted(name, strategiesByName.keySet()));
        }
        return strategy;
    }

    /**
     * Splits the texts of {@code value} into entries keyed {@code embedding-0},
     * {@code embedding-1}, ... in the order produced by {@link #createEntries}.
     */
    private static Map<String, String> createDataEntries(Value value, int maxTokens, List<EmbeddingStrategy> strategies) {
        List<String> texts = value.asList().stream().map(Value::asString).toList();
        List<String> entries = createEntries(texts, maxTokens, strategies);
        Map<String, String> dataEntries = new HashMap<>();
        for (int i = 0; i < entries.size(); i++) {
            dataEntries.put("embedding-%s".formatted(i), entries.get(i));
        }
        return dataEntries;
    }

    /**
     * Returns the texts, splitting any text above {@code maxTokens} with the first
     * strategy and cascading the pieces that strategy left unchanged to the remaining
     * strategies. Pieces that are still too large once no strategy remains are dropped
     * with a warning.
     *
     * <p>NOTE(review): token counts come from {@code TokenUtils.calculateTokens(String)},
     * which does not receive the configured tokenizer - confirm this is intended.
     */
    private static List<String> createEntries(List<String> texts, int maxTokens, List<EmbeddingStrategy> strategies) {
        List<String> entries = new ArrayList<>();
        for (String text : texts) {
            int tokens = TokenUtils.calculateTokens(text);
            if (tokens <= maxTokens) {
                entries.add(text);
                continue;
            }
            if (strategies.isEmpty()) {
                LOG.warn("Unable to split large entry (tokens: {}, max tokens: {})", tokens, maxTokens);
                continue;
            }
            EmbeddingStrategy.SplitResult split = strategies.get(0).split(text, maxTokens);
            entries.addAll(split.splitted());
            // Re-check leftovers uniformly: ones that already fit are kept, the rest go to the next strategy.
            entries.addAll(createEntries(split.unchanged(), maxTokens, strategies.subList(1, strategies.size())));
        }
        return entries;
    }

    /** The {@code value} parameter may be supplied implicitly. */
    @Override
    public Optional<String> getImplicitParameter() {
        return Optional.of(PARAM_VALUE);
    }
}
