2828 TokenLimitMiddleware ,
2929)
3030from splunklib .ai .messages import AIMessage , AgentResponse
31- from splunklib .ai .middleware import AgentMiddleware , AgentRequest , AgentState , ModelRequest , ModelResponse
31+ from splunklib .ai .middleware import (
32+ AgentMiddleware ,
33+ AgentRequest ,
34+ AgentState ,
35+ ModelRequest ,
36+ ModelResponse ,
37+ )
3238from splunklib .ai .model import OpenAIModel
3339from splunklib .client import Service
3440
@@ -103,7 +109,11 @@ def test_user_timeout_limit_suppresses_default(self) -> None:
103109
104110 def test_all_user_limits_suppress_all_defaults (self ) -> None :
105111 agent = _make_agent (
106- middleware = [TokenLimitMiddleware (50_000 ), StepLimitMiddleware (10 ), TimeoutLimitMiddleware (30.0 )]
112+ middleware = [
113+ TokenLimitMiddleware (50_000 ),
114+ StepLimitMiddleware (10 ),
115+ TimeoutLimitMiddleware (30.0 ),
116+ ]
107117 )
108118 mw = list (agent .middleware or [])
109119 assert len ([m for m in mw if isinstance (m , TokenLimitMiddleware )]) == 1
@@ -124,23 +134,34 @@ async def test_deadline_reset_on_each_invoke(self) -> None:
124134 mw = TimeoutLimitMiddleware (60.0 )
125135 request = _make_agent_request ()
126136
127- await mw .agent_middleware (request , _noop_agent_handler )
128- first_deadline = mw ._deadline # pyright: ignore[reportPrivateUsage]
137+ first_deadline : float | None = None
138+ second_deadline : float | None = None
139+
140+ async def _first_agent_handler (_request : AgentRequest ) -> AgentResponse [None ]:
141+ nonlocal first_deadline
142+ first_deadline = mw ._deadline_per_thread_id ["foo" ] # pyright: ignore[reportPrivateUsage]
143+ return AgentResponse (messages = [], structured_output = None )
144+
145+ async def _second_agent_handler (_request : AgentRequest ) -> AgentResponse [None ]:
146+ nonlocal second_deadline
147+ second_deadline = mw ._deadline_per_thread_id ["foo" ] # pyright: ignore[reportPrivateUsage]
148+ return AgentResponse (messages = [], structured_output = None )
149+
150+ await mw .agent_middleware (request , _first_agent_handler )
129151
130- await mw .agent_middleware (request , _noop_agent_handler )
131- second_deadline = mw ._deadline # pyright: ignore[reportPrivateUsage]
152+ await mw .agent_middleware (request , _second_agent_handler )
132153
133154 assert first_deadline is not None
134155 assert second_deadline is not None
135156 assert second_deadline >= first_deadline
136157
137158 async def test_deadline_is_none_before_first_invoke (self ) -> None :
138159 mw = TimeoutLimitMiddleware (60.0 )
139- assert mw ._deadline is None # pyright: ignore[reportPrivateUsage]
160+ assert mw ._deadline_per_thread_id . get ( "foo" ) is None # pyright: ignore[reportPrivateUsage]
140161
141162 async def test_timeout_fires_when_deadline_exceeded (self ) -> None :
142163 mw = TimeoutLimitMiddleware (60.0 )
143- mw ._deadline = monotonic () - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
164+ mw ._deadline_per_thread_id [ "foo" ] = monotonic () - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
144165
145166 state = AgentState (messages = [], total_steps = 0 , token_count = 0 , thread_id = "foo" )
146167 request = ModelRequest (system_message = "" , state = state )
@@ -153,17 +174,29 @@ class TestTokenLimitMiddleware(unittest.IsolatedAsyncioTestCase):
153174 async def test_raises_when_token_count_in_request_exceeds_limit (self ) -> None :
154175 mw = TokenLimitMiddleware (200 )
155176
156- await mw .model_middleware (_make_model_request (token_count = 100 ), _noop_model_handler )
157- await mw .model_middleware (_make_model_request (token_count = 199 ), _noop_model_handler )
177+ await mw .model_middleware (
178+ _make_model_request (token_count = 100 ), _noop_model_handler
179+ )
180+ await mw .model_middleware (
181+ _make_model_request (token_count = 199 ), _noop_model_handler
182+ )
158183 with self .assertRaises (TokenLimitExceededException ):
159- await mw .model_middleware (_make_model_request (token_count = 200 ), _noop_model_handler )
184+ await mw .model_middleware (
185+ _make_model_request (token_count = 200 ), _noop_model_handler
186+ )
160187
161188
162189class TestStepLimitMiddleware (unittest .IsolatedAsyncioTestCase ):
163190 async def test_raises_when_steps_in_request_reach_limit (self ) -> None :
164191 mw = StepLimitMiddleware (3 )
165192
166- await mw .model_middleware (_make_model_request (total_steps = 1 ), _noop_model_handler )
167- await mw .model_middleware (_make_model_request (total_steps = 2 ), _noop_model_handler )
193+ await mw .model_middleware (
194+ _make_model_request (total_steps = 1 ), _noop_model_handler
195+ )
196+ await mw .model_middleware (
197+ _make_model_request (total_steps = 2 ), _noop_model_handler
198+ )
168199 with self .assertRaises (StepsLimitExceededException ):
169- await mw .model_middleware (_make_model_request (total_steps = 3 ), _noop_model_handler )
200+ await mw .model_middleware (
201+ _make_model_request (total_steps = 3 ), _noop_model_handler
202+ )
0 commit comments