package dev.langchain4j.store.embedding.vespa;

import ai.vespa.client.dsl.A;
import ai.vespa.client.dsl.Annotation;
import ai.vespa.client.dsl.NearestNeighbor;
import ai.vespa.client.dsl.Q;
import ai.vespa.feed.client.DocumentId;
import ai.vespa.feed.client.FeedClientBuilder;
import ai.vespa.feed.client.FeedException;
import ai.vespa.feed.client.JsonFeeder;
import ai.vespa.feed.client.Result;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Json;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import retrofit2.Response;

/* loaded from: input_file:dev/langchain4j/store/embedding/vespa/VespaEmbeddingStore.class */
public class VespaEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5);
    private static final String DEFAULT_NAMESPACE = "namespace";
    private static final String DEFAULT_DOCUMENT_TYPE = "langchain4j";
    private static final boolean DEFAULT_AVOID_DUPS = true;
    private static final String FIELD_NAME_TEXT_SEGMENT = "text_segment";
    private static final String FIELD_NAME_VECTOR = "vector";
    private static final String FIELD_NAME_DOCUMENT_ID = "documentid";
    private static final String DEFAULT_RANK_PROFILE = "cosine_similarity";
    private static final int DEFAULT_TARGET_HITS = 10;
    private final String url;
    private final Path keyPath;
    private final Path certPath;
    private final Duration timeout;
    private final String namespace;
    private final String documentType;
    private final String rankProfile;
    private final int targetHits;
    private final boolean avoidDups;
    private VespaQueryApi queryApi;

    /* loaded from: input_file:dev/langchain4j/store/embedding/vespa/VespaEmbeddingStore$VespaEmbeddingStoreBuilder.class */
    public static class VespaEmbeddingStoreBuilder {
        private String url;
        private String keyPath;
        private String certPath;
        private Duration timeout;
        private String namespace;
        private String documentType;
        private String rankProfile;
        private Integer targetHits;
        private Boolean avoidDups;

        VespaEmbeddingStoreBuilder() {
        }

        public VespaEmbeddingStoreBuilder url(String str) {
            this.url = str;
            return this;
        }

        public VespaEmbeddingStoreBuilder keyPath(String str) {
            this.keyPath = str;
            return this;
        }

        public VespaEmbeddingStoreBuilder certPath(String str) {
            this.certPath = str;
            return this;
        }

        public VespaEmbeddingStoreBuilder timeout(Duration duration) {
            this.timeout = duration;
            return this;
        }

        public VespaEmbeddingStoreBuilder namespace(String str) {
            this.namespace = str;
            return this;
        }

        public VespaEmbeddingStoreBuilder documentType(String str) {
            this.documentType = str;
            return this;
        }

        public VespaEmbeddingStoreBuilder rankProfile(String str) {
            this.rankProfile = str;
            return this;
        }

        public VespaEmbeddingStoreBuilder targetHits(Integer num) {
            this.targetHits = num;
            return this;
        }

        public VespaEmbeddingStoreBuilder avoidDups(Boolean bool) {
            this.avoidDups = bool;
            return this;
        }

        public VespaEmbeddingStore build() {
            return new VespaEmbeddingStore(this.url, this.keyPath, this.certPath, this.timeout, this.namespace, this.documentType, this.rankProfile, this.targetHits, this.avoidDups);
        }

        public String toString() {
            return "VespaEmbeddingStore.VespaEmbeddingStoreBuilder(url=" + this.url + ", keyPath=" + this.keyPath + ", certPath=" + this.certPath + ", timeout=" + String.valueOf(this.timeout) + ", namespace=" + this.namespace + ", documentType=" + this.documentType + ", rankProfile=" + this.rankProfile + ", targetHits=" + this.targetHits + ", avoidDups=" + this.avoidDups + ")";
        }
    }

    public VespaEmbeddingStore(String str, String str2, String str3, Duration duration, String str4, String str5, String str6, Integer num, Boolean bool) {
        this.url = str;
        this.keyPath = Paths.get(str2, new String[0]);
        this.certPath = Paths.get(str3, new String[0]);
        this.timeout = duration != null ? duration : DEFAULT_TIMEOUT;
        this.namespace = str4 != null ? str4 : DEFAULT_NAMESPACE;
        this.documentType = str5 != null ? str5 : DEFAULT_DOCUMENT_TYPE;
        this.rankProfile = str6 != null ? str6 : DEFAULT_RANK_PROFILE;
        this.targetHits = num != null ? num.intValue() : DEFAULT_TARGET_HITS;
        this.avoidDups = bool != null ? bool.booleanValue() : true;
    }

    public String add(Embedding embedding) {
        return add(null, embedding, null);
    }

    public void add(String str, Embedding embedding) {
        add(str, embedding, null);
    }

    public String add(Embedding embedding, TextSegment textSegment) {
        return add(null, embedding, textSegment);
    }

    public List<String> addAll(List<Embedding> list) {
        return addAll(list, null);
    }

    public List<String> addAll(List<Embedding> list, List<TextSegment> list2) {
        if (list2 != null && list.size() != list2.size()) {
            throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
        }
        final ArrayList arrayList = new ArrayList();
        try {
            JsonFeeder buildJsonFeeder = buildJsonFeeder();
            try {
                ArrayList arrayList2 = new ArrayList();
                for (int i = 0; i < list.size(); i += DEFAULT_AVOID_DUPS) {
                    arrayList2.add(buildRecord(list.get(i), list2 != null ? list2.get(i) : null));
                }
                buildJsonFeeder.feedMany(Json.toInputStream(arrayList2, List.class), new JsonFeeder.ResultCallback() { // from class: dev.langchain4j.store.embedding.vespa.VespaEmbeddingStore.1
                    public void onNextResult(Result result, FeedException feedException) {
                        if (feedException != null) {
                            throw new RuntimeException(feedException.getMessage());
                        }
                        if (Result.Type.success.equals(result.type())) {
                            arrayList.add(result.documentId().toString());
                        }
                    }

                    public void onError(FeedException feedException) {
                        throw new RuntimeException(feedException.getMessage());
                    }
                });
                if (buildJsonFeeder != null) {
                    buildJsonFeeder.close();
                }
                return arrayList;
            } finally {
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding embedding, int i, double d) {
        try {
            Response execute = getQueryApi().search(Q.select(FIELD_NAME_DOCUMENT_ID, new String[]{FIELD_NAME_TEXT_SEGMENT, FIELD_NAME_VECTOR}).from(this.documentType).where(buildNearestNeighbor()).fix().hits(i).ranking(this.rankProfile).param("input.query(q)", Json.toJson(embedding.vectorAsList())).param("input.query(threshold)", String.valueOf(d)).build()).execute();
            if (execute.isSuccessful()) {
                return (List) ((QueryResponse) execute.body()).getRoot().getChildren().stream().map(VespaEmbeddingStore::toEmbeddingMatch).collect(Collectors.toList());
            }
            throw new RuntimeException("Request failed");
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private String add(String str, Embedding embedding, TextSegment textSegment) {
        AtomicReference atomicReference = new AtomicReference();
        try {
            JsonFeeder buildJsonFeeder = buildJsonFeeder();
            try {
                buildJsonFeeder.feedSingle(Json.toJson(buildRecord(str, embedding, textSegment))).whenComplete((result, th) -> {
                    if (th != null) {
                        throw new RuntimeException(th);
                    }
                    if (Result.Type.success.equals(result.type())) {
                        atomicReference.set(result.documentId().toString());
                    }
                });
                if (buildJsonFeeder != null) {
                    buildJsonFeeder.close();
                }
                return (String) atomicReference.get();
            } finally {
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private JsonFeeder buildJsonFeeder() {
        return JsonFeeder.builder(FeedClientBuilder.create(URI.create(this.url)).setCertificate(this.certPath, this.keyPath).build()).withTimeout(this.timeout).build();
    }

    private VespaQueryApi getQueryApi() {
        if (this.queryApi == null) {
            this.queryApi = VespaQueryClient.createInstance(this.url, this.certPath, this.keyPath);
        }
        return this.queryApi;
    }

    private static EmbeddingMatch<TextSegment> toEmbeddingMatch(Record record) {
        return new EmbeddingMatch<>(Double.valueOf(record.getRelevance()), record.getFields().getDocumentId(), Embedding.from(record.getFields().getVector().getValues()), TextSegment.from(record.getFields().getTextSegment()));
    }

    private Record buildRecord(String str, Embedding embedding, TextSegment textSegment) {
        return new Record(DocumentId.of(this.namespace, this.documentType, str != null ? str : (!this.avoidDups || textSegment == null) ? Utils.randomUUID() : Utils.generateUUIDFrom(textSegment.text())).toString(), textSegment != null ? textSegment.text() : null, embedding.vectorAsList());
    }

    private Record buildRecord(Embedding embedding, TextSegment textSegment) {
        return buildRecord(null, embedding, textSegment);
    }

    private NearestNeighbor buildNearestNeighbor() throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
        NearestNeighbor nearestNeighbor = Q.nearestNeighbor(FIELD_NAME_VECTOR, "q");
        Method declaredMethod = NearestNeighbor.class.getDeclaredMethod("annotate", Annotation.class);
        declaredMethod.setAccessible(true);
        declaredMethod.invoke(nearestNeighbor, A.a("targetHits", Integer.valueOf(this.targetHits)));
        return nearestNeighbor;
    }

    public static VespaEmbeddingStoreBuilder builder() {
        return new VespaEmbeddingStoreBuilder();
    }
}
