package dev.langchain4j.guardrail;

import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.guardrail.config.InputGuardrailsConfig;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.test.guardrail.GuardrailAssertions;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Pattern;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.aggregator.AggregateWith;
import org.junit.jupiter.params.aggregator.ArgumentsAccessor;
import org.junit.jupiter.params.aggregator.ArgumentsAggregationException;
import org.junit.jupiter.params.aggregator.ArgumentsAggregator;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;

/* loaded from: input_file:dev/langchain4j/guardrail/InputGuardrailExecutorTests.class */
class InputGuardrailExecutorTests {

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:dev/langchain4j/guardrail/InputGuardrailExecutorTests$FailureInputGuardrail.class */
    public static class FailureInputGuardrail<G extends FailureInputGuardrail> implements InputGuardrail {
        protected final String failureMessage;
        private boolean shouldBeExecuted;

        private FailureInputGuardrail(int i) {
            this("failure " + i);
        }

        private FailureInputGuardrail(String str) {
            this.shouldBeExecuted = true;
            this.failureMessage = str;
        }

        G shouldNotBeExecuted() {
            this.shouldBeExecuted = false;
            return this;
        }

        public InputGuardrailResult validate(UserMessage userMessage) {
            return failure(this.failureMessage);
        }
    }

    /* loaded from: input_file:dev/langchain4j/guardrail/InputGuardrailExecutorTests$FatalInputGuardrail.class */
    private static class FatalInputGuardrail extends FailureInputGuardrail<FatalInputGuardrail> {
        private FatalInputGuardrail(int i) {
            super(i);
        }

        @Override // dev.langchain4j.guardrail.InputGuardrailExecutorTests.FailureInputGuardrail
        public InputGuardrailResult validate(UserMessage userMessage) {
            return fatal(this.failureMessage);
        }
    }

    /* loaded from: input_file:dev/langchain4j/guardrail/InputGuardrailExecutorTests$InputGuardrailAggregator.class */
    static class InputGuardrailAggregator implements ArgumentsAggregator {
        InputGuardrailAggregator() {
        }

