package dev.langchain4j.store.embedding.cassandra;

import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.CqlSessionBuilder;
import com.dtsx.astra.sdk.cassio.AnnQuery;
import com.dtsx.astra.sdk.cassio.AnnResult;
import com.dtsx.astra.sdk.cassio.CassIO;
import com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric;
import com.dtsx.astra.sdk.cassio.MetadataVectorRecord;
import com.dtsx.astra.sdk.cassio.MetadataVectorTable;
import com.dtsx.astra.sdk.utils.AstraEnvironment;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.NonNull;

/* loaded from: input_file:dev/langchain4j/store/embedding/cassandra/CassandraEmbeddingStore.class */
/**
 * {@link EmbeddingStore} backed by a Cassandra (or DataStax Astra) table accessed through the
 * CassIO {@link MetadataVectorTable} wrapper.
 *
 * <p>Each row stores a row id and the embedding vector. When a {@link TextSegment} is supplied,
 * its text is stored as the row body and its metadata as a string-valued map. The constructor
 * calls {@link MetadataVectorTable#create()} on the wrapper before the store is used.
 */
public class CassandraEmbeddingStore implements EmbeddingStore<TextSegment> {

    /** Table wrapper used for every read and write; initialised in the constructor. */
    protected MetadataVectorTable embeddingTable;

    /** Session supplied at construction time; its keyspace determines where the table lives. */
    protected CqlSession cassandraSession;

    /**
     * Builder that opens a new {@link CqlSession} against explicitly listed contact points.
     *
     * <p>{@code contactPoints}, {@code keyspace}, {@code table} and a positive {@code dimension}
     * are required; {@link #build()} rejects missing values with an {@link IllegalArgumentException}.
     */
    public static class Builder {

        /** Port used for every contact point when {@link #port(Integer)} is not called. */
        public static final Integer DEFAULT_PORT = 9042;

        private List<String> contactPoints;
        private String localDataCenter;
        private String userName;
        private String password;
        protected String keyspace;
        protected String table;
        protected Integer dimension;
        private Integer port = DEFAULT_PORT;
        protected CassandraSimilarityMetric metric = CassandraSimilarityMetric.COSINE;

        /** Host names/addresses to connect to; each is combined with the configured port. */
        public Builder contactPoints(List<String> contactPoints) {
            this.contactPoints = contactPoints;
            return this;
        }

        /** Local data center name passed to the driver's session builder. */
        public Builder localDataCenter(String localDataCenter) {
            this.localDataCenter = localDataCenter;
            return this;
        }

        /** Port used for all contact points (defaults to {@link #DEFAULT_PORT}). */
        public Builder port(Integer port) {
            this.port = port;
            return this;
        }

        /** User name; credentials are only applied when both user name and password are set. */
        public Builder userName(String userName) {
            this.userName = userName;
            return this;
        }

        /** Password; credentials are only applied when both user name and password are set. */
        public Builder password(String password) {
            this.password = password;
            return this;
        }

        /** Keyspace the session is bound to; the embedding table is created in it. */
        public Builder keyspace(String keyspace) {
            this.keyspace = keyspace;
            return this;
        }

        /** Name of the embedding table. */
        public Builder table(String table) {
            this.table = table;
            return this;
        }

        /** Vector dimension of the embedding column; must be greater than zero. */
        public Builder dimension(Integer dimension) {
            this.dimension = dimension;
            return this;
        }

        /** Similarity metric passed to the table definition (defaults to COSINE). */
        public Builder metric(CassandraSimilarityMetric metric) {
            this.metric = metric;
            return this;
        }

