package dev.langchain4j.store.embedding.infinispan;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.infinispan.client.hotrod.RemoteCache;
import org.infinispan.client.hotrod.RemoteCacheManager;
import org.infinispan.client.hotrod.configuration.ConfigurationBuilder;
import org.infinispan.commons.marshall.ProtoStreamMarshaller;
import org.infinispan.protostream.FileDescriptorSource;
import org.infinispan.protostream.SerializationContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link EmbeddingStore} backed by a remote Infinispan cache accessed through Hot Rod.
 *
 * <p>On construction the store registers a protobuf schema whose vector field is sized to the
 * configured dimension, and configures an indexed distributed cache for that entity type.
 * Embeddings are stored as {@link LangChainInfinispanItem} entries keyed by their id. They are
 * searched with an Ickle kNN query ({@code <->}) on the indexed {@code embedding} field, which
 * the schema declares with {@code similarity=COSINE}.
 */
public class InfinispanEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(InfinispanEmbeddingStore.class);
    private final RemoteCache<String, LangChainInfinispanItem> remoteCache;
    private final LangChainItemMarshaller itemMarshaller;
    // CACHE_NAME and LANGCHAINITEM are placeholders substituted in the constructor.
    private static final String DEFAULT_CACHE_CONFIG = "<distributed-cache name=\"CACHE_NAME\">\n<indexing storage=\"local-heap\">\n<indexed-entities>\n<indexed-entity>LANGCHAINITEM</indexed-entity>\n</indexed-entities>\n</indexing>\n</distributed-cache>";
    // DIMENSION is a placeholder substituted in the constructor; it appears both in the message
    // name (one message type per dimension) and in the @Vector annotation.
    private static final String PROTO = "syntax = \"proto2\";\n\n/**\n * @Indexed\n */\nmessage LangChainItemDIMENSION {\n   \n   /**\n    * @Keyword\n    */\n   optional string id = 1;\n   \n   /**\n    * @Vector(dimension=DIMENSION, similarity=COSINE)\n    */\n   repeated float embedding = 2;\n   \n   optional string text = 3;\n   \n   repeated string metadataKeys = 4;\n   \n   repeated string metadataValues = 5;\n}\n";

    /**
     * Fluent builder for {@link InfinispanEmbeddingStore}.
     *
     * <p>All three settings are required; {@link #build()} delegates validation to the store
     * constructor.
     */
    public static class Builder {
        private ConfigurationBuilder builder;
        private String name;
        private Integer dimension;

        /**
         * Sets the name of the remote cache that will hold the embeddings.
         *
         * @param str the cache name (must not be blank)
         * @return this builder
         */
        public Builder cacheName(String str) {
            this.name = str;
            return this;
        }

        /**
         * Sets the dimension of the stored vectors; used to size the protobuf schema.
         *
         * @param num the vector dimension (must not be null)
         * @return this builder
         */
        public Builder dimension(Integer num) {
            this.dimension = num;
            return this;
        }

        /**
         * Sets the Hot Rod client configuration builder. The store mutates it by adding the
         * cache definition and the protostream marshaller.
         *
         * @param configurationBuilder the client configuration builder (must not be null)
         * @return this builder
         */
        public Builder infinispanConfigBuilder(ConfigurationBuilder configurationBuilder) {
            this.builder = configurationBuilder;
            return this;
        }

        /**
         * Creates the store, connecting to Infinispan and registering the schema.
         *
         * @return a new store
         */
        public InfinispanEmbeddingStore build() {
            return new InfinispanEmbeddingStore(this.builder, this.name, this.dimension);
        }
    }

    /**
     * Creates a store connected to the given remote cache.
     *
     * <p>This constructor performs the following steps:
     * <ul>
     *   <li>Adds an indexed distributed-cache definition for {@code name} to
     *       {@code builder}.</li>
     *   <li>Registers a dimension-specific protobuf schema and marshaller on a new
     *       {@link ProtoStreamMarshaller}.</li>
     *   <li>Opens a {@link RemoteCacheManager}.</li>
     *   <li>Publishes the schema into the server's {@code ___protobuf_metadata} cache.</li>
     * </ul>
     * The {@link RemoteCacheManager} created here is not closed by this class.
     *
     * @param builder   Hot Rod client configuration builder; it is mutated by this constructor
     * @param name      remote cache name
     * @param dimension vector dimension
     * @throws IllegalArgumentException if any argument is null, or {@code name} is blank
     */
    public InfinispanEmbeddingStore(ConfigurationBuilder builder, String name, Integer dimension) {
        ValidationUtils.ensureNotNull(builder, "builder");
        ValidationUtils.ensureNotBlank(name, "name");
        ValidationUtils.ensureNotNull(dimension, "dimension");
        this.itemMarshaller = new LangChainItemMarshaller(dimension);

        String cacheConfig = DEFAULT_CACHE_CONFIG
                .replace("CACHE_NAME", name)
                .replace("LANGCHAINITEM", this.itemMarshaller.getTypeName());
        builder.remoteCache(name).configuration(cacheConfig);

        ProtoStreamMarshaller marshaller = new ProtoStreamMarshaller();
        SerializationContext serializationContext = marshaller.getSerializationContext();
        String protoFileName = "langchain_dimension_" + dimension + ".proto";
        String protoContent = PROTO.replace("DIMENSION", dimension.toString());
        FileDescriptorSource fileDescriptorSource = new FileDescriptorSource();
        fileDescriptorSource.addProtoFile(protoFileName, protoContent);
        serializationContext.registerProtoFiles(fileDescriptorSource);
        serializationContext.registerMarshaller(this.itemMarshaller);
        builder.marshaller(marshaller);

        RemoteCacheManager remoteCacheManager = new RemoteCacheManager(builder.build());
        // Make the schema known server-side so the indexed cache can resolve the entity type.
        remoteCacheManager.getCache("___protobuf_metadata").put(protoFileName, protoContent);
        this.remoteCache = remoteCacheManager.getCache(name);
    }

    /**
     * Stores an embedding under a freshly generated random UUID.
     *
     * @param embedding the embedding to store
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    /**
     * Stores an embedding under the given id, without text or metadata.
     *
     * @param id        the entry key
     * @param embedding the embedding to store
     */
    @Override
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /**
     * Stores an embedding together with its text segment under a freshly generated random UUID.
     *
     * @param embedding   the embedding to store
     * @param textSegment the text (and metadata) associated with the embedding; may be null
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Stores several embeddings, each under a freshly generated random UUID.
     *
     * @param embeddings the embeddings to store; null or empty stores nothing
     * @return the generated ids, in the same order as {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = generateIds(embeddings);
        addAllInternal(ids, embeddings, null);
        return ids;
    }

    /**
     * Stores several embeddings with their text segments, each under a freshly generated UUID.
     *
     * @param embeddings the embeddings to store; null or empty stores nothing
     * @param embedded   the matching text segments; must be null or the same size as
     *                   {@code embeddings}
     * @return the generated ids, in the same order as {@code embeddings}
     * @throws IllegalArgumentException if the list sizes differ
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = generateIds(embeddings);
        addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Finds up to {@code maxResults} stored embeddings nearest to {@code referenceEmbedding}.
     *
     * <p>The kNN clause requests exactly {@code maxResults} neighbours; results scoring below
     * {@code minScore} are then dropped. As a result, fewer than {@code maxResults} matches may
     * be returned.
     *
     * @param referenceEmbedding the query vector
     * @param maxResults         maximum number of matches; must be greater than zero
     * @param minScore           minimum score (inclusive) a match must have to be returned
     * @return the matches in the order returned by the query
     * @throws IllegalArgumentException if {@code maxResults} is not positive
     */
    @Override
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        ValidationUtils.ensureTrue(maxResults > 0, "maxResults must be greater than zero");
        // The number after '~' is the k of the kNN search. It previously was hard-coded to 3,
        // which capped results at 3 regardless of maxResults.
        String query = "select i, score(i) from " + this.itemMarshaller.getTypeName()
                + " i where i.embedding <-> " + Arrays.toString(referenceEmbedding.vector())
                + "~" + maxResults;
        List<Object[]> rows = this.remoteCache.<Object[]>query(query).maxResults(maxResults).list();

        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>(rows.size());
        for (Object[] row : rows) {
            LangChainInfinispanItem item = (LangChainInfinispanItem) row[0];
            // Go through Number rather than Float so any numeric score type is accepted.
            double score = ((Number) row[1]).doubleValue();
            if (score < minScore) {
                continue;
            }
            matches.add(new EmbeddingMatch<>(score, item.getId(), new Embedding(item.getEmbedding()), toTextSegment(item)));
        }
        return matches;
    }

    /** Generates one random UUID per embedding; a null list yields no ids. */
    private static List<String> generateIds(List<Embedding> embeddings) {
        if (embeddings == null) {
            return new ArrayList<>();
        }
        return embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
    }

    /**
     * Rebuilds the text segment of a stored item, or returns null when the item has no text.
     *
     * <p>Metadata keys and values are stored as parallel lists. They are paired by index; if
     * either list is missing, no metadata is restored.
     */
    private static TextSegment toTextSegment(LangChainInfinispanItem item) {
        if (item.getText() == null) {
            return null;
        }
        Map<String, String> metadata = new HashMap<>();
        List<String> keys = item.getMetadataKeys();
        List<String> values = item.getMetadataValues();
        if (keys != null && values != null) {
            int pairs = Math.min(keys.size(), values.size());
            for (int k = 0; k < pairs; k++) {
                metadata.put(keys.get(k), values.get(k));
            }
        }
        return new TextSegment(item.getText(), new Metadata(metadata));
    }

    /** Stores a single entry; delegates to {@link #addAllInternal}. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAllInternal(
                Collections.singletonList(id),
                Collections.singletonList(embedding),
                textSegment == null ? null : Collections.singletonList(textSegment));
    }

    /**
     * Converts the parallel lists into cache entries and writes them with a single
     * {@code putAll}.
     *
     * @param ids        entry keys
     * @param embeddings embeddings, same size as {@code ids}
     * @param embedded   text segments, null or same size as {@code embeddings}
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("do not add empty embeddings to infinispan");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");

        int size = ids.size();
        Map<String, LangChainInfinispanItem> items = new HashMap<>(size);
        for (int i = 0; i < size; i++) {
            String id = ids.get(i);
            TextSegment textSegment = embedded == null ? null : embedded.get(i);
            items.put(id, toItem(id, embeddings.get(i), textSegment));
        }
        this.remoteCache.putAll(items);
    }

    /**
     * Builds the cache value for one entry.
     *
     * <p>Metadata is flattened into parallel key/value lists. When there is no text segment,
     * the text and metadata fields are left null.
     */
    private static LangChainInfinispanItem toItem(String id, Embedding embedding, TextSegment textSegment) {
        if (textSegment == null) {
            return new LangChainInfinispanItem(id, embedding.vector(), null, null, null);
        }
        Map<String, ?> metadata = textSegment.metadata().asMap();
        List<String> keys = new ArrayList<>(metadata.size());
        List<String> values = new ArrayList<>(metadata.size());
        for (Map.Entry<String, ?> entry : metadata.entrySet()) {
            keys.add(entry.getKey());
            // NOTE(review): assumes metadata values are Strings, as the original cast did.
            values.add((String) entry.getValue());
        }
        return new LangChainInfinispanItem(id, embedding.vector(), textSegment.text(), keys, values);
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Returns the underlying remote cache holding the stored items.
     *
     * @return the remote cache
     */
    public RemoteCache<String, LangChainInfinispanItem> remoteCache() {
        return this.remoteCache;
    }

    /** Removes every entry from the underlying remote cache. */
    public void clearCache() {
        this.remoteCache.clear();
    }
}
