Skip to content

Commit 3d4c711

Browse files
committed
Use thread_ids in TimeoutLimitMiddleware
1 parent 4f91f65 commit 3d4c711

2 files changed

Lines changed: 28 additions & 14 deletions

File tree

splunklib/ai/hooks.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,30 +182,34 @@ class TimeoutLimitMiddleware(AgentMiddleware):
182182
"""
183183

184184
_seconds: float
185-
_deadline: float | None
185+
_deadline_per_thread_id: dict[str, float]
186186

187187
def __init__(self, seconds: float) -> None:
188188
self._seconds = seconds
189-
self._deadline = None
189+
self._deadline_per_thread_id = {}
190190

191191
@override
192192
async def agent_middleware(
193193
self,
194194
request: AgentRequest,
195195
handler: AgentMiddlewareHandler,
196196
) -> AgentResponse[Any | None]:
197-
# WARN: this might not work with agents handling
198-
# different threads at the same time.
199-
self._deadline = monotonic() + self._seconds
200-
return await handler(request)
197+
try:
198+
# Agent loop starting.
199+
self._deadline_per_thread_id[request.thread_id] = (
200+
monotonic() + self._seconds
201+
)
202+
return await handler(request)
203+
finally:
204+
del self._deadline_per_thread_id[request.thread_id] # don't leak memory
201205

202206
@override
203207
async def model_middleware(
204208
self,
205209
request: ModelRequest,
206210
handler: ModelMiddlewareHandler,
207211
) -> ModelResponse:
208-
if self._deadline is not None and monotonic() >= self._deadline:
212+
if monotonic() >= self._deadline_per_thread_id[request.state.thread_id]:
209213
raise TimeoutExceededException(timeout_seconds=self._seconds)
210214
return await handler(request)
211215

tests/unit/ai/test_default_limits.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,23 +124,33 @@ async def test_deadline_reset_on_each_invoke(self) -> None:
124124
mw = TimeoutLimitMiddleware(60.0)
125125
request = _make_agent_request()
126126

127-
await mw.agent_middleware(request, _noop_agent_handler)
128-
first_deadline = mw._deadline # pyright: ignore[reportPrivateUsage]
127+
first_deadline: float | None = None
128+
second_deadline: float | None = None
129129

130-
await mw.agent_middleware(request, _noop_agent_handler)
131-
second_deadline = mw._deadline # pyright: ignore[reportPrivateUsage]
130+
async def _first_agent_handler(_request: AgentRequest) -> AgentResponse[None]:
131+
nonlocal first_deadline
132+
first_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage]
133+
return AgentResponse(messages=[], structured_output=None)
134+
135+
async def _second_agent_handler(_request: AgentRequest) -> AgentResponse[None]:
136+
nonlocal second_deadline
137+
second_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage]
138+
return AgentResponse(messages=[], structured_output=None)
139+
140+
await mw.agent_middleware(request, _first_agent_handler)
141+
await mw.agent_middleware(request, _second_agent_handler)
132142

133143
assert first_deadline is not None
134-
assert second_deadline is not None
144+
assert second_deadline is not None # pyright: ignore[reportUnreachable]
135145
assert second_deadline >= first_deadline
136146

137147
async def test_deadline_is_none_before_first_invoke(self) -> None:
138148
mw = TimeoutLimitMiddleware(60.0)
139-
assert mw._deadline is None # pyright: ignore[reportPrivateUsage]
149+
assert mw._deadline_per_thread_id.get("foo") is None # pyright: ignore[reportPrivateUsage]
140150

141151
async def test_timeout_fires_when_deadline_exceeded(self) -> None:
142152
mw = TimeoutLimitMiddleware(60.0)
143-
mw._deadline = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
153+
mw._deadline_per_thread_id["foo"] = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
144154

145155
state = AgentState(messages=[], total_steps=0, token_count=0, thread_id="foo")
146156
request = ModelRequest(system_message="", state=state)

0 commit comments

Comments
 (0)