package dev.langchain4j.rag.query.router;

import dev.langchain4j.model.chat.mock.ChatModelMock;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.query.router.LanguageModelQueryRouter;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

/**
 * Tests for {@link LanguageModelQueryRouter}.
 *
 * <p>Each test offers the router two mocked retrievers (cats and dogs). A {@link ChatModelMock}
 * supplies a canned answer, or throws. The tests then check which retrievers are selected. They
 * also check which prompt text is sent to the model, and how each
 * {@link LanguageModelQueryRouter.FallbackStrategy} behaves. Two failure modes are covered: the
 * model's answer cannot be parsed, or the model call fails.
 */
@ExtendWith(MockitoExtension.class)
class LanguageModelQueryRouterTest {

    /**
     * Prompt expected with the default template for the query "Do Labradors shed?" and the
     * cat-then-dog retriever map.
     */
    private static final String EXPECTED_DEFAULT_PROMPT_FOR_LABRADOR_QUERY =
            "Based on the user query, determine the most suitable data source(s) to retrieve relevant information from the following options:\n"
                    + "1: articles about cats\n"
                    + "2: articles about dogs\n"
                    + "It is very important that your answer consists of either a single number or multiple numbers separated by commas and nothing else!\n"
                    + "User query: Do Labradors shed?";

    @Mock
    ContentRetriever catArticlesRetriever;

    @Mock
    ContentRetriever dogArticlesRetriever;

    /**
     * Builds a fresh, insertion-ordered map from retriever to description.
     *
     * <p>The cat retriever is inserted first and the dog retriever second. Ordering matters: the
     * expected prompts list the cat retriever as option 1 and the dog retriever as option 2.
     *
     * @return a new mutable map that callers may freely modify or copy
     */
    private Map<ContentRetriever, String> catThenDogRetrievers() {
        Map<ContentRetriever, String> retrieverToDescription = new LinkedHashMap<>();
        retrieverToDescription.put(catArticlesRetriever, "articles about cats");
        retrieverToDescription.put(dogArticlesRetriever, "articles about dogs");
        return retrieverToDescription;
    }

    @Test
    void should_route_to_single_retriever() {
        Query query = Query.from("Do Labradors shed?");
        ChatModelMock model = ChatModelMock.thatAlwaysResponds("2");
        LanguageModelQueryRouter router = new LanguageModelQueryRouter(model, catThenDogRetrievers());

        Assertions.assertThat(router.route(query)).containsExactly(dogArticlesRetriever);
        Assertions.assertThat(model.userMessageText()).isEqualTo(EXPECTED_DEFAULT_PROMPT_FOR_LABRADOR_QUERY);
    }

    @Test
    void should_route_to_single_retriever_builder() {
        Query query = Query.from("Do Labradors shed?");
        ChatModelMock model = ChatModelMock.thatAlwaysResponds("2");
        LanguageModelQueryRouter router = LanguageModelQueryRouter.builder()
                .chatLanguageModel(model)
                .retrieverToDescription(catThenDogRetrievers())
                .build();

        Assertions.assertThat(router.route(query)).containsExactly(dogArticlesRetriever);
        Assertions.assertThat(model.userMessageText()).isEqualTo(EXPECTED_DEFAULT_PROMPT_FOR_LABRADOR_QUERY);
    }

    @Test
    void should_route_to_multiple_retrievers() {
        Query query = Query.from("Which animal is the fluffiest?");
        // A plain HashMap (no guaranteed iteration order) is used on purpose, so this test also
        // covers maps that are not insertion-ordered; hence the order-insensitive assertion.
        Map<ContentRetriever, String> retrieverToDescription = new HashMap<>(catThenDogRetrievers());
        LanguageModelQueryRouter router =
                new LanguageModelQueryRouter(ChatModelMock.thatAlwaysResponds("1, 2"), retrieverToDescription);

        Assertions.assertThat(router.route(query))
                .containsExactlyInAnyOrder(catArticlesRetriever, dogArticlesRetriever);
    }

