Skip to content

Commit 889bd6c

Browse files
committed
[REL-11697] Simplify types; add enabled directly to objects; cursor feedback
1 parent c543ce6 commit 889bd6c

5 files changed

Lines changed: 85 additions & 52 deletions

File tree

packages/sdk/server-ai/src/ldai/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from ldclient import log
44

5-
from ldai.agent_graph import AgentGraphDefinition, AIAgentGraphResponse
5+
from ldai.agent_graph import AgentGraphDefinition
66
from ldai.chat import Chat
77
from ldai.client import LDAIClient
88
from ldai.judge import Judge
@@ -21,7 +21,6 @@
2121
'AIAgentConfigRequest',
2222
'AIAgents',
2323
'AIAgentGraphConfig',
24-
'AIAgentGraphResponse',
2524
'Edge',
2625
'AICompletionConfig',
2726
'AICompletionConfigDefault',

packages/sdk/server-ai/src/ldai/agent_graph/__init__.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,23 @@ class AgentGraphDefinition:
4646
"""
4747
Graph implementation for managing AI agent graphs.
4848
"""
49+
enabled: bool
4950

5051
def __init__(
5152
self,
5253
agent_graph: AIAgentGraphConfig,
5354
nodes: Dict[str, AgentGraphNode],
5455
context: Context,
56+
enabled: bool,
5557
):
5658
self._agent_graph = agent_graph
5759
self._context = context
5860
self._nodes = nodes
61+
self.enabled = enabled
62+
63+
def is_enabled(self) -> bool:
64+
"""Check if the graph is enabled."""
65+
return self.enabled
5966

6067
@staticmethod
6168
def build_nodes(
@@ -143,17 +150,20 @@ def terminal_nodes(self) -> List[AgentGraphNode]:
143150
if len(self.get_child_nodes(node.get_key())) == 0
144151
]
145152

146-
def root(self) -> Optional[AgentGraphNode]:
153+
def root(self) -> AgentGraphNode:
147154
"""Get the root node of the graph."""
148-
return self._nodes[self._agent_graph.root_config_key]
155+
return self._nodes.get(self._agent_graph.root_config_key)
149156

