11"""Graph implementation for managing AI agent graphs."""
22
3- from typing import Any , Callable , Dict , List , Optional , Set
4- from ldai .models import AIAgentGraph , AIAgentConfig , AIAgentGraphEdge
3+ from typing import Any , Callable , Dict , List , Set
4+ from ldai .models import AIAgentGraphConfig , AIAgentConfig , Edge
55from ldclient import Context
66
7-
7+ DEFAULT_FALSE = AIAgentConfig ( key = "" , enabled = False )
88class AgentGraphNode :
99 """
1010 Node in an agent graph.
1111 """
1212
13- default_false = AIAgentConfig (key = "" , enabled = False )
14-
1513 def __init__ (
1614 self ,
1715 key : str ,
1816 config : AIAgentConfig ,
19- children : List [AIAgentGraphEdge ],
20- parent_graph : "AgentGraph" ,
17+ children : List [Edge ],
2118 ):
2219 self ._key = key
2320 self ._config = config
2421 self ._children = children
25- self ._parent_graph = parent_graph
2622
2723 def get_key (self ) -> str :
2824 """Get the key of the node."""
@@ -32,124 +28,85 @@ def get_config(self) -> AIAgentConfig:
3228 """Get the config of the node."""
3329 return self ._config
3430
35- def get_edges (self ) -> List [AIAgentGraphEdge ]:
36- """Get the edges of the node."""
37- return self ._children
38-
39- def get_child_nodes (self ) -> List ["AgentGraphNode" ]:
40- """Get the child nodes of the node as AgentGraphNode objects."""
41- return [
42- self ._parent_graph .get_node (edge .targetConfig ) for edge in self ._children
43- ]
44-
4531 def is_terminal (self ) -> bool :
4632 """Check if the node is a terminal node."""
47- return len (self ._children ) == 0
48-
49- def get_parent_nodes (self ) -> List ["AgentGraphNode" ]:
50- """Get the parent nodes of the node as AgentGraphNode objects."""
51- return [
52- self ._parent_graph .get_node (edge .sourceConfig )
53- for edge in self ._parent_graph ._get_parent_edges (self ._key )
54- ]
55-
56- def traverse (
57- self , fn : Callable [["AgentGraphNode" , Dict [str , Any ]], None ], execution_context : Dict [str , Any ] = {}, visited : Optional [Set [str ]] = None
58- ) -> None :
59- """Traverse the graph downwardly from this node, calling fn on each node."""
60- if visited is None :
61- visited = set ()
62-
63- # Avoid cycles by tracking visited nodes
64- if self ._key in visited :
65- return
66-
67- visited .add (self ._key )
68- fn (self , execution_context )
69-
70- for child in self ._children :
71- node = self ._parent_graph .get_node (child .targetConfig )
72- if node is not None :
73- node .traverse (fn , execution_context , visited )
74-
75- def reverse_traverse (
76- self ,
77- fn : Callable [["AgentGraphNode" , Dict [str , Any ]], None ],
78- execution_context : Dict [str , Any ] = {},
79- visited : Optional [Set [str ]] = None ,
80- ) -> None :
81- """Reverse traverse the graph upwardly from this node, calling fn on each node."""
82- if visited is None :
83- visited = set ()
84-
85- # Avoid cycles by tracking visited nodes
86- if self ._key in visited :
87- return
88-
89- visited .add (self ._key )
90- fn (self , execution_context )
91-
92- for parent in self ._parent_graph ._get_parent_edges (self ._key ):
93- node = self ._parent_graph .get_node (parent .sourceConfig )
94- if node is not None :
95- node .reverse_traverse (fn , execution_context , visited )
33+ return len (self ._children ) == 0
9634
35+ def get_edges (self ) -> List [Edge ]:
36+ """Get the edges of the node."""
37+ return self ._children
9738
98- class AgentGraph :
39+ class AgentGraphDefinition :
9940 """
10041 Graph implementation for managing AI agent graphs.
10142 """
102-
103- default_false = AIAgentConfig (key = "" , enabled = False )
104-
10543 def __init__ (
10644 self ,
107- agent_graph : AIAgentGraph ,
45+ agent_graph : AIAgentGraphConfig ,
46+ nodes : Dict [str , AgentGraphNode ],
10847 context : Context ,
109- get_agent : Callable [[str , Context , dict ], AIAgentConfig ],
11048 ):
11149 self ._agent_graph = agent_graph
11250 self ._context = context
113- self ._get_agent = get_agent
114- self ._nodes = self ._build_nodes ()
51+ self ._nodes = nodes
11552
116- def _build_nodes (self ) -> Dict [str , AgentGraphNode ]:
53+ @staticmethod
54+ def build_nodes (
55+ agent_graph : AIAgentGraphConfig ,
56+ graph_nodes : Dict [str , AIAgentConfig ],
57+ ) -> Dict [str , "AgentGraphNode" ]:
11758 """Build the nodes of the graph into AgentGraphNode objects."""
11859 nodes = {
119- self ._agent_graph .rootConfigKey : AgentGraphNode (
120- self ._agent_graph .rootConfigKey ,
121- self ._get_agent (
122- self ._agent_graph .rootConfigKey , self ._context , self .default_false
123- ),
124- self ._get_child_edges (self ._agent_graph .rootConfigKey ),
125- self ,
60+ agent_graph .root_config_key : AgentGraphNode (
61+ agent_graph .root_config_key ,
62+ graph_nodes [agent_graph .root_config_key ],
63+ [
64+ edge
65+ for edge in agent_graph .edges
66+ if edge .source_config == agent_graph .root_config_key
67+ ],
12668 ),
12769 }
12870
129- for edge in self ._agent_graph .edges :
130- nodes [edge .targetConfig ] = AgentGraphNode (
131- edge .targetConfig ,
132- self ._get_agent (edge .targetConfig , self ._context , self .default_false ),
133- self ._get_child_edges (edge .targetConfig ),
134- self ,
71+ for edge in agent_graph .edges :
72+ nodes [edge .target_config ] = AgentGraphNode (
73+ edge .target_config ,
74+ graph_nodes [edge .target_config ],
75+ [
76+ e
77+ for e in agent_graph .edges
78+ if e .source_config == edge .target_config
79+ ],
13580 )
13681
13782 return nodes
13883
139- def _get_child_edges (self , config_key : str ) -> List [AIAgentGraphEdge ]:
84+ def get_node (self , key : str ) -> AgentGraphNode | None :
85+ """Get a node by its key."""
86+ return self ._nodes .get (key )
87+
88+ def _get_child_edges (self , config_key : str ) -> List [Edge ]:
14089 """Get the child edges of the given config."""
14190 return [
14291 edge
14392 for edge in self ._agent_graph .edges
144- if edge .sourceConfig == config_key
93+ if edge .source_config == config_key
14594 ]
14695
147- def _get_parent_edges (self , config_key : str ) -> List [AIAgentGraphEdge ]:
148- """Get the parent edges of the given config ."""
96+ def get_child_nodes (self , node_key : str ) -> List [AgentGraphNode ]:
97+ """Get the child nodes of the given node key as AgentGraphNode objects ."""
14998 return [
150- edge
99+ self .get_node (edge .target_config )
100+ for edge in self ._agent_graph .edges
101+ if edge .source_config == node_key and self .get_node (edge .target_config ) is not None
102+ ]
103+
104+ def get_parent_nodes (self , node_key : str ) -> List [AgentGraphNode ]:
105+ """Get the parent nodes of the given node key as AgentGraphNode objects."""
106+ return [
107+ self .get_node (edge .source_config )
151108 for edge in self ._agent_graph .edges
152- if edge .targetConfig == config_key
109+ if edge .target_config == node_key and self . get_node ( edge . source_config ) is not None
153110 ]
154111
155112 def _collect_nodes (
@@ -170,33 +127,18 @@ def _collect_nodes(
170127 nodes_by_depth [node_depth ] = []
171128 nodes_by_depth [node_depth ].append (node )
172129
173- for child in node .get_child_nodes ():
130+ for child in self .get_child_nodes (node_key ):
174131 self ._collect_nodes (child , node_depths , nodes_by_depth , visited )
175132
176133 def terminal_nodes (self ) -> List [AgentGraphNode ]:
177134 """Get the terminal nodes of the graph, meaning any nodes without children."""
178135 return [
179- node for node in self ._nodes .values () if len (node .get_child_nodes ()) == 0
136+ node for node in self ._nodes .values () if len (self .get_child_nodes (node . get_key () )) == 0
180137 ]
181138
182139 def root (self ) -> AgentGraphNode | None :
183140 """Get the root node of the graph."""
184- config = self ._get_agent (
185- self ._agent_graph .rootConfigKey , self ._context , self .default_false
186- )
187-
188- if config .enabled is False :
189- return None
190-
191- children = [
192- edge
193- for edge in self ._agent_graph .edges
194- if edge .sourceConfig == self ._agent_graph .rootConfigKey
195- ]
196-
197- node = AgentGraphNode (self ._agent_graph .rootConfigKey , config , children , self )
198-
199- return node
141+ return self ._nodes [self ._agent_graph .root_config_key ]
200142
201143 def traverse (self , fn : Callable [["AgentGraphNode" , Dict [str , Any ]], None ], execution_context : Dict [str , Any ] = {}) -> None :
202144 """Traverse from the root down to terminal nodes, visiting nodes in order of depth.
@@ -215,7 +157,8 @@ def traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], None], execu
215157 depth += 1
216158
217159 for node in current_level :
218- for child in node .get_child_nodes ():
160+ node_key = node .get_key ()
161+ for child in self .get_child_nodes (node_key ):
219162 child_key = child .get_key ()
220163 # Defer this child to the next level if it's at a longer path
221164 if child_key not in node_depths or (
@@ -236,7 +179,7 @@ def traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], None], execu
236179 for node in nodes_by_depth [depth_level ]:
237180 execution_context [node .get_key ()] = fn (node , execution_context )
238181
239- return execution_context [self ._agent_graph .rootConfigKey ]
182+ return execution_context [self ._agent_graph .root_config_key ]
240183
241184 def reverse_traverse (self , fn : Callable [["AgentGraphNode" , Dict [str , Any ]], Any ], execution_context : Dict [str , Any ] = {}) -> None :
242185 """Traverse from terminal nodes up to the root, visiting nodes level by level.
@@ -247,7 +190,7 @@ def reverse_traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], Any]
247190
248191 visited : Set [str ] = set ()
249192 current_level : List [AgentGraphNode ] = terminal_nodes
250- root_key = self ._agent_graph .rootConfigKey
193+ root_key = self ._agent_graph .root_config_key
251194 root_node_seen = False
252195
253196 while current_level :
@@ -266,7 +209,7 @@ def reverse_traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], Any]
266209
267210 execution_context [node_key ] = fn (node , execution_context )
268211
269- for parent in node .get_parent_nodes ():
212+ for parent in self .get_parent_nodes (node_key ):
270213 parent_key = parent .get_key ()
271214 if parent_key not in visited :
272215 next_level .append (parent )
@@ -280,8 +223,6 @@ def reverse_traverse(self, fn: Callable[["AgentGraphNode", Dict[str, Any]], Any]
280223 if root_node is not None :
281224 execution_context [root_node .get_key ()] = fn (root_node , execution_context )
282225
283- return execution_context [self ._agent_graph .rootConfigKey ]
226+ return execution_context [self ._agent_graph .root_config_key ]
227+
284228
285- def get_node (self , key : str ) -> AgentGraphNode | None :
286- """Get a node by its key."""
287- return self ._nodes .get (key )
0 commit comments