        public Object aggregateArguments(ArgumentsAccessor argumentsAccessor, ParameterContext parameterContext) throws ArgumentsAggregationException {
            Stream skip = argumentsAccessor.toList().stream().skip(parameterContext.getIndex());
            Class<InputGuardrail> cls = InputGuardrail.class;
            Objects.requireNonNull(InputGuardrail.class);
            return skip.map(cls::cast).toArray(i -> {
                return new InputGuardrail[i];
            });
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:dev/langchain4j/guardrail/InputGuardrailExecutorTests$SuccessInputGuardrail.class */
    public static class SuccessInputGuardrail implements InputGuardrail {
        private boolean shouldBeExecuted;

        SuccessInputGuardrail(boolean z) {
            this.shouldBeExecuted = true;
            this.shouldBeExecuted = z;
        }

        SuccessInputGuardrail() {
            this(true);
        }

        public InputGuardrailResult validate(UserMessage userMessage) {
            return InputGuardrailResult.success();
        }
    }

    InputGuardrailExecutorTests() {
    }

    @MethodSource({"successGuardrails"})
    @ParameterizedTest(name = "{0}")
    void allSuccessfulGuardrails(String str, int i, @AggregateWith(InputGuardrailAggregator.class) InputGuardrail... inputGuardrailArr) {
        InputGuardrail[] inputGuardrailArr2 = (InputGuardrail[]) Stream.of((Object[]) inputGuardrailArr).map((v0) -> {
            return Mockito.spy(v0);
        }).toArray(i2 -> {
            return new InputGuardrail[i2];
        });
        InputGuardrailRequest from = from(UserMessage.from("test"));
        GuardrailAssertions.assertThat(InputGuardrailExecutor.builder().guardrails(inputGuardrailArr2).build().execute(from)).isSuccessful();
        IntStream.range(0, i).mapToObj(i3 -> {
            return (SuccessInputGuardrail) inputGuardrailArr2[i3];
        }).forEach(successInputGuardrail -> {
            GuardrailAssertions.assertThat(successInputGuardrail.shouldBeExecuted).isTrue();
            ((SuccessInputGuardrail) Mockito.verify(successInputGuardrail)).validate(from);
        });
        IntStream.range(i, inputGuardrailArr2.length).mapToObj(i4 -> {
            return (SuccessInputGuardrail) inputGuardrailArr2[i4];
        }).forEach(successInputGuardrail2 -> {
            GuardrailAssertions.assertThat(successInputGuardrail2.shouldBeExecuted).isFalse();
            ((SuccessInputGuardrail) Mockito.verify(successInputGuardrail2, Mockito.never())).validate(from);
        });
    }

    @Test
    void noGuardrails() {
        GuardrailAssertions.assertThat(InputGuardrailExecutor.builder().build().execute(from(UserMessage.from("test")))).isSuccessful();
    }

    @MethodSource({"failedFatalGuardrails"})
    @ParameterizedTest(name = "{0}")
    void failedFatal(String str, int i, int i2, @AggregateWith(InputGuardrailAggregator.class) InputGuardrail... inputGuardrailArr) {
        InputGuardrail[] inputGuardrailArr2 = (InputGuardrail[]) Stream.of((Object[]) inputGuardrailArr).map((v0) -> {
            return Mockito.spy(v0);
        }).toArray(i3 -> {
            return new InputGuardrail[i3];
        });
        InputGuardrailRequest from = from(UserMessage.from("test"));
        InputGuardrailExecutor build = InputGuardrailExecutor.builder().guardrails(inputGuardrailArr2).config(InputGuardrailsConfig.builder().build()).build();
        Assertions.assertThatExceptionOfType(InputGuardrailException.class).isThrownBy(() -> {
            build.execute(from);
        }).withMessageMatching("The guardrail " + getClass().getName() + "\\$.+Guardrail failed with this message: failure \\d");
        IntStream.range(0, i).mapToObj(i4 -> {
            return inputGuardrailArr2[i4];
        }).forEach(inputGuardrail -> {
            GuardrailAssertions.assertThat(inputGuardrail instanceof SuccessInputGuardrail ? ((SuccessInputGuardrail) inputGuardrail).shouldBeExecuted : ((FailureInputGuardrail) inputGuardrail).shouldBeExecuted).isTrue();
            ((InputGuardrail) Mockito.verify(inputGuardrail)).validate(from);
        });
        IntStream.range(i, inputGuardrailArr2.length).mapToObj(i5 -> {
            return inputGuardrailArr2[i5];
        }).forEach(inputGuardrail2 -> {
            GuardrailAssertions.assertThat(inputGuardrail2 instanceof SuccessInputGuardrail ? ((SuccessInputGuardrail) inputGuardrail2).shouldBeExecuted : ((FailureInputGuardrail) inputGuardrail2).shouldBeExecuted).isFalse();
            ((InputGuardrail) Mockito.verify(inputGuardrail2, Mockito.never())).validate(from);
        });
        Stream of = Stream.of((Object[]) inputGuardrailArr2);
        Class<FailureInputGuardrail> cls = FailureInputGuardrail.class;
        Objects.requireNonNull(FailureInputGuardrail.class);
        Stream filter = of.filter((v1) -> {
            return r1.isInstance(v1);
        });
        Class<FailureInputGuardrail> cls2 = FailureInputGuardrail.class;
        Objects.requireNonNull(FailureInputGuardrail.class);
        GuardrailAssertions.assertThat(filter.map((v1) -> {
            return r1.cast(v1);
        }).filter(failureInputGuardrail -> {
            return failureInputGuardrail.shouldBeExecuted;
        }).count()).isEqualTo(i2);
    }

    static Stream<Arguments> successGuardrails() {
        return Stream.of((Object[]) new Arguments[]{Arguments.of(new Object[]{"No guardrails", 0}), Arguments.of(new Object[]{"One successful guardrail", 1, new SuccessInputGuardrail()}), Arguments.of(new Object[]{"Two successful guardrails", 2, new SuccessInputGuardrail(), new SuccessInputGuardrail()}), Arguments.of(new Object[]{"Three successful guardrails", 3, new SuccessInputGuardrail(), new SuccessInputGuardrail(), new SuccessInputGuardrail()})});
    }

    static Stream<Arguments> failedFatalGuardrails() {
        return Stream.of((Object[]) new Arguments[]{Arguments.of(new Object[]{"One successful one fatal guardrail", 2, 1, new SuccessInputGuardrail(), new FatalInputGuardrail(1)}), Arguments.of(new Object[]{"One fatal one successful guardrail", 1, 1, new FatalInputGuardrail(1), new SuccessInputGuardrail(false)}), Arguments.of(new Object[]{"One successful one fatal one successful guardrails", 2, 1, new SuccessInputGuardrail(), new FatalInputGuardrail(1), new SuccessInputGuardrail(false)}), Arguments.of(new Object[]{"One successful one fatal one failed guardrails", 2, 1, new SuccessInputGuardrail(), new FatalInputGuardrail(1), new FailureInputGuardrail(2).shouldNotBeExecuted()}), Arguments.of(new Object[]{"One failure one successful guardrail", 2, 1, new FailureInputGuardrail(1), new SuccessInputGuardrail()}), Arguments.of(new Object[]{"One successful one failure one successful guardrails", 3, 1, new SuccessInputGuardrail(), new FailureInputGuardrail(1), new SuccessInputGuardrail()}), Arguments.of(new Object[]{"One successful one fatal one failure guardrails", 2, 1, new SuccessInputGuardrail(), new FatalInputGuardrail(1), new FailureInputGuardrail(2).shouldNotBeExecuted()}), Arguments.of(new Object[]{"Two failure guardrails", 2, 2, new FailureInputGuardrail(1), new FailureInputGuardrail(2)}), Arguments.of(new Object[]{"One successful one failure one fatal one failure guardrails", 3, 2, new SuccessInputGuardrail(), new FailureInputGuardrail(2), new FatalInputGuardrail(1), new FailureInputGuardrail(3).shouldNotBeExecuted()})});
    }

    public static InputGuardrailRequest from(UserMessage userMessage) {
        return InputGuardrailRequest.builder().userMessage(userMessage).commonParams(GuardrailRequestParams.builder().chatMemory((ChatMemory) null).augmentationResult((AugmentationResult) null).userMessageTemplate("").variables(Map.of()).build()).build();
    }
}
