Skip to content

Commit cbdb570

Browse files
committed
Use thread_ids in TimeoutLimitMiddleware
1 parent dfecb28 commit cbdb570

2 files changed

Lines changed: 28 additions & 18 deletions

File tree

splunklib/ai/hooks.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,30 +182,34 @@ class TimeoutLimitMiddleware(AgentMiddleware):
182182
"""
183183

184184
_seconds: float
185-
_deadline: float | None
185+
_deadline_per_thread_id: dict[str, float]
186186

187187
def __init__(self, seconds: float) -> None:
188188
self._seconds = seconds
189-
self._deadline = None
189+
self._deadline_per_thread_id = {}
190190

191191
@override
192192
async def agent_middleware(
193193
self,
194194
request: AgentRequest,
195195
handler: AgentMiddlewareHandler,
196196
) -> AgentResponse[Any | None]:
197-
# WARN: this might not work with agents handling
198-
# different threads at the same time.
199-
self._deadline = monotonic() + self._seconds
200-
return await handler(request)
197+
try:
198+
# Agent loop starting.
199+
self._deadline_per_thread_id[request.thread_id] = (
200+
monotonic() + self._seconds
201+
)
202+
return await handler(request)
203+
finally:
204+
del self._deadline_per_thread_id[request.thread_id] # don't leak memory
201205

202206
@override
203207
async def model_middleware(
204208
self,
205209
request: ModelRequest,
206210
handler: ModelMiddlewareHandler,
207211
) -> ModelResponse:
208-
if self._deadline is not None and monotonic() >= self._deadline:
212+
if monotonic() >= self._deadline_per_thread_id[request.state.thread_id]:
209213
raise TimeoutExceededException(timeout_seconds=self._seconds)
210214
return await handler(request)
211215

tests/unit/ai/test_default_limits.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,6 @@ def test_all_user_limits_suppress_all_defaults(self) -> None:
111111
assert len([m for m in mw if isinstance(m, TimeoutLimitMiddleware)]) == 1
112112

113113

114-
async def _noop_agent_handler(_request: AgentRequest) -> AgentResponse[None]:
115-
return AgentResponse(messages=[], structured_output=None)
116-
117-
118114
async def _noop_model_handler(_request: ModelRequest) -> ModelResponse:
119115
return ModelResponse(message=AIMessage(content="", calls=[]))
120116

@@ -124,23 +120,33 @@ async def test_deadline_reset_on_each_invoke(self) -> None:
124120
mw = TimeoutLimitMiddleware(60.0)
125121
request = _make_agent_request()
126122

127-
await mw.agent_middleware(request, _noop_agent_handler)
128-
first_deadline = mw._deadline # pyright: ignore[reportPrivateUsage]
123+
first_deadline: float | None = None
124+
second_deadline: float | None = None
125+
126+
async def _first_agent_handler(_request: AgentRequest) -> AgentResponse[None]:
127+
nonlocal first_deadline
128+
first_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage]
129+
return AgentResponse(messages=[], structured_output=None)
130+
131+
async def _second_agent_handler(_request: AgentRequest) -> AgentResponse[None]:
132+
nonlocal second_deadline
133+
second_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage]
134+
return AgentResponse(messages=[], structured_output=None)
129135

130-
await mw.agent_middleware(request, _noop_agent_handler)
131-
second_deadline = mw._deadline # pyright: ignore[reportPrivateUsage]
136+
await mw.agent_middleware(request, _first_agent_handler)
137+
await mw.agent_middleware(request, _second_agent_handler)
132138

133139
assert first_deadline is not None
134-
assert second_deadline is not None
140+
assert second_deadline is not None # pyright: ignore[reportUnreachable]
135141
assert second_deadline >= first_deadline
136142

137143
async def test_deadline_is_none_before_first_invoke(self) -> None:
138144
mw = TimeoutLimitMiddleware(60.0)
139-
assert mw._deadline is None # pyright: ignore[reportPrivateUsage]
145+
assert mw._deadline_per_thread_id.get("foo") is None # pyright: ignore[reportPrivateUsage]
140146

141147
async def test_timeout_fires_when_deadline_exceeded(self) -> None:
142148
mw = TimeoutLimitMiddleware(60.0)
143-
mw._deadline = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
149+
mw._deadline_per_thread_id["foo"] = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
144150

145151
state = AgentState(messages=[], total_steps=0, token_count=0, thread_id="foo")
146152
request = ModelRequest(system_message="", state=state)

0 commit comments

Comments
 (0)