        /**
         * Opens a session with the configured settings and wraps it in a store.
         *
         * @return a new store bound to the configured keyspace and table
         * @throws IllegalArgumentException if contact points, port, keyspace, table or a positive
         *                                  dimension are missing
         */
        public CassandraEmbeddingStore build() {
            // Fail fast with descriptive errors instead of NPEs from unboxing / iterating nulls.
            ValidationUtils.ensureNotEmpty(contactPoints, "contactPoints");
            ValidationUtils.ensureNotNull(port, "port");
            ValidationUtils.ensureNotBlank(keyspace, "keyspace");
            ValidationUtils.ensureNotBlank(table, "table");
            int vectorDimension = ValidationUtils.ensureGreaterThanZero(dimension, "dimension");

            CqlSessionBuilder sessionBuilder = CqlSession.builder()
                    .withKeyspace(keyspace)
                    .withLocalDatacenter(localDataCenter);
            if (userName != null && password != null) {
                sessionBuilder.withAuthCredentials(userName, password);
            }
            for (String contactPoint : contactPoints) {
                sessionBuilder.addContactPoint(new InetSocketAddress(contactPoint, port));
            }
            return new CassandraEmbeddingStore(sessionBuilder.build(), table, vectorDimension, metric);
        }
    }

    /**
     * Builder that obtains a session for an Astra database through {@link CassIO#init}.
     *
     * <p>Defaults: keyspace {@code "default_keyspace"}, region {@code "us-east1"}, metric COSINE,
     * environment {@link AstraEnvironment#PROD}. Token, database id, table and a positive dimension
     * are required.
     */
    public static class BuilderAstra {
        private String token;
        private UUID dbId;
        private String tableName;
        private int dimension;
        private String keyspaceName = "default_keyspace";
        private String dbRegion = "us-east1";
        private CassandraSimilarityMetric metric = CassandraSimilarityMetric.COSINE;
        private AstraEnvironment env = AstraEnvironment.PROD;

        /** Astra token used to authenticate. */
        public BuilderAstra token(String token) {
            this.token = token;
            return this;
        }

        /** Astra environment (defaults to PROD). */
        public BuilderAstra env(AstraEnvironment env) {
            this.env = env;
            return this;
        }

        /** Identifier of the Astra database. */
        public BuilderAstra databaseId(UUID databaseId) {
            this.dbId = databaseId;
            return this;
        }

        /** Region of the Astra database (defaults to {@code "us-east1"}). */
        public BuilderAstra databaseRegion(String databaseRegion) {
            this.dbRegion = databaseRegion;
            return this;
        }

        /** Keyspace to use (defaults to {@code "default_keyspace"}). */
        public BuilderAstra keyspace(String keyspace) {
            this.keyspaceName = keyspace;
            return this;
        }

        /** Name of the embedding table. */
        public BuilderAstra table(String table) {
            this.tableName = table;
            return this;
        }

        /** Vector dimension of the embedding column; must be greater than zero. */
        public BuilderAstra dimension(int dimension) {
            this.dimension = dimension;
            return this;
        }

        /** Similarity metric passed to the table definition (defaults to COSINE). */
        public BuilderAstra metric(CassandraSimilarityMetric metric) {
            this.metric = metric;
            return this;
        }

        /**
         * Initialises an Astra session via CassIO and wraps it in a store.
         *
         * @throws IllegalArgumentException if token, database id, table or a positive dimension
         *                                  are missing
         */
        public CassandraEmbeddingStore build() {
            ValidationUtils.ensureNotBlank(token, "token");
            ValidationUtils.ensureNotNull(dbId, "databaseId");
            ValidationUtils.ensureNotBlank(tableName, "table");
            ValidationUtils.ensureGreaterThanZero(dimension, "dimension");
            CqlSession session = CassIO.init(token, dbId, dbRegion, keyspaceName, env);
            return new CassandraEmbeddingStore(session, tableName, dimension, metric);
        }
    }

    /**
     * Creates a store using the COSINE similarity metric.
     *
     * @param session   session bound to the keyspace that holds (or will hold) the table
     * @param tableName name of the embedding table
     * @param dimension vector dimension of the embedding column
     */
    public CassandraEmbeddingStore(CqlSession session, String tableName, int dimension) {
        this(session, tableName, dimension, CassandraSimilarityMetric.COSINE);
    }

