Skip to content

Commit dc2a32b

Browse files
jsonbaileyclaude
andcommitted
fix: ManagedAgentGraph takes config as first required parameter
Aligns ManagedAgentGraph constructor with the ManagedModel and ManagedAgent pattern: graph (AgentGraphDefinition) is now the first required parameter, followed by the runner. Removes the Optional guard and tightens the type annotation accordingly. Updates client.py call site and all test fixtures. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 7f0642e commit dc2a32b

3 files changed

Lines changed: 27 additions & 22 deletions

File tree

packages/sdk/server-ai/src/ldai/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,7 @@ async def create_agent_graph(
799799
if not runner:
800800
return None
801801

802-
return ManagedAgentGraph(runner, graph=graph)
802+
return ManagedAgentGraph(graph, runner)
803803

804804
def agents(
805805
self,

packages/sdk/server-ai/src/ldai/managed_agent_graph.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""ManagedAgentGraph — LaunchDarkly managed wrapper for agent graph execution."""
22

3-
from typing import Any, Optional
3+
from typing import Any
44

5+
from ldai.agent_graph import AgentGraphDefinition
56
from ldai.providers import AgentGraphRunner
67
from ldai.providers.types import (
78
AgentGraphRunnerResult,
@@ -23,18 +24,18 @@ class ManagedAgentGraph:
2324

2425
def __init__(
2526
self,
27+
graph: AgentGraphDefinition,
2628
runner: AgentGraphRunner,
27-
graph: Optional[Any] = None,
2829
):
2930
"""
3031
Initialize ManagedAgentGraph.
3132
32-
:param runner: The AgentGraphRunner to delegate execution to
33-
:param graph: Optional AgentGraphDefinition used to drive graph-level and
33+
:param graph: The AgentGraphDefinition used to drive graph-level and
3434
per-node tracking from the runner result metrics.
35+
:param runner: The AgentGraphRunner to delegate execution to
3536
"""
36-
self._runner = runner
3737
self._graph = graph
38+
self._runner = runner
3839

3940
async def run(self, input: Any) -> ManagedGraphResult:
4041
"""
@@ -56,10 +57,9 @@ async def run(self, input: Any) -> ManagedGraphResult:
5657

5758
summary = self._build_summary_from_runner_result(result)
5859

59-
if self._graph is not None:
60-
graph_tracker = self._graph.create_tracker()
61-
self._flush_graph_tracking(result, graph_tracker)
62-
self._flush_node_tracking(result)
60+
graph_tracker = self._graph.create_tracker()
61+
self._flush_graph_tracking(result, graph_tracker)
62+
self._flush_node_tracking(result)
6363

6464
return ManagedGraphResult(
6565
content=result.content,
@@ -106,9 +106,6 @@ def _flush_node_tracking(self, result: AgentGraphRunnerResult) -> None:
106106
config tracker via the graph definition and fires token, duration,
107107
tool call, and success/error events.
108108
"""
109-
if self._graph is None:
110-
return
111-
112109
for node_key, node_ldai_metrics in result.metrics.node_metrics.items():
113110
node = self._graph.get_node(node_key)
114111
if node is None:

packages/sdk/server-ai/tests/test_managed_agent_graph.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ async def run(self, input) -> AgentGraphRunnerResult:
6262
async def test_managed_agent_graph_run_delegates_to_runner():
6363
"""Runner result content is surfaced correctly."""
6464
runner = StubAgentGraphRunner("hello world")
65-
managed = ManagedAgentGraph(runner)
65+
mock_graph = MagicMock()
66+
mock_graph.create_tracker = MagicMock(return_value=MagicMock())
67+
mock_graph.get_node = MagicMock(return_value=None)
68+
managed = ManagedAgentGraph(mock_graph, runner)
6669
result = await managed.run("test input")
6770
assert isinstance(result, ManagedGraphResult)
6871
assert result.content == "hello world"
@@ -71,7 +74,8 @@ async def test_managed_agent_graph_run_delegates_to_runner():
7174

7275
def test_managed_agent_graph_get_runner():
7376
runner = StubAgentGraphRunner()
74-
managed = ManagedAgentGraph(runner)
77+
mock_graph = MagicMock()
78+
managed = ManagedAgentGraph(mock_graph, runner)
7579
assert managed.get_agent_graph_runner() is runner
7680

7781

@@ -84,7 +88,7 @@ async def test_managed_agent_graph_run_surfaces_graph_metrics():
8488
mock_graph.create_tracker = MagicMock(return_value=mock_tracker)
8589
mock_graph.get_node = MagicMock(return_value=None) # no nodes for this test
8690

87-
managed = ManagedAgentGraph(runner, graph=mock_graph)
91+
managed = ManagedAgentGraph(mock_graph, runner)
8892
result = await managed.run("test input")
8993

9094
assert isinstance(result, ManagedGraphResult)
@@ -105,7 +109,7 @@ async def test_managed_agent_graph_drives_graph_level_tracking():
105109
mock_graph.create_tracker = MagicMock(return_value=mock_tracker)
106110
mock_graph.get_node = MagicMock(return_value=None)
107111

108-
managed = ManagedAgentGraph(runner, graph=mock_graph)
112+
managed = ManagedAgentGraph(mock_graph, runner)
109113
await managed.run("test input")
110114

111115
mock_tracker.track_path.assert_called_once_with(["root", "specialist"])
@@ -135,7 +139,7 @@ def get_node(key):
135139

136140
mock_graph.get_node = get_node
137141

138-
managed = ManagedAgentGraph(runner, graph=mock_graph)
142+
managed = ManagedAgentGraph(mock_graph, runner)
139143
await managed.run("test input")
140144

141145
# root node tracking
@@ -150,10 +154,14 @@ def get_node(key):
150154

151155

152156
@pytest.mark.asyncio
153-
async def test_managed_agent_graph_no_graph_skips_tracking():
154-
"""Without a graph reference, no tracking is called but run succeeds."""
157+
async def test_managed_agent_graph_run_succeeds_with_graph():
158+
"""Run succeeds and returns correct content when graph is provided."""
155159
runner = StubRunnerWithMetrics()
156-
managed = ManagedAgentGraph(runner, graph=None)
160+
mock_graph = MagicMock()
161+
mock_tracker = MagicMock()
162+
mock_graph.create_tracker = MagicMock(return_value=mock_tracker)
163+
mock_graph.get_node = MagicMock(return_value=None)
164+
managed = ManagedAgentGraph(mock_graph, runner)
157165
result = await managed.run("test input")
158166
assert result.content == "new shape output"
159167
assert result.metrics.success is True
@@ -176,7 +184,7 @@ async def run(self, input) -> AgentGraphRunnerResult:
176184
mock_graph.create_tracker = MagicMock(return_value=mock_tracker)
177185
mock_graph.get_node = MagicMock(return_value=None)
178186

179-
managed = ManagedAgentGraph(FailingRunner(), graph=mock_graph)
187+
managed = ManagedAgentGraph(mock_graph, FailingRunner())
180188
result = await managed.run("test input")
181189

182190
assert result.metrics.success is False

0 commit comments

Comments
 (0)