|
41 | 41 |
|
42 | 42 | from .. import testing_utils |
43 | 43 |
|
44 | | - |
45 | 44 | # --------------------------------------------------------------------------- |
46 | 45 | # Concrete agent implementations |
47 | 46 | # --------------------------------------------------------------------------- |
48 | 47 |
|
| 48 | + |
49 | 49 | class _SuccessAgent(BaseAgent): |
50 | | - def __init__(self): |
51 | | - super().__init__(name="success_agent") |
52 | 50 |
|
53 | | - @override |
54 | | - async def _run_async_impl( |
55 | | - self, ctx: InvocationContext |
56 | | - ) -> AsyncGenerator[Event, None]: |
57 | | - yield Event( |
58 | | - invocation_id=ctx.invocation_id, |
59 | | - author=self.name, |
60 | | - content=types.Content(parts=[types.Part.from_text(text="done")]), |
61 | | - ) |
| 51 | + def __init__(self): |
| 52 | + super().__init__(name="success_agent") |
| 53 | + |
| 54 | + @override |
| 55 | + async def _run_async_impl( |
| 56 | + self, ctx: InvocationContext |
| 57 | + ) -> AsyncGenerator[Event, None]: |
| 58 | + yield Event( |
| 59 | + invocation_id=ctx.invocation_id, |
| 60 | + author=self.name, |
| 61 | + content=types.Content(parts=[types.Part.from_text(text="done")]), |
| 62 | + ) |
62 | 63 |
|
63 | 64 |
|
64 | 65 | class _FailingAgent(BaseAgent): |
65 | | - BOOM: ClassVar[RuntimeError] = RuntimeError("agent impl exploded") |
| 66 | + BOOM: ClassVar[RuntimeError] = RuntimeError("agent impl exploded") |
66 | 67 |
|
67 | | - def __init__(self): |
68 | | - super().__init__(name="failing_agent") |
| 68 | + def __init__(self): |
| 69 | + super().__init__(name="failing_agent") |
69 | 70 |
|
70 | | - @override |
71 | | - async def _run_async_impl( |
72 | | - self, ctx: InvocationContext |
73 | | - ) -> AsyncGenerator[Event, None]: |
74 | | - raise _FailingAgent.BOOM |
75 | | - yield # pragma: no cover |
| 71 | + @override |
| 72 | + async def _run_async_impl( |
| 73 | + self, ctx: InvocationContext |
| 74 | + ) -> AsyncGenerator[Event, None]: |
| 75 | + raise _FailingAgent.BOOM |
| 76 | + yield # pragma: no cover |
76 | 77 |
|
77 | 78 |
|
78 | 79 | class _FailingLiveAgent(BaseAgent): |
79 | | - BOOM: ClassVar[RuntimeError] = RuntimeError("live agent impl exploded") |
| 80 | + BOOM: ClassVar[RuntimeError] = RuntimeError("live agent impl exploded") |
80 | 81 |
|
81 | | - def __init__(self): |
82 | | - super().__init__(name="failing_live_agent") |
| 82 | + def __init__(self): |
| 83 | + super().__init__(name="failing_live_agent") |
83 | 84 |
|
84 | | - @override |
85 | | - async def _run_async_impl( |
86 | | - self, ctx: InvocationContext |
87 | | - ) -> AsyncGenerator[Event, None]: |
88 | | - yield # pragma: no cover |
| 85 | + @override |
| 86 | + async def _run_async_impl( |
| 87 | + self, ctx: InvocationContext |
| 88 | + ) -> AsyncGenerator[Event, None]: |
| 89 | + yield # pragma: no cover |
89 | 90 |
|
90 | | - @override |
91 | | - async def _run_live_impl( |
92 | | - self, ctx: InvocationContext |
93 | | - ) -> AsyncGenerator[Event, None]: |
94 | | - raise _FailingLiveAgent.BOOM |
95 | | - yield # pragma: no cover |
| 91 | + @override |
| 92 | + async def _run_live_impl( |
| 93 | + self, ctx: InvocationContext |
| 94 | + ) -> AsyncGenerator[Event, None]: |
| 95 | + raise _FailingLiveAgent.BOOM |
| 96 | + yield # pragma: no cover |
96 | 97 |
|
97 | 98 |
|
98 | 99 | # --------------------------------------------------------------------------- |
99 | 100 | # Tracking plugin |
100 | 101 | # --------------------------------------------------------------------------- |
101 | 102 |
|
| 103 | + |
102 | 104 | class TrackingPlugin(BasePlugin): |
103 | | - __test__ = False |
| 105 | + __test__ = False |
104 | 106 |
|
105 | | - def __init__(self, name: str = "tracker"): |
106 | | - super().__init__(name) |
107 | | - self.after_agent_called = False |
108 | | - self.agent_error_calls: list[dict] = [] |
| 107 | + def __init__(self, name: str = "tracker"): |
| 108 | + super().__init__(name) |
| 109 | + self.after_agent_called = False |
| 110 | + self.agent_error_calls: list[dict] = [] |
109 | 111 |
|
110 | | - async def after_agent_callback(self, *, agent, callback_context, **kwargs): |
111 | | - self.after_agent_called = True |
| 112 | + async def after_agent_callback(self, *, agent, callback_context, **kwargs): |
| 113 | + self.after_agent_called = True |
112 | 114 |
|
113 | | - async def on_agent_error_callback( |
114 | | - self, *, agent, callback_context, error, **kwargs |
115 | | - ) -> None: |
116 | | - self.agent_error_calls.append( |
117 | | - {"agent": agent, "callback_context": callback_context, "error": error} |
118 | | - ) |
| 115 | + async def on_agent_error_callback( |
| 116 | + self, *, agent, callback_context, error, **kwargs |
| 117 | + ) -> None: |
| 118 | + self.agent_error_calls.append( |
| 119 | + {"agent": agent, "callback_context": callback_context, "error": error} |
| 120 | + ) |
119 | 121 |
|
120 | 122 |
|
121 | 123 | # --------------------------------------------------------------------------- |
122 | 124 | # Helper to drive run_async |
123 | 125 | # --------------------------------------------------------------------------- |
124 | 126 |
|
125 | | -async def _collect_events(agent: BaseAgent, plugins: list[BasePlugin]) -> list[Event]: |
126 | | - inv_ctx = await testing_utils.create_invocation_context( |
127 | | - agent=agent, plugins=plugins |
128 | | - ) |
129 | | - events = [] |
130 | | - async for event in agent.run_async(inv_ctx): |
131 | | - events.append(event) |
132 | | - return events |
| 127 | + |
| 128 | +async def _collect_events( |
| 129 | + agent: BaseAgent, plugins: list[BasePlugin] |
| 130 | +) -> list[Event]: |
| 131 | + inv_ctx = await testing_utils.create_invocation_context( |
| 132 | + agent=agent, plugins=plugins |
| 133 | + ) |
| 134 | + events = [] |
| 135 | + async for event in agent.run_async(inv_ctx): |
| 136 | + events.append(event) |
| 137 | + return events |
133 | 138 |
|
134 | 139 |
|
135 | 140 | async def _collect_live_events( |
136 | 141 | agent: BaseAgent, plugins: list[BasePlugin] |
137 | 142 | ) -> list[Event]: |
138 | | - inv_ctx = await testing_utils.create_invocation_context( |
139 | | - agent=agent, plugins=plugins |
140 | | - ) |
141 | | - events = [] |
142 | | - async for event in agent.run_live(inv_ctx): |
143 | | - events.append(event) |
144 | | - return events |
| 143 | + inv_ctx = await testing_utils.create_invocation_context( |
| 144 | + agent=agent, plugins=plugins |
| 145 | + ) |
| 146 | + events = [] |
| 147 | + async for event in agent.run_live(inv_ctx): |
| 148 | + events.append(event) |
| 149 | + return events |
145 | 150 |
|
146 | 151 |
|
147 | 152 | # --------------------------------------------------------------------------- |
148 | 153 | # Tests — run_async path |
149 | 154 | # --------------------------------------------------------------------------- |
150 | 155 |
|
| 156 | + |
151 | 157 | class TestAgentOnAgentErrorCallbackAsync: |
152 | 158 |
|
153 | | - @pytest.mark.asyncio |
154 | | - async def test_on_agent_error_callback_called_when_impl_raises(self): |
155 | | - tracker = TrackingPlugin() |
156 | | - with pytest.raises(RuntimeError, match="agent impl exploded"): |
157 | | - await _collect_events(_FailingAgent(), [tracker]) |
| 159 | + @pytest.mark.asyncio |
| 160 | + async def test_on_agent_error_callback_called_when_impl_raises(self): |
| 161 | + tracker = TrackingPlugin() |
| 162 | + with pytest.raises(RuntimeError, match="agent impl exploded"): |
| 163 | + await _collect_events(_FailingAgent(), [tracker]) |
158 | 164 |
|
159 | | - assert len(tracker.agent_error_calls) == 1 |
160 | | - assert tracker.agent_error_calls[0]["error"] is _FailingAgent.BOOM |
| 165 | + assert len(tracker.agent_error_calls) == 1 |
| 166 | + assert tracker.agent_error_calls[0]["error"] is _FailingAgent.BOOM |
161 | 167 |
|
162 | | - @pytest.mark.asyncio |
163 | | - async def test_on_agent_error_callback_receives_correct_agent(self): |
164 | | - tracker = TrackingPlugin() |
165 | | - agent = _FailingAgent() |
| 168 | + @pytest.mark.asyncio |
| 169 | + async def test_on_agent_error_callback_receives_correct_agent(self): |
| 170 | + tracker = TrackingPlugin() |
| 171 | + agent = _FailingAgent() |
166 | 172 |
|
167 | | - with pytest.raises(RuntimeError): |
168 | | - await _collect_events(agent, [tracker]) |
| 173 | + with pytest.raises(RuntimeError): |
| 174 | + await _collect_events(agent, [tracker]) |
169 | 175 |
|
170 | | - assert tracker.agent_error_calls[0]["agent"] is agent |
| 176 | + assert tracker.agent_error_calls[0]["agent"] is agent |
171 | 177 |
|
172 | | - @pytest.mark.asyncio |
173 | | - async def test_on_agent_error_callback_receives_callback_context(self): |
174 | | - tracker = TrackingPlugin() |
| 178 | + @pytest.mark.asyncio |
| 179 | + async def test_on_agent_error_callback_receives_callback_context(self): |
| 180 | + tracker = TrackingPlugin() |
175 | 181 |
|
176 | | - with pytest.raises(RuntimeError): |
177 | | - await _collect_events(_FailingAgent(), [tracker]) |
| 182 | + with pytest.raises(RuntimeError): |
| 183 | + await _collect_events(_FailingAgent(), [tracker]) |
178 | 184 |
|
179 | | - cb_ctx = tracker.agent_error_calls[0]["callback_context"] |
180 | | - assert isinstance(cb_ctx, CallbackContext) |
| 185 | + cb_ctx = tracker.agent_error_calls[0]["callback_context"] |
| 186 | + assert isinstance(cb_ctx, CallbackContext) |
181 | 187 |
|
182 | | - @pytest.mark.asyncio |
183 | | - async def test_original_exception_reraised_after_notification(self): |
184 | | - tracker = TrackingPlugin() |
| 188 | + @pytest.mark.asyncio |
| 189 | + async def test_original_exception_reraised_after_notification(self): |
| 190 | + tracker = TrackingPlugin() |
185 | 191 |
|
186 | | - with pytest.raises(RuntimeError) as exc_info: |
187 | | - await _collect_events(_FailingAgent(), [tracker]) |
| 192 | + with pytest.raises(RuntimeError) as exc_info: |
| 193 | + await _collect_events(_FailingAgent(), [tracker]) |
188 | 194 |
|
189 | | - assert exc_info.value is _FailingAgent.BOOM |
| 195 | + assert exc_info.value is _FailingAgent.BOOM |
190 | 196 |
|
191 | | - @pytest.mark.asyncio |
192 | | - async def test_after_agent_callback_not_called_on_error(self): |
193 | | - tracker = TrackingPlugin() |
| 197 | + @pytest.mark.asyncio |
| 198 | + async def test_after_agent_callback_not_called_on_error(self): |
| 199 | + tracker = TrackingPlugin() |
194 | 200 |
|
195 | | - with pytest.raises(RuntimeError): |
196 | | - await _collect_events(_FailingAgent(), [tracker]) |
| 201 | + with pytest.raises(RuntimeError): |
| 202 | + await _collect_events(_FailingAgent(), [tracker]) |
197 | 203 |
|
198 | | - assert not tracker.after_agent_called |
| 204 | + assert not tracker.after_agent_called |
199 | 205 |
|
200 | | - @pytest.mark.asyncio |
201 | | - async def test_on_agent_error_callback_not_called_on_success(self): |
202 | | - tracker = TrackingPlugin() |
203 | | - events = await _collect_events(_SuccessAgent(), [tracker]) |
| 206 | + @pytest.mark.asyncio |
| 207 | + async def test_on_agent_error_callback_not_called_on_success(self): |
| 208 | + tracker = TrackingPlugin() |
| 209 | + events = await _collect_events(_SuccessAgent(), [tracker]) |
204 | 210 |
|
205 | | - assert len(events) >= 1 |
206 | | - assert len(tracker.agent_error_calls) == 0 |
| 211 | + assert len(events) >= 1 |
| 212 | + assert len(tracker.agent_error_calls) == 0 |
207 | 213 |
|
208 | | - @pytest.mark.asyncio |
209 | | - async def test_after_agent_callback_still_called_on_success(self): |
210 | | - tracker = TrackingPlugin() |
211 | | - await _collect_events(_SuccessAgent(), [tracker]) |
| 214 | + @pytest.mark.asyncio |
| 215 | + async def test_after_agent_callback_still_called_on_success(self): |
| 216 | + tracker = TrackingPlugin() |
| 217 | + await _collect_events(_SuccessAgent(), [tracker]) |
212 | 218 |
|
213 | | - assert tracker.after_agent_called |
| 219 | + assert tracker.after_agent_called |
214 | 220 |
|
215 | | - @pytest.mark.asyncio |
216 | | - async def test_multiple_plugins_all_notified_on_agent_error(self): |
217 | | - tracker_a = TrackingPlugin("a") |
218 | | - tracker_b = TrackingPlugin("b") |
| 221 | + @pytest.mark.asyncio |
| 222 | + async def test_multiple_plugins_all_notified_on_agent_error(self): |
| 223 | + tracker_a = TrackingPlugin("a") |
| 224 | + tracker_b = TrackingPlugin("b") |
219 | 225 |
|
220 | | - with pytest.raises(RuntimeError): |
221 | | - await _collect_events(_FailingAgent(), [tracker_a, tracker_b]) |
| 226 | + with pytest.raises(RuntimeError): |
| 227 | + await _collect_events(_FailingAgent(), [tracker_a, tracker_b]) |
222 | 228 |
|
223 | | - assert len(tracker_a.agent_error_calls) == 1 |
224 | | - assert len(tracker_b.agent_error_calls) == 1 |
| 229 | + assert len(tracker_a.agent_error_calls) == 1 |
| 230 | + assert len(tracker_b.agent_error_calls) == 1 |
225 | 231 |
|
226 | 232 |
|
227 | 233 | # --------------------------------------------------------------------------- |
228 | 234 | # Tests — run_live path |
229 | 235 | # --------------------------------------------------------------------------- |
230 | 236 |
|
| 237 | + |
231 | 238 | class TestAgentOnAgentErrorCallbackLive: |
232 | 239 |
|
233 | | - @pytest.mark.asyncio |
234 | | - async def test_on_agent_error_callback_called_when_live_impl_raises(self): |
235 | | - tracker = TrackingPlugin() |
| 240 | + @pytest.mark.asyncio |
| 241 | + async def test_on_agent_error_callback_called_when_live_impl_raises(self): |
| 242 | + tracker = TrackingPlugin() |
236 | 243 |
|
237 | | - with pytest.raises(RuntimeError, match="live agent impl exploded"): |
238 | | - await _collect_live_events(_FailingLiveAgent(), [tracker]) |
| 244 | + with pytest.raises(RuntimeError, match="live agent impl exploded"): |
| 245 | + await _collect_live_events(_FailingLiveAgent(), [tracker]) |
239 | 246 |
|
240 | | - assert len(tracker.agent_error_calls) == 1 |
241 | | - assert tracker.agent_error_calls[0]["error"] is _FailingLiveAgent.BOOM |
| 247 | + assert len(tracker.agent_error_calls) == 1 |
| 248 | + assert tracker.agent_error_calls[0]["error"] is _FailingLiveAgent.BOOM |
242 | 249 |
|
243 | | - @pytest.mark.asyncio |
244 | | - async def test_after_agent_callback_not_called_on_live_error(self): |
245 | | - tracker = TrackingPlugin() |
| 250 | + @pytest.mark.asyncio |
| 251 | + async def test_after_agent_callback_not_called_on_live_error(self): |
| 252 | + tracker = TrackingPlugin() |
246 | 253 |
|
247 | | - with pytest.raises(RuntimeError): |
248 | | - await _collect_live_events(_FailingLiveAgent(), [tracker]) |
| 254 | + with pytest.raises(RuntimeError): |
| 255 | + await _collect_live_events(_FailingLiveAgent(), [tracker]) |
249 | 256 |
|
250 | | - assert not tracker.after_agent_called |
| 257 | + assert not tracker.after_agent_called |
251 | 258 |
|
252 | | - @pytest.mark.asyncio |
253 | | - async def test_original_live_exception_reraised_unchanged(self): |
254 | | - tracker = TrackingPlugin() |
| 259 | + @pytest.mark.asyncio |
| 260 | + async def test_original_live_exception_reraised_unchanged(self): |
| 261 | + tracker = TrackingPlugin() |
255 | 262 |
|
256 | | - with pytest.raises(RuntimeError) as exc_info: |
257 | | - await _collect_live_events(_FailingLiveAgent(), [tracker]) |
| 263 | + with pytest.raises(RuntimeError) as exc_info: |
| 264 | + await _collect_live_events(_FailingLiveAgent(), [tracker]) |
258 | 265 |
|
259 | | - assert exc_info.value is _FailingLiveAgent.BOOM |
| 266 | + assert exc_info.value is _FailingLiveAgent.BOOM |
0 commit comments