    /**
     * Creates a store and calls {@link MetadataVectorTable#create()} for its table.
     *
     * @param session   session bound to a keyspace; the table is placed in that keyspace
     * @param tableName name of the embedding table
     * @param dimension vector dimension of the embedding column
     * @param metric    similarity metric passed to the table definition
     * @throws IllegalArgumentException if the session is not bound to a keyspace
     */
    public CassandraEmbeddingStore(CqlSession session, String tableName, int dimension,
                                   CassandraSimilarityMetric metric) {
        this.cassandraSession = session;
        String keyspace = session.getKeyspace()
                .map(CqlIdentifier::asInternal)
                .orElseThrow(() -> new IllegalArgumentException(
                        "CqlSession must be bound to a keyspace (see CqlSessionBuilder.withKeyspace)"));
        this.embeddingTable = new MetadataVectorTable(session, keyspace, tableName, dimension, metric);
        this.embeddingTable.create();
    }

    /** Delegates to {@link MetadataVectorTable#delete()} (presumably drops the table — confirm). */
    public void delete() {
        this.embeddingTable.delete();
    }

    /** Delegates to {@link MetadataVectorTable#clear()} (presumably removes all rows — confirm). */
    public void clear() {
        this.embeddingTable.clear();
    }

    /** @return the session this store was constructed with */
    public CqlSession getCassandraSession() {
        return this.cassandraSession;
    }

    /** @return a builder for a self-managed Cassandra cluster */
    public static Builder builder() {
        return new Builder();
    }

    /** @return a builder for an Astra database */
    public static BuilderAstra builderAstra() {
        return new BuilderAstra();
    }

    /**
     * Stores an embedding without text under a freshly generated id.
     *
     * @return the generated row id
     * @throws NullPointerException if {@code embedding} is null
     */
    @Override
    public String add(@NonNull Embedding embedding) {
        if (embedding == null) {
            throw new NullPointerException("embedding is marked non-null but is null");
        }
        return add(embedding, (TextSegment) null);
    }

    /**
     * Stores an embedding, and optionally its text segment, under a freshly generated id.
     *
     * @param embedding   the vector to store
     * @param textSegment optional segment whose text and metadata are stored alongside; may be null
     * @return the generated row id
     * @throws NullPointerException if {@code embedding} is null
     */
    @Override
    public String add(@NonNull Embedding embedding, TextSegment textSegment) {
        if (embedding == null) {
            throw new NullPointerException("embedding is marked non-null but is null");
        }
        return addInternal(Utils.randomUUID(), embedding, textSegment);
    }

    /**
     * Synchronously writes one row with the given id.
     *
     * @return the row id of the written record
     */
    private String addInternal(@NonNull String id, @NonNull Embedding embedding, TextSegment textSegment) {
        if (id == null) {
            throw new NullPointerException("id is marked non-null but is null");
        }
        if (embedding == null) {
            throw new NullPointerException("embedding is marked non-null but is null");
        }
        MetadataVectorRecord vectorRecord = new MetadataVectorRecord(id, embedding.vectorAsList());
        if (textSegment != null) {
            vectorRecord.setBody(textSegment.text());
            vectorRecord.setMetadata(Utils.toStringValueMap(textSegment.metadata().toMap()));
        }
        this.embeddingTable.put(vectorRecord);
        return vectorRecord.getRowId();
    }

    /**
     * Stores an embedding without text under the caller-supplied id.
     *
     * @throws NullPointerException if {@code rowId} or {@code embedding} is null
     */
    @Override
    public void add(@NonNull String rowId, @NonNull Embedding embedding) {
        if (rowId == null) {
            throw new NullPointerException("rowId is marked non-null but is null");
        }
        if (embedding == null) {
            throw new NullPointerException("embedding is marked non-null but is null");
        }
        this.embeddingTable.put(new MetadataVectorRecord(rowId, embedding.vectorAsList()));
    }

