|
21 | 21 | from pydantic.dataclasses import dataclass |
22 | 22 |
|
23 | 23 | from splunklib.ai import Agent |
| 24 | +from splunklib.ai.hooks import ( |
| 25 | + StructuredOutputRetryLimitExceededException, |
| 26 | + StructuredOutputRetryLimitMiddleware, |
| 27 | +) |
24 | 28 | from splunklib.ai.messages import ( |
25 | 29 | AgentResponse, |
26 | 30 | AIMessage, |
@@ -930,5 +934,148 @@ async def _model_middleware( |
930 | 934 | assert len(result.messages) == 3 |
931 | 935 | assert result.structured_output.name == "MIKE" |
932 | 936 |
|
@pytest.mark.asyncio
@ai_snapshot_test()
async def test_default_retry_limit(self) -> None:
    """Check the retry limit applied when no retry middleware is configured.

    The only middleware registered here fails every model call with a
    StructuredOutputGenerationException. The test expects ``invoke`` to
    raise StructuredOutputRetryLimitExceededException reporting a limit of
    3, and the failing middleware to have been entered 4 times (the
    initial call plus three retries).
    """
    pytest.importorskip("langchain_openai")

    class Person(BaseModel):
        name: str = Field(description="The person's full name", min_length=1)

    attempts = 0

    @model_middleware
    async def _model_middleware(
        _request: ModelRequest,
        _handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        # The handler is never called: every attempt is counted and then
        # rejected as an invalid structured output.
        nonlocal attempts
        attempts += 1

        empty_reply = AIMessage(content="", calls=[])
        validation_failure = StructuredOutputValidationError(
            validation_error="Invalid output"
        )
        raise StructuredOutputGenerationException(
            message=empty_reply,
            error=validation_failure,
        )

    expected_error = "Structured output retry limit of 3 exceeded"
    llm = await self.model()

    async with Agent(
        model=llm,
        system_prompt="Respond with structured data",
        output_schema=Person,
        service=self.service,
        middleware=[_model_middleware],
    ) as agent:
        with pytest.raises(
            StructuredOutputRetryLimitExceededException,
            match=expected_error,
        ):
            question = HumanMessage(content="My name is Mike, what is my name?")
            await agent.invoke([question])

        # One initial call followed by three retries.
        assert attempts == 4
| 978 | + |
@pytest.mark.asyncio
@ai_snapshot_test()
async def test_custom_retry_limit_retry(self) -> None:
    """Check that StructuredOutputRetryLimitMiddleware enforces a custom limit.

    For each limit in 0, 1 and 20 a fresh agent is built with
    ``StructuredOutputRetryLimitMiddleware(limit)`` placed ahead of a
    middleware that always fails. ``invoke`` is expected to raise
    StructuredOutputRetryLimitExceededException mentioning that limit,
    after the failing middleware has run ``limit + 1`` times.
    """
    pytest.importorskip("langchain_openai")

    class Person(BaseModel):
        name: str = Field(description="The person's full name", min_length=1)

    for limit in (0, 1, 20):
        with self.subTest(limit):
            # Reset for every sub-test; the closure below increments it.
            attempts = 0

            @model_middleware
            async def _model_middleware(
                _request: ModelRequest,
                _handler: ModelMiddlewareHandler,
            ) -> ModelResponse:
                # Count the attempt, then reject it without calling the handler.
                nonlocal attempts
                attempts += 1

                empty_reply = AIMessage(content="", calls=[])
                validation_failure = StructuredOutputValidationError(
                    validation_error="Invalid output"
                )
                raise StructuredOutputGenerationException(
                    message=empty_reply,
                    error=validation_failure,
                )

            llm = await self.model()
            chain = [
                StructuredOutputRetryLimitMiddleware(limit),
                _model_middleware,
            ]

            async with Agent(
                model=llm,
                system_prompt="Respond with structured data",
                output_schema=Person,
                service=self.service,
                middleware=chain,
            ) as agent:
                with pytest.raises(
                    StructuredOutputRetryLimitExceededException,
                    match=f"Structured output retry limit of {limit} exceeded",
                ):
                    question = HumanMessage(
                        content="My name is Mike, what is my name?"
                    )
                    await agent.invoke([question])

                # The first LLM call is not a retry, hence limit + 1 calls.
                assert attempts == limit + 1
| 1027 | + |
@pytest.mark.asyncio
@ai_snapshot_test()
async def test_retry_limit_is_per_agent_loop(self) -> None:
    """Check that an exhausted retry budget does not leak into the next invoke.

    The middleware fails every model call until ``after_first_call`` is
    flipped, so the first ``invoke`` is expected to raise
    StructuredOutputRetryLimitExceededException with a limit of 3. The
    flag is then set and the same agent is invoked again; that call must
    succeed and return the structured output produced by the middleware.
    """
    pytest.importorskip("langchain_openai")

    class Person(BaseModel):
        name: str = Field(description="The person's full name", min_length=1)

    # Flipped once the first (failing) invoke has finished.
    after_first_call = False

    @model_middleware
    async def _model_middleware(
        _request: ModelRequest,
        _handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        # The flag is only read here, so no ``nonlocal`` is required.
        if after_first_call:
            return ModelResponse(
                message=AIMessage(content="", calls=[]),
                structured_output=Person(name="Mike"),
            )
        else:
            raise StructuredOutputGenerationException(
                message=AIMessage(content="", calls=[]),
                error=StructuredOutputValidationError(
                    validation_error="Invalid output"
                ),
            )

    async with Agent(
        model=(await self.model()),
        system_prompt="Respond with structured data",
        output_schema=Person,
        service=self.service,
        middleware=[
            _model_middleware,
        ],
    ) as agent:
        with pytest.raises(
            StructuredOutputRetryLimitExceededException,
            match="Structured output retry limit of 3 exceeded",
        ):
            await agent.invoke(
                [HumanMessage(content="My name is Mike, what is my name?")]
            )

        after_first_call = True

        # Since structured output retry limit is per agent loop, this should not fail.
        result = await agent.invoke(
            [HumanMessage(content="My name is Mike, what is my name?")]
        )

        # Verify the second loop really completed with the middleware's
        # output instead of merely not raising.
        assert result.structured_output.name == "Mike"
| 1079 | + |
933 | 1080 | # TODO: test what happens if model/agent middleware removes the structured_output. |
934 | 1081 | # Do we detect that? We should, and invoke should raise to report that the output was removed.
0 commit comments