package dev.langchain4j.store.embedding.qdrant;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import io.qdrant.client.PointIdFactory;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.ValueFactory;
import io.qdrant.client.VectorsFactory;
import io.qdrant.client.WithPayloadSelectorFactory;
import io.qdrant.client.WithVectorsSelectorFactory;
import io.qdrant.client.grpc.JsonWithInt;
import io.qdrant.client.grpc.Points;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

/* loaded from: input_file:dev/langchain4j/store/embedding/qdrant/QdrantEmbeddingStore.class */
public class QdrantEmbeddingStore implements EmbeddingStore<TextSegment> {
    private final QdrantClient client;
    private final String payloadTextKey;
    private final String collectionName;

    /* loaded from: input_file:dev/langchain4j/store/embedding/qdrant/QdrantEmbeddingStore$Builder.class */
    public static class Builder {
        private String collectionName;
        private String host = "localhost";
        private int port = 6334;
        private boolean useTls = false;
        private String payloadTextKey = "text_segment";
        private String apiKey = null;
        private QdrantClient client = null;

        public Builder host(String str) {
            this.host = str;
            return this;
        }

        public Builder collectionName(String str) {
            this.collectionName = str;
            return this;
        }

        public Builder port(int i) {
            this.port = i;
            return this;
        }

        public Builder useTls(boolean z) {
            this.useTls = z;
            return this;
        }

        public Builder payloadTextKey(String str) {
            this.payloadTextKey = str;
            return this;
        }

        public Builder apiKey(String str) {
            this.apiKey = str;
            return this;
        }

        public Builder client(QdrantClient qdrantClient) {
            this.client = qdrantClient;
            return this;
        }

        public QdrantEmbeddingStore build() {
            Objects.requireNonNull(this.collectionName, "collectionName cannot be null");
            return this.client != null ? new QdrantEmbeddingStore(this.client, this.collectionName, this.payloadTextKey) : new QdrantEmbeddingStore(this.collectionName, this.host, this.port, this.useTls, this.payloadTextKey, this.apiKey);
        }
    }

    public QdrantEmbeddingStore(String str, String str2, int i, boolean z, String str3, @Nullable String str4) {
        QdrantGrpcClient.Builder newBuilder = QdrantGrpcClient.newBuilder(str2, i, z);
        if (str4 != null) {
            newBuilder.withApiKey(str4);
        }
        this.client = new QdrantClient(newBuilder.build());
        this.collectionName = str;
        this.payloadTextKey = str3;
    }

    public QdrantEmbeddingStore(QdrantClient qdrantClient, String str, String str2) {
        this.client = qdrantClient;
        this.collectionName = str;
        this.payloadTextKey = str2;
    }

    public String add(Embedding embedding) {
        String randomUUID = Utils.randomUUID();
        add(randomUUID, embedding);
        return randomUUID;
    }

    public void add(String str, Embedding embedding) {
        addInternal(str, embedding, null);
    }

    public String add(Embedding embedding, TextSegment textSegment) {
        String randomUUID = Utils.randomUUID();
        addInternal(randomUUID, embedding, textSegment);
        return randomUUID;
    }

    public List<String> addAll(List<Embedding> list) {
        List<String> list2 = (List) list.stream().map(embedding -> {
            return Utils.randomUUID();
        }).collect(Collectors.toList());
        addAllInternal(list2, list, null);
        return list2;
    }

    public List<String> addAll(List<Embedding> list, List<TextSegment> list2) {
        List<String> list3 = (List) list.stream().map(embedding -> {
            return Utils.randomUUID();
        }).collect(Collectors.toList());
        addAllInternal(list3, list, list2);
        return list3;
    }

    private void addInternal(String str, Embedding embedding, TextSegment textSegment) {
        addAllInternal(Collections.singletonList(str), Collections.singletonList(embedding), textSegment == null ? null : Collections.singletonList(textSegment));
    }

