Skip to content

Commit 2eeaeef

Browse files
committed
Add StructuredOutputRetryLimitMiddleware and default retry limit
1 parent 29bea8d commit 2eeaeef

3 files changed

Lines changed: 179 additions & 2 deletions

File tree

splunklib/ai/base_agent.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323
from splunklib.ai.conversation_store import ConversationStore
2424
from splunklib.ai.hooks import (
2525
DEFAULT_STEP_LIMIT,
26+
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT,
2627
DEFAULT_TIMEOUT_SECONDS,
2728
DEFAULT_TOKEN_LIMIT,
2829
StepLimitMiddleware,
30+
StructuredOutputRetryLimitMiddleware,
2931
TimeoutLimitMiddleware,
3032
TokenLimitMiddleware,
3133
)
@@ -87,8 +89,24 @@ def __init__(
8789
TimeoutLimitMiddleware(DEFAULT_TIMEOUT_SECONDS),
8890
]
8991
# Append predefined middlewares by default if not provided already.
90-
default_middleware = [m for m in predefined if type(m) not in user_middleware_types]
91-
self._middleware = (*user_middleware, *default_middleware)
92+
default_middleware = [
93+
m for m in predefined if type(m) not in user_middleware_types
94+
]
95+
96+
predefined_before: list[AgentMiddleware] = [
97+
StructuredOutputRetryLimitMiddleware(DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT),
98+
]
99+
100+
default_before_middleware = [
101+
m for m in predefined_before if type(m) not in user_middleware_types
102+
]
103+
104+
self._middleware = (
105+
*default_before_middleware,
106+
*user_middleware,
107+
*default_middleware,
108+
)
109+
92110
self._trace_id = secrets.token_hex(16) # 32 Hex characters
93111
self._conversation_store = conversation_store
94112
self._thread_id = thread_id

splunklib/ai/hooks.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,18 @@
1212
ModelRequest,
1313
ModelResponse,
1414
)
15+
from splunklib.ai.structured_output import StructuredOutputGenerationException
1516

1617
DEFAULT_TIMEOUT_SECONDS: float = 600.0
1718
DEFAULT_STEP_LIMIT: int = 100
1819
DEFAULT_TOKEN_LIMIT: int = 200_000
20+
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3
21+
22+
23+
# TODO: should we include the messages in the exception? We have them
24+
# in AgentState.
25+
26+
# TODO: what happens if we pass an AIMessage with tool calls to invoke?
1927

2028

2129
class AgentStopException(Exception):
@@ -25,6 +33,7 @@ class AgentStopException(Exception):
2533
class TokenLimitExceededException(AgentStopException):
2634
"""Raised by `Agent.invoke`, when token limit exceeds"""
2735

36+
# TODO: should be an int
2837
def __init__(self, token_limit: float) -> None:
2938
super().__init__(f"Token limit of {token_limit} exceeded.")
3039

@@ -43,6 +52,13 @@ def __init__(self, timeout_seconds: float) -> None:
4352
super().__init__(f"Timed out after {timeout_seconds} seconds.")
4453

4554

55+
class StructuredOutputRetryLimitExceededException(AgentStopException):
    """Raised by `Agent.invoke` when the structured output retry limit is exceeded."""

    def __init__(self, retry_count: int) -> None:
        # `retry_count` is the value reported in the message as the limit
        # that was exceeded.
        message = f"Structured output retry limit of {retry_count} exceeded"
        super().__init__(message)
