Skip to content

Commit 4058eb3

Browse files
committed
Add StructuredOutputRetryLimitMiddleware and default retry limit
1 parent 29bea8d commit 4058eb3

5 files changed

Lines changed: 225 additions & 9 deletions

File tree

splunklib/ai/base_agent.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323
from splunklib.ai.conversation_store import ConversationStore
2424
from splunklib.ai.hooks import (
2525
DEFAULT_STEP_LIMIT,
26+
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT,
2627
DEFAULT_TIMEOUT_SECONDS,
2728
DEFAULT_TOKEN_LIMIT,
2829
StepLimitMiddleware,
30+
StructuredOutputRetryLimitMiddleware,
2931
TimeoutLimitMiddleware,
3032
TokenLimitMiddleware,
3133
)
@@ -79,16 +81,24 @@ def __init__(
7981
self._output_schema = output_schema
8082
user_middleware = tuple(middleware) if middleware else ()
8183
user_middleware_types = {type(m) for m in user_middleware}
84+
8285
# NOTE: we're creating separate instances per agent - TimeoutLimitMiddleware is stateful
8386
# and sharing one would cause agents to overwrite each other's deadline.
84-
predefined: list[AgentMiddleware] = [
87+
predefined_before: list[AgentMiddleware] = [
88+
StructuredOutputRetryLimitMiddleware(DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT),
89+
]
90+
predefined_after: list[AgentMiddleware] = [
8591
TokenLimitMiddleware(DEFAULT_TOKEN_LIMIT),
8692
StepLimitMiddleware(DEFAULT_STEP_LIMIT),
8793
TimeoutLimitMiddleware(DEFAULT_TIMEOUT_SECONDS),
8894
]
89-
# Append predefined middlewares by default if not provided already.
90-
default_middleware = [m for m in predefined if type(m) not in user_middleware_types]
91-
self._middleware = (*user_middleware, *default_middleware)
95+
96+
self._middleware = (
97+
*{m for m in predefined_before if type(m) not in user_middleware_types},
98+
*user_middleware,
99+
*{m for m in predefined_after if type(m) not in user_middleware_types},
100+
)
101+
92102
self._trace_id = secrets.token_hex(16) # 32 Hex characters
93103
self._conversation_store = conversation_store
94104
self._thread_id = thread_id

splunklib/ai/engines/langchain.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -882,11 +882,6 @@ async def llm_handler(req: ModelRequest) -> ModelResponse:
882882
except StructuredOutputGenerationException as e:
883883
# Structured output generation failed, retry.
884884

885-
# TODO: we should provide a mechanism to limit the amount of retries
886-
# that happen sequentially (say 3), otherwise raise a different exception.
887-
# For now this can be done with the use of model middleware that counts
888-
# the amount of StructuredOutputGenerationException that were raised.
889-
890885
ai_msg = _map_message_to_langchain(e.message)
891886
assert isinstance(ai_msg, LC_AIMessage)
892887

splunklib/ai/hooks.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,18 @@
1212
ModelRequest,
1313
ModelResponse,
1414
)
15+
from splunklib.ai.structured_output import StructuredOutputGenerationException
1516

1617
DEFAULT_TIMEOUT_SECONDS: float = 600.0
1718
DEFAULT_STEP_LIMIT: int = 100
1819
DEFAULT_TOKEN_LIMIT: int = 200_000
20+
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3
21+
22+
23+
# TODO: should we include the messages in the exception? We have them
24+
# in AgentState.
25+
26+
# TODO: what if we pass a AiMessage with tool calls to invoke.
1927

2028

2129
class AgentStopException(Exception):
@@ -25,6 +33,7 @@ class AgentStopException(Exception):
2533
class TokenLimitExceededException(AgentStopException):
2634
"""Raised by `Agent.invoke`, when token limit exceeds"""
2735

36+
# TODO: should be an int
2837
def __init__(self, token_limit: float) -> None:
2938
super().__init__(f"Token limit of {token_limit} exceeded.")
3039

@@ -43,6 +52,13 @@ def __init__(self, timeout_seconds: float) -> None:
4352
super().__init__(f"Timed out after {timeout_seconds} seconds.")
4453

4554

55+
class StructuredOutputRetryLimitExceededException(AgentStopException):
    """Raised by `Agent.invoke` when the structured output retry limit is exceeded."""

    def __init__(self, retry_count: int) -> None:
        # retry_count is the configured limit that was hit.
        super().__init__(f"Structured output retry limit of {retry_count} exceeded")
61+
4662
def before_model(
4763
func: Callable[[ModelRequest], None | Awaitable[None]],
4864
) -> AgentMiddleware:
@@ -125,6 +141,10 @@ async def agent_middleware(
125141
return _Middleware()
126142

127143

144+
# TODO: we should have a token budget limit.
145+
146+
147+
# TODO: actually we could call this context window limit, right?
128148
class TokenLimitMiddleware(AgentMiddleware):
129149
"""Stops agent execution when the token count of messages passed to the model exceeds the given limit."""
130150

@@ -187,6 +207,7 @@ async def agent_middleware(
187207
) -> AgentResponse[Any | None]:
188208
# WARN: this might not work with agents handling
189209
# different threads at the same time.
210+
# TODO: now we have thread_id, thus we can solve this.
190211
self._deadline = monotonic() + self._seconds
191212
return await handler(request)
192213

@@ -199,3 +220,44 @@ async def model_middleware(
199220
if self._deadline is not None and monotonic() >= self._deadline:
200221
raise TimeoutExceededException(timeout_seconds=self._seconds)
201222
return await handler(request)
223+
224+
225+
class StructuredOutputRetryLimitMiddleware(AgentMiddleware):
    """Stops agent execution when the agent exceeds the structured output
    retry limit during a single agent loop invocation.

    A limit of 0 means no retries are allowed: the first
    `StructuredOutputGenerationException` is converted into a
    `StructuredOutputRetryLimitExceededException`.
    """

    # Maximum number of structured-output retries per agent loop invocation.
    _limit: int
    # Retry counters for currently running loops, keyed by thread id.
    _retries_per_thread_id: dict[str, int]

    def __init__(self, limit: int) -> None:
        self._limit = limit
        self._retries_per_thread_id = {}

    @override
    async def agent_middleware(
        self,
        request: AgentRequest,
        handler: AgentMiddlewareHandler,
    ) -> AgentResponse[Any | None]:
        # Agent loop starting: reset the retry counter for this thread.
        # Done before the try so the cleanup below only runs once the
        # key is guaranteed to exist.
        self._retries_per_thread_id[request.thread_id] = 0
        try:
            return await handler(request)
        finally:
            # Drop the counter so finished threads don't leak memory.
            # pop() with a default (rather than del) cannot raise KeyError
            # if the key is already gone — e.g. when concurrent loops share
            # a thread id. NOTE(review): concurrent invocations on the same
            # thread_id still share one counter; confirm that is acceptable.
            self._retries_per_thread_id.pop(request.thread_id, None)

    @override
    async def model_middleware(
        self,
        request: ModelRequest,
        handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        try:
            return await handler(request)
        except StructuredOutputGenerationException:
            # A limit of 0 means the first failure already exceeds the
            # limit, i.e. no retries.
            self._retries_per_thread_id[request.state.thread_id] += 1
            if self._retries_per_thread_id[request.state.thread_id] > self._limit:
                raise StructuredOutputRetryLimitExceededException(self._limit)
            raise  # re-raise, to retry structured output generation

splunklib/ai/messages.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class OpaqueBlock:
6161
# Type alias for all content block variants.
6262
ContentBlock = TextBlock | OpaqueBlock
6363

64+
# TODO: should we set kw_only = True
65+
6466

6567
@dataclass(frozen=True)
6668
class ToolCall:

tests/integration/ai/test_structured_output.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
from pydantic.dataclasses import dataclass
2222

2323
from splunklib.ai import Agent
24+
from splunklib.ai.hooks import (
25+
StructuredOutputRetryLimitExceededException,
26+
StructuredOutputRetryLimitMiddleware,
27+
)
2428
from splunklib.ai.messages import (
2529
AgentResponse,
2630
AIMessage,
@@ -930,5 +934,148 @@ async def _model_middleware(
930934
assert len(result.messages) == 3
931935
assert result.structured_output.name == "MIKE"
932936

937+
@pytest.mark.asyncio
938+
@ai_snapshot_test()
939+
async def test_default_retry_limit(self) -> None:
940+
pytest.importorskip("langchain_openai")
941+
942+
class Person(BaseModel):
943+
name: str = Field(description="The person's full name", min_length=1)
944+
945+
model_call_count = 0
946+
947+
@model_middleware
948+
async def _model_middleware(
949+
_request: ModelRequest,
950+
_handler: ModelMiddlewareHandler,
951+
) -> ModelResponse:
952+
nonlocal model_call_count
953+
model_call_count += 1
954+
955+
raise StructuredOutputGenerationException(
956+
message=AIMessage(content="", calls=[]),
957+
error=StructuredOutputValidationError(
958+
validation_error="Invalid output"
959+
),
960+
)
961+
962+
async with Agent(
963+
model=(await self.model()),
964+
system_prompt="Respond with structured data",
965+
output_schema=Person,
966+
service=self.service,
967+
middleware=[_model_middleware],
968+
) as agent:
969+
with pytest.raises(
970+
StructuredOutputRetryLimitExceededException,
971+
match="Structured output retry limit of 3 exceeded",
972+
):
973+
await agent.invoke(
974+
[HumanMessage(content="My name is Mike, what is my name?")]
975+
)
976+
977+
assert model_call_count == 4
978+
979+
@pytest.mark.asyncio
980+
@ai_snapshot_test()
981+
async def test_custom_retry_limit_retry(self) -> None:
982+
pytest.importorskip("langchain_openai")
983+
984+
class Person(BaseModel):
985+
name: str = Field(description="The person's full name", min_length=1)
986+
987+
limits = [0, 1, 20]
988+
for limit in limits:
989+
with self.subTest(limit):
990+
model_call_count = 0
991+
992+
@model_middleware
993+
async def _model_middleware(
994+
_request: ModelRequest,
995+
_handler: ModelMiddlewareHandler,
996+
) -> ModelResponse:
997+
nonlocal model_call_count
998+
model_call_count += 1
999+
1000+
raise StructuredOutputGenerationException(
1001+
message=AIMessage(content="", calls=[]),
1002+
error=StructuredOutputValidationError(
1003+
validation_error="Invalid output"
1004+
),
1005+
)
1006+
1007+
async with Agent(
1008+
model=(await self.model()),
1009+
system_prompt="Respond with structured data",
1010+
output_schema=Person,
1011+
service=self.service,
1012+
middleware=[
1013+
StructuredOutputRetryLimitMiddleware(limit),
1014+
_model_middleware,
1015+
],
1016+
) as agent:
1017+
with pytest.raises(
1018+
StructuredOutputRetryLimitExceededException,
1019+
match=f"Structured output retry limit of {limit} exceeded",
1020+
):
1021+
await agent.invoke(
1022+
[HumanMessage(content="My name is Mike, what is my name?")]
1023+
)
1024+
1025+
# We expect limit + 1, since first LLM call is not a retry.
1026+
assert model_call_count == limit + 1
1027+
1028+
@pytest.mark.asyncio
1029+
@ai_snapshot_test()
1030+
async def test_retry_limit_is_per_agent_loop(self) -> None:
1031+
pytest.importorskip("langchain_openai")
1032+
1033+
class Person(BaseModel):
1034+
name: str = Field(description="The person's full name", min_length=1)
1035+
1036+
after_first_call = False
1037+
1038+
@model_middleware
1039+
async def _model_middleware(
1040+
_request: ModelRequest,
1041+
_handler: ModelMiddlewareHandler,
1042+
) -> ModelResponse:
1043+
if after_first_call:
1044+
return ModelResponse(
1045+
message=AIMessage(content="", calls=[]),
1046+
structured_output=Person(name="Mike"),
1047+
)
1048+
else:
1049+
raise StructuredOutputGenerationException(
1050+
message=AIMessage(content="", calls=[]),
1051+
error=StructuredOutputValidationError(
1052+
validation_error="Invalid output"
1053+
),
1054+
)
1055+
1056+
async with Agent(
1057+
model=(await self.model()),
1058+
system_prompt="Respond with structured data",
1059+
output_schema=Person,
1060+
service=self.service,
1061+
middleware=[
1062+
_model_middleware,
1063+
],
1064+
) as agent:
1065+
with pytest.raises(
1066+
StructuredOutputRetryLimitExceededException,
1067+
match="Structured output retry limit of 3 exceeded",
1068+
):
1069+
await agent.invoke(
1070+
[HumanMessage(content="My name is Mike, what is my name?")]
1071+
)
1072+
1073+
after_first_call = True
1074+
1075+
# Since structured output retry limit is per agent loop, this should not fail.
1076+
await agent.invoke(
1077+
[HumanMessage(content="My name is Mike, what is my name?")]
1078+
)
1079+
9331080
# TODO: test what happens if model/agent middleware removes the structured_output.
9341081
# do we detect that? We should and raise in invoke, that output was removed.

0 commit comments

Comments
 (0)