    private void addAllInternal(List<String> list, List<Embedding> list2, List<TextSegment> list3) throws RuntimeException {
        try {
            ArrayList arrayList = new ArrayList(list2.size());
            for (int i = 0; i < list2.size(); i++) {
                Points.PointStruct.Builder vectors = Points.PointStruct.newBuilder().setId(PointIdFactory.id(UUID.fromString(list.get(i)))).setVectors(VectorsFactory.vectors(list2.get(i).vector()));
                if (list3 != null) {
                    Map<String, JsonWithInt.Value> valueMap = ValueMapFactory.valueMap(list3.get(i).metadata().toMap());
                    valueMap.put(this.payloadTextKey, ValueFactory.value(list3.get(i).text()));
                    vectors.putAllPayload(valueMap);
                }
                arrayList.add(vectors.build());
            }
            this.client.upsertAsync(this.collectionName, arrayList).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
        Points.SearchPoints.Builder limit = Points.SearchPoints.newBuilder().setCollectionName(this.collectionName).addAllVector(embeddingSearchRequest.queryEmbedding().vectorAsList()).setWithVectors(WithVectorsSelectorFactory.enable(true)).setWithPayload(WithPayloadSelectorFactory.enable(true)).setLimit(embeddingSearchRequest.maxResults());
        if (embeddingSearchRequest.filter() != null) {
            limit.setFilter(QdrantFilterConverter.convertExpression(embeddingSearchRequest.filter()));
        }
        try {
            List list = (List) this.client.searchAsync(limit.build()).get();
            if (list.isEmpty()) {
                return new EmbeddingSearchResult<>(Collections.emptyList());
            }
            List list2 = (List) list.stream().map(scoredPoint -> {
                return toEmbeddingMatch(scoredPoint, embeddingSearchRequest.queryEmbedding());
            }).filter(embeddingMatch -> {
                return embeddingMatch.score().doubleValue() >= embeddingSearchRequest.minScore();
            }).sorted(Comparator.comparingDouble((v0) -> {
                return v0.score();
            })).collect(Collectors.toList());
            Collections.reverse(list2);
            return new EmbeddingSearchResult<>(list2);
        } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, int i, double d) {
        try {
            List list = (List) this.client.searchAsync(Points.SearchPoints.newBuilder().setCollectionName(this.collectionName).addAllVector(embedding.vectorAsList()).setWithVectors(WithVectorsSelectorFactory.enable(true)).setWithPayload(WithPayloadSelectorFactory.enable(true)).setLimit(i).build()).get();
            if (list.isEmpty()) {
                return Collections.emptyList();
            }
            List<EmbeddingMatch<TextSegment>> list2 = (List) list.stream().map(scoredPoint -> {
                return toEmbeddingMatch(scoredPoint, embedding);
            }).filter(embeddingMatch -> {
                return embeddingMatch.score().doubleValue() >= d;
            }).sorted(Comparator.comparingDouble((v0) -> {
                return v0.score();
            })).collect(Collectors.toList());
            Collections.reverse(list2);
            return list2;
        } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    public void clearStore() {
        try {
            this.client.deleteAsync(Points.DeletePoints.newBuilder().setCollectionName(this.collectionName).setPoints(Points.PointsSelector.newBuilder().setFilter(Points.Filter.newBuilder().build()).build()).build()).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    public void close() {
        this.client.close();
    }

    private EmbeddingMatch<TextSegment> toEmbeddingMatch(Points.ScoredPoint scoredPoint, Embedding embedding) {
        Map payloadMap = scoredPoint.getPayloadMap();
        JsonWithInt.Value value = (JsonWithInt.Value) payloadMap.getOrDefault(this.payloadTextKey, null);
        Map map = (Map) payloadMap.entrySet().stream().filter(entry -> {
            return !((String) entry.getKey()).equals(this.payloadTextKey);
        }).collect(Collectors.toMap((v0) -> {
            return v0.getKey();
        }, entry2 -> {
            return ObjectFactory.object((JsonWithInt.Value) entry2.getValue());
        }));
        Embedding from = Embedding.from(scoredPoint.getVectors().getVector().getDataList());
        return new EmbeddingMatch<>(Double.valueOf(RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(from, embedding))), scoredPoint.getId().getUuid(), from, value == null ? null : TextSegment.from(value.getStringValue(), new Metadata(map)));
    }

    public static Builder builder() {
        return new Builder();
    }
}
