Skip to content

Commit f38b3f6

Browse files
committed
Use thread_ids in TimeoutLimitMiddleware
1 parent 11db930 commit f38b3f6

2 files changed

Lines changed: 22 additions & 12 deletions

File tree

splunklib/ai/hooks.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -182,30 +182,34 @@ class TimeoutLimitMiddleware(AgentMiddleware):
182182
"""
183183

184184
_seconds: float
185-
_deadline: float | None
185+
_deadline_per_thread_id: dict[str, Any]
186186

187187
def __init__(self, seconds: float) -> None:
188188
self._seconds = seconds
189-
self._deadline = None
189+
self._deadline_per_thread_id = {}
190190

191191
@override
192192
async def agent_middleware(
193193
self,
194194
request: AgentRequest,
195195
handler: AgentMiddlewareHandler,
196196
) -> AgentResponse[Any | None]:
197-
# WARN: this might not work with agents handling
198-
# different threads at the same time.
199-
self._deadline = monotonic() + self._seconds
200-
return await handler(request)
197+
try:
198+
# Agent loop starting.
199+
self._deadline_per_thread_id[request.thread_id] = (
200+
monotonic() + self._seconds
201+
)
202+
return await handler(request)
203+
finally:
204+
del self._deadline_per_thread_id[request.thread_id] # don't leak memory
201205

202206
@override
203207
async def model_middleware(
204208
self,
205209
request: ModelRequest,
206210
handler: ModelMiddlewareHandler,
207211
) -> ModelResponse:
208-
if self._deadline is not None and monotonic() >= self._deadline:
212+
if monotonic() >= self._deadline_per_thread_id[request.state.thread_id]:
209213
raise TimeoutExceededException(timeout_seconds=self._seconds)
210214
return await handler(request)
211215

tests/unit/ai/test_default_limits.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,13 @@
2828
TokenLimitMiddleware,
2929
)
3030
from splunklib.ai.messages import AIMessage, AgentResponse
31-
from splunklib.ai.middleware import AgentMiddleware, AgentRequest, AgentState, ModelRequest, ModelResponse
31+
from splunklib.ai.middleware import (
32+
AgentMiddleware,
33+
AgentRequest,
34+
AgentState,
35+
ModelRequest,
36+
ModelResponse,
37+
)
3238
from splunklib.ai.model import OpenAIModel
3339
from splunklib.client import Service
3440

@@ -125,22 +131,22 @@ async def test_deadline_reset_on_each_invoke(self) -> None:
125131
request = _make_agent_request()
126132

127133
await mw.agent_middleware(request, _noop_agent_handler)
128-
first_deadline = mw._deadline # pyright: ignore[reportPrivateUsage]
134+
first_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage]
129135

130136
await mw.agent_middleware(request, _noop_agent_handler)
131-
second_deadline = mw._deadline # pyright: ignore[reportPrivateUsage]
137+
second_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage]
132138

133139
assert first_deadline is not None
134140
assert second_deadline is not None
135141
assert second_deadline >= first_deadline
136142

137143
async def test_deadline_is_none_before_first_invoke(self) -> None:
138144
mw = TimeoutLimitMiddleware(60.0)
139-
assert mw._deadline is None # pyright: ignore[reportPrivateUsage]
145+
assert mw._deadline_per_thread_id["foo"] is None # pyright: ignore[reportPrivateUsage]
140146

141147
async def test_timeout_fires_when_deadline_exceeded(self) -> None:
142148
mw = TimeoutLimitMiddleware(60.0)
143-
mw._deadline = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
149+
mw._deadline_per_thread_id["foo"] = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
144150

145151
state = AgentState(messages=[], total_steps=0, token_count=0, thread_id="foo")
146152
request = ModelRequest(system_message="", state=state)

0 commit comments

Comments
 (0)