Skip to content

Commit ee3bc90

Browse files
committed
[REL-11697] add protection if graph exhausts max_depth_limit setting
1 parent c1aefd1 commit ee3bc90

2 files changed

Lines changed: 28 additions & 11 deletions

File tree

packages/sdk/server-ai/src/ldai/agent_graph/__init__.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -127,20 +127,22 @@ def _collect_nodes(
127127
node_depths: Dict[str, int],
128128
nodes_by_depth: Dict[int, List[AgentGraphNode]],
129129
visited: Set[str],
130+
max_depth: int,
130131
) -> None:
131132
"""Collect all reachable nodes from the given node and group them by depth."""
132133
node_key = node.get_key()
133134
if node_key in visited:
134135
return
135136
visited.add(node_key)
136137

137-
node_depth = node_depths.get(node_key, 0)
138+
# Use max_depth for nodes not in node_depths to ensure they execute last
139+
node_depth = node_depths.get(node_key, max_depth)
138140
if node_depth not in nodes_by_depth:
139141
nodes_by_depth[node_depth] = []
140142
nodes_by_depth[node_depth].append(node)
141143

142144
for child in self.get_child_nodes(node_key):
143-
self._collect_nodes(child, node_depths, nodes_by_depth, visited)
145+
self._collect_nodes(child, node_depths, nodes_by_depth, visited, max_depth)
144146

145147
def terminal_nodes(self) -> List[AgentGraphNode]:
146148
"""Get the terminal nodes of the graph, meaning any nodes without children."""
@@ -172,29 +174,44 @@ def traverse(
172174
current_level: List[AgentGraphNode] = [root_node]
173175
depth = 0
174176
max_depth_limit = 10 # Infinite loop protection limit
177+
max_depth_encountered = 0
178+
visited: Set[str] = {root_node.get_key()} # Track visited nodes in BFS to prevent cycles
175179

176-
while current_level and depth < max_depth_limit:
180+
# Continue BFS to discover all nodes, but stop recording depths after max_depth_limit
181+
while current_level:
177182
next_level: List[AgentGraphNode] = []
178183
depth += 1
179184

180185
for node in current_level:
181186
node_key = node.get_key()
182187
for child in self.get_child_nodes(node_key):
183188
child_key = child.get_key()
184-
# Defer this child to the next level if it's at a longer path
185-
if child_key not in node_depths or (
186-
depth > node_depths[child_key] and depth < max_depth_limit
187-
):
188-
node_depths[child_key] = depth
189-
next_level.append(child)
189+
if depth <= max_depth_limit:
190+
# Defer this child to the next level if it's at a longer path
191+
if child_key not in node_depths or depth > node_depths[child_key]:
192+
node_depths[child_key] = depth
193+
max_depth_encountered = max(max_depth_encountered, depth)
194+
# Add to next level if not already visited (prevents cycles)
195+
if child_key not in visited:
196+
visited.add(child_key)
197+
next_level.append(child)
198+
else:
199+
max_depth_encountered = max(max_depth_encountered, depth)
200+
if child_key not in visited:
201+
# Push this to the next level to be visited
202+
visited.add(child_key)
203+
next_level.append(child)
190204

191205
current_level = next_level
192206

207+
# Use max_depth_limit + 1 to ensure they execute after all recorded nodes
208+
max_depth = max(max_depth_limit + 1, max_depth_encountered + 1)
209+
193210
# Group all nodes by depth
194211
nodes_by_depth: Dict[int, List[AgentGraphNode]] = {}
195212
visited: Set[str] = set()
196213

197-
self._collect_nodes(root_node, node_depths, nodes_by_depth, visited)
214+
self._collect_nodes(root_node, node_depths, nodes_by_depth, visited, max_depth)
198215
# Execute the lambda at this level for the nodes at this depth
199216
for depth_level in sorted(nodes_by_depth.keys()):
200217
for node in nodes_by_depth[depth_level]:

packages/sdk/server-ai/src/ldai/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def agent_graph(
448448
)
449449

450450
all_agent_keys = [variation["rootConfigKey"]] + [
451-
edge["targetConfig"] for edge in variation.get("edges", [])
451+
edge.get("targetConfig", "") for edge in variation.get("edges", []) if edge.get("targetConfig")
452452
]
453453
agent_configs = {
454454
key: self.agent_config(key, context, AIAgentConfigDefault(enabled=False))

0 commit comments

Comments
 (0)