package dev.langchain4j.store.embedding.coherence;

import com.oracle.coherence.ai.DocumentChunk;
import com.oracle.coherence.ai.Float32Vector;
import com.oracle.coherence.ai.QueryResult;
import com.oracle.coherence.ai.Vector;
import com.oracle.coherence.ai.VectorIndexExtractor;
import com.oracle.coherence.ai.search.SimilaritySearch;
import com.oracle.coherence.common.base.Logger;
import com.tangosol.internal.util.processor.CacheProcessors;
import com.tangosol.net.Coherence;
import com.tangosol.net.NamedMap;
import com.tangosol.net.Session;
import com.tangosol.util.UUID;
import com.tangosol.util.ValueExtractor;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * An {@link EmbeddingStore} backed by an Oracle Coherence {@link NamedMap} whose keys are
 * {@link DocumentChunk.Id}s and whose values are {@link DocumentChunk}s.
 * <p>
 * Every stored embedding becomes a {@link DocumentChunk} holding the segment text (or {@code null}),
 * the segment metadata (or an empty map) and the embedding vector as a {@link Float32Vector}.
 * Searches are executed with a {@link SimilaritySearch} aggregator against that map.
 */
public class CoherenceEmbeddingStore implements EmbeddingStore<TextSegment> {

    /**
     * Extracts the stored vector from a {@link DocumentChunk}; this is the vector source used by
     * {@link #search(EmbeddingSearchRequest)}.
     */
    private static final ValueExtractor<DocumentChunk, Vector<float[]>> EXTRACTOR =
            ValueExtractor.of(DocumentChunk::vector);

    /**
     * The map name used when no (or an empty) name is configured via {@link Builder#name(String)}.
     */
    public static final String DEFAULT_MAP_NAME = "documentChunks";

    /**
     * The Coherence map holding the stored document chunks.
     */
    protected final NamedMap<DocumentChunk.Id, DocumentChunk> documentChunks;

    /**
     * When {@code true}, embeddings are normalized before being stored, and query embeddings are
     * normalized before searching. Normalization is applied to a copy; caller-supplied
     * {@link Embedding} instances are never modified.
     */
    protected final boolean normalizeEmbeddings;

    /**
     * Builder for {@link CoherenceEmbeddingStore} instances.
     * <p>
     * {@link #session(String)} and {@link #session(Session)} are mutually exclusive: setting one
     * clears the other.
     */
    public static class Builder {
        private String sessionName;
        private Session session;
        private VectorIndexExtractor<DocumentChunk, Vector<?>> extractor;
        private String name = CoherenceEmbeddingStore.DEFAULT_MAP_NAME;
        private boolean normalizeEmbeddings = false;

        /**
         * Creates a builder; obtain one via {@link CoherenceEmbeddingStore#builder()}.
         */
        protected Builder() {
        }

        /**
         * Sets the name of the {@link NamedMap} that holds the document chunks.
         *
         * @param name the map name; {@code null} or empty selects {@link #DEFAULT_MAP_NAME}
         * @return this builder
         */
        public Builder name(String name) {
            this.name = (name == null || name.isEmpty()) ? CoherenceEmbeddingStore.DEFAULT_MAP_NAME : name;
            return this;
        }

        /**
         * Selects a Coherence session by name and clears any {@link Session} set earlier.
         *
         * @param sessionName the session name; {@code null} selects the default session at build time
         * @return this builder
         */
        public Builder session(String sessionName) {
            this.sessionName = sessionName;
            this.session = null;
            return this;
        }

        /**
         * Uses the given {@link Session} directly and clears any session name set earlier.
         *
         * @param session the session to obtain the map from
         * @return this builder
         */
        public Builder session(Session session) {
            this.session = session;
            this.sessionName = null;
            return this;
        }

        /**
         * Sets a vector index to add to the map when {@link #build()} is called.
         *
         * @param extractor the index extractor, or {@code null} to add no index
         * @return this builder
         */
        public Builder index(VectorIndexExtractor<DocumentChunk, Vector<?>> extractor) {
            this.extractor = extractor;
            return this;
        }

        /**
         * Enables or disables normalization of stored and query embeddings.
         *
         * @param normalizeEmbeddings {@code true} to normalize embeddings
         * @return this builder
         */
        public Builder normalizeEmbeddings(boolean normalizeEmbeddings) {
            this.normalizeEmbeddings = normalizeEmbeddings;
            return this;
        }

        /**
         * Builds the store.
         * <p>
         * If no {@link Session} was set, the named session (or, without a name, the default session)
         * is obtained from {@link Coherence#getInstance()}. If an index was configured it is added
         * to the map.
         *
         * @return a new {@link CoherenceEmbeddingStore}
         */
        public CoherenceEmbeddingStore build() {
            Session target = session;
            if (target == null) {
                Coherence coherence = Coherence.getInstance();
                target = sessionName == null ? coherence.getSession() : coherence.getSession(sessionName);
            }
            NamedMap<DocumentChunk.Id, DocumentChunk> map = target.getMap(name);
            if (extractor != null) {
                map.addIndex(extractor);
            }
            return new CoherenceEmbeddingStore(map, normalizeEmbeddings);
        }
    }

