package dev.langchain4j.model.chat.common;

import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
/* loaded from: input_file:dev/langchain4j/model/chat/common/AbstractStreamingChatModelIT.class */
public abstract class AbstractStreamingChatModelIT extends AbstractBaseChatModelIT<StreamingChatModel> {
    public abstract StreamingChatModel createModelWith(ChatModelListener chatModelListener);

    @Test
    void should_propagate_user_exceptions_thrown_from_onPartialResponse() throws Exception {
        final AtomicInteger atomicInteger = new AtomicInteger(0);
        final CompletableFuture completableFuture = new CompletableFuture();
        final ArrayList arrayList = new ArrayList();
        final CompletableFuture completableFuture2 = new CompletableFuture();
        final RuntimeException runtimeException = new RuntimeException("something wrong happened in user code");
        StreamingChatResponseHandler streamingChatResponseHandler = new StreamingChatResponseHandler() { // from class: dev.langchain4j.model.chat.common.AbstractStreamingChatModelIT.1
            public void onPartialResponse(String str) {
                atomicInteger.incrementAndGet();
                throw runtimeException;
            }

            public void onCompleteResponse(ChatResponse chatResponse) {
                completableFuture.complete(chatResponse);
            }

            public void onError(Throwable th) {
                arrayList.add(th);
                completableFuture2.complete(null);
            }
        };
        ChatModelListener chatModelListener = (ChatModelListener) Mockito.mock(ChatModelListener.class);
        StreamingChatModel createModelWith = createModelWith(chatModelListener);
        if (createModelWith == null) {
            return;
        }
        createModelWith.chat("What is the capital of Germany?", streamingChatResponseHandler);
        Assertions.assertThat(((ChatResponse) completableFuture.get(30L, TimeUnit.SECONDS)).aiMessage().text()).containsIgnoringCase("Berlin");
        Assertions.assertThat(atomicInteger.get()).isGreaterThan(1);
        completableFuture2.get(30L, TimeUnit.SECONDS);
        Assertions.assertThat(arrayList).hasSize(atomicInteger.get());
        Iterator it = arrayList.iterator();
        while (it.hasNext()) {
            Assertions.assertThat((Throwable) it.next()).isEqualTo(runtimeException);
        }
        ((ChatModelListener) Mockito.verify(chatModelListener)).onRequest((ChatModelRequestContext) ArgumentMatchers.any());
        ((ChatModelListener) Mockito.verify(chatModelListener, Mockito.times(atomicInteger.get()))).onError((ChatModelErrorContext) ArgumentMatchers.any());
        ((ChatModelListener) Mockito.verify(chatModelListener)).onResponse((ChatModelResponseContext) ArgumentMatchers.any());
        Mockito.verifyNoMoreInteractions(new Object[]{chatModelListener});
    }

    @Test
    void should_propagate_user_exceptions_thrown_from_onCompleteResponse() throws Exception {
        final CompletableFuture completableFuture = new CompletableFuture();
        final ArrayList arrayList = new ArrayList();
        final CompletableFuture completableFuture2 = new CompletableFuture();
        final RuntimeException runtimeException = new RuntimeException("something wrong happened in user code");
        StreamingChatResponseHandler streamingChatResponseHandler = new StreamingChatResponseHandler() { // from class: dev.langchain4j.model.chat.common.AbstractStreamingChatModelIT.2
            public void onPartialResponse(String str) {
            }

            public void onCompleteResponse(ChatResponse chatResponse) {
                completableFuture.complete(chatResponse);
                throw runtimeException;
            }

            public void onError(Throwable th) {
                arrayList.add(th);
                completableFuture2.complete(null);
            }
        };
        ChatModelListener chatModelListener = (ChatModelListener) Mockito.mock(ChatModelListener.class);
        StreamingChatModel createModelWith = createModelWith(chatModelListener);
        if (createModelWith == null) {
            return;
        }
        createModelWith.chat("What is the capital of Germany?", streamingChatResponseHandler);
        Assertions.assertThat(((ChatResponse) completableFuture.get(30L, TimeUnit.SECONDS)).aiMessage().text()).containsIgnoringCase("Berlin");
        completableFuture2.get(30L, TimeUnit.SECONDS);
        Assertions.assertThat(arrayList).hasSize(1);
        Assertions.assertThat((Throwable) arrayList.get(0)).isEqualTo(runtimeException);
        ((ChatModelListener) Mockito.verify(chatModelListener)).onRequest((ChatModelRequestContext) ArgumentMatchers.any());
        ((ChatModelListener) Mockito.verify(chatModelListener)).onError((ChatModelErrorContext) ArgumentMatchers.any());
        ((ChatModelListener) Mockito.verify(chatModelListener)).onResponse((ChatModelResponseContext) ArgumentMatchers.any());
        Mockito.verifyNoMoreInteractions(new Object[]{chatModelListener});
    }

