|
11 | 11 | from unittest.mock import AsyncMock, MagicMock, patch |
12 | 12 |
|
13 | 13 | from ldai.agent_graph import AgentGraphDefinition |
14 | | -from ldai.models import AIAgentGraphConfig, AIAgentConfig, ModelConfig, ProviderConfig |
| 14 | +from ldai.models import AIAgentGraphConfig, AIAgentConfig, Edge, ModelConfig, ProviderConfig |
15 | 15 | from ldai.tracker import AIGraphTracker, LDAIConfigTracker |
16 | 16 | from ldai_openai.openai_agent_graph_runner import OpenAIAgentGraphRunner |
17 | 17 |
|
@@ -153,6 +153,74 @@ def _make_agents_modules(run_result: MagicMock) -> dict: |
153 | 153 | } |
154 | 154 |
|
155 | 155 |
|
| 156 | +def _make_two_node_graph(mock_ld_client: MagicMock) -> AgentGraphDefinition: |
| 157 | + """Build a two-node AgentGraphDefinition (root-agent → child-agent).""" |
| 158 | + context = MagicMock() |
| 159 | + |
| 160 | + root_tracker = LDAIConfigTracker( |
| 161 | + ld_client=mock_ld_client, |
| 162 | + variation_key='test-variation', |
| 163 | + config_key='root-agent', |
| 164 | + version=1, |
| 165 | + model_name='gpt-4', |
| 166 | + provider_name='openai', |
| 167 | + context=context, |
| 168 | + ) |
| 169 | + child_tracker = LDAIConfigTracker( |
| 170 | + ld_client=mock_ld_client, |
| 171 | + variation_key='test-variation', |
| 172 | + config_key='child-agent', |
| 173 | + version=1, |
| 174 | + model_name='gpt-4', |
| 175 | + provider_name='openai', |
| 176 | + context=context, |
| 177 | + ) |
| 178 | + graph_tracker = AIGraphTracker( |
| 179 | + ld_client=mock_ld_client, |
| 180 | + variation_key='test-variation', |
| 181 | + graph_key='two-node-graph', |
| 182 | + version=1, |
| 183 | + context=context, |
| 184 | + ) |
| 185 | + |
| 186 | + root_config = AIAgentConfig( |
| 187 | + key='root-agent', |
| 188 | + enabled=True, |
| 189 | + model=ModelConfig(name='gpt-4', parameters={}), |
| 190 | + provider=ProviderConfig(name='openai'), |
| 191 | + instructions='You are root.', |
| 192 | + tracker=root_tracker, |
| 193 | + ) |
| 194 | + child_config = AIAgentConfig( |
| 195 | + key='child-agent', |
| 196 | + enabled=True, |
| 197 | + model=ModelConfig(name='gpt-4', parameters={}), |
| 198 | + provider=ProviderConfig(name='openai'), |
| 199 | + instructions='You are child.', |
| 200 | + tracker=child_tracker, |
| 201 | + ) |
| 202 | + |
| 203 | + edge = Edge(key='root-to-child', source_config='root-agent', target_config='child-agent') |
| 204 | + graph_config = AIAgentGraphConfig( |
| 205 | + key='two-node-graph', |
| 206 | + root_config_key='root-agent', |
| 207 | + edges=[edge], |
| 208 | + enabled=True, |
| 209 | + ) |
| 210 | + |
| 211 | + nodes = AgentGraphDefinition.build_nodes(graph_config, { |
| 212 | + 'root-agent': root_config, |
| 213 | + 'child-agent': child_config, |
| 214 | + }) |
| 215 | + return AgentGraphDefinition( |
| 216 | + agent_graph=graph_config, |
| 217 | + nodes=nodes, |
| 218 | + context=context, |
| 219 | + enabled=True, |
| 220 | + tracker=graph_tracker, |
| 221 | + ) |
| 222 | + |
| 223 | + |
156 | 224 | def _events(mock_ld_client: MagicMock) -> dict: |
157 | 225 | """Return dict of event_name -> list of (data, value) from all track() calls.""" |
158 | 226 | result = defaultdict(list) |
@@ -303,3 +371,89 @@ async def test_tracks_failure_and_latency_on_runner_error(): |
303 | 371 | assert '$ld:ai:graph:invocation_failure' in ev |
304 | 372 | assert '$ld:ai:graph:latency' in ev |
305 | 373 | assert '$ld:ai:graph:invocation_success' not in ev |
| 374 | + |
| 375 | + |
| 376 | +@pytest.mark.asyncio |
| 377 | +async def test_multi_node_tracks_per_node_tokens_and_handoff(): |
| 378 | + """Each node emits its own token events; handoff event fires between them.""" |
| 379 | + mock_ld_client = MagicMock() |
| 380 | + graph = _make_two_node_graph(mock_ld_client) |
| 381 | + |
| 382 | + root_entry = MagicMock() |
| 383 | + root_entry.total_tokens = 15 |
| 384 | + root_entry.input_tokens = 10 |
| 385 | + root_entry.output_tokens = 5 |
| 386 | + |
| 387 | + child_entry = MagicMock() |
| 388 | + child_entry.total_tokens = 9 |
| 389 | + child_entry.input_tokens = 6 |
| 390 | + child_entry.output_tokens = 3 |
| 391 | + |
| 392 | + run_result = MagicMock() |
| 393 | + run_result.final_output = 'child answer' |
| 394 | + run_result.new_items = [] |
| 395 | + run_result.usage = None |
| 396 | + run_result.context_wrapper.usage.total_tokens = 24 |
| 397 | + run_result.context_wrapper.usage.input_tokens = 16 |
| 398 | + run_result.context_wrapper.usage.output_tokens = 8 |
| 399 | + run_result.context_wrapper.usage.request_usage_entries = [root_entry, child_entry] |
| 400 | + |
| 401 | + on_handoff_callbacks = [] |
| 402 | + |
| 403 | + def capture_handoff(**kwargs): |
| 404 | + cb = kwargs.get('on_handoff') |
| 405 | + if cb: |
| 406 | + on_handoff_callbacks.append(cb) |
| 407 | + return MagicMock() |
| 408 | + |
| 409 | + async def mock_run(agent, input_str, **kwargs): |
| 410 | + # Simulate the root→child handoff before returning |
| 411 | + if on_handoff_callbacks: |
| 412 | + run_ctx = MagicMock() |
| 413 | + run_ctx.usage.request_usage_entries = [root_entry] |
| 414 | + on_handoff_callbacks[0](run_ctx) |
| 415 | + return run_result |
| 416 | + |
| 417 | + mock_runner_cls = MagicMock() |
| 418 | + mock_runner_cls.run = mock_run |
| 419 | + |
| 420 | + mock_agents = MagicMock() |
| 421 | + mock_agents.Runner = mock_runner_cls |
| 422 | + mock_agents.Agent = MagicMock(return_value=MagicMock()) |
| 423 | + mock_agents.Handoff = MagicMock() |
| 424 | + mock_agents.Tool = MagicMock() |
| 425 | + mock_agents.function_tool = lambda fn: MagicMock() |
| 426 | + mock_agents.handoff = capture_handoff |
| 427 | + |
| 428 | + mock_ext = MagicMock() |
| 429 | + mock_ext.RECOMMENDED_PROMPT_PREFIX = '[PREFIX]' |
| 430 | + |
| 431 | + with patch.dict('sys.modules', { |
| 432 | + 'agents': mock_agents, |
| 433 | + 'agents.extensions': MagicMock(), |
| 434 | + 'agents.extensions.handoff_prompt': mock_ext, |
| 435 | + 'agents.tool_context': MagicMock(), |
| 436 | + }): |
| 437 | + runner = OpenAIAgentGraphRunner(graph, {}) |
| 438 | + result = await runner.run('hello') |
| 439 | + |
| 440 | + assert result.metrics.success is True |
| 441 | + |
| 442 | + ev = _events(mock_ld_client) |
| 443 | + |
| 444 | + # Per-node token events identified by configKey |
| 445 | + root_tokens = [(d, v) for d, v in ev.get('$ld:ai:tokens:total', []) if d.get('configKey') == 'root-agent'] |
| 446 | + child_tokens = [(d, v) for d, v in ev.get('$ld:ai:tokens:total', []) if d.get('configKey') == 'child-agent'] |
| 447 | + assert root_tokens[0][1] == 15 |
| 448 | + assert child_tokens[0][1] == 9 |
| 449 | + |
| 450 | + # Execution path includes both node keys |
| 451 | + path_data = ev['$ld:ai:graph:path'][0][0] |
| 452 | + assert 'root-agent' in path_data['path'] |
| 453 | + assert 'child-agent' in path_data['path'] |
| 454 | + |
| 455 | + # Handoff event fires with correct source and target |
| 456 | + handoff_events = ev.get('$ld:ai:graph:handoff_success', []) |
| 457 | + assert len(handoff_events) == 1 |
| 458 | + assert handoff_events[0][0]['sourceKey'] == 'root-agent' |
| 459 | + assert handoff_events[0][0]['targetKey'] == 'child-agent' |
0 commit comments