60+
61+
4662
def before_model(
4763
func: Callable[[ModelRequest], None | Awaitable[None]],
4864
) -> AgentMiddleware:
@@ -125,6 +141,10 @@ async def agent_middleware(
125141
return _Middleware()
126142

127143

144+
# TODO: we should have a token budget limit.
145+
146+
147+
# TODO: actually we could call this context window limit, right?
128148
class TokenLimitMiddleware(AgentMiddleware):
129149
"""Stops agent execution when the token count of messages passed to the model exceeds the given limit."""
130150

@@ -187,6 +207,7 @@ async def agent_middleware(
187207
) -> AgentResponse[Any | None]:
188208
# WARN: this might not work with agents handling
189209
# different threads at the same time.
210+
# TODO: now we have thread_id, thus we can solve this.
190211
self._deadline = monotonic() + self._seconds
191212
return await handler(request)
192213

@@ -199,3 +220,44 @@ async def model_middleware(
199220
if self._deadline is not None and monotonic() >= self._deadline:
200221
raise TimeoutExceededException(timeout_seconds=self._seconds)
201222
return await handler(request)
223+
224+
225+
class StructuredOutputRetryLimitMiddleware(AgentMiddleware):
    """Stops agent execution when the agent exceeds the structured output
    retry limit during a single agent loop invocation.

    Every `StructuredOutputGenerationException` that propagates out of the
    wrapped model handler counts as one failed attempt for the current
    thread id. While the number of failures is within `limit`, the exception
    is re-raised unchanged so that structured output generation is retried.
    Once the number of failures is greater than `limit`,
    `StructuredOutputRetryLimitExceededException` is raised instead.

    A limit of 0 is valid and means "no retries": the first failure already
    exceeds the limit.
    """

    _limit: int
    # Failed structured output attempts, keyed by thread id. An entry exists
    # only while an agent loop for that thread id is running.
    _retries_per_thread_id: dict[str, int]

    def __init__(self, limit: int) -> None:
        """Create the middleware.

        Args:
            limit: Number of structured output retries allowed per agent
                loop invocation; 0 disables retries.
        """
        self._limit = limit
        self._retries_per_thread_id = {}

    @override
    async def agent_middleware(
        self,
        request: AgentRequest,
        handler: AgentMiddlewareHandler,
    ) -> AgentResponse[Any | None]:
        """Reset the retry counter for this thread id around one agent loop."""
        # Agent loop starting.
        self._retries_per_thread_id[request.thread_id] = 0
        try:
            return await handler(request)
        finally:
            # Don't leak memory. `pop` with a default (rather than `del`)
            # so that a missing entry can never raise KeyError here and mask
            # the handler's real result or exception.
            self._retries_per_thread_id.pop(request.thread_id, None)

    @override
    async def model_middleware(
        self,
        request: ModelRequest,
        handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        """Count structured output failures and enforce the retry limit.

        Raises:
            StructuredOutputRetryLimitExceededException: when the number of
                failures for this thread id is greater than the limit. The
                triggering `StructuredOutputGenerationException` is attached
                as its cause.
            StructuredOutputGenerationException: re-raised unchanged while
                the limit has not been exceeded.
        """
        try:
            return await handler(request)
        except StructuredOutputGenerationException as err:
            thread_id = request.state.thread_id
            # Fall back to 0 when no counter is registered for this thread
            # id, so the failure is still counted instead of being replaced
            # by a KeyError.
            failures = self._retries_per_thread_id.get(thread_id, 0) + 1
            self._retries_per_thread_id[thread_id] = failures
            if failures > self._limit:
                raise StructuredOutputRetryLimitExceededException(self._limit) from err
            raise  # re-raise, to retry structured output generation

tests/integration/ai/test_structured_output.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
from pydantic.dataclasses import dataclass
2222

2323
from splunklib.ai import Agent
24+
from splunklib.ai.hooks import (
25+
StructuredOutputRetryLimitExceededException,
26+
StructuredOutputRetryLimitMiddleware,
27+
)
2428
from splunklib.ai.messages import (
2529
AgentResponse,
2630
AIMessage,
@@ -930,5 +934,98 @@ async def _model_middleware(
930934
assert len(result.messages) == 3
931935
assert result.structured_output.name == "MIKE"
932936

937+
@pytest.mark.asyncio
@ai_snapshot_test()
async def test_default_retry_limit(self) -> None:
    """With no retry middleware supplied, the default limit of 3 applies.

    The model middleware fails every call with a structured output error,
    so the agent is expected to stop after the initial call plus 3 retries.
    """
    pytest.importorskip("langchain_openai")

    class Person(BaseModel):
        name: str = Field(description="The person's full name", min_length=1)

    attempts = 0

    @model_middleware
    async def _always_invalid(
        _request: ModelRequest,
        _handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        # Never reaches the real model: every call is recorded and then
        # rejected as an invalid structured output.
        nonlocal attempts
        attempts += 1

        empty_reply = AIMessage(content="", calls=[])
        failure = StructuredOutputValidationError(validation_error="Invalid output")
        raise StructuredOutputGenerationException(message=empty_reply, error=failure)

    expected_message = "Structured output retry limit of 3 exceeded"

    async with Agent(
        model=(await self.model()),
        system_prompt="Respond with structured data",
        output_schema=Person,
        service=self.service,
        middleware=[_always_invalid],
    ) as agent:
        with pytest.raises(
            StructuredOutputRetryLimitExceededException,
            match=expected_message,
        ):
            await agent.invoke(
                [HumanMessage(content="My name is Mike, what is my name?")]
            )

        # One initial call plus three retries.
        assert attempts == 4
978+
979+
@pytest.mark.asyncio
980+
@ai_snapshot_test()
981+
async def test_custom_retry_limit_retry(self) -> None:
982+
pytest.importorskip("langchain_openai")
983+
984+
class Person(BaseModel):
985+
name: str = Field(description="The person's full name", min_length=1)
986+
987+
limits = [0, 1, 20]
988+
for limit in limits:
989+
with self.subTest(limit):
990+
model_call_count = 0
991+
992+
@model_middleware
993+
async def _model_middleware(
994+
_request: ModelRequest,
995+
_handler: ModelMiddlewareHandler,
996+
) -> ModelResponse:
997+
nonlocal model_call_count
998+
model_call_count += 1
999+
1000+
raise StructuredOutputGenerationException(
1001+
message=AIMessage(content="", calls=[]),
1002+
error=StructuredOutputValidationError(
1003+
validation_error="Invalid output"
1004+
),
1005+
)
1006+
1007+
async with Agent(
1008+
model=(await self.model()),
1009+
system_prompt="Respond with structured data",
1010+
output_schema=Person,
1011+
service=self.service,
1012+
middleware=[
1013+
StructuredOutputRetryLimitMiddleware(limit),
1014+
_model_middleware,
1015+
],
1016+
) as agent:
1017+
with pytest.raises(
1018+
StructuredOutputRetryLimitExceededException,
1019+
match=f"Structured output retry limit of {limit} exceeded",
1020+
):
1021+
await agent.invoke(
1022+
[HumanMessage(content="My name is Mike, what is my name?")]
1023+
)
1024+
1025+
# We expect limit + 1, since first LLM call is not a retry.
1026+
assert model_call_count == limit + 1
1027+
1028+
# TODO: add a test verifying that the retries happen within a single agent loop invocation.
1029+
9331030
# TODO: test what happens if model/agent middleware removes the structured_output.
9341031
# do we detect that? We should and raise in invoke, that output was removed.

0 commit comments

Comments
 (0)