package dev.langchain4j.store.embedding.neo4j;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;
import org.neo4j.driver.AuthTokens;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
import org.neo4j.driver.Record;
import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.Value;
import org.neo4j.driver.Values;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:dev/langchain4j/store/embedding/neo4j/Neo4jEmbeddingStore.class */
/**
 * {@link EmbeddingStore} for {@link TextSegment}s backed by a Neo4j vector index.
 *
 * <p>Each stored entry is a node with the configured label. The node's id lives in {@code idProperty},
 * its text in {@code textProperty} and its vector in {@code embeddingProperty}. Construction ensures
 * that the vector index {@code indexName} exists (creating it if absent) and that a uniqueness
 * constraint on the id property exists. Construction fails if an index with the same name already
 * exists for a different label/property.
 */
public class Neo4jEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(Neo4jEmbeddingStore.class);

    /**
     * Fallback for {@code awaitIndexTimeout} when the supplied value is not positive. The builder leaves
     * that field at {@code 0} when it is never set, so a null check on a primitive could never apply it.
     */
    private static final long DEFAULT_AWAIT_INDEX_TIMEOUT = 60L;

    /** Vector-index lookup that is prepended to the configured retrieval query. */
    private static final String SEARCH_QUERY_PREFIX = """
            CALL db.index.vector.queryNodes($indexName, $maxResults, $embeddingValue)
            YIELD node, score
            WHERE score >= $minScore
            """;

    /**
     * Batched upsert template. Positional arguments: 1 = sanitized label, 2 = sanitized id property,
     * 3 = row key of the property map merged into the node, 4 = row key of the embedding vector.
     */
    private static final String UPSERT_QUERY_TEMPLATE = """
            UNWIND $rows AS row
            MERGE (u:%1$s {%2$s: row.%2$s})
            SET u += row.%3$s
            WITH row, u
            CALL db.create.setNodeVectorProperty(u, $embeddingProperty, row.%4$s)
            RETURN count(*)""";

    /** Idempotent uniqueness constraint on the id property: 1 = sanitized label, 2 = sanitized id property. */
    private static final String UNIQUE_CONSTRAINT_TEMPLATE =
            "CREATE CONSTRAINT IF NOT EXISTS FOR (n:%s) REQUIRE n.%s IS UNIQUE";

    private final Driver driver;
    private final SessionConfig config;
    private final int dimension;
    private final long awaitIndexTimeout;
    private final String indexName;
    private final String metadataPrefix;
    private final String embeddingProperty;
    private final String idProperty;
    // The sanitized* values come from Neo4jEmbeddingUtils.sanitizeOrThrows and are the forms
    // interpolated directly into Cypher text; the raw values are only ever passed as parameters.
    private final String sanitizedEmbeddingProperty;
    private final String sanitizedIdProperty;
    private final String sanitizedText;
    private final String label;
    private final String sanitizedLabel;
    private final String textProperty;
    private final String databaseName;
    private final String retrievalQuery;
    /** Raw property names (id, embedding, text) that are distinguished from the remaining node properties. */
    private final Set<String> notMetaKeys;

    /**
     * Fluent builder for {@link Neo4jEmbeddingStore}.
     *
     * <p>Fields left unset are passed to the constructor as {@code null} or {@code 0}. The constructor then
     * applies its defaults.
     */
    public static class Neo4jEmbeddingStoreBuilder {
        private SessionConfig config;
        private Driver driver;
        private int dimension;
        private String label;
        private String embeddingProperty;
        private String idProperty;
        private String metadataPrefix;
        private String textProperty;
        private String indexName;
        private String databaseName;
        private String retrievalQuery;
        private long awaitIndexTimeout;

        Neo4jEmbeddingStoreBuilder() {
        }

        /**
         * Creates a new driver for {@code uri} using basic authentication and uses it as this builder's driver.
         *
         * @param uri      Neo4j connection URI
         * @param user     user name
         * @param password password
         * @return this builder
         */
        public Neo4jEmbeddingStoreBuilder withBasicAuth(String uri, String user, String password) {
            return driver(GraphDatabase.driver(uri, AuthTokens.basic(user, password)));
        }

        public Neo4jEmbeddingStoreBuilder config(SessionConfig config) {
            this.config = config;
            return this;
        }

        public Neo4jEmbeddingStoreBuilder driver(Driver driver) {
            this.driver = driver;
            return this;
        }

        public Neo4jEmbeddingStoreBuilder dimension(int dimension) {
            this.dimension = dimension;
            return this;
        }

        public Neo4jEmbeddingStoreBuilder label(String label) {
            this.label = label;
            return this;
        }

        public Neo4jEmbeddingStoreBuilder embeddingProperty(String embeddingProperty) {
            this.embeddingProperty = embeddingProperty;
            return this;
        }

        public Neo4jEmbeddingStoreBuilder idProperty(String idProperty) {
            this.idProperty = idProperty;
            return this;
        }

        public Neo4jEmbeddingStoreBuilder metadataPrefix(String metadataPrefix) {
            this.metadataPrefix = metadataPrefix;
            return this;
        }

        public Neo4jEmbeddingStoreBuilder textProperty(String textProperty) {
            this.textProperty = textProperty;
            return this;
        }

        public Neo4jEmbeddingStoreBuilder indexName(String indexName) {
            this.indexName = indexName;
            return this;
        }

        public Neo4jEmbeddingStoreBuilder databaseName(String databaseName) {
            this.databaseName = databaseName;
            return this;
        }

        public Neo4jEmbeddingStoreBuilder retrievalQuery(String retrievalQuery) {
            this.retrievalQuery = retrievalQuery;
            return this;
        }

        /**
         * Sets the value passed to {@code db.awaitIndexes} after creating the vector index. Non-positive
         * values (including leaving it unset) fall back to 60.
         */
        public Neo4jEmbeddingStoreBuilder awaitIndexTimeout(long awaitIndexTimeout) {
            this.awaitIndexTimeout = awaitIndexTimeout;
            return this;
        }

        /** Builds the store; this connects to Neo4j and ensures the schema exists. */
        public Neo4jEmbeddingStore build() {
            return new Neo4jEmbeddingStore(this.config, this.driver, this.dimension, this.label,
                    this.embeddingProperty, this.idProperty, this.metadataPrefix, this.textProperty,
                    this.indexName, this.databaseName, this.retrievalQuery, this.awaitIndexTimeout);
        }

        @Override
        public String toString() {
            return "Neo4jEmbeddingStore.Neo4jEmbeddingStoreBuilder(config=" + this.config
                    + ", driver=" + this.driver
                    + ", dimension=" + this.dimension
                    + ", label=" + this.label
                    + ", embeddingProperty=" + this.embeddingProperty
                    + ", idProperty=" + this.idProperty
                    + ", metadataPrefix=" + this.metadataPrefix
                    + ", textProperty=" + this.textProperty
                    + ", indexName=" + this.indexName
                    + ", databaseName=" + this.databaseName
                    + ", retrievalQuery=" + this.retrievalQuery
                    + ", awaitIndexTimeout=" + this.awaitIndexTimeout + ")";
        }
    }

    /**
     * Creates the store and ensures the vector index and the id uniqueness constraint exist.
     *
     * @param config            session configuration; defaults to {@code SessionConfig.forDatabase(databaseName)}
     * @param driver            Neo4j driver (required)
     * @param dimension         vector dimension, validated to lie in [0, 4096]; only used when the index must be created
     * @param label             node label; defaults to {@code Neo4jEmbeddingUtils.DEFAULT_LABEL}
     * @param embeddingProperty property holding the vector; defaults to {@code DEFAULT_EMBEDDING_PROP}
     * @param idProperty        property holding the id; defaults to {@code DEFAULT_ID_PROP}
     * @param metadataPrefix    metadata key prefix; defaults to the empty string
     * @param textProperty      property holding the segment text; defaults to {@code DEFAULT_TEXT_PROP}
     * @param indexName         vector index name; defaults to {@code DEFAULT_IDX_NAME}
     * @param databaseName      database name; defaults to {@code DEFAULT_DATABASE_NAME}
     * @param retrievalQuery    Cypher appended after the vector lookup in {@link #search}; the default
     *                          returns the node's properties, id, text, embedding and score
     * @param awaitIndexTimeout value passed to {@code db.awaitIndexes}; non-positive values fall back to 60
     * @throws RuntimeException if an index named {@code indexName} exists for a different label/property,
     *                          or if validation/sanitization of the arguments fails
     */
    public Neo4jEmbeddingStore(SessionConfig config, Driver driver, int dimension, String label,
            String embeddingProperty, String idProperty, String metadataPrefix, String textProperty,
            String indexName, String databaseName, String retrievalQuery, long awaitIndexTimeout) {
        this.driver = ValidationUtils.ensureNotNull(driver, "driver");
        this.dimension = ValidationUtils.ensureBetween(dimension, 0, 4096, "dimension");
        this.databaseName = Utils.getOrDefault(databaseName, Neo4jEmbeddingUtils.DEFAULT_DATABASE_NAME);
        // The default session must target the database name resolved just above.
        this.config = Utils.getOrDefault(config, SessionConfig.forDatabase(this.databaseName));
        this.label = Utils.getOrDefault(label, Neo4jEmbeddingUtils.DEFAULT_LABEL);
        this.embeddingProperty = Utils.getOrDefault(embeddingProperty, Neo4jEmbeddingUtils.DEFAULT_EMBEDDING_PROP);
        this.idProperty = Utils.getOrDefault(idProperty, Neo4jEmbeddingUtils.DEFAULT_ID_PROP);
        this.indexName = Utils.getOrDefault(indexName, Neo4jEmbeddingUtils.DEFAULT_IDX_NAME);
        this.metadataPrefix = Utils.getOrDefault(metadataPrefix, "");
        this.textProperty = Utils.getOrDefault(textProperty, Neo4jEmbeddingUtils.DEFAULT_TEXT_PROP);
        // A primitive long is never null, so "unset" (0 from the builder) or nonsensical negative values
        // are what must map to the default; a null-based getOrDefault could never apply it.
        this.awaitIndexTimeout = awaitIndexTimeout > 0 ? awaitIndexTimeout : DEFAULT_AWAIT_INDEX_TIMEOUT;
        this.sanitizedLabel = Neo4jEmbeddingUtils.sanitizeOrThrows(this.label, "label");
        this.sanitizedEmbeddingProperty = Neo4jEmbeddingUtils.sanitizeOrThrows(this.embeddingProperty, "embeddingProperty");
        this.sanitizedIdProperty = Neo4jEmbeddingUtils.sanitizeOrThrows(this.idProperty, "idProperty");
        this.sanitizedText = Neo4jEmbeddingUtils.sanitizeOrThrows(this.textProperty, "textProperty");
        this.retrievalQuery = Utils.getOrDefault(retrievalQuery, String.format(
                "RETURN properties(node) AS metadata, node.%1$s AS %1$s, node.%2$s AS %2$s, node.%3$s AS %3$s, score",
                this.sanitizedIdProperty, this.sanitizedText, this.sanitizedEmbeddingProperty));
        this.notMetaKeys = new HashSet<>(Arrays.asList(this.idProperty, this.embeddingProperty, this.textProperty));
        createSchema();
    }

    /**
     * Stores {@code embedding} (without text) under a freshly generated id.
     *
     * @return the generated id
     */
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    /** Stores {@code embedding} (without text) under the given {@code id}. */
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /**
     * Stores {@code embedding} together with {@code textSegment} under a freshly generated id.
     *
     * @return the generated id
     */
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /** Stores all {@code embeddings} without text; delegates to the two-argument {@code addAll}. */
    public List<String> addAll(List<Embedding> embeddings) {
        return addAll(embeddings, null);
    }

    /**
     * Runs a vector-index lookup followed by the configured retrieval query and maps each record
     * through {@code Neo4jEmbeddingUtils.toEmbeddingMatch}.
     *
     * <p>Only the query embedding, {@code maxResults} and {@code minScore} of the request are bound into
     * the query. NOTE(review): the request's metadata filter is not referenced here. Confirm whether it
     * is supposed to be applied.
     */
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        Value embeddingValue = Values.value(request.queryEmbedding().vector());
        Map<String, Object> params = Map.of(
                "indexName", this.indexName,
                "embeddingValue", embeddingValue,
                "minScore", request.minScore(),
                "maxResults", request.maxResults());
        String query = SEARCH_QUERY_PREFIX + this.retrievalQuery;
        try (Session session = session()) {
            // Records must be materialized while the session is still open.
            return new EmbeddingSearchResult<>(session.run(query, params)
                    .list(matchRecord -> Neo4jEmbeddingUtils.toEmbeddingMatch(this, matchRecord)));
        }
    }

    /** Single-item adapter onto {@link #addAll(List, List, List)}; a null segment means "no text". */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAll(Collections.singletonList(id), Collections.singletonList(embedding),
                textSegment == null ? null : Collections.singletonList(textSegment));
    }

    /**
     * Upserts the given embeddings (and optional segments) keyed by {@code ids}.
     *
     * <p>Null or empty {@code ids}/{@code embeddings} are logged and ignored. Size mismatches are rejected
     * through {@code ValidationUtils.ensureTrue}.
     *
     * @param ids        node ids, parallel to {@code embeddings}
     * @param embeddings vectors to store
     * @param embedded   optional segments, parallel to {@code embeddings}; may be {@code null}
     */
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("[do not add empty embeddings to neo4j]");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(),
                "embeddings size is not equal to embedded size");
        bulk(ids, embeddings, embedded);
    }

    /** Writes the rows produced by {@code Neo4jEmbeddingUtils.getRowsBatched}, one write transaction per batch. */
    private void bulk(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        Stream<List<Map<String, Object>>> batches = Neo4jEmbeddingUtils.getRowsBatched(this, ids, embeddings, embedded);
        // The statement depends only on configuration, so format it once instead of once per batch.
        String statement = UPSERT_QUERY_TEMPLATE.formatted(this.sanitizedLabel, this.sanitizedIdProperty,
                Neo4jEmbeddingUtils.PROPS, Neo4jEmbeddingUtils.EMBEDDINGS_ROW_KEY);
        try (Session session = session()) {
            batches.forEach(rows -> {
                Map<String, Object> params = Map.of("rows", rows, "embeddingProperty", this.embeddingProperty);
                session.executeWrite(tx -> tx.run(statement, params).consume());
            });
        }
    }

    /** Creates the vector index if it is missing, then ensures the id uniqueness constraint. */
    private void createSchema() {
        if (!indexExists()) {
            createIndex();
        }
        createUniqueConstraint();
    }

    /** Issues an idempotent {@code CREATE CONSTRAINT IF NOT EXISTS} on the id property. */
    private void createUniqueConstraint() {
        String statement = String.format(UNIQUE_CONSTRAINT_TEMPLATE, this.sanitizedLabel, this.sanitizedIdProperty);
        try (Session session = session()) {
            // Consume so that server-side failures surface here rather than at session close.
            session.run(statement).consume();
        }
    }

    /**
     * Checks whether a vector index named {@code indexName} exists.
     *
     * @return {@code true} if it exists for exactly this label and embedding property, {@code false} if absent
     * @throws RuntimeException if an index with that name exists for different labels or properties
     */
    private boolean indexExists() {
        try (Session session = session()) {
            Result result = session.run("SHOW INDEX WHERE type = 'VECTOR' AND name = $name", Map.of("name", this.indexName));
            if (!result.hasNext()) {
                return false;
            }
            Record index = result.single();
            List<String> labels = index.get("labelsOrTypes").asList(Value::asString);
            List<Object> properties = index.get("properties").asList();
            boolean sameDefinition = labels.equals(Collections.singletonList(this.label))
                    && properties.equals(Collections.singletonList(this.embeddingProperty));
            if (!sameDefinition) {
                throw new RuntimeException(String.format("It's not possible to create an index for the label `%s` and the property `%s`,\nas there is another index with name `%s` with different labels: `%s` and properties `%s`.\nPlease provide another indexName to create the vector index, or delete the existing one", this.label, this.embeddingProperty, this.indexName, labels, properties));
            }
            return true;
        }
    }

    /** Creates the cosine vector index and waits for it via {@code db.awaitIndexes(awaitIndexTimeout)}. */
    private void createIndex() {
        Map<String, Object> params = Map.of(
                "indexName", this.indexName,
                "label", this.label,
                "embeddingProperty", this.embeddingProperty,
                "dimension", this.dimension);
        try (Session session = session()) {
            session.run("CALL db.index.vector.createNodeIndex($indexName, $label, $embeddingProperty, $dimension, 'cosine')", params).consume();
            session.run("CALL db.awaitIndexes($timeout)", Map.of("timeout", this.awaitIndexTimeout)).consume();
        }
    }

    /** Opens a new session with this store's {@link SessionConfig}; callers must close it. */
    private Session session() {
        return this.driver.session(this.config);
    }

    public static Neo4jEmbeddingStoreBuilder builder() {
        return new Neo4jEmbeddingStoreBuilder();
    }

    public Driver getDriver() {
        return this.driver;
    }

    public SessionConfig getConfig() {
        return this.config;
    }

    public int getDimension() {
        return this.dimension;
    }

    public long getAwaitIndexTimeout() {
        return this.awaitIndexTimeout;
    }

    public String getIndexName() {
        return this.indexName;
    }

    public String getMetadataPrefix() {
        return this.metadataPrefix;
    }

    public String getEmbeddingProperty() {
        return this.embeddingProperty;
    }

    public String getIdProperty() {
        return this.idProperty;
    }

    public String getSanitizedEmbeddingProperty() {
        return this.sanitizedEmbeddingProperty;
    }

    public String getSanitizedIdProperty() {
        return this.sanitizedIdProperty;
    }

    public String getSanitizedText() {
        return this.sanitizedText;
    }

    public String getLabel() {
        return this.label;
    }

    public String getSanitizedLabel() {
        return this.sanitizedLabel;
    }

    public String getTextProperty() {
        return this.textProperty;
    }

    public String getDatabaseName() {
        return this.databaseName;
    }

    public String getRetrievalQuery() {
        return this.retrievalQuery;
    }

    public Set<String> getNotMetaKeys() {
        return this.notMetaKeys;
    }
}
