11"""Graph implementation for managing AI agent graphs."""
22
3- from typing import Any , Callable , Dict , List , Set
4- from ldai . models import AIAgentGraphConfig , AIAgentConfig , Edge
3+ from typing import Any , Callable , Dict , List , Optional , Set
4+
55from ldclient import Context
66
7+ from ldai .models import AIAgentConfig , AIAgentGraphConfig , Edge
8+
79DEFAULT_FALSE = AIAgentConfig (key = "" , enabled = False )
10+
11+
812class AgentGraphNode :
913 """
1014 Node in an agent graph.
@@ -36,10 +40,12 @@ def get_edges(self) -> List[Edge]:
3640 """Get the edges of the node."""
3741 return self ._children
3842
43+
3944class AgentGraphDefinition :
4045 """
4146 Graph implementation for managing AI agent graphs.
4247 """
48+
4349 def __init__ (
4450 self ,
4551 agent_graph : AIAgentGraphConfig ,
@@ -72,42 +78,40 @@ def build_nodes(
7278 nodes [edge .target_config ] = AgentGraphNode (
7379 edge .target_config ,
7480 graph_nodes [edge .target_config ],
75- [
76- e
77- for e in agent_graph .edges
78- if e .source_config == edge .target_config
79- ],
81+ [e for e in agent_graph .edges if e .source_config == edge .target_config ],
8082 )
8183
8284 return nodes
8385
84- def get_node (self , key : str ) -> AgentGraphNode | None :
86+ def get_node (self , key : str ) -> Optional [ AgentGraphNode ] :
8587 """Get a node by its key."""
8688 return self ._nodes .get (key )
8789
8890 def _get_child_edges (self , config_key : str ) -> List [Edge ]:
8991 """Get the child edges of the given config."""
9092 return [
91- edge
92- for edge in self ._agent_graph .edges
93- if edge .source_config == config_key
93+ edge for edge in self ._agent_graph .edges if edge .source_config == config_key
9494 ]
9595
9696 def get_child_nodes (self , node_key : str ) -> List [AgentGraphNode ]:
9797 """Get the child nodes of the given node key as AgentGraphNode objects."""
98- return [
99- self .get_node (edge .target_config )
100- for edge in self ._agent_graph .edges
101- if edge .source_config == node_key and self .get_node (edge .target_config ) is not None
102- ]
98+ nodes : List [AgentGraphNode ] = []
99+ for edge in self ._agent_graph .edges :
100+ if edge .source_config == node_key :
101+ node = self .get_node (edge .target_config )
102+ if node is not None :
103+ nodes .append (node )
104+ return nodes
103105
104106 def get_parent_nodes (self , node_key : str ) -> List [AgentGraphNode ]:
105107 """Get the parent nodes of the given node key as AgentGraphNode objects."""
106- return [
107- self .get_node (edge .source_config )
108- for edge in self ._agent_graph .edges
109- if edge .target_config == node_key and self .get_node (edge .source_config ) is not None
110- ]
108+ nodes : List [AgentGraphNode ] = []
109+ for edge in self ._agent_graph .edges :
110+ if edge .target_config == node_key :
111+ node = self .get_node (edge .source_config )
112+ if node is not None :
113+ nodes .append (node )
114+ return nodes
111115
112116 def _collect_nodes (
113117 self ,
@@ -133,14 +137,20 @@ def _collect_nodes(
133137 def terminal_nodes (self ) -> List [AgentGraphNode ]:
134138 """Get the terminal nodes of the graph, meaning any nodes without children."""
135139 return [
136- node for node in self ._nodes .values () if len (self .get_child_nodes (node .get_key ())) == 0
140+ node
141+ for node in self ._nodes .values ()
142+ if len (self .get_child_nodes (node .get_key ())) == 0
137143 ]
138144
139- def root (self ) -> AgentGraphNode | None :
145+ def root (self ) -> Optional [ AgentGraphNode ] :
140146 """Get the root node of the graph."""
141147 return self ._nodes [self ._agent_graph .root_config_key ]
142148
143- def traverse (self , fn : Callable [["AgentGraphNode" , Dict [str , Any ]], None ], execution_context : Dict [str , Any ] = {}) -> None :
149+ def traverse (
150+ self ,
151+ fn : Callable [["AgentGraphNode" , Dict [str , Any ]], Any ],
152+ execution_context : Dict [str , Any ] = {},
153+ ) -> None :
144154 """Traverse from the root down to terminal nodes, visiting nodes in order of depth.
145155 Nodes with the longest paths from the root (deepest nodes) will always be visited last."""
146156 root_node = self .root ()
@@ -181,7 +191,11 @@ def traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], None], execu
181191
182192 return execution_context [self ._agent_graph .root_config_key ]
183193
184- def reverse_traverse (self , fn : Callable [["AgentGraphNode" , Dict [str , Any ]], Any ], execution_context : Dict [str , Any ] = {}) -> None :
194+ def reverse_traverse (
195+ self ,
196+ fn : Callable [["AgentGraphNode" , Dict [str , Any ]], Any ],
197+ execution_context : Dict [str , Any ] = {},
198+ ) -> None :
185199 """Traverse from terminal nodes up to the root, visiting nodes level by level.
186200 The root node will always be visited last, even if multiple paths converge at it."""
187201 terminal_nodes = self .terminal_nodes ()
@@ -208,7 +222,7 @@ def reverse_traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], Any]
208222 continue
209223
210224 execution_context [node_key ] = fn (node , execution_context )
211-
225+
212226 for parent in self .get_parent_nodes (node_key ):
213227 parent_key = parent .get_key ()
214228 if parent_key not in visited :
@@ -221,8 +235,8 @@ def reverse_traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], Any]
221235 if root_node_seen :
222236 root_node = self .root ()
223237 if root_node is not None :
224- execution_context [root_node .get_key ()] = fn (root_node , execution_context )
238+ execution_context [root_node .get_key ()] = fn (
239+ root_node , execution_context
240+ )
225241
226242 return execution_context [self ._agent_graph .root_config_key ]
227-
228-
0 commit comments