Skip to content

Commit cd4ac4e

Browse files
committed
style: apply pyink + isort formatting to new test files
1 parent 7cdef7c commit cd4ac4e

File tree

3 files changed

+575
-554
lines changed

3 files changed

+575
-554
lines changed

tests/unittests/agents/test_agent_error_callbacks.py

Lines changed: 141 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -41,219 +41,226 @@
4141

4242
from .. import testing_utils
4343

44-
4544
# ---------------------------------------------------------------------------
4645
# Concrete agent implementations
4746
# ---------------------------------------------------------------------------
4847

48+
4949
class _SuccessAgent(BaseAgent):
50-
def __init__(self):
51-
super().__init__(name="success_agent")
5250

53-
@override
54-
async def _run_async_impl(
55-
self, ctx: InvocationContext
56-
) -> AsyncGenerator[Event, None]:
57-
yield Event(
58-
invocation_id=ctx.invocation_id,
59-
author=self.name,
60-
content=types.Content(parts=[types.Part.from_text(text="done")]),
61-
)
51+
def __init__(self):
52+
super().__init__(name="success_agent")
53+
54+
@override
55+
async def _run_async_impl(
56+
self, ctx: InvocationContext
57+
) -> AsyncGenerator[Event, None]:
58+
yield Event(
59+
invocation_id=ctx.invocation_id,
60+
author=self.name,
61+
content=types.Content(parts=[types.Part.from_text(text="done")]),
62+
)
6263

6364

6465
class _FailingAgent(BaseAgent):
65-
BOOM: ClassVar[RuntimeError] = RuntimeError("agent impl exploded")
66+
BOOM: ClassVar[RuntimeError] = RuntimeError("agent impl exploded")
6667

67-
def __init__(self):
68-
super().__init__(name="failing_agent")
68+
def __init__(self):
69+
super().__init__(name="failing_agent")
6970

70-
@override
71-
async def _run_async_impl(
72-
self, ctx: InvocationContext
73-
) -> AsyncGenerator[Event, None]:
74-
raise _FailingAgent.BOOM
75-
yield # pragma: no cover
71+
@override
72+
async def _run_async_impl(
73+
self, ctx: InvocationContext
74+
) -> AsyncGenerator[Event, None]:
75+
raise _FailingAgent.BOOM
76+
yield # pragma: no cover
7677

7778

7879
class _FailingLiveAgent(BaseAgent):
79-
BOOM: ClassVar[RuntimeError] = RuntimeError("live agent impl exploded")
80+
BOOM: ClassVar[RuntimeError] = RuntimeError("live agent impl exploded")
8081

81-
def __init__(self):
82-
super().__init__(name="failing_live_agent")
82+
def __init__(self):
83+
super().__init__(name="failing_live_agent")
8384

84-
@override
85-
async def _run_async_impl(
86-
self, ctx: InvocationContext
87-
) -> AsyncGenerator[Event, None]:
88-
yield # pragma: no cover
85+
@override
86+
async def _run_async_impl(
87+
self, ctx: InvocationContext
88+
) -> AsyncGenerator[Event, None]:
89+
yield # pragma: no cover
8990

90-
@override
91-
async def _run_live_impl(
92-
self, ctx: InvocationContext
93-
) -> AsyncGenerator[Event, None]:
94-
raise _FailingLiveAgent.BOOM
95-
yield # pragma: no cover
91+
@override
92+
async def _run_live_impl(
93+
self, ctx: InvocationContext
94+
) -> AsyncGenerator[Event, None]:
95+
raise _FailingLiveAgent.BOOM
96+
yield # pragma: no cover
9697

9798

9899
# ---------------------------------------------------------------------------
99100
# Tracking plugin
100101
# ---------------------------------------------------------------------------
101102

103+
102104
class TrackingPlugin(BasePlugin):
103-
__test__ = False
105+
__test__ = False
104106

105-
def __init__(self, name: str = "tracker"):
106-
super().__init__(name)
107-
self.after_agent_called = False
108-
self.agent_error_calls: list[dict] = []
107+
def __init__(self, name: str = "tracker"):
108+
super().__init__(name)
109+
self.after_agent_called = False
110+
self.agent_error_calls: list[dict] = []
109111

110-
async def after_agent_callback(self, *, agent, callback_context, **kwargs):
111-
self.after_agent_called = True
112+
async def after_agent_callback(self, *, agent, callback_context, **kwargs):
113+
self.after_agent_called = True
112114

113-
async def on_agent_error_callback(
114-
self, *, agent, callback_context, error, **kwargs
115-
) -> None:
116-
self.agent_error_calls.append(
117-
{"agent": agent, "callback_context": callback_context, "error": error}
118-
)
115+
async def on_agent_error_callback(
116+
self, *, agent, callback_context, error, **kwargs
117+
) -> None:
118+
self.agent_error_calls.append(
119+
{"agent": agent, "callback_context": callback_context, "error": error}
120+
)
119121

120122

121123
# ---------------------------------------------------------------------------
122124
# Helper to drive run_async
123125
# ---------------------------------------------------------------------------
124126

