Skip to content

Commit 542d30b

Browse files
committed
Use thread_ids in TimeoutLimitMiddleware
1 parent 11db930 commit 542d30b

2 files changed

Lines changed: 58 additions & 21 deletions

File tree

splunklib/ai/hooks.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,30 +182,34 @@ class TimeoutLimitMiddleware(AgentMiddleware):
182182
"""
183183

184184
_seconds: float
185-
_deadline: float | None
185+
_deadline_per_thread_id: dict[str, float]
186186

187187
def __init__(self, seconds: float) -> None:
188188
self._seconds = seconds
189-
self._deadline = None
189+
self._deadline_per_thread_id = {}
190190

191191
@override
192192
async def agent_middleware(
193193
self,
194194
request: AgentRequest,
195195
handler: AgentMiddlewareHandler,
196196
) -> AgentResponse[Any | None]:
197-
# WARN: this might not work with agents handling
198-
# different threads at the same time.
199-
self._deadline = monotonic() + self._seconds
200-
return await handler(request)
197+
try:
198+
# Agent loop starting.
199+
self._deadline_per_thread_id[request.thread_id] = (
200+
monotonic() + self._seconds
201+
)
202+
return await handler(request)
203+
finally:
204+
del self._deadline_per_thread_id[request.thread_id] # don't leak memory
201205

202206
@override
203207
async def model_middleware(
204208
self,
205209
request: ModelRequest,
206210
handler: ModelMiddlewareHandler,
207211
) -> ModelResponse:
208-
if self._deadline is not None and monotonic() >= self._deadline:
212+
if monotonic() >= self._deadline_per_thread_id[request.state.thread_id]:
209213
raise TimeoutExceededException(timeout_seconds=self._seconds)
210214
return await handler(request)
211215

tests/unit/ai/test_default_limits.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@
2828
TokenLimitMiddleware,
2929
)
3030
from splunklib.ai.messages import AIMessage, AgentResponse
31-
from splunklib.ai.middleware import AgentMiddleware, AgentRequest, AgentState, ModelRequest, ModelResponse
31+
from splunklib.ai.middleware import (
32+
AgentMiddleware,
33+
AgentRequest,
34+
AgentState,
35+
ModelRequest,
36+
ModelResponse,
37+
)
3238
from splunklib.ai.model import OpenAIModel
3339
from splunklib.client import Service
3440

@@ -103,7 +109,11 @@ def test_user_timeout_limit_suppresses_default(self) -> None:
103109

104110
def test_all_user_limits_suppress_all_defaults(self) -> None:
105111
agent = _make_agent(
106-
middleware=[TokenLimitMiddleware(50_000), StepLimitMiddleware(10), TimeoutLimitMiddleware(30.0)]
112+
middleware=[
113+
TokenLimitMiddleware(50_000),
114+
StepLimitMiddleware(10),
115+
TimeoutLimitMiddleware(30.0),
116+
]
107117
)
108118
mw = list(agent.middleware or [])
109119
assert len([m for m in mw if isinstance(m, TokenLimitMiddleware)]) == 1
@@ -124,23 +134,34 @@ async def test_deadline_reset_on_each_invoke(self) -> None:
124134
mw = TimeoutLimitMiddleware(60.0)
125135
request = _make_agent_request()
126136

127-
await mw.agent_middleware(request, _noop_agent_handler)
128-
first_deadline = mw._deadline # pyright: ignore[reportPrivateUsage]
137+
first_deadline: float | None = None
138+
second_deadline: float | None = None
139+
140+
async def _first_agent_handler(_request: AgentRequest) -> AgentResponse[None]:
141+
nonlocal first_deadline
142+
first_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage]
143+
return AgentResponse(messages=[], structured_output=None)
144+
145+
async def _second_agent_handler(_request: AgentRequest) -> AgentResponse[None]:
146+
nonlocal second_deadline
147+
second_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage]
148+
return AgentResponse(messages=[], structured_output=None)
149+
150+
await mw.agent_middleware(request, _first_agent_handler)
129151

130-
await mw.agent_middleware(request, _noop_agent_handler)
131-
second_deadline = mw._deadline # pyright: ignore[reportPrivateUsage]
152+
await mw.agent_middleware(request, _second_agent_handler)
132153

133154
assert first_deadline is not None
134155
assert second_deadline is not None
135156
assert second_deadline >= first_deadline
136157

137158
async def test_deadline_is_none_before_first_invoke(self) -> None:
138159
mw = TimeoutLimitMiddleware(60.0)
139-
assert mw._deadline is None # pyright: ignore[reportPrivateUsage]
160+
assert mw._deadline_per_thread_id.get("foo") is None # pyright: ignore[reportPrivateUsage]
140161

141162
async def test_timeout_fires_when_deadline_exceeded(self) -> None:
142163
mw = TimeoutLimitMiddleware(60.0)
143-
mw._deadline = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
164+
mw._deadline_per_thread_id["foo"] = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
144165

145166
state = AgentState(messages=[], total_steps=0, token_count=0, thread_id="foo")
146167
request = ModelRequest(system_message="", state=state)
@@ -153,17 +174,29 @@ class TestTokenLimitMiddleware(unittest.IsolatedAsyncioTestCase):
153174
async def test_raises_when_token_count_in_request_exceeds_limit(self) -> None:
154175
mw = TokenLimitMiddleware(200)
155176

156-
await mw.model_middleware(_make_model_request(token_count=100), _noop_model_handler)
157-
await mw.model_middleware(_make_model_request(token_count=199), _noop_model_handler)
177+
await mw.model_middleware(
178+
_make_model_request(token_count=100), _noop_model_handler
179+
)
180+
await mw.model_middleware(
181+
_make_model_request(token_count=199), _noop_model_handler
182+
)
158183
with self.assertRaises(TokenLimitExceededException):
159-
await mw.model_middleware(_make_model_request(token_count=200), _noop_model_handler)
184+
await mw.model_middleware(
185+
_make_model_request(token_count=200), _noop_model_handler
186+
)
160187

161188

162189
class TestStepLimitMiddleware(unittest.IsolatedAsyncioTestCase):
163190
async def test_raises_when_steps_in_request_reach_limit(self) -> None:
164191
mw = StepLimitMiddleware(3)
165192

166-
await mw.model_middleware(_make_model_request(total_steps=1), _noop_model_handler)
167-
await mw.model_middleware(_make_model_request(total_steps=2), _noop_model_handler)
193+
await mw.model_middleware(
194+
_make_model_request(total_steps=1), _noop_model_handler
195+
)
196+
await mw.model_middleware(
197+
_make_model_request(total_steps=2), _noop_model_handler
198+
)
168199
with self.assertRaises(StepsLimitExceededException):
169-
await mw.model_middleware(_make_model_request(total_steps=3), _noop_model_handler)
200+
await mw.model_middleware(
201+
_make_model_request(total_steps=3), _noop_model_handler
202+
)

0 commit comments

Comments
 (0)