Skip to content

Commit 6f36041

Browse files
fix: update test methods to use new workflow run interface and improve context handling
1 parent 418000d commit 6f36041

4 files changed

Lines changed: 87 additions & 66 deletions

File tree

src/processor/src/tests/unit/libs/agent_framework/test_agent_builder.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,17 @@ def test_build_passes_all_state_to_chat_agent(self):
157157
.build()
158158
)
159159
assert agent is mock_chat.return_value
160+
args = mock_chat.call_args.args
160161
kwargs = mock_chat.call_args.kwargs
161-
assert kwargs["chat_client"] is chat_client
162+
assert args[0] is chat_client
162163
assert kwargs["instructions"] == "inst"
163164
assert kwargs["id"] == "id1"
164165
assert kwargs["name"] == "name1"
165166
assert kwargs["description"] == "desc1"
166-
assert kwargs["temperature"] == 0.3
167-
assert kwargs["max_tokens"] == 100
168-
assert kwargs["tool_choice"] == "auto"
167+
opts = kwargs["default_options"]
168+
assert opts["temperature"] == 0.3
169+
assert opts["max_tokens"] == 100
170+
assert opts["tool_choice"] == "auto"
169171
assert kwargs["extra"] == 42
170172

171173

@@ -180,11 +182,13 @@ def test_create_agent_invokes_chat_agent(self):
180182
temperature=0.4,
181183
)
182184
assert agent is mock_chat.return_value
185+
args = mock_chat.call_args.args
183186
kwargs = mock_chat.call_args.kwargs
184-
assert kwargs["chat_client"] is chat_client
187+
assert args[0] is chat_client
185188
assert kwargs["instructions"] == "i"
186189
assert kwargs["name"] == "n"
187-
assert kwargs["temperature"] == 0.4
190+
opts = kwargs["default_options"]
191+
assert opts["temperature"] == 0.4
188192

189193
def test_create_agent_by_agentinfo_uses_helper_and_creates_client(self):
190194
# Build a fake AgentInfo with the minimum surface used by the method
@@ -215,12 +219,14 @@ def test_create_agent_by_agentinfo_uses_helper_and_creates_client(self):
215219
assert agent is mock_chat.return_value
216220
helper.settings.get_service_config.assert_called_once_with("default")
217221
helper.create_client.assert_called_once()
222+
args = mock_chat.call_args.args
218223
ck = mock_chat.call_args.kwargs
219-
assert ck["chat_client"] == "client-instance"
224+
assert args[0] == "client-instance"
220225
assert ck["instructions"] == "instr"
221226
assert ck["name"] == "A"
222227
assert ck["description"] == "D"
223-
assert ck["temperature"] == 0.2
228+
opts = ck["default_options"]
229+
assert opts["temperature"] == 0.2
224230

225231
def test_create_agent_by_agentinfo_falls_back_to_system_prompt(self):
226232
helper = MagicMock()

src/processor/src/tests/unit/libs/agent_framework/test_groupchat_orchestrator_internals.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -743,16 +743,14 @@ def test_build_groupchat_invokes_builder(self):
743743
})
744744
with patch("libs.agent_framework.groupchat_orchestrator.GroupChatBuilder") as MockBuilder:
745745
built = MagicMock()
746-
built.set_manager.return_value = built
747-
built.participants.return_value = built
748746
built.build.return_value = "wf"
749747
MockBuilder.return_value = built
750748
wf = _run(orch._build_groupchat())
751749
assert wf == "wf"
752750
# ResultGenerator excluded from participants
753-
kwargs = built.participants.call_args.args[0]
754-
assert "arch" in kwargs
755-
assert "rg" not in kwargs
751+
kwargs = MockBuilder.call_args.kwargs
752+
assert "arch" in kwargs["participants"]
753+
assert "rg" not in kwargs["participants"]
756754

757755

758756
# -----------------------------------------------------------------------------

src/processor/src/tests/unit/libs/agent_framework/test_shared_memory_context_provider.py

Lines changed: 68 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,36 @@ def _make_provider(store=None):
6161
), store
6262

6363

64+
def _make_context(messages=None, response_text=None):
65+
"""Create a mock SessionContext for before_run/after_run calls."""
66+
ctx = MagicMock()
67+
ctx.get_messages = MagicMock(return_value=messages or [])
68+
ctx.extend_instructions = MagicMock()
69+
if response_text is not None:
70+
ctx.response = MagicMock()
71+
ctx.response.text = response_text
72+
else:
73+
ctx.response = None
74+
return ctx
75+
76+
77+
async def _call_before_run(provider, messages):
78+
"""Helper to call before_run and return the instructions that were injected."""
79+
ctx = _make_context(messages=messages)
80+
await provider.before_run(agent=MagicMock(), session=MagicMock(), context=ctx, state={})
81+
if ctx.extend_instructions.called:
82+
return ctx.extend_instructions.call_args[0][1] # second positional arg = instructions
83+
return None
84+
85+
86+
async def _call_after_run(provider, response_text):
87+
"""Helper to call after_run with a response."""
88+
ctx = _make_context(response_text=response_text)
89+
await provider.after_run(agent=MagicMock(), session=MagicMock(), context=ctx, state={})
90+
91+
6492
# ---------------------------------------------------------------------------
65-
# invoking() — Pre-LLM memory injection
93+
# before_run() — Pre-LLM memory injection
6694
# ---------------------------------------------------------------------------
6795