125-
async def _collect_events(agent: BaseAgent, plugins: list[BasePlugin]) -> list[Event]:
126-
inv_ctx = await testing_utils.create_invocation_context(
127-
agent=agent, plugins=plugins
128-
)
129-
events = []
130-
async for event in agent.run_async(inv_ctx):
131-
events.append(event)
132-
return events
127+
128+
async def _collect_events(
129+
agent: BaseAgent, plugins: list[BasePlugin]
130+
) -> list[Event]:
131+
inv_ctx = await testing_utils.create_invocation_context(
132+
agent=agent, plugins=plugins
133+
)
134+
events = []
135+
async for event in agent.run_async(inv_ctx):
136+
events.append(event)
137+
return events
133138

134139

135140
async def _collect_live_events(
136141
agent: BaseAgent, plugins: list[BasePlugin]
137142
) -> list[Event]:
138-
inv_ctx = await testing_utils.create_invocation_context(
139-
agent=agent, plugins=plugins
140-
)
141-
events = []
142-
async for event in agent.run_live(inv_ctx):
143-
events.append(event)
144-
return events
143+
inv_ctx = await testing_utils.create_invocation_context(
144+
agent=agent, plugins=plugins
145+
)
146+
events = []
147+
async for event in agent.run_live(inv_ctx):
148+
events.append(event)
149+
return events
145150

146151

147152
# ---------------------------------------------------------------------------
148153
# Tests — run_async path
149154
# ---------------------------------------------------------------------------
150155

156+
151157
class TestAgentOnAgentErrorCallbackAsync:
152158

153-
@pytest.mark.asyncio
154-
async def test_on_agent_error_callback_called_when_impl_raises(self):
155-
tracker = TrackingPlugin()
156-
with pytest.raises(RuntimeError, match="agent impl exploded"):
157-
await _collect_events(_FailingAgent(), [tracker])
159+
@pytest.mark.asyncio
160+
async def test_on_agent_error_callback_called_when_impl_raises(self):
161+
tracker = TrackingPlugin()
162+
with pytest.raises(RuntimeError, match="agent impl exploded"):
163+
await _collect_events(_FailingAgent(), [tracker])
158164

159-
assert len(tracker.agent_error_calls) == 1
160-
assert tracker.agent_error_calls[0]["error"] is _FailingAgent.BOOM
165+
assert len(tracker.agent_error_calls) == 1
166+
assert tracker.agent_error_calls[0]["error"] is _FailingAgent.BOOM
161167

162-
@pytest.mark.asyncio
163-
async def test_on_agent_error_callback_receives_correct_agent(self):
164-
tracker = TrackingPlugin()
165-
agent = _FailingAgent()
168+
@pytest.mark.asyncio
169+
async def test_on_agent_error_callback_receives_correct_agent(self):
170+
tracker = TrackingPlugin()
171+
agent = _FailingAgent()
166172

167-
with pytest.raises(RuntimeError):
168-
await _collect_events(agent, [tracker])
173+
with pytest.raises(RuntimeError):
174+
await _collect_events(agent, [tracker])
169175

170-
assert tracker.agent_error_calls[0]["agent"] is agent
176+
assert tracker.agent_error_calls[0]["agent"] is agent
171177

172-
@pytest.mark.asyncio
173-
async def test_on_agent_error_callback_receives_callback_context(self):
174-
tracker = TrackingPlugin()
178+
@pytest.mark.asyncio
179+
async def test_on_agent_error_callback_receives_callback_context(self):
180+
tracker = TrackingPlugin()
175181

176-
with pytest.raises(RuntimeError):
177-
await _collect_events(_FailingAgent(), [tracker])
182+
with pytest.raises(RuntimeError):
183+
await _collect_events(_FailingAgent(), [tracker])
178184

179-
cb_ctx = tracker.agent_error_calls[0]["callback_context"]
180-
assert isinstance(cb_ctx, CallbackContext)
185+
cb_ctx = tracker.agent_error_calls[0]["callback_context"]
186+
assert isinstance(cb_ctx, CallbackContext)
181187

182-
@pytest.mark.asyncio
183-
async def test_original_exception_reraised_after_notification(self):
184-
tracker = TrackingPlugin()
188+
@pytest.mark.asyncio
189+
async def test_original_exception_reraised_after_notification(self):
190+
tracker = TrackingPlugin()
185191

186-
with pytest.raises(RuntimeError) as exc_info:
187-
await _collect_events(_FailingAgent(), [tracker])
192+
with pytest.raises(RuntimeError) as exc_info:
193+
await _collect_events(_FailingAgent(), [tracker])
188194

189-
assert exc_info.value is _FailingAgent.BOOM
195+
assert exc_info.value is _FailingAgent.BOOM
190196

191-
@pytest.mark.asyncio
192-
async def test_after_agent_callback_not_called_on_error(self):
193-
tracker = TrackingPlugin()
197+
@pytest.mark.asyncio
198+
async def test_after_agent_callback_not_called_on_error(self):
199+
tracker = TrackingPlugin()
194200