    /**
     * Stores embeddings without text, each under an id generated by the record constructor.
     *
     * @param embeddings vectors to store
     * @return the row ids, in the same order as {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        return embeddings.stream()
                .map(embedding -> new MetadataVectorRecord(embedding.vectorAsList()))
                .map(vectorRecord -> {
                    // NOTE(review): the write is issued with putAsync and its outcome is not observed
                    // here, so a failed write is not reported to the caller — confirm this is intended.
                    embeddingTable.putAsync(vectorRecord);
                    return vectorRecord.getRowId();
                })
                .collect(Collectors.toList());
    }

    /**
     * Synchronously stores embeddings with their text segments under the given ids.
     *
     * @param ids          row ids, one per embedding
     * @param embeddings   vectors to store
     * @param textSegments segments stored alongside, one per embedding (elements may be null)
     * @throws IllegalArgumentException if any list is null or the sizes differ
     */
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
        // Validate all three lists up front so a size mismatch cannot fail halfway through the inserts.
        if (ids == null || embeddings == null || textSegments == null
                || ids.size() != embeddings.size() || embeddings.size() != textSegments.size()) {
            throw new IllegalArgumentException(
                    "ids, embeddingList and textSegmentList must not be null and have the same size");
        }
        for (int i = 0; i < embeddings.size(); i++) {
            addInternal(ids.get(i), embeddings.get(i), textSegments.get(i));
        }
    }

    /**
     * Runs an ANN search for the request's query embedding.
     *
     * @throws UnsupportedOperationException if the request carries a filter
     */
    @Override
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        if (request.filter() != null) {
            throw new UnsupportedOperationException("EmbeddingSearchRequest.Filter is not supported yet.");
        }
        return new EmbeddingSearchResult<>(
                findRelevant(request.queryEmbedding(), request.maxResults(), request.minScore()));
    }

    /**
     * Finds the stored embeddings most similar to {@code referenceEmbedding}.
     *
     * @param referenceEmbedding query vector
     * @param maxResults         maximum number of matches; must be greater than zero
     * @param minScore           minimum relevance score in [0, 1]
     * @return matches produced from the ANN results
     * @throws IllegalArgumentException if {@code maxResults} or {@code minScore} is out of range
     */
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        return annSearch(referenceEmbedding, maxResults, minScore, null);
    }

    /**
     * Like {@link #findRelevant(Embedding, int, double)}, additionally passing {@code metadata}
     * (converted to a string-valued map) to the query when it is non-null.
     */
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults,
                                                          double minScore, Metadata metadata) {
        return annSearch(referenceEmbedding, maxResults, minScore, metadata);
    }

    /** Builds and runs the ANN query shared by both {@code findRelevant} overloads. */
    private List<EmbeddingMatch<TextSegment>> annSearch(Embedding referenceEmbedding, int maxResults,
                                                        double minScore, Metadata metadata) {
        // NOTE(review): the query metric is always COSINE, even if the store was built with another
        // metric, and scores are converted with cosine helpers — confirm this is intended.
        AnnQuery.AnnQueryBuilder query = AnnQuery.builder()
                .embeddings(referenceEmbedding.vectorAsList())
                .metric(CassandraSimilarityMetric.COSINE)
                .recordCount(ValidationUtils.ensureGreaterThanZero(maxResults, "maxResults"))
                .threshold(CosineSimilarity.fromRelevanceScore(
                        ValidationUtils.ensureBetween(Double.valueOf(minScore), 0.0d, 1.0d, "minScore")));
        if (metadata != null) {
            query.metaData(Utils.toStringValueMap(metadata.toMap()));
        }
        return embeddingTable.similaritySearch(query.build()).stream()
                .map(CassandraEmbeddingStore::mapSearchResult)
                .collect(Collectors.toList());
    }

    /**
     * Converts one ANN result row into an {@link EmbeddingMatch}.
     *
     * <p>A text segment is attached whenever the row has a non-empty body. A row with no stored
     * metadata gets empty {@link Metadata} rather than losing its text.
     */
    private static EmbeddingMatch<TextSegment> mapSearchResult(AnnResult<MetadataVectorRecord> annResult) {
        MetadataVectorRecord row = annResult.getEmbedded();
        TextSegment textSegment = null;
        String body = row.getBody();
        if (body != null && !body.isEmpty()) {
            Metadata metadata = row.getMetadata() == null ? new Metadata() : new Metadata(row.getMetadata());
            textSegment = TextSegment.from(body, metadata);
        }
        return new EmbeddingMatch<>(
                RelevanceScore.fromCosineSimilarity(annResult.getSimilarity()),
                row.getRowId(),
                Embedding.from(row.getVector()),
                textSegment);
    }
}