6896

@@ -75,11 +103,11 @@ async def _run():
75103
]
76104
messages = [_make_chat_message("How should we handle storage configuration?")]
77105

78-
context = await provider.invoking(messages)
106+
instructions = await _call_before_run(provider, messages)
79107

80-
assert context.instructions is not None
81-
assert "GKE Filestore CSI" in context.instructions
82-
assert "Azure Files for AKS" in context.instructions
108+
assert instructions is not None
109+
assert "GKE Filestore CSI" in instructions
110+
assert "Azure Files for AKS" in instructions
83111
store.search.assert_called_once()
84112

85113
asyncio.run(_run())
@@ -88,9 +116,8 @@ async def _run():
88116
def test_invoking_empty_messages_returns_empty():
89117
async def _run():
90118
provider, _ = _make_provider()
91-
context = await provider.invoking([])
92-
assert context.instructions is None
93-
assert getattr(context, "messages", []) == []
119+
instructions = await _call_before_run(provider, [])
120+
assert instructions is None
94121

95122
asyncio.run(_run())
96123

@@ -101,8 +128,8 @@ async def _run():
101128
store.search.return_value = []
102129
messages = [_make_chat_message("What is the overall migration plan for AKS?")]
103130

104-
context = await provider.invoking(messages)
105-
assert context.instructions is None
131+
instructions = await _call_before_run(provider, messages)
132+
assert instructions is None
106133

107134
asyncio.run(_run())
108135

@@ -113,8 +140,8 @@ async def _run():
113140
store.search.side_effect = Exception("search failed")
114141
messages = [_make_chat_message("What is the networking plan for AKS?")]
115142

116-
context = await provider.invoking(messages)
117-
assert context.instructions is None
143+
instructions = await _call_before_run(provider, messages)
144+
assert instructions is None
118145

119146
asyncio.run(_run())
120147

@@ -125,7 +152,7 @@ async def _run():
125152
long_text = "x" * 5000
126153
messages = [_make_chat_message(long_text)]
127154

128-
await provider.invoking(messages)
155+
await _call_before_run(provider, messages)
129156

130157
query = store.search.call_args.kwargs["query"]
131158
assert len(query) <= 2000
@@ -142,7 +169,7 @@ async def _run():
142169
_make_chat_message("Latest question about storage"),
143170
]
144171

145-
await provider.invoking(messages)
172+
await _call_before_run(provider, messages)
146173

147174
query = store.search.call_args.kwargs["query"]
148175
assert "Latest question about storage" in query
@@ -159,10 +186,10 @@ async def _run():
159186
store.search.return_value = large_memories
160187
messages = [_make_chat_message("What storage configuration should we use for persistent volumes?")]
161188

162-
context = await provider.invoking(messages)
189+
instructions = await _call_before_run(provider, messages)
163190

164-
assert context.instructions is not None
165-
assert len(context.instructions) <= MAX_MEMORY_CONTEXT_CHARS + 200
191+
assert instructions is not None
192+
assert len(instructions) <= MAX_MEMORY_CONTEXT_CHARS + 200
166193

167194
asyncio.run(_run())
168195

@@ -175,10 +202,10 @@ async def _run():
175202
]
176203
messages = [_make_chat_message("What storage class should we choose for the cluster?")]
177204

178-
context = await provider.invoking(messages)
205+
instructions = await _call_before_run(provider, messages)
179206

180-
assert "Chief Architect" in context.instructions
181-
assert "design" in context.instructions
207+
assert "Chief Architect" in instructions
208+
assert "design" in instructions
182209

183210
asyncio.run(_run())
184211

@@ -189,26 +216,25 @@ async def _run():
189216
store.search.return_value = [_make_memory_entry("some memory")]
190217
single = _make_chat_message("What about networking configuration for AKS?")
191218

192-
context = await provider.invoking(single)
219+
instructions = await _call_before_run(provider, [single])
193220

194-
assert context.instructions is not None
221+
assert instructions is not None
195222
store.search.assert_called_once()
196223

197224
asyncio.run(_run())
198225

199226

200227
# ---------------------------------------------------------------------------
201-
# invoked() — Post-LLM memory storage
228+
# after_run() — Post-LLM memory storage
202229
# ---------------------------------------------------------------------------
203230

204231

205232
def test_invoked_stores_response():
206233
async def _run():
207234
provider, store = _make_provider()
208-
request = [_make_chat_message("What is the networking plan for AKS?")]
209-
response = [_make_chat_message("We should use Azure CNI for networking configuration in the AKS cluster")]
235+
response_text = "We should use Azure CNI for networking configuration in the AKS cluster"
210236

