Skip to content

Commit 3ff3b8b

Browse files
committed
[REL-11697] Update PoC per spec and implement tests
1 parent 2d9e842 commit 3ff3b8b

5 files changed

Lines changed: 564 additions & 205 deletions

File tree

packages/sdk/server-ai/src/ldai/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
from ldai.chat import Chat
66
from ldai.client import LDAIClient
7-
from ldai.agent_graph import AgentGraph
7+
from ldai.agent_graph import AgentGraphDefinition
88
from ldai.judge import Judge
99
from ldai.models import ( # Deprecated aliases for backward compatibility
1010
AIAgentConfig, AIAgentConfigDefault, AIAgentConfigRequest, AIAgents,
1111
AICompletionConfig, AICompletionConfigDefault, AIConfig, AIJudgeConfig,
1212
AIJudgeConfigDefault, JudgeConfiguration, LDAIAgent, LDAIAgentConfig,
13-
LDAIAgentDefaults, LDMessage, ModelConfig, ProviderConfig, AIAgentGraph, AIAgentGraphEdge)
13+
LDAIAgentDefaults, LDMessage, ModelConfig, ProviderConfig, AIAgentGraphConfig, Edge)
1414
from ldai.providers.types import EvalScore, JudgeResponse
1515

1616
__all__ = [
@@ -19,15 +19,15 @@
1919
'AIAgentConfigDefault',
2020
'AIAgentConfigRequest',
2121
'AIAgents',
22-
'AIAgentGraph',
23-
'AIAgentGraphEdge',
22+
'AIAgentGraphConfig',
23+
'Edge',
2424
'AICompletionConfig',
2525
'AICompletionConfigDefault',
2626
'AIJudgeConfig',
2727
'AIJudgeConfigDefault',
2828
'Chat',
2929
'EvalScore',
30-
'AgentGraph',
30+
'AgentGraphDefinition',
3131
'Judge',
3232
'JudgeConfiguration',
3333
'JudgeResponse',
Lines changed: 62 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,24 @@
11
"""Graph implementation for managing AI agent graphs."""
22

3-
from typing import Any, Callable, Dict, List, Optional, Set
4-
from ldai.models import AIAgentGraph, AIAgentConfig, AIAgentGraphEdge
3+
from typing import Any, Callable, Dict, List, Set
4+
from ldai.models import AIAgentGraphConfig, AIAgentConfig, Edge
55
from ldclient import Context
66

7-
7+
DEFAULT_FALSE = AIAgentConfig(key="", enabled=False)
88
class AgentGraphNode:
99
"""
1010
Node in an agent graph.
1111
"""
1212

13-
default_false = AIAgentConfig(key="", enabled=False)
14-
1513
def __init__(
1614
self,
1715
key: str,
1816
config: AIAgentConfig,
19-
children: List[AIAgentGraphEdge],
20-
parent_graph: "AgentGraph",
17+
children: List[Edge],
2118
):
2219
self._key = key
2320
self._config = config
2421
self._children = children
25-
self._parent_graph = parent_graph
2622

2723
def get_key(self) -> str:
2824
"""Get the key of the node."""
@@ -32,124 +28,85 @@ def get_config(self) -> AIAgentConfig:
3228
"""Get the config of the node."""
3329
return self._config
3430

35-
def get_edges(self) -> List[AIAgentGraphEdge]:
36-
"""Get the edges of the node."""
37-
return self._children
38-
39-
def get_child_nodes(self) -> List["AgentGraphNode"]:
40-
"""Get the child nodes of the node as AgentGraphNode objects."""
41-
return [
42-
self._parent_graph.get_node(edge.targetConfig) for edge in self._children
43-
]
44-
4531
def is_terminal(self) -> bool:
4632
"""Check if the node is a terminal node."""
47-
return len(self._children) == 0
48-
49-
def get_parent_nodes(self) -> List["AgentGraphNode"]:
50-
"""Get the parent nodes of the node as AgentGraphNode objects."""
51-
return [
52-
self._parent_graph.get_node(edge.sourceConfig)
53-
for edge in self._parent_graph._get_parent_edges(self._key)
54-
]
55-
56-
def traverse(
57-
self, fn: Callable[["AgentGraphNode", Dict[str, Any]], None], execution_context: Dict[str, Any] = {}, visited: Optional[Set[str]] = None
58-
) -> None:
59-
"""Traverse the graph downwardly from this node, calling fn on each node."""
60-
if visited is None:
61-
visited = set()
62-
63-
# Avoid cycles by tracking visited nodes
64-
if self._key in visited:
65-
return
66-
67-
visited.add(self._key)
68-
fn(self, execution_context)
69-
70-
for child in self._children:
71-
node = self._parent_graph.get_node(child.targetConfig)
72-
if node is not None:
73-
node.traverse(fn, execution_context, visited)
74-
75-
def reverse_traverse(
76-
self,
77-
fn: Callable[["AgentGraphNode", Dict[str, Any]], None],
78-
execution_context: Dict[str, Any] = {},
79-
visited: Optional[Set[str]] = None,
80-
) -> None:
81-
"""Reverse traverse the graph upwardly from this node, calling fn on each node."""
82-
if visited is None:
83-
visited = set()
84-
85-
# Avoid cycles by tracking visited nodes
86-
if self._key in visited:
87-
return
88-
89-
visited.add(self._key)
90-
fn(self, execution_context)
91-
92-
for parent in self._parent_graph._get_parent_edges(self._key):
93-
node = self._parent_graph.get_node(parent.sourceConfig)
94-
if node is not None:
95-
node.reverse_traverse(fn, execution_context, visited)
33+
return len(self._children) == 0
9634

35+
def get_edges(self) -> List[Edge]:
36+
"""Get the edges of the node."""
37+
return self._children
9738

98-
class AgentGraph:
39+
class AgentGraphDefinition:
9940
"""
10041
Graph implementation for managing AI agent graphs.
10142
"""
102-
103-
default_false = AIAgentConfig(key="", enabled=False)
104-
10543
def __init__(
10644
self,
107-
agent_graph: AIAgentGraph,
45+
agent_graph: AIAgentGraphConfig,
46+
nodes: Dict[str, AgentGraphNode],
10847
context: Context,
109-
get_agent: Callable[[str, Context, dict], AIAgentConfig],
11048
):
11149
self._agent_graph = agent_graph
11250
self._context = context
113-
self._get_agent = get_agent
114-
self._nodes = self._build_nodes()
51+
self._nodes = nodes
11552

116-
def _build_nodes(self) -> Dict[str, AgentGraphNode]:
53+
@staticmethod
54+
def build_nodes(
55+
agent_graph: AIAgentGraphConfig,
56+
graph_nodes: Dict[str, AIAgentConfig],
57+
) -> Dict[str, "AgentGraphNode"]:
11758
"""Build the nodes of the graph into AgentGraphNode objects."""
11859
nodes = {
119-
self._agent_graph.rootConfigKey: AgentGraphNode(
120-
self._agent_graph.rootConfigKey,
121-
self._get_agent(
122-
self._agent_graph.rootConfigKey, self._context, self.default_false
123-
),
124-
self._get_child_edges(self._agent_graph.rootConfigKey),
125-
self,
60+
agent_graph.root_config_key: AgentGraphNode(
61+
agent_graph.root_config_key,
62+
graph_nodes[agent_graph.root_config_key],
63+
[
64+
edge
65+
for edge in agent_graph.edges
66+
if edge.source_config == agent_graph.root_config_key
67+
],
12668
),
12769
}
12870

129-
for edge in self._agent_graph.edges:
130-
nodes[edge.targetConfig] = AgentGraphNode(
131-
edge.targetConfig,
132-
self._get_agent(edge.targetConfig, self._context, self.default_false),
133-
self._get_child_edges(edge.targetConfig),
134-
self,
71+
for edge in agent_graph.edges:
72+
nodes[edge.target_config] = AgentGraphNode(
73+
edge.target_config,
74+
graph_nodes[edge.target_config],
75+
[
76+
e
77+
for e in agent_graph.edges
78+
if e.source_config == edge.target_config
79+
],
13580
)
13681

13782
return nodes
13883

139-
def _get_child_edges(self, config_key: str) -> List[AIAgentGraphEdge]:
84+
def get_node(self, key: str) -> AgentGraphNode | None:
85+
"""Get a node by its key."""
86+
return self._nodes.get(key)
87+
88+
def _get_child_edges(self, config_key: str) -> List[Edge]:
14089
"""Get the child edges of the given config."""
14190
return [
14291
edge
14392
for edge in self._agent_graph.edges
144-
if edge.sourceConfig == config_key
93+
if edge.source_config == config_key
14594
]
14695

147-
def _get_parent_edges(self, config_key: str) -> List[AIAgentGraphEdge]:
148-
"""Get the parent edges of the given config."""
96+
def get_child_nodes(self, node_key: str) -> List[AgentGraphNode]:
97+
"""Get the child nodes of the given node key as AgentGraphNode objects."""
14998
return [
150-
edge
99+
self.get_node(edge.target_config)
100+
for edge in self._agent_graph.edges
101+
if edge.source_config == node_key and self.get_node(edge.target_config) is not None
102+
]
103+
104+
def get_parent_nodes(self, node_key: str) -> List[AgentGraphNode]:
105+
"""Get the parent nodes of the given node key as AgentGraphNode objects."""
106+
return [
107+
self.get_node(edge.source_config)
151108
for edge in self._agent_graph.edges
152-
if edge.targetConfig == config_key
109+
if edge.target_config == node_key and self.get_node(edge.source_config) is not None
153110
]
154111

155112
def _collect_nodes(
@@ -170,33 +127,18 @@ def _collect_nodes(
170127
nodes_by_depth[node_depth] = []
171128
nodes_by_depth[node_depth].append(node)
172129

173-
for child in node.get_child_nodes():
130+
for child in self.get_child_nodes(node_key):
174131
self._collect_nodes(child, node_depths, nodes_by_depth, visited)
175132

176133
def terminal_nodes(self) -> List[AgentGraphNode]:
177134
"""Get the terminal nodes of the graph, meaning any nodes without children."""
178135
return [
179-
node for node in self._nodes.values() if len(node.get_child_nodes()) == 0
136+
node for node in self._nodes.values() if len(self.get_child_nodes(node.get_key())) == 0
180137
]
181138

182139
def root(self) -> AgentGraphNode | None:
183140
"""Get the root node of the graph."""
184-
config = self._get_agent(
185-
self._agent_graph.rootConfigKey, self._context, self.default_false
186-
)
187-
188-
if config.enabled is False:
189-
return None
190-
191-
children = [
192-
edge
193-
for edge in self._agent_graph.edges
194-
if edge.sourceConfig == self._agent_graph.rootConfigKey
195-
]
196-
197-
node = AgentGraphNode(self._agent_graph.rootConfigKey, config, children, self)
198-
199-
return node
141+
return self._nodes[self._agent_graph.root_config_key]
200142

201143
def traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], None], execution_context: Dict[str, Any] = {}) -> None:
202144
"""Traverse from the root down to terminal nodes, visiting nodes in order of depth.
@@ -215,7 +157,8 @@ def traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], None], execu
215157
depth += 1
216158

217159
for node in current_level:
218-
for child in node.get_child_nodes():
160+
node_key = node.get_key()
161+
for child in self.get_child_nodes(node_key):
219162
child_key = child.get_key()
220163
# Defer this child to the next level if it's at a longer path
221164
if child_key not in node_depths or (
@@ -236,7 +179,7 @@ def traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], None], execu
236179
for node in nodes_by_depth[depth_level]:
237180
execution_context[node.get_key()] = fn(node, execution_context)
238181

239-
return execution_context[self._agent_graph.rootConfigKey]
182+
return execution_context[self._agent_graph.root_config_key]
240183

241184
def reverse_traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], Any], execution_context: Dict[str, Any] = {}) -> None:
242185
"""Traverse from terminal nodes up to the root, visiting nodes level by level.
@@ -247,7 +190,7 @@ def reverse_traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], Any]
247190

248191
visited: Set[str] = set()
249192
current_level: List[AgentGraphNode] = terminal_nodes
250-
root_key = self._agent_graph.rootConfigKey
193+
root_key = self._agent_graph.root_config_key
251194
root_node_seen = False
252195

253196
while current_level:
@@ -266,7 +209,7 @@ def reverse_traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], Any]
266209

267210
execution_context[node_key] = fn(node, execution_context)
268211

269-
for parent in node.get_parent_nodes():
212+
for parent in self.get_parent_nodes(node_key):
270213
parent_key = parent.get_key()
271214
if parent_key not in visited:
272215
next_level.append(parent)
@@ -280,8 +223,6 @@ def reverse_traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], Any]
280223
if root_node is not None:
281224
execution_context[root_node.get_key()] = fn(root_node, execution_context)
282225

283-
return execution_context[self._agent_graph.rootConfigKey]
226+
return execution_context[self._agent_graph.root_config_key]
227+
284228

285-
def get_node(self, key: str) -> AgentGraphNode | None:
286-
"""Get a node by its key."""
287-
return self._nodes.get(key)

0 commit comments

Comments
 (0)