    @Test
    void should_ignore_user_exceptions_thrown_from_onError() throws Exception {
        final CompletableFuture completableFuture = new CompletableFuture();
        final ArrayList arrayList = new ArrayList();
        final CompletableFuture completableFuture2 = new CompletableFuture();
        final RuntimeException runtimeException = new RuntimeException("something wrong happened in user code");
        StreamingChatResponseHandler streamingChatResponseHandler = new StreamingChatResponseHandler() { // from class: dev.langchain4j.model.chat.common.AbstractStreamingChatModelIT.3
            public void onPartialResponse(String str) {
            }

            public void onCompleteResponse(ChatResponse chatResponse) {
                completableFuture.complete(chatResponse);
                throw runtimeException;
            }

            public void onError(Throwable th) {
                arrayList.add(th);
                completableFuture2.complete(null);
                throw new RuntimeException("something unexpected happened, but it should be ignored");
            }
        };
        ChatModelListener chatModelListener = (ChatModelListener) Mockito.mock(ChatModelListener.class);
        StreamingChatModel createModelWith = createModelWith(chatModelListener);
        if (createModelWith == null) {
            return;
        }
        createModelWith.chat("What is the capital of Germany?", streamingChatResponseHandler);
        Assertions.assertThat(((ChatResponse) completableFuture.get(30L, TimeUnit.SECONDS)).aiMessage().text()).containsIgnoringCase("Berlin");
        completableFuture2.get(30L, TimeUnit.SECONDS);
        Assertions.assertThat(arrayList).hasSize(1);
        Assertions.assertThat((Throwable) arrayList.get(0)).isEqualTo(runtimeException);
        ((ChatModelListener) Mockito.verify(chatModelListener)).onRequest((ChatModelRequestContext) ArgumentMatchers.any());
        ((ChatModelListener) Mockito.verify(chatModelListener)).onError((ChatModelErrorContext) ArgumentMatchers.any());
        ((ChatModelListener) Mockito.verify(chatModelListener)).onResponse((ChatModelResponseContext) ArgumentMatchers.any());
        Mockito.verifyNoMoreInteractions(new Object[]{chatModelListener});
    }

    /* JADX INFO: Access modifiers changed from: protected */
    @Override // dev.langchain4j.model.chat.common.AbstractBaseChatModelIT
    public ChatResponseAndStreamingMetadata chat(StreamingChatModel streamingChatModel, ChatRequest chatRequest) {
        final CompletableFuture completableFuture = new CompletableFuture();
        final StringBuffer stringBuffer = new StringBuffer();
        final AtomicInteger atomicInteger = new AtomicInteger();
        final AtomicInteger atomicInteger2 = new AtomicInteger();
        final CopyOnWriteArraySet copyOnWriteArraySet = new CopyOnWriteArraySet();
        streamingChatModel.chat(chatRequest, new StreamingChatResponseHandler() { // from class: dev.langchain4j.model.chat.common.AbstractStreamingChatModelIT.4
            public void onPartialResponse(String str) {
                stringBuffer.append(str);
                atomicInteger.incrementAndGet();
                copyOnWriteArraySet.add(Thread.currentThread());
            }

            public void onCompleteResponse(ChatResponse chatResponse) {
                completableFuture.complete(chatResponse);
                atomicInteger2.incrementAndGet();
                copyOnWriteArraySet.add(Thread.currentThread());
            }

            public void onError(Throwable th) {
                completableFuture.completeExceptionally(th);
                copyOnWriteArraySet.add(Thread.currentThread());
            }
        });
        try {
            ChatResponse chatResponse = (ChatResponse) completableFuture.get(120L, TimeUnit.SECONDS);
            String stringBuffer2 = stringBuffer.toString();
            return new ChatResponseAndStreamingMetadata(chatResponse, new StreamingMetadata(stringBuffer2.isEmpty() ? null : stringBuffer2, atomicInteger.get(), atomicInteger2.get(), copyOnWriteArraySet));
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