    /**
     * Creates a store over an existing map.
     *
     * @param documentChunks      the map used to hold document chunks
     * @param normalizeEmbeddings {@code true} to normalize embeddings on add and search
     */
    protected CoherenceEmbeddingStore(NamedMap<DocumentChunk.Id, DocumentChunk> documentChunks,
                                      boolean normalizeEmbeddings) {
        this.documentChunks = documentChunks;
        this.normalizeEmbeddings = normalizeEmbeddings;
    }

    /**
     * Stores an embedding without text under a freshly generated id.
     *
     * @param embedding the embedding to store
     * @return the string form of the generated {@link DocumentChunk.Id}
     */
    @Override
    public String add(Embedding embedding) {
        DocumentChunk.Id id = newId();
        addInternal(id, embedding, null);
        return id.toString();
    }

    /**
     * Stores an embedding without text under the given id.
     *
     * @param id        the id, parsed with {@link DocumentChunk.Id#parse(String)}
     * @param embedding the embedding to store
     */
    @Override
    public void add(String id, Embedding embedding) {
        addInternal(DocumentChunk.Id.parse(id), embedding, null);
    }

    /**
     * Stores an embedding together with its text segment under a freshly generated id.
     *
     * @param embedding   the embedding to store
     * @param textSegment the segment whose text and metadata are stored alongside the embedding
     * @return the string form of the generated {@link DocumentChunk.Id}
     */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        DocumentChunk.Id id = newId();
        addInternal(id, embedding, textSegment);
        return id.toString();
    }

    /**
     * Stores embeddings without text; delegates to {@code addAll(embeddings, null)}.
     *
     * @param embeddings the embeddings to store
     * @return the generated ids
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        return addAll(embeddings, null);
    }

    /**
     * Removes the chunk with the given id.
     *
     * @param id the id, parsed with {@link DocumentChunk.Id#parse(String)}
     * @throws IllegalArgumentException if {@code id} is {@code null} or blank
     */
    @Override
    public void remove(String id) {
        if (id == null || id.isBlank()) {
            throw new IllegalArgumentException("id cannot be null or blank");
        }
        documentChunks.remove(DocumentChunk.Id.parse(id));
    }

    /**
     * Removes all chunks whose ids are in the given collection.
     *
     * @param ids the ids, each parsed with {@link DocumentChunk.Id#parse(String)}
     * @throws IllegalArgumentException if {@code ids} is {@code null} or empty
     */
    @Override
    public void removeAll(Collection<String> ids) {
        if (ids == null || ids.isEmpty()) {
            throw new IllegalArgumentException("ids cannot be null or empty");
        }
        Set<DocumentChunk.Id> keys = ids.stream()
                .map(DocumentChunk.Id::parse)
                .collect(Collectors.toSet());
        documentChunks.keySet().removeAll(keys);
    }

    /**
     * Removes all chunks matching the given metadata filter.
     *
     * @param filter the LangChain4j filter, translated by {@code CoherenceMetadataFilterMapper}
     * @throws IllegalArgumentException if {@code filter} is {@code null}
     */
    @Override
    public void removeAll(Filter filter) {
        if (filter == null) {
            throw new IllegalArgumentException("filter cannot be null");
        }
        documentChunks.invokeAll(CoherenceMetadataFilterMapper.map(filter), CacheProcessors.removeBlind());
    }

    /**
     * Removes every chunk by truncating the underlying map.
     */
    @Override
    public void removeAll() {
        documentChunks.truncate();
    }

    /**
     * Finds the chunks most similar to the request's query embedding.
     * <p>
     * Relevance is computed as {@code RelevanceScore.fromCosineSimilarity(1 - distance)}.
     * NOTE(review): this assumes {@link SimilaritySearch} reports cosine distance; confirm.
     * A {@code minScore} of {@code 0} disables score filtering.
     *
     * @param request the search request (query embedding, max results, min score, optional filter)
     * @return the matches in the order returned by the aggregator
     */
    @Override
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        // Normalize a copy so the caller's query embedding is left untouched.
        float[] queryVector = vectorOf(request.queryEmbedding());
        double minScore = request.minScore();
        boolean hasMinScore = minScore != 0.0d;

        SimilaritySearch<DocumentChunk.Id, DocumentChunk, float[]> aggregator =
                new SimilaritySearch<>(EXTRACTOR, new Float32Vector(queryVector), request.maxResults());
        com.tangosol.util.Filter<DocumentChunk> coherenceFilter =
                CoherenceMetadataFilterMapper.map(request.filter());
        if (coherenceFilter != null) {
            aggregator = aggregator.filter(coherenceFilter);
        }

        List<QueryResult<DocumentChunk.Id, DocumentChunk>> results = documentChunks.aggregate(aggregator);
        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>(results.size());
        for (QueryResult<DocumentChunk.Id, DocumentChunk> result : results) {
            double score = RelevanceScore.fromCosineSimilarity(1.0d - result.getDistance());
            if (!hasMinScore || score >= minScore) {
                matches.add(embeddingMatch(score, result.getKey(), result.getValue()));
            }
        }
        return new EmbeddingSearchResult<>(matches);
    }

    /**
     * Stores a single chunk under {@code id}.
     */
    private void addInternal(DocumentChunk.Id id, Embedding embedding, TextSegment textSegment) {
        // NOTE(review): putAll is presumably used instead of put to avoid returning the previous
        // value from the cluster; confirm before changing.
        Map<DocumentChunk.Id, DocumentChunk> entry = new HashMap<>();
        entry.put(id, createChunk(embedding, textSegment));
        documentChunks.putAll(entry);
    }

    /**
     * Stores embeddings (and optionally their text segments) under the given ids in one
     * {@code putAll} call.
     * <p>
     * Each id is wrapped as {@code new DocumentChunk.Id(id, 0)}. Unlike {@link #add(String, Embedding)},
     * it is not parsed. NOTE(review): confirm that both forms produce matching keys.
     * If {@code ids} or {@code embeddings} is null or empty, nothing is stored and an info message
     * is logged.
     *
     * @param ids        the ids, one per embedding
     * @param embeddings the embeddings to store
     * @param embedded   the text segments, one per embedding, or {@code null}/empty for none
     * @throws IllegalArgumentException if the list sizes do not match
     */
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            Logger.info("Skipped adding empty embeddings");
            return;
        }
        boolean hasSegments = embedded != null && !embedded.isEmpty();
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        if (hasSegments) {
            ValidationUtils.ensureTrue(embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");
        }
        Map<DocumentChunk.Id, DocumentChunk> chunks = new HashMap<>();
        for (int i = 0; i < embeddings.size(); i++) {
            TextSegment segment = hasSegments ? embedded.get(i) : null;
            chunks.put(new DocumentChunk.Id(ids.get(i), 0), createChunk(embeddings.get(i), segment));
        }
        documentChunks.putAll(chunks);
    }

    /**
     * Converts a stored chunk into an {@link EmbeddingMatch}.
     * A chunk with {@code null} text yields a match without a segment, and a chunk with a
     * {@code null} vector yields a match without an embedding.
     */
    private EmbeddingMatch<TextSegment> embeddingMatch(double score, DocumentChunk.Id id, DocumentChunk chunk) {
        String text = chunk.text();
        TextSegment segment = text == null ? null : new TextSegment(text, mapToMetadata(chunk.metadata()));
        Vector<?> vector = chunk.vector();
        Embedding embedding = vector == null ? null : new Embedding((float[]) vector.get());
        return new EmbeddingMatch<>(score, id.toString(), embedding, segment);
    }

    /**
     * Builds {@link Metadata} from a chunk's metadata map, skipping entries with {@code null} values.
     * The source map is copied rather than modified, so unmodifiable maps are handled.
     * A {@code null} map yields empty metadata.
     */
    private static Metadata mapToMetadata(Map<String, Object> map) {
        Map<String, Object> values = new HashMap<>();
        if (map != null) {
            map.forEach((key, value) -> {
                if (value != null) {
                    values.put(key, value);
                }
            });
        }
        return Metadata.from(values);
    }

    /**
     * Creates the {@link DocumentChunk} for an embedding and its optional segment.
     */
    private DocumentChunk createChunk(Embedding embedding, TextSegment textSegment) {
        String text = textSegment == null ? null : textSegment.text();
        Map<String, Object> metadata = textSegment == null ? Collections.emptyMap() : textSegment.metadata().toMap();
        DocumentChunk chunk = new DocumentChunk(text, metadata);
        chunk.setVector(new Float32Vector(vectorOf(embedding)));
        return chunk;
    }

    /**
     * Returns the embedding's vector, normalized when {@link #normalizeEmbeddings} is set.
     * Normalization runs on a copy because {@link Embedding#normalize()} works in place, and the
     * caller's instance must not change.
     */
    private float[] vectorOf(Embedding embedding) {
        if (!normalizeEmbeddings) {
            return embedding.vector();
        }
        Embedding copy = new Embedding(embedding.vector().clone());
        copy.normalize();
        return copy.vector();
    }

    /**
     * Generates a new chunk id from a random Coherence {@link UUID} with index {@code 0}.
     */
    private static DocumentChunk.Id newId() {
        return DocumentChunk.id(new UUID().toString(), 0);
    }

    /**
     * Creates a store using the default session and the {@link #DEFAULT_MAP_NAME} map.
     *
     * @return a new store
     */
    public static CoherenceEmbeddingStore create() {
        return builder().build();
    }

    /**
     * Creates a store using the default session and the named map.
     *
     * @param name the map name; {@code null} or empty selects {@link #DEFAULT_MAP_NAME}
     * @return a new store
     */
    public static CoherenceEmbeddingStore create(String name) {
        return builder().name(name).build();
    }

    /**
     * Creates a store over an existing map, without embedding normalization.
     *
     * @param map the map holding the document chunks
     * @return a new store
     */
    public static CoherenceEmbeddingStore create(NamedMap<DocumentChunk.Id, DocumentChunk> map) {
        return new CoherenceEmbeddingStore(map, false);
    }

    /**
     * Returns a new {@link Builder}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