195-
with pytest.raises(RuntimeError):
196-
await _collect_events(_FailingAgent(), [tracker])
201+
with pytest.raises(RuntimeError):
202+
await _collect_events(_FailingAgent(), [tracker])
197203

198-
assert not tracker.after_agent_called
204+
assert not tracker.after_agent_called
199205

200-
@pytest.mark.asyncio
201-
async def test_on_agent_error_callback_not_called_on_success(self):
202-
tracker = TrackingPlugin()
203-
events = await _collect_events(_SuccessAgent(), [tracker])
206+
@pytest.mark.asyncio
207+
async def test_on_agent_error_callback_not_called_on_success(self):
208+
tracker = TrackingPlugin()
209+
events = await _collect_events(_SuccessAgent(), [tracker])
204210

205-
assert len(events) >= 1
206-
assert len(tracker.agent_error_calls) == 0
211+
assert len(events) >= 1
212+
assert len(tracker.agent_error_calls) == 0
207213

208-
@pytest.mark.asyncio
209-
async def test_after_agent_callback_still_called_on_success(self):
210-
tracker = TrackingPlugin()
211-
await _collect_events(_SuccessAgent(), [tracker])
214+
@pytest.mark.asyncio
215+
async def test_after_agent_callback_still_called_on_success(self):
216+
tracker = TrackingPlugin()
217+
await _collect_events(_SuccessAgent(), [tracker])
212218

213-
assert tracker.after_agent_called
219+
assert tracker.after_agent_called
214220

215-
@pytest.mark.asyncio
216-
async def test_multiple_plugins_all_notified_on_agent_error(self):
217-
tracker_a = TrackingPlugin("a")
218-
tracker_b = TrackingPlugin("b")
221+
@pytest.mark.asyncio
222+
async def test_multiple_plugins_all_notified_on_agent_error(self):
223+
tracker_a = TrackingPlugin("a")
224+
tracker_b = TrackingPlugin("b")
219225

220-
with pytest.raises(RuntimeError):
221-
await _collect_events(_FailingAgent(), [tracker_a, tracker_b])
226+
with pytest.raises(RuntimeError):
227+
await _collect_events(_FailingAgent(), [tracker_a, tracker_b])
222228

223-
assert len(tracker_a.agent_error_calls) == 1
224-
assert len(tracker_b.agent_error_calls) == 1
229+
assert len(tracker_a.agent_error_calls) == 1
230+
assert len(tracker_b.agent_error_calls) == 1
225231

226232

227233
# ---------------------------------------------------------------------------
228234
# Tests — run_live path
229235
# ---------------------------------------------------------------------------
230236

237+
231238
class TestAgentOnAgentErrorCallbackLive:
232239

233-
@pytest.mark.asyncio
234-
async def test_on_agent_error_callback_called_when_live_impl_raises(self):
235-
tracker = TrackingPlugin()
240+
@pytest.mark.asyncio
241+
async def test_on_agent_error_callback_called_when_live_impl_raises(self):
242+
tracker = TrackingPlugin()
236243

237-
with pytest.raises(RuntimeError, match="live agent impl exploded"):
238-
await _collect_live_events(_FailingLiveAgent(), [tracker])
244+
with pytest.raises(RuntimeError, match="live agent impl exploded"):
245+
await _collect_live_events(_FailingLiveAgent(), [tracker])
239246

240-
assert len(tracker.agent_error_calls) == 1
241-
assert tracker.agent_error_calls[0]["error"] is _FailingLiveAgent.BOOM
247+
assert len(tracker.agent_error_calls) == 1
248+
assert tracker.agent_error_calls[0]["error"] is _FailingLiveAgent.BOOM
242249

243-
@pytest.mark.asyncio
244-
async def test_after_agent_callback_not_called_on_live_error(self):
245-
tracker = TrackingPlugin()
250+
@pytest.mark.asyncio
251+
async def test_after_agent_callback_not_called_on_live_error(self):
252+
tracker = TrackingPlugin()
246253

247-
with pytest.raises(RuntimeError):
248-
await _collect_live_events(_FailingLiveAgent(), [tracker])
254+
with pytest.raises(RuntimeError):
255+
await _collect_live_events(_FailingLiveAgent(), [tracker])
249256

250-
assert not tracker.after_agent_called
257+
assert not tracker.after_agent_called
251258

252-
@pytest.mark.asyncio
253-
async def test_original_live_exception_reraised_unchanged(self):
254-
tracker = TrackingPlugin()
259+
@pytest.mark.asyncio
260+
async def test_original_live_exception_reraised_unchanged(self):
261+
tracker = TrackingPlugin()
255262

256-
with pytest.raises(RuntimeError) as exc_info:
257-
await _collect_live_events(_FailingLiveAgent(), [tracker])
263+
with pytest.raises(RuntimeError) as exc_info:
264+
await _collect_live_events(_FailingLiveAgent(), [tracker])
258265

259-
assert exc_info.value is _FailingLiveAgent.BOOM
266+
assert exc_info.value is _FailingLiveAgent.BOOM

0 commit comments

Comments
 (0)