@@ -111,10 +111,6 @@ def test_all_user_limits_suppress_all_defaults(self) -> None:
111111 assert len ([m for m in mw if isinstance (m , TimeoutLimitMiddleware )]) == 1
112112
113113
114- async def _noop_agent_handler (_request : AgentRequest ) -> AgentResponse [None ]:
115- return AgentResponse (messages = [], structured_output = None )
116-
117-
118114async def _noop_model_handler (_request : ModelRequest ) -> ModelResponse :
119115 return ModelResponse (message = AIMessage (content = "" , calls = []))
120116
@@ -124,23 +120,33 @@ async def test_deadline_reset_on_each_invoke(self) -> None:
124120 mw = TimeoutLimitMiddleware (60.0 )
125121 request = _make_agent_request ()
126122
127- await mw .agent_middleware (request , _noop_agent_handler )
128- first_deadline = mw ._deadline # pyright: ignore[reportPrivateUsage]
123+ first_deadline : float | None = None
124+ second_deadline : float | None = None
125+
126+ async def _first_agent_handler (_request : AgentRequest ) -> AgentResponse [None ]:
127+ nonlocal first_deadline
128+ first_deadline = mw ._deadline_per_thread_id ["foo" ] # pyright: ignore[reportPrivateUsage]
129+ return AgentResponse (messages = [], structured_output = None )
130+
131+ async def _second_agent_handler (_request : AgentRequest ) -> AgentResponse [None ]:
132+ nonlocal second_deadline
133+ second_deadline = mw ._deadline_per_thread_id ["foo" ] # pyright: ignore[reportPrivateUsage]
134+ return AgentResponse (messages = [], structured_output = None )
129135
130- await mw .agent_middleware (request , _noop_agent_handler )
131- second_deadline = mw ._deadline # pyright: ignore[reportPrivateUsage]
136+ await mw .agent_middleware (request , _first_agent_handler )
137+ await mw .agent_middleware ( request , _second_agent_handler )
132138
133139 assert first_deadline is not None
134- assert second_deadline is not None
140+ assert second_deadline is not None # pyright: ignore[reportUnreachable]
135141 assert second_deadline >= first_deadline
136142
137143 async def test_deadline_is_none_before_first_invoke (self ) -> None :
138144 mw = TimeoutLimitMiddleware (60.0 )
139- assert mw ._deadline is None # pyright: ignore[reportPrivateUsage]
145+ assert mw ._deadline_per_thread_id . get ( "foo" ) is None # pyright: ignore[reportPrivateUsage]
140146
141147 async def test_timeout_fires_when_deadline_exceeded (self ) -> None :
142148 mw = TimeoutLimitMiddleware (60.0 )
143- mw ._deadline = monotonic () - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
149+ mw ._deadline_per_thread_id [ "foo" ] = monotonic () - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
144150
145151 state = AgentState (messages = [], total_steps = 0 , token_count = 0 , thread_id = "foo" )
146152 request = ModelRequest (system_message = "" , state = state )
0 commit comments