Commit 11b9c6d

Add StructuredOutputRetryLimitMiddleware and default retry limit
Regenerated `test_agent_understands_other_agents.json`, since the previous recording hit that default limit.
1 parent 29bea8d commit 11b9c6d

7 files changed

Lines changed: 340 additions & 671 deletions

splunklib/ai/README.md

Lines changed: 14 additions & 5 deletions
@@ -613,6 +613,8 @@ triggers the retry logic described above. A custom `model_middleware` can interc
 to observe, log, or override the retry behavior. A custom `model_middleware` can also raise
 the `StructuredOutputGenerationException` manually to reject structured output and force a re-generation.

+The maximum number of retries is limited per agent loop invocation; see [Default limit middlewares](#default-limit-middlewares).
+
 ### Subagents with structured output/input

 In addition to output schemas, subagents can define input schemas. These schemas both constrain
@@ -926,7 +928,7 @@ async with Agent(
 ) as agent: ...
 ```

-### Default limit middlewares
+## Default limit middlewares

 Every `Agent` automatically applies sane default limits to prevent runaway execution
 or excessive token usage. Default limit middlewares are appended after any user-supplied
@@ -939,15 +941,17 @@ chain - place it last if you want the same behavior.
 | `TokenLimitMiddleware` | 200 000 tokens | token count of messages passed to the model |
 | `StepLimitMiddleware` | 100 steps | steps taken |
 | `TimeoutLimitMiddleware` | 600 seconds (10 minutes) | per `invoke` call |
+| `StructuredOutputRetryLimitMiddleware` | 3 retries | per `invoke` call |

 `TokenLimitMiddleware` and `StepLimitMiddleware` check the values from the messages passed to the
-model on each call. `TimeoutLimitMiddleware` resets its deadline on each `invoke`, so every call
-gets a fresh time budget.
+model on each call. `TimeoutLimitMiddleware` and `StructuredOutputRetryLimitMiddleware` reset their
+deadline or retry count on each `invoke`, so these limits apply to a single agent loop invocation.

 When a limit is exceeded, the agent raises the corresponding exception:
-`TokenLimitExceededException`, `StepsLimitExceededException`, or `TimeoutExceededException`.
+`TokenLimitExceededException`, `StepsLimitExceededException`, `TimeoutExceededException`, or
+`StructuredOutputRetryLimitExceededException`.

-#### Overriding defaults
+### Overriding defaults

 To override a specific limit, pass your own instance of the corresponding middleware
 class. The default for that limit is suppressed automatically - the other defaults
@@ -970,13 +974,18 @@ To override all defaults, pass all three:
 async with Agent(
     ...,
     middleware=[
+        StructuredOutputRetryLimitMiddleware(0),  # no-retries.
         TokenLimitMiddleware(50_000),
         StepLimitMiddleware(10),
         TimeoutLimitMiddleware(30.0),
     ],
 ) as agent: ...
 ```

+**Note**: When overriding limit middlewares, order matters. Place `StructuredOutputRetryLimitMiddleware`
+first and `TokenLimitMiddleware`, `StepLimitMiddleware`, and `TimeoutLimitMiddleware` last,
+otherwise the limits may not behave as expected.
+
 There is no explicit opt-out - the intent is that agents should always have some guardrails.

 ## Logger
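
The README change above lists one exception per limit. As a rough illustration of how a caller might handle them, here is a minimal sketch. The two exception classes are local stand-ins that mirror the definitions in `splunklib/ai/hooks.py` from this commit; `invoke_with_fallback` is a hypothetical helper, and whether the other limit exceptions also derive from `AgentStopException` is an assumption not confirmed by this diff.

```python
from typing import Callable


class AgentStopException(Exception):
    """Stand-in for the base class defined in splunklib/ai/hooks.py."""


class StructuredOutputRetryLimitExceededException(AgentStopException):
    """Stand-in mirroring the exception added by this commit."""

    def __init__(self, retry_count: int) -> None:
        super().__init__(f"Structured output retry limit of {retry_count} exceeded")


def invoke_with_fallback(invoke: Callable[[], str], fallback: str) -> str:
    """Hypothetical helper: run `invoke` and degrade gracefully if a limit stops the agent."""
    try:
        return invoke()
    except AgentStopException as exc:
        # Catching the base class covers every limit exception that derives from it.
        return f"{fallback} ({exc})"
```

Catching the base class keeps the call site unchanged if further limit middlewares are added later, at the cost of not distinguishing which limit fired.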

splunklib/ai/base_agent.py

Lines changed: 14 additions & 4 deletions
@@ -23,9 +23,11 @@
 from splunklib.ai.conversation_store import ConversationStore
 from splunklib.ai.hooks import (
     DEFAULT_STEP_LIMIT,
+    DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT,
     DEFAULT_TIMEOUT_SECONDS,
     DEFAULT_TOKEN_LIMIT,
     StepLimitMiddleware,
+    StructuredOutputRetryLimitMiddleware,
     TimeoutLimitMiddleware,
     TokenLimitMiddleware,
 )
@@ -79,16 +81,24 @@ def __init__(
         self._output_schema = output_schema
         user_middleware = tuple(middleware) if middleware else ()
         user_middleware_types = {type(m) for m in user_middleware}
+
         # NOTE: we're creating separate instances per agent - TimeoutLimitMiddleware is stateful
         # and sharing one would cause agents to overwrite each other's deadline.
-        predefined: list[AgentMiddleware] = [
+        predefined_before: list[AgentMiddleware] = [
+            StructuredOutputRetryLimitMiddleware(DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT),
+        ]
+        predefined_after: list[AgentMiddleware] = [
             TokenLimitMiddleware(DEFAULT_TOKEN_LIMIT),
             StepLimitMiddleware(DEFAULT_STEP_LIMIT),
             TimeoutLimitMiddleware(DEFAULT_TIMEOUT_SECONDS),
         ]
-        # Append predefined middlewares by default if not provided already.
-        default_middleware = [m for m in predefined if type(m) not in user_middleware_types]
-        self._middleware = (*user_middleware, *default_middleware)
+
+        self._middleware = (
+            *[m for m in predefined_before if type(m) not in user_middleware_types],
+            *user_middleware,
+            *[m for m in predefined_after if type(m) not in user_middleware_types],
+        )
+
         self._trace_id = secrets.token_hex(16)  # 32 Hex characters
         self._conversation_store = conversation_store
         self._thread_id = thread_id
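
To make the assembly logic in this hunk easier to follow, here is a minimal, self-contained sketch of the same idea: defaults are split into a "before" and an "after" group, and a default is dropped when the caller supplies a middleware of the same type. The classes and `build_chain` are hypothetical stand-ins, not the library's API. The sketch filters with list comprehensions because the README states that order matters, and unpacking a set would not guarantee the order of the defaults.

```python
from typing import Sequence


class Middleware:
    """Hypothetical stand-in for AgentMiddleware."""

    def __init__(self, limit: float) -> None:
        self.limit = limit


class RetryLimit(Middleware): ...
class TokenLimit(Middleware): ...
class StepLimit(Middleware): ...
class TimeoutLimit(Middleware): ...


def build_chain(user: Sequence[Middleware]) -> tuple[Middleware, ...]:
    """Place defaults around user middleware, skipping types the user already supplied."""
    user_types = {type(m) for m in user}
    before = [RetryLimit(3)]
    after = [TokenLimit(200_000), StepLimit(100), TimeoutLimit(600.0)]
    return (
        # Lists keep the declared order; a set would not.
        *[m for m in before if type(m) not in user_types],
        *user,
        *[m for m in after if type(m) not in user_types],
    )


chain = build_chain([StepLimit(10)])
names = [type(m).__name__ for m in chain]
```

In this sketch the user-supplied `StepLimit(10)` keeps its position among the user middleware, and the default step limit is omitted from the trailing group.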

splunklib/ai/engines/langchain.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -882,11 +882,6 @@ async def llm_handler(req: ModelRequest) -> ModelResponse:
882882
except StructuredOutputGenerationException as e:
883883
# Structured output generation failed, retry.
884884

885-
# TODO: we should provide a mechanism to limit the amount of retries
886-
# thath happen sequentially (say 3), otherwise raise a different exception.
887-
# For now this can be done with the use of model middleware that counts
888-
# the amount of StructuredOutputGenerationException that were raised.
889-
890885
ai_msg = _map_message_to_langchain(e.message)
891886
assert isinstance(ai_msg, LC_AIMessage)
892887

splunklib/ai/hooks.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
ModelRequest,
1313
ModelResponse,
1414
)
15+
from splunklib.ai.structured_output import StructuredOutputGenerationException
1516

1617
DEFAULT_TIMEOUT_SECONDS: float = 600.0
1718
DEFAULT_STEP_LIMIT: int = 100
1819
DEFAULT_TOKEN_LIMIT: int = 200_000
20+
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3
1921

2022

2123
class AgentStopException(Exception):
@@ -43,6 +45,13 @@ def __init__(self, timeout_seconds: float) -> None:
         super().__init__(f"Timed out after {timeout_seconds} seconds.")


+class StructuredOutputRetryLimitExceededException(AgentStopException):
+    """Raised by `Agent.invoke` when the structured output retry limit is exceeded."""
+
+    def __init__(self, retry_count: int) -> None:
+        super().__init__(f"Structured output retry limit of {retry_count} exceeded")
+
+
 def before_model(
     func: Callable[[ModelRequest], None | Awaitable[None]],
 ) -> AgentMiddleware:
@@ -199,3 +208,43 @@ async def model_middleware(
         if self._deadline is not None and monotonic() >= self._deadline:
             raise TimeoutExceededException(timeout_seconds=self._seconds)
         return await handler(request)
+
+
+class StructuredOutputRetryLimitMiddleware(AgentMiddleware):
+    """Stops agent execution when the agent exceeds the structured output
+    retry limit during a single agent loop invocation.
+    """
+
+    _limit: int
+    _retries_per_thread_id: dict[str, int]
+
+    def __init__(self, limit: int) -> None:
+        self._limit = limit
+        self._retries_per_thread_id = {}
+
+    @override
+    async def agent_middleware(
+        self,
+        request: AgentRequest,
+        handler: AgentMiddlewareHandler,
+    ) -> AgentResponse[Any | None]:
+        try:
+            # Agent loop starting.
+            self._retries_per_thread_id[request.thread_id] = 0
+            return await handler(request)
+        finally:
+            del self._retries_per_thread_id[request.thread_id]  # don't leak memory
+
+    @override
+    async def model_middleware(
+        self,
+        request: ModelRequest,
+        handler: ModelMiddlewareHandler,
+    ) -> ModelResponse:
+        try:
+            return await handler(request)
+        except StructuredOutputGenerationException:
+            self._retries_per_thread_id[request.state.thread_id] += 1
+            if self._retries_per_thread_id[request.state.thread_id] > self._limit:
+                raise StructuredOutputRetryLimitExceededException(self._limit)
+            raise  # re-raise, to retry structured output generation
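
The counting rule in `model_middleware` can be easy to misread: the counter is incremented on each failure and the limit exception is raised only once the count is strictly greater than the limit, so a limit of 3 permits up to four model attempts. The following is a minimal simulation under stated assumptions: `GenerationError`, `RetryLimitExceeded`, `RetryCounter`, and `simulate` are hypothetical stand-ins, and the `while` loop only approximates how an engine might retry after the exception is re-raised.

```python
import asyncio
from typing import Awaitable, Callable


class GenerationError(Exception):
    """Stand-in for StructuredOutputGenerationException."""


class RetryLimitExceeded(Exception):
    """Stand-in for StructuredOutputRetryLimitExceededException."""


class RetryCounter:
    """Toy version of the per-thread retry counting shown in the diff."""

    def __init__(self, limit: int) -> None:
        self._limit = limit
        self._retries: dict[str, int] = {}

    async def run_loop(self, thread_id: str, loop: Callable[[], Awaitable[str]]) -> str:
        self._retries[thread_id] = 0  # fresh budget for each agent loop
        try:
            return await loop()
        finally:
            self._retries.pop(thread_id, None)  # do not leak per-thread state

    async def call_model(self, thread_id: str, handler: Callable[[], Awaitable[str]]) -> str:
        try:
            return await handler()
        except GenerationError:
            self._retries[thread_id] += 1
            if self._retries[thread_id] > self._limit:
                raise RetryLimitExceeded(self._limit) from None
            raise  # let the caller retry


async def simulate(limit: int, failures: int) -> tuple[str, int]:
    """Run a toy loop whose model fails `failures` times, then succeeds."""
    counter = RetryCounter(limit)
    attempts = 0

    async def handler() -> str:
        nonlocal attempts
        attempts += 1
        if attempts <= failures:
            raise GenerationError()
        return "ok"

    async def loop() -> str:
        while True:
            try:
                return await counter.call_model("t1", handler)
            except GenerationError:
                continue  # assumed engine behavior: try again

    return await counter.run_loop("t1", loop), attempts
```

With `limit=0`, the first failure already exceeds the limit, which matches the README's `StructuredOutputRetryLimitMiddleware(0)` example for disabling retries.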

tests/ai_testlib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def _json_body_matcher(r1: Any, r2: Any) -> None:
172172
my_vcr = vcr.VCR(
173173
cassette_library_dir=snapshot_dir,
174174
serializer="json-friendly",
175-
record_mode=RecordMode.ONCE,
175+
record_mode=RecordMode.NEW_EPISODES,
176176
match_on=[
177177
"method",
178178
"scheme",
@@ -184,7 +184,7 @@ def _json_body_matcher(r1: Any, r2: Any) -> None:
184184
],
185185
before_record_request=_before_record_request,
186186
before_record_response=_before_record_response,
187-
record_on_exception=False,
187+
# record_on_exception=False,
188188
drop_unused_requests=True,
189189
)
190190
my_vcr.register_serializer("json-friendly", _JSONFriendlySerializer())