    @Test
    void should_route_to_multiple_retrievers_with_custom_prompt_template() {
        PromptTemplate promptTemplate =
                PromptTemplate.from("Which source should I use to get answer for '{{query}}'? Options: {{options}}'");
        Query query = Query.from("Which animal is the fluffiest?");
        ChatModelMock model = ChatModelMock.thatAlwaysResponds("1, 2");
        LanguageModelQueryRouter router = new LanguageModelQueryRouter(
                model, catThenDogRetrievers(), promptTemplate, LanguageModelQueryRouter.FallbackStrategy.FAIL);

        Assertions.assertThat(router.route(query))
                .containsExactlyInAnyOrder(catArticlesRetriever, dogArticlesRetriever);
        Assertions.assertThat(model.userMessageText())
                .isEqualTo("Which source should I use to get answer for 'Which animal is the fluffiest?'? Options: 1: articles about cats\n2: articles about dogs'");
    }

    @Test
    void should_not_route_by_default_when_LLM_returns_invalid_response() {
        Query query = Query.from("Hey what's up?");
        ChatModelMock model = ChatModelMock.thatAlwaysResponds("Sorry, I don't know");
        LanguageModelQueryRouter router = new LanguageModelQueryRouter(model, catThenDogRetrievers());

        Assertions.assertThat(router.route(query)).isEmpty();
    }

    @Test
    void should_not_route_by_default_when_LLM_call_fails() {
        Query query = Query.from("Hey what's up?");
        ChatModelMock model = ChatModelMock.thatAlwaysThrowsException();
        LanguageModelQueryRouter router = new LanguageModelQueryRouter(model, catThenDogRetrievers());

        Assertions.assertThat(router.route(query)).isEmpty();
    }

    @Test
    void should_route_to_all_retrievers_when_LLM_returns_invalid_response() {
        Query query = Query.from("Hey what's up?");
        ChatModelMock model = ChatModelMock.thatAlwaysResponds("Sorry, I don't know");
        LanguageModelQueryRouter router = LanguageModelQueryRouter.builder()
                .chatLanguageModel(model)
                .retrieverToDescription(catThenDogRetrievers())
                .fallbackStrategy(LanguageModelQueryRouter.FallbackStrategy.ROUTE_TO_ALL)
                .build();

        Assertions.assertThat(router.route(query))
                .containsExactlyInAnyOrder(catArticlesRetriever, dogArticlesRetriever);
    }

    @Test
    void should_route_to_all_retrievers_when_LLM_call_fails() {
        Query query = Query.from("Hey what's up?");
        ChatModelMock model = ChatModelMock.thatAlwaysThrowsException();
        LanguageModelQueryRouter router = LanguageModelQueryRouter.builder()
                .chatLanguageModel(model)
                .retrieverToDescription(catThenDogRetrievers())
                .fallbackStrategy(LanguageModelQueryRouter.FallbackStrategy.ROUTE_TO_ALL)
                .build();

        Assertions.assertThat(router.route(query))
                .containsExactlyInAnyOrder(catArticlesRetriever, dogArticlesRetriever);
    }

    @Test
    void should_fail_when_LLM_returns_invalid_response() {
        Query query = Query.from("Hey what's up?");
        ChatModelMock model = ChatModelMock.thatAlwaysResponds("Sorry, I don't know");
        LanguageModelQueryRouter router = LanguageModelQueryRouter.builder()
                .chatLanguageModel(model)
                .retrieverToDescription(catThenDogRetrievers())
                .fallbackStrategy(LanguageModelQueryRouter.FallbackStrategy.FAIL)
                .build();

        // The unparsable answer should surface as the root cause of whatever the router throws.
        Assertions.assertThatThrownBy(() -> router.route(query))
                .hasRootCauseExactlyInstanceOf(NumberFormatException.class);
    }

    @Test
    void should_fail_when_LLM_call_fails() {
        Query query = Query.from("Hey what's up?");
        ChatModelMock model = ChatModelMock.thatAlwaysThrowsExceptionWithMessage("Something went wrong");
        LanguageModelQueryRouter router = LanguageModelQueryRouter.builder()
                .chatLanguageModel(model)
                .retrieverToDescription(catThenDogRetrievers())
                .fallbackStrategy(LanguageModelQueryRouter.FallbackStrategy.FAIL)
                .build();

        Assertions.assertThatThrownBy(() -> router.route(query))
                .isExactlyInstanceOf(RuntimeException.class)
                .hasMessageContaining("Something went wrong");
    }
}
