
Commit 4f91f65

Add StructuredOutputRetryLimitMiddleware and default retry limit
1 parent 29bea8d commit 4f91f65

5 files changed

Lines changed: 224 additions & 14 deletions

File tree

splunklib/ai/README.md

Lines changed: 14 additions & 5 deletions
@@ -613,6 +613,8 @@ triggers the retry logic described above. A custom `model_middleware` can interc
613613
to observe, log, or override the retry behavior. A custom `model_middleware` can also raise
614614
the `StructuredOutputGenerationException` manually to reject structured output and force a re-generation.
615615

616+
The maximum number of retries is limited per agent loop invocation; see [Default limit middlewares](#default-limit-middlewares).
617+
616618
### Subagents with structured output/input
617619

618620
In addition to output schemas, subagents can define input schemas. These schemas both constrain
@@ -926,7 +928,7 @@ async with Agent(
926928
) as agent: ...
927929
```
928930

929-
### Default limit middlewares
931+
## Default limit middlewares
930932

931933
Every `Agent` automatically applies sane default limits to prevent runaway execution
932934
or excessive token usage. Default limit middlewares are appended after any user-supplied
@@ -939,15 +941,17 @@ chain - place it last if you want the same behavior.
939941
| `TokenLimitMiddleware` | 200 000 tokens | token count of messages passed to the model |
940942
| `StepLimitMiddleware` | 100 steps | steps taken |
941943
| `TimeoutLimitMiddleware` | 600 seconds (10 minutes) | per `invoke` call |
944+
| `StructuredOutputRetryLimitMiddleware` | 3 retries | per `invoke` call |
942945

943946
`TokenLimitMiddleware` and `StepLimitMiddleware` check the values from the messages passed to the
944-
model on each call. `TimeoutLimitMiddleware` resets its deadline on each `invoke`, so every call
945-
gets a fresh time budget.
947+
model on each call. `TimeoutLimitMiddleware` and `StructuredOutputRetryLimitMiddleware` reset their
948+
deadline/limit on each `invoke`, so these limits effectively apply per agent loop.
946949

947950
When a limit is exceeded, the agent raises the corresponding exception:
948-
`TokenLimitExceededException`, `StepsLimitExceededException`, or `TimeoutExceededException`.
951+
`TokenLimitExceededException`, `StepsLimitExceededException`, `TimeoutExceededException`, or
952+
`StructuredOutputRetryLimitExceededException`.
949953

950-
#### Overriding defaults
954+
### Overriding defaults
951955

952956
To override a specific limit, pass your own instance of the corresponding middleware
953957
class. The default for that limit is suppressed automatically - the other defaults
@@ -970,13 +974,18 @@ To override all defaults, pass all three:
970974
async with Agent(
971975
...,
972976
middleware=[
977+
StructuredOutputRetryLimitMiddleware(0), # no-retries.
973978
TokenLimitMiddleware(50_000),
974979
StepLimitMiddleware(10),
975980
TimeoutLimitMiddleware(30.0),
976981
],
977982
) as agent: ...
978983
```
979984

985+
**Note**: When overriding limit middlewares, order matters. Place `StructuredOutputRetryLimitMiddleware`
986+
first and `TokenLimitMiddleware`, `StepLimitMiddleware`, and `TimeoutLimitMiddleware` last,
987+
otherwise the limits may not behave as expected.
988+
980989
There is no explicit opt-out - the intent is that agents should always have some guardrails.
981990

982991
## Logger
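
The ordering note in the README change can be illustrated with a minimal, self-contained sketch. It assumes (consistent with the tests in this commit, but not confirmed against the engine) that the first middleware in the list becomes the outermost layer, so only a retry-counting layer placed first observes failures raised by later user middlewares. All names here (`compose`, `GenerationError`, `counting`, `rejecting`) are hypothetical stand-ins, not the library's API.

```python
import asyncio
from typing import Awaitable, Callable

Handler = Callable[[str], Awaitable[str]]
Middleware = Callable[[str, Handler], Awaitable[str]]


class GenerationError(Exception):
    """Hypothetical stand-in for StructuredOutputGenerationException."""


def compose(middlewares: list[Middleware], model: Handler) -> Handler:
    """Chain middlewares so the first list entry becomes the outermost layer."""
    handler = model
    for mw in reversed(middlewares):
        # Default arguments bind the current layer and the one beneath it.
        def layer(
            request: str, mw: Middleware = mw, inner: Handler = handler
        ) -> Awaitable[str]:
            return mw(request, inner)

        handler = layer
    return handler


async def model(request: str) -> str:
    return "ok"


async def rejecting(request: str, handler: Handler) -> str:
    # A user middleware that rejects the output without calling the model.
    raise GenerationError("rejected")


def observed_failures(retry_limiter_first: bool) -> int:
    """Count the failures a retry-counting layer sees for a given ordering."""
    failures = 0

    async def counting(request: str, handler: Handler) -> str:
        nonlocal failures
        try:
            return await handler(request)
        except GenerationError:
            failures += 1
            raise

    chain = [counting, rejecting] if retry_limiter_first else [rejecting, counting]

    async def run() -> None:
        try:
            await compose(chain, model)("hello")
        except GenerationError:
            pass

    asyncio.run(run())
    return failures


seen_when_first = observed_failures(retry_limiter_first=True)
seen_when_last = observed_failures(retry_limiter_first=False)
print(seen_when_first, seen_when_last)
```

When the counting layer comes first it wraps the rejecting middleware and sees its exception; when it comes last it is never reached, which is the kind of surprise the note warns about.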

splunklib/ai/base_agent.py

Lines changed: 14 additions & 4 deletions
@@ -23,9 +23,11 @@
2323
from splunklib.ai.conversation_store import ConversationStore
2424
from splunklib.ai.hooks import (
2525
DEFAULT_STEP_LIMIT,
26+
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT,
2627
DEFAULT_TIMEOUT_SECONDS,
2728
DEFAULT_TOKEN_LIMIT,
2829
StepLimitMiddleware,
30+
StructuredOutputRetryLimitMiddleware,
2931
TimeoutLimitMiddleware,
3032
TokenLimitMiddleware,
3133
)
@@ -79,16 +81,24 @@ def __init__(
7981
self._output_schema = output_schema
8082
user_middleware = tuple(middleware) if middleware else ()
8183
user_middleware_types = {type(m) for m in user_middleware}
84+
8285
# NOTE: we're creating separate instances per agent - TimeoutLimitMiddleware is stateful
8386
# and sharing one would cause agents to overwrite each other's deadline.
84-
predefined: list[AgentMiddleware] = [
87+
predefined_before: list[AgentMiddleware] = [
88+
StructuredOutputRetryLimitMiddleware(DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT),
89+
]
90+
predefined_after: list[AgentMiddleware] = [
8591
TokenLimitMiddleware(DEFAULT_TOKEN_LIMIT),
8692
StepLimitMiddleware(DEFAULT_STEP_LIMIT),
8793
TimeoutLimitMiddleware(DEFAULT_TIMEOUT_SECONDS),
8894
]
89-
# Append predefined middlewares by default if not provided already.
90-
default_middleware = [m for m in predefined if type(m) not in user_middleware_types]
91-
self._middleware = (*user_middleware, *default_middleware)
95+
96+
self._middleware = (
97+
*[m for m in predefined_before if type(m) not in user_middleware_types],
98+
*user_middleware,
99+
*[m for m in predefined_after if type(m) not in user_middleware_types],
100+
)
101+
92102
self._trace_id = secrets.token_hex(16) # 32 Hex characters
93103
self._conversation_store = conversation_store
94104
self._thread_id = thread_id
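
The composition logic in this hunk can be sketched in isolation. The classes below are hypothetical placeholders for the real middlewares, and `build_chain` is an illustrative helper rather than a function in `base_agent.py`; the default values are taken from the constants in this commit. The sketch uses list comprehensions so that the declared order of the defaults is preserved (a set would not guarantee it).

```python
from dataclasses import dataclass


@dataclass
class Middleware:
    """Hypothetical base class; `limit` stands in for each middleware's setting."""

    limit: float


class RetryLimit(Middleware): ...


class TokenLimit(Middleware): ...


class StepLimit(Middleware): ...


class TimeoutLimit(Middleware): ...


def build_chain(user: list[Middleware]) -> tuple[Middleware, ...]:
    """Place defaults around user middlewares, skipping types the user supplied."""
    user_types = {type(m) for m in user}
    before = [RetryLimit(3)]
    after = [TokenLimit(200_000), StepLimit(100), TimeoutLimit(600.0)]
    return (
        *[m for m in before if type(m) not in user_types],
        *user,
        *[m for m in after if type(m) not in user_types],
    )


default_order = [type(m).__name__ for m in build_chain([])]
custom_chain = build_chain([StepLimit(10)])
custom_order = [type(m).__name__ for m in custom_chain]
print(default_order)
print(custom_order)
```

Note that a user-supplied middleware keeps the position the user gave it; only the default of the same type is dropped.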

splunklib/ai/engines/langchain.py

Lines changed: 0 additions & 5 deletions
@@ -882,11 +882,6 @@ async def llm_handler(req: ModelRequest) -> ModelResponse:
882882
except StructuredOutputGenerationException as e:
883883
# Structured output generation failed, retry.
884884

885-
# TODO: we should provide a mechanism to limit the amount of retries
886-
# thath happen sequentially (say 3), otherwise raise a different exception.
887-
# For now this can be done with the use of model middleware that counts
888-
# the amount of StructuredOutputGenerationException that were raised.
889-
890885
ai_msg = _map_message_to_langchain(e.message)
891886
assert isinstance(ai_msg, LC_AIMessage)
892887

splunklib/ai/hooks.py

Lines changed: 49 additions & 0 deletions
@@ -12,10 +12,12 @@
1212
ModelRequest,
1313
ModelResponse,
1414
)
15+
from splunklib.ai.structured_output import StructuredOutputGenerationException
1516

1617
DEFAULT_TIMEOUT_SECONDS: float = 600.0
1718
DEFAULT_STEP_LIMIT: int = 100
1819
DEFAULT_TOKEN_LIMIT: int = 200_000
20+
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3
1921

2022

2123
class AgentStopException(Exception):
@@ -43,6 +45,13 @@ def __init__(self, timeout_seconds: float) -> None:
4345
super().__init__(f"Timed out after {timeout_seconds} seconds.")
4446

4547

48+
class StructuredOutputRetryLimitExceededException(AgentStopException):
49+
"""Raised by `Agent.invoke` when the structured output retry limit is exceeded."""
50+
51+
def __init__(self, retry_count: int) -> None:
52+
super().__init__(f"Structured output retry limit of {retry_count} exceeded")
53+
54+
4655
def before_model(
4756
func: Callable[[ModelRequest], None | Awaitable[None]],
4857
) -> AgentMiddleware:
@@ -199,3 +208,43 @@ async def model_middleware(
199208
if self._deadline is not None and monotonic() >= self._deadline:
200209
raise TimeoutExceededException(timeout_seconds=self._seconds)
201210
return await handler(request)
211+
212+
213+
class StructuredOutputRetryLimitMiddleware(AgentMiddleware):
214+
"""Stops agent execution when the agent exceeds the structured output
215+
retry limit during a single agent loop invocation.
216+
"""
217+
218+
_limit: int
219+
_retries_per_thread_id: dict[str, int]
220+
221+
def __init__(self, limit: int) -> None:
222+
self._limit = limit
223+
self._retries_per_thread_id = {}
224+
225+
@override
226+
async def agent_middleware(
227+
self,
228+
request: AgentRequest,
229+
handler: AgentMiddlewareHandler,
230+
) -> AgentResponse[Any | None]:
231+
try:
232+
# Agent loop starting.
233+
self._retries_per_thread_id[request.thread_id] = 0
234+
return await handler(request)
235+
finally:
236+
del self._retries_per_thread_id[request.thread_id] # don't leak memory
237+
238+
@override
239+
async def model_middleware(
240+
self,
241+
request: ModelRequest,
242+
handler: ModelMiddlewareHandler,
243+
) -> ModelResponse:
244+
try:
245+
return await handler(request)
246+
except StructuredOutputGenerationException:
247+
self._retries_per_thread_id[request.state.thread_id] += 1
248+
if self._retries_per_thread_id[request.state.thread_id] > self._limit:
249+
raise StructuredOutputRetryLimitExceededException(self._limit)
250+
raise # re-raise, to retry structured output generation
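
The counting behavior above can be approximated with a small, self-contained simulation. It is a sketch under assumptions: `RetryLimit`, `invoke`, `flaky`, and both exception classes are hypothetical stand-ins, and the engine's retry loop is reduced to a plain `while` loop. It shows two properties the tests in this commit check: a limit of `N` allows `N + 1` model calls, and the counter starts fresh on every invocation.

```python
import asyncio
from typing import Awaitable, Callable

Model = Callable[[], Awaitable[str]]


class GenerationError(Exception):
    """Hypothetical stand-in for StructuredOutputGenerationException."""


class RetryLimitExceeded(Exception):
    """Hypothetical stand-in for StructuredOutputRetryLimitExceededException."""


class RetryLimit:
    def __init__(self, limit: int) -> None:
        self._limit = limit
        self._retries = 0

    def reset(self) -> None:
        self._retries = 0

    async def call_model(self, model: Model) -> str:
        try:
            return await model()
        except GenerationError:
            self._retries += 1
            if self._retries > self._limit:
                raise RetryLimitExceeded(
                    f"Structured output retry limit of {self._limit} exceeded"
                )
            raise  # let the engine loop retry


async def invoke(limiter: RetryLimit, model: Model) -> str:
    limiter.reset()  # every agent loop starts with a fresh retry budget
    while True:
        try:
            return await limiter.call_model(model)
        except GenerationError:
            continue  # simplified engine behavior: try again


def flaky(failures: int, calls: list[int]) -> Model:
    """Build a model that fails `failures` times, then succeeds."""
    remaining = failures

    async def model() -> str:
        nonlocal remaining
        calls.append(1)  # record every model call
        if remaining > 0:
            remaining -= 1
            raise GenerationError("invalid output")
        return "Mike"

    return model


async def calls_until_stop(limit: int) -> int:
    calls: list[int] = []
    try:
        await invoke(RetryLimit(limit), flaky(1_000, calls))
    except RetryLimitExceeded:
        pass
    return len(calls)


async def reuse_after_failure() -> str:
    limiter = RetryLimit(1)
    try:
        await invoke(limiter, flaky(5, []))  # exhausts the budget
    except RetryLimitExceeded:
        pass
    return await invoke(limiter, flaky(1, []))  # one failure is tolerated again


calls_by_limit = {limit: asyncio.run(calls_until_stop(limit)) for limit in (0, 1, 3)}
second_result = asyncio.run(reuse_after_failure())
print(calls_by_limit, second_result)
```

Unlike the middleware in the diff, this sketch keeps a single counter instead of a per-`thread_id` dictionary, so it does not model concurrent invocations.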

tests/integration/ai/test_structured_output.py

Lines changed: 147 additions & 0 deletions
@@ -21,6 +21,10 @@
2121
from pydantic.dataclasses import dataclass
2222

2323
from splunklib.ai import Agent
24+
from splunklib.ai.hooks import (
25+
StructuredOutputRetryLimitExceededException,
26+
StructuredOutputRetryLimitMiddleware,
27+
)
2428
from splunklib.ai.messages import (
2529
AgentResponse,
2630
AIMessage,
@@ -930,5 +934,148 @@ async def _model_middleware(
930934
assert len(result.messages) == 3
931935
assert result.structured_output.name == "MIKE"
932936

937+
@pytest.mark.asyncio
938+
@ai_snapshot_test()
939+
async def test_default_retry_limit(self) -> None:
940+
pytest.importorskip("langchain_openai")
941+
942+
class Person(BaseModel):
943+
name: str = Field(description="The person's full name", min_length=1)
944+
945+
model_call_count = 0
946+
947+
@model_middleware
948+
async def _model_middleware(
949+
_request: ModelRequest,
950+
_handler: ModelMiddlewareHandler,
951+
) -> ModelResponse:
952+
nonlocal model_call_count
953+
model_call_count += 1
954+
955+
raise StructuredOutputGenerationException(
956+
message=AIMessage(content="", calls=[]),
957+
error=StructuredOutputValidationError(
958+
validation_error="Invalid output"
959+
),
960+
)
961+
962+
async with Agent(
963+
model=(await self.model()),
964+
system_prompt="Respond with structured data",
965+
output_schema=Person,
966+
service=self.service,
967+
middleware=[_model_middleware],
968+
) as agent:
969+
with pytest.raises(
970+
StructuredOutputRetryLimitExceededException,
971+
match="Structured output retry limit of 3 exceeded",
972+
):
973+
await agent.invoke(
974+
[HumanMessage(content="My name is Mike, what is my name?")]
975+
)
976+
977+
assert model_call_count == 4
978+
979+
@pytest.mark.asyncio
980+
@ai_snapshot_test()
981+
async def test_custom_retry_limit_retry(self) -> None:
982+
pytest.importorskip("langchain_openai")
983+
984+
class Person(BaseModel):
985+
name: str = Field(description="The person's full name", min_length=1)
986+
987+
limits = [0, 1, 20]
988+
for limit in limits:
989+
with self.subTest(limit):
990+
model_call_count = 0
991+
992+
@model_middleware
993+
async def _model_middleware(
994+
_request: ModelRequest,
995+
_handler: ModelMiddlewareHandler,
996+
) -> ModelResponse:
997+
nonlocal model_call_count
998+
model_call_count += 1
999+
1000+
raise StructuredOutputGenerationException(
1001+
message=AIMessage(content="", calls=[]),
1002+
error=StructuredOutputValidationError(
1003+
validation_error="Invalid output"
1004+
),
1005+
)
1006+
1007+
async with Agent(
1008+
model=(await self.model()),
1009+
system_prompt="Respond with structured data",
1010+
output_schema=Person,
1011+
service=self.service,
1012+
middleware=[
1013+
StructuredOutputRetryLimitMiddleware(limit),
1014+
_model_middleware,
1015+
],
1016+
) as agent:
1017+
with pytest.raises(
1018+
StructuredOutputRetryLimitExceededException,
1019+
match=f"Structured output retry limit of {limit} exceeded",
1020+
):
1021+
await agent.invoke(
1022+
[HumanMessage(content="My name is Mike, what is my name?")]
1023+
)
1024+
1025+
# We expect limit + 1, since first LLM call is not a retry.
1026+
assert model_call_count == limit + 1
1027+
1028+
@pytest.mark.asyncio
1029+
@ai_snapshot_test()
1030+
async def test_retry_limit_is_per_agent_loop(self) -> None:
1031+
pytest.importorskip("langchain_openai")
1032+
1033+
class Person(BaseModel):
1034+
name: str = Field(description="The person's full name", min_length=1)
1035+
1036+
after_first_call = False
1037+
1038+
@model_middleware
1039+
async def _model_middleware(
1040+
_request: ModelRequest,
1041+
_handler: ModelMiddlewareHandler,
1042+
) -> ModelResponse:
1043+
if after_first_call:
1044+
return ModelResponse(
1045+
message=AIMessage(content="", calls=[]),
1046+
structured_output=Person(name="Mike"),
1047+
)
1048+
else:
1049+
raise StructuredOutputGenerationException(
1050+
message=AIMessage(content="", calls=[]),
1051+
error=StructuredOutputValidationError(
1052+
validation_error="Invalid output"
1053+
),
1054+
)
1055+
1056+
async with Agent(
1057+
model=(await self.model()),
1058+
system_prompt="Respond with structured data",
1059+
output_schema=Person,
1060+
service=self.service,
1061+
middleware=[
1062+
_model_middleware,
1063+
],
1064+
) as agent:
1065+
with pytest.raises(
1066+
StructuredOutputRetryLimitExceededException,
1067+
match="Structured output retry limit of 3 exceeded",
1068+
):
1069+
await agent.invoke(
1070+
[HumanMessage(content="My name is Mike, what is my name?")]
1071+
)
1072+
1073+
after_first_call = True
1074+
1075+
# Since structured output retry limit is per agent loop, this should not fail.
1076+
await agent.invoke(
1077+
[HumanMessage(content="My name is Mike, what is my name?")]
1078+
)
1079+
9331080
# TODO: test what happens if model/agent middleware removes the structured_output.
9341081
# do we detect that? We should and raise in invoke, that output was removed.