211-
await provider.invoked(request, response)
237+
await _call_after_run(provider, response_text)
212238
await provider.flush()
213239

214240
store.add.assert_called_once()
@@ -222,10 +248,9 @@ async def _run():
222248
def test_invoked_skips_on_exception():
223249
async def _run():
224250
provider, store = _make_provider()
225-
request = [_make_chat_message("Q")]
226-
response = [_make_chat_message("A" * 100)]
227-
228-
await provider.invoked(request, response, invoke_exception=Exception("fail"))
251+
# after_run with no response simulates exception path
252+
ctx = _make_context(response_text=None)
253+
await provider.after_run(agent=MagicMock(), session=MagicMock(), context=ctx, state={})
229254
store.add.assert_not_called()
230255

231256
asyncio.run(_run())
@@ -234,9 +259,8 @@ async def _run():
234259
def test_invoked_skips_none_response():
235260
async def _run():
236261
provider, store = _make_provider()
237-
request = [_make_chat_message("Q")]
238-
239-
await provider.invoked(request, None)
262+
ctx = _make_context(response_text=None)
263+
await provider.after_run(agent=MagicMock(), session=MagicMock(), context=ctx, state={})
240264
store.add.assert_not_called()
241265

242266
asyncio.run(_run())
@@ -245,10 +269,8 @@ async def _run():
245269
def test_invoked_skips_short_response():
246270
async def _run():
247271
provider, store = _make_provider()
248-
request = [_make_chat_message("Q")]
249-
short = [_make_chat_message("x" * (MIN_CONTENT_LENGTH_TO_STORE - 1))]
250-
251-
await provider.invoked(request, short)
272+
short_text = "x" * (MIN_CONTENT_LENGTH_TO_STORE - 1)
273+
await _call_after_run(provider, short_text)
252274
store.add.assert_not_called()
253275

254276
asyncio.run(_run())
@@ -257,10 +279,8 @@ async def _run():
257279
def test_invoked_stores_long_response():
258280
async def _run():
259281
provider, store = _make_provider()
260-
request = [_make_chat_message("Q")]
261-
long_resp = [_make_chat_message("x" * (MIN_CONTENT_LENGTH_TO_STORE + 1))]
262-
263-
await provider.invoked(request, long_resp)
282+
long_text = "x" * (MIN_CONTENT_LENGTH_TO_STORE + 1)
283+
await _call_after_run(provider, long_text)
264284
await provider.flush()
265285
store.add.assert_called_once()
266286

@@ -270,11 +290,10 @@ async def _run():
270290
def test_invoked_increments_turn_counter():
271291
async def _run():
272292
provider, store = _make_provider()
273-
request = [_make_chat_message("Q")]
274-
response = [_make_chat_message("A" * 100)]
293+
response_text = "A" * 100
275294

276-
await provider.invoked(request, response)
277-
await provider.invoked(request, response)
295+
await _call_after_run(provider, response_text)
296+
await _call_after_run(provider, response_text)
278297
assert provider._turn_counter == 2
279298

280299
asyncio.run(_run())
@@ -284,10 +303,9 @@ def test_invoked_store_failure_does_not_raise():
284303
async def _run():
285304
provider, store = _make_provider()
286305
store.add.side_effect = Exception("store failed")
287-
request = [_make_chat_message("Q")]
288-
response = [_make_chat_message("A" * 100)]
306+
response_text = "A" * 100
289307

290-
await provider.invoked(request, response)
308+
await _call_after_run(provider, response_text)
291309
await provider.flush() # Should not raise
292310

293311
asyncio.run(_run())
@@ -296,10 +314,9 @@ async def _run():
296314
def test_invoked_with_single_message():
297315
async def _run():
298316
provider, store = _make_provider()
299-
request = _make_chat_message("What is the question about networking?")
300-
response = _make_chat_message("We should use Azure CNI Overlay for the networking configuration in AKS")
317+
response_text = "We should use Azure CNI Overlay for the networking configuration in AKS"
301318

302-
await provider.invoked(request, response)
319+
await _call_after_run(provider, response_text)
303320
await provider.flush()
304321
store.add.assert_called_once()
305322

src/processor/src/tests/unit/steps/test_migration_processor_run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,12 @@ def _make_processor(events: list, memory_store=None) -> MigrationProcessor:
5454

5555
proc._telemetry = telemetry # expose for assertions
5656

57-
async def _stream(_input):
57+
async def _stream(_input, **kwargs):
5858
for ev in events:
5959
yield ev
6060

6161
workflow = MagicMock()
62-
workflow.run_stream = _stream
62+
workflow.run = _stream
6363
proc.workflow = workflow
6464

6565
# Patch _create_memory_store as an AsyncMock returning the provided value.

0 commit comments

Comments
 (0)