Skip to content

Commit 2d9e842

Browse files
committed
local commit
1 parent 0d0cecc commit 2d9e842

4 files changed

Lines changed: 406 additions & 2 deletions

File tree

packages/sdk/server-ai/src/ldai/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
from ldai.chat import Chat
66
from ldai.client import LDAIClient
7+
from ldai.agent_graph import AgentGraph
78
from ldai.judge import Judge
89
from ldai.models import ( # Deprecated aliases for backward compatibility
910
AIAgentConfig, AIAgentConfigDefault, AIAgentConfigRequest, AIAgents,
1011
AICompletionConfig, AICompletionConfigDefault, AIConfig, AIJudgeConfig,
1112
AIJudgeConfigDefault, JudgeConfiguration, LDAIAgent, LDAIAgentConfig,
12-
LDAIAgentDefaults, LDMessage, ModelConfig, ProviderConfig)
13+
LDAIAgentDefaults, LDMessage, ModelConfig, ProviderConfig, AIAgentGraph, AIAgentGraphEdge)
1314
from ldai.providers.types import EvalScore, JudgeResponse
1415

1516
__all__ = [
@@ -18,12 +19,15 @@
1819
'AIAgentConfigDefault',
1920
'AIAgentConfigRequest',
2021
'AIAgents',
22+
'AIAgentGraph',
23+
'AIAgentGraphEdge',
2124
'AICompletionConfig',
2225
'AICompletionConfigDefault',
2326
'AIJudgeConfig',
2427
'AIJudgeConfigDefault',
2528
'Chat',
2629
'EvalScore',
30+
'AgentGraph',
2731
'Judge',
2832
'JudgeConfiguration',
2933
'JudgeResponse',
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
"""Graph implementation for managing AI agent graphs."""
2+
3+
from typing import Any, Callable, Dict, List, Optional, Set
4+
from ldai.models import AIAgentGraph, AIAgentConfig, AIAgentGraphEdge
5+
from ldclient import Context
6+
7+
8+
class AgentGraphNode:
9+
"""
10+
Node in an agent graph.
11+
"""
12+
13+
default_false = AIAgentConfig(key="", enabled=False)
14+
15+
def __init__(
16+
self,
17+
key: str,
18+
config: AIAgentConfig,
19+
children: List[AIAgentGraphEdge],
20+
parent_graph: "AgentGraph",
21+
):
22+
self._key = key
23+
self._config = config
24+
self._children = children
25+
self._parent_graph = parent_graph
26+
27+
def get_key(self) -> str:
28+
"""Get the key of the node."""
29+
return self._key
30+
31+
def get_config(self) -> AIAgentConfig:
32+
"""Get the config of the node."""
33+
return self._config
34+
35+
def get_edges(self) -> List[AIAgentGraphEdge]:
36+
"""Get the edges of the node."""
37+
return self._children
38+
39+
def get_child_nodes(self) -> List["AgentGraphNode"]:
40+
"""Get the child nodes of the node as AgentGraphNode objects."""
41+
return [
42+
self._parent_graph.get_node(edge.targetConfig) for edge in self._children
43+
]
44+
45+
def is_terminal(self) -> bool:
46+
"""Check if the node is a terminal node."""
47+
return len(self._children) == 0
48+
49+
def get_parent_nodes(self) -> List["AgentGraphNode"]:
50+
"""Get the parent nodes of the node as AgentGraphNode objects."""
51+
return [
52+
self._parent_graph.get_node(edge.sourceConfig)
53+
for edge in self._parent_graph._get_parent_edges(self._key)
54+
]
55+
56+
def traverse(
57+
self, fn: Callable[["AgentGraphNode", Dict[str, Any]], None], execution_context: Dict[str, Any] = {}, visited: Optional[Set[str]] = None
58+
) -> None:
59+
"""Traverse the graph downwardly from this node, calling fn on each node."""
60+
if visited is None:
61+
visited = set()
62+
63+
# Avoid cycles by tracking visited nodes
64+
if self._key in visited:
65+
return
66+
67+
visited.add(self._key)
68+
fn(self, execution_context)
69+
70+
for child in self._children:
71+
node = self._parent_graph.get_node(child.targetConfig)
72+
if node is not None:
73+
node.traverse(fn, execution_context, visited)
74+
75+
def reverse_traverse(
76+
self,
77+
fn: Callable[["AgentGraphNode", Dict[str, Any]], None],
78+
execution_context: Dict[str, Any] = {},
79+
visited: Optional[Set[str]] = None,
80+
) -> None:
81+
"""Reverse traverse the graph upwardly from this node, calling fn on each node."""
82+
if visited is None:
83+
visited = set()
84+
85+
# Avoid cycles by tracking visited nodes
86+
if self._key in visited:
87+
return
88+
89+
visited.add(self._key)
90+
fn(self, execution_context)
91+
92+
for parent in self._parent_graph._get_parent_edges(self._key):
93+
node = self._parent_graph.get_node(parent.sourceConfig)
94+
if node is not None:
95+
node.reverse_traverse(fn, execution_context, visited)
96+
97+
98+
class AgentGraph:
99+
"""
100+
Graph implementation for managing AI agent graphs.
101+
"""
102+
103+
default_false = AIAgentConfig(key="", enabled=False)
104+
105+
def __init__(
106+
self,
107+
agent_graph: AIAgentGraph,
108+
context: Context,
109+
get_agent: Callable[[str, Context, dict], AIAgentConfig],
110+
):
111+
self._agent_graph = agent_graph
112+
self._context = context
113+
self._get_agent = get_agent
114+
self._nodes = self._build_nodes()
115+
116+
def _build_nodes(self) -> Dict[str, AgentGraphNode]:
117+
"""Build the nodes of the graph into AgentGraphNode objects."""
118+
nodes = {
119+
self._agent_graph.rootConfigKey: AgentGraphNode(
120+
self._agent_graph.rootConfigKey,
121+
self._get_agent(
122+
self._agent_graph.rootConfigKey, self._context, self.default_false
123+
),
124+
self._get_child_edges(self._agent_graph.rootConfigKey),
125+
self,
126+
),
127+
}
128+
129+
for edge in self._agent_graph.edges:
130+
nodes[edge.targetConfig] = AgentGraphNode(
131+
edge.targetConfig,
132+
self._get_agent(edge.targetConfig, self._context, self.default_false),
133+
self._get_child_edges(edge.targetConfig),
134+
self,
135+
)
136+
137+
return nodes
138+
139+
def _get_child_edges(self, config_key: str) -> List[AIAgentGraphEdge]:
140+
"""Get the child edges of the given config."""
141+
return [
142+
edge
143+
for edge in self._agent_graph.edges
144+
if edge.sourceConfig == config_key
145+
]
146+
147+
def _get_parent_edges(self, config_key: str) -> List[AIAgentGraphEdge]:
148+
"""Get the parent edges of the given config."""
149+
return [
150+
edge
151+
for edge in self._agent_graph.edges
152+
if edge.targetConfig == config_key
153+
]
154+
155+
def _collect_nodes(
156+
self,
157+
node: AgentGraphNode,
158+
node_depths: Dict[str, int],
159+
nodes_by_depth: Dict[int, List[AgentGraphNode]],
160+
visited: Set[str],
161+
) -> None:
162+
"""Collect all reachable nodes from the given node and group them by depth."""
163+
node_key = node.get_key()
164+
if node_key in visited:
165+
return
166+
visited.add(node_key)
167+
168+
node_depth = node_depths.get(node_key, 0)
169+
if node_depth not in nodes_by_depth:
170+
nodes_by_depth[node_depth] = []
171+
nodes_by_depth[node_depth].append(node)
172+
173+
for child in node.get_child_nodes():
174+
self._collect_nodes(child, node_depths, nodes_by_depth, visited)
175+
176+
def terminal_nodes(self) -> List[AgentGraphNode]:
177+
"""Get the terminal nodes of the graph, meaning any nodes without children."""
178+
return [
179+
node for node in self._nodes.values() if len(node.get_child_nodes()) == 0
180+
]
181+
182+
def root(self) -> AgentGraphNode | None:
183+
"""Get the root node of the graph."""
184+
config = self._get_agent(
185+
self._agent_graph.rootConfigKey, self._context, self.default_false
186+
)
187+
188+
if config.enabled is False:
189+
return None
190+
191+
children = [
192+
edge
193+
for edge in self._agent_graph.edges
194+
if edge.sourceConfig == self._agent_graph.rootConfigKey
195+
]
196+
197+
node = AgentGraphNode(self._agent_graph.rootConfigKey, config, children, self)
198+
199+
return node
200+
201+
def traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], None], execution_context: Dict[str, Any] = {}) -> None:
202+
"""Traverse from the root down to terminal nodes, visiting nodes in order of depth.
203+
Nodes with the longest paths from the root (deepest nodes) will always be visited last."""
204+
root_node = self.root()
205+
if root_node is None:
206+
return
207+
208+
node_depths: Dict[str, int] = {root_node.get_key(): 0}
209+
current_level: List[AgentGraphNode] = [root_node]
210+
depth = 0
211+
max_depth_limit = 10 # Infinite loop protection limit
212+
213+
while current_level and depth < max_depth_limit:
214+
next_level: List[AgentGraphNode] = []
215+
depth += 1
216+
217+
for node in current_level:
218+
for child in node.get_child_nodes():
219+
child_key = child.get_key()
220+
# Defer this child to the next level if it's at a longer path
221+
if child_key not in node_depths or (
222+
depth > node_depths[child_key] and depth < max_depth_limit
223+
):
224+
node_depths[child_key] = depth
225+
next_level.append(child)
226+
227+
current_level = next_level
228+
229+
# Group all nodes by depth
230+
nodes_by_depth: Dict[int, List[AgentGraphNode]] = {}
231+
visited: Set[str] = set()
232+
233+
self._collect_nodes(root_node, node_depths, nodes_by_depth, visited)
234+
# Execute the lambda at this level for the nodes at this depth
235+
for depth_level in sorted(nodes_by_depth.keys()):
236+
for node in nodes_by_depth[depth_level]:
237+
execution_context[node.get_key()] = fn(node, execution_context)
238+
239+
return execution_context[self._agent_graph.rootConfigKey]
240+
241+
def reverse_traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], Any], execution_context: Dict[str, Any] = {}) -> None:
242+
"""Traverse from terminal nodes up to the root, visiting nodes level by level.
243+
The root node will always be visited last, even if multiple paths converge at it."""
244+
terminal_nodes = self.terminal_nodes()
245+
if not terminal_nodes:
246+
return
247+
248+
visited: Set[str] = set()
249+
current_level: List[AgentGraphNode] = terminal_nodes
250+
root_key = self._agent_graph.rootConfigKey
251+
root_node_seen = False
252+
253+
while current_level:
254+
next_level: List[AgentGraphNode] = []
255+
256+
for node in current_level:
257+
node_key = node.get_key()
258+
if node_key in visited:
259+
continue
260+
261+
visited.add(node_key)
262+
# Skip the root node if we reach a terminus, it will be visited last
263+
if node_key == root_key:
264+
root_node_seen = True
265+
continue
266+
267+
execution_context[node_key] = fn(node, execution_context)
268+
269+
for parent in node.get_parent_nodes():
270+
parent_key = parent.get_key()
271+
if parent_key not in visited:
272+
next_level.append(parent)
273+
274+
current_level = next_level
275+
276+
# If we saw the root node, append it at the end as it'll always be the last node in a
277+
# reverse traversal (this should always happen, non-contiguous graphs are invalid)
278+
if root_node_seen:
279+
root_node = self.root()
280+
if root_node is not None:
281+
execution_context[root_node.get_key()] = fn(root_node, execution_context)
282+
283+
return execution_context[self._agent_graph.rootConfigKey]
284+
285+
def get_node(self, key: str) -> AgentGraphNode | None:
286+
"""Get a node by its key."""
287+
return self._nodes.get(key)

0 commit comments

Comments
 (0)