150157
def traverse(
151158
self,
152159
fn: Callable[["AgentGraphNode", Dict[str, Any]], Any],
153-
execution_context: Dict[str, Any] = {},
154-
) -> None:
160+
execution_context: Dict[str, Any] = None,
161+
) -> Any:
155162
"""Traverse from the root down to terminal nodes, visiting nodes in order of depth.
156163
Nodes with the longest paths from the root (deepest nodes) will always be visited last."""
164+
if execution_context is None:
165+
execution_context = {}
166+
157167
root_node = self.root()
158168
if root_node is None:
159169
return
@@ -195,10 +205,14 @@ def traverse(
195205
def reverse_traverse(
196206
self,
197207
fn: Callable[["AgentGraphNode", Dict[str, Any]], Any],
198-
execution_context: Dict[str, Any] = {},
199-
) -> None:
208+
execution_context: Dict[str, Any] = None,
209+
) -> Any:
210+
200211
"""Traverse from terminal nodes up to the root, visiting nodes level by level.
201212
The root node will always be visited last, even if multiple paths converge at it."""
213+
if execution_context is None:
214+
execution_context = {}
215+
202216
terminal_nodes = self.terminal_nodes()
203217
if not terminal_nodes:
204218
return
@@ -242,15 +256,3 @@ def reverse_traverse(
242256

243257
return execution_context[self._agent_graph.root_config_key]
244258

245-
246-
# ============================================================================
247-
# AI Config Agent Graph Response
248-
# ============================================================================
249-
@dataclass
250-
class AIAgentGraphResponse:
251-
"""
252-
Agent graph response.
253-
"""
254-
255-
enabled: bool
256-
graph: Optional[AgentGraphDefinition] = None

packages/sdk/server-ai/src/ldai/client.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ldclient.client import LDClient
66

77
from ldai import log
8-
from ldai.agent_graph import AgentGraphDefinition, AIAgentGraphResponse
8+
from ldai.agent_graph import AgentGraphDefinition
99
from ldai.chat import Chat
1010
from ldai.judge import Judge
1111
from ldai.models import (AIAgentConfig, AIAgentConfigDefault,
@@ -425,18 +425,30 @@ def agent_graph(
425425
self,
426426
key: str,
427427
context: Context,
428-
) -> AIAgentGraphResponse:
428+
) -> AgentGraphDefinition:
429429
"""`
430430
Retrieve an AI agent graph.
431431
"""
432432
variation = self._client.variation(key, context, {})
433433

434434
if not variation.get("rootConfigKey"):
435435
log.debug(f"Agent graph {key} is disabled, no root config key found")
436-
return AIAgentGraphResponse(enabled=False, graph=None)
436+
return AgentGraphDefinition(
437+
AIAgentGraphConfig(
438+
key=key,
439+
name="",
440+
root_config_key="",
441+
edges=[],
442+
description="",
443+
enabled=False,
444+
),
445+
nodes={},
446+
context=context,
447+
enabled=False,
448+
)
437449

438450
all_agent_keys = [variation["rootConfigKey"]] + [
439-
edge["targetConfig"] for edge in variation["edges"]
451+
edge["targetConfig"] for edge in variation.get("edges", [])
440452
]
441453
agent_configs = {
442454
key: self.agent_config(key, context, AIAgentConfigDefault(enabled=False))
@@ -447,7 +459,19 @@ def agent_graph(
447459
log.debug(
448460
f"Agent graph {key} is disabled, not all agent configs are enabled"
449461
)
450-
return AIAgentGraphResponse(enabled=False, graph=None)
462+
return AgentGraphDefinition(
463+
AIAgentGraphConfig(
464+
key=key,
465+
name="",
466+
root_config_key="",
467+
edges=[],
468+
description="",
469+
enabled=False,
470+
),
471+
nodes={},
472+
context=context,
473+
enabled=False,
474+
)
451475

452476
try:
453477
agent_graph_config = AIAgentGraphConfig(
@@ -467,20 +491,30 @@ def agent_graph(
467491
)
468492
except Exception as e:
469493
log.debug(f"Agent graph {key} is disabled, invalid agent graph config")
470-
return AIAgentGraphResponse(enabled=False, graph=None)
494+
return AgentGraphDefinition(
495+
AIAgentGraphConfig(
496+
key=key,
497+
name="",
498+
root_config_key="",
499+
edges=[],
500+
description="",
501+
enabled=False,
502+
),
503+
nodes={},
504+
context=context,
505+
enabled=False,
506+
)
471507

472508
nodes = AgentGraphDefinition.build_nodes(
473509
agent_graph_config,
474510
agent_configs,
475511
)
476512

477-
return AIAgentGraphResponse(
478-
enabled=True,
479-
graph=AgentGraphDefinition(
480-
agent_graph=agent_graph_config,
481-
nodes=nodes,
482-
context=context,
483-
),
513+
return AgentGraphDefinition(
514+
agent_graph=agent_graph_config,
515+
nodes=nodes,
516+
context=context,
517+
enabled=agent_graph_config.enabled,
484518
)
485519

486520
def agents(

packages/sdk/server-ai/src/ldai/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from dataclasses import dataclass, field
33
from typing import Any, Dict, List, Literal, Optional, Union
44

5-
from ldai.agent_graph import AgentGraphDefinition
65
from ldai.tracker import LDAIConfigTracker
76

87

@@ -370,6 +369,7 @@ class AIAgentGraphConfig:
370369
root_config_key: str
371370
edges: List[Edge]
372371
description: Optional[str] = ""
372+
enabled: bool = True
373373

374374

375375
# ============================================================================

packages/sdk/server-ai/tests/test_agent_graph.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -237,32 +237,32 @@ def ldai_client(client: LDClient) -> LDAIClient:
237237
def test_agent_graph_method(ldai_client: LDAIClient):
238238
graph = ldai_client.agent_graph("test-agent-graph", Context.create("user-key"))
239239

240-
assert graph["enabled"] is True
241-
assert graph["graph"] is not None
242-
assert graph["graph"].root() is not None
243-
assert graph["graph"].root().get_key() == "customer-support-agent"
244-
assert len(graph["graph"].get_child_nodes("customer-support-agent")) == 3
245-
assert len(graph["graph"].get_child_nodes("personalized-agent")) == 0
246-
assert len(graph["graph"].get_child_nodes("multi-context-agent")) == 0
247-
assert len(graph["graph"].get_child_nodes("minimal-agent")) == 0
240+
assert graph.enabled is True
241+
assert graph is not None
242+
assert graph.root() is not None
243+
assert graph.root().get_key() == "customer-support-agent"
244+
assert len(graph.get_child_nodes("customer-support-agent")) == 3
245+
assert len(graph.get_child_nodes("personalized-agent")) == 0
246+
assert len(graph.get_child_nodes("multi-context-agent")) == 0
247+
assert len(graph.get_child_nodes("minimal-agent")) == 0
248248

249249

250250
def test_agent_graph_method_disabled_agent(ldai_client: LDAIClient):
251251
graph = ldai_client.agent_graph(
252252
"test-agent-graph-disabled-agent", Context.create("user-key")
253253
)
254254

255-
assert graph["enabled"] is False
256-
assert graph["graph"] is None
255+
assert graph.enabled is False
256+
assert graph.root() is None
257257

258258

259259
def test_agent_graph_method_no_root_key(ldai_client: LDAIClient):
260260
graph = ldai_client.agent_graph(
261261
"test-agent-graph-no-root-key", Context.create("user-key")
262262
)
263263

264-
assert graph["enabled"] is False
265-
assert graph["graph"] is None
264+
assert graph.enabled is False
265+
assert graph.root() is None
266266

267267

268268
def test_agent_graph_build_nodes(ldai_client: LDAIClient):
@@ -319,9 +319,7 @@ def test_agent_graph_build_nodes(ldai_client: LDAIClient):
319319

320320

321321
def test_agent_graph_get_methods(ldai_client: LDAIClient):
322-
graph = ldai_client.agent_graph("test-agent-graph", Context.create("user-key"))[
323-
"graph"
324-
]
322+
graph = ldai_client.agent_graph("test-agent-graph", Context.create("user-key"))
325323

326324
assert graph.root() is not None
327325
assert graph.root().get_key() == "customer-support-agent"
@@ -359,7 +357,7 @@ def test_agent_graph_get_methods(ldai_client: LDAIClient):
359357
def test_agent_graph_traverse(ldai_client: LDAIClient):
360358
graph = ldai_client.agent_graph(
361359
"test-agent-graph-depth-3", Context.create("user-key")
362-
)["graph"]
360+
)
363361

364362
context = {}
365363
order = []
@@ -387,7 +385,7 @@ def handle_traverse(node, context):
387385
def test_agent_graph_reverse_traverse(ldai_client: LDAIClient):
388386
graph = ldai_client.agent_graph(
389387
"test-agent-graph-depth-3", Context.create("user-key")
390-
)["graph"]
388+
)
391389

392390
context = {}
393391
order = []
@@ -414,7 +412,7 @@ def handle_reverse_traverse(node, context):
414412
def test_agent_graph_handoff(ldai_client: LDAIClient):
415413
graph = ldai_client.agent_graph(
416414
"test-agent-graph-depth-3", Context.create("user-key")
417-
)["graph"]
415+
)
418416

419417
context = {}
420418

0 commit comments

Comments
 (0)