@@ -184,11 +184,11 @@ def _get_selected_branch_fields(self, node: IfInvocation) -> tuple[str, str]:
184184
185185 def _prune_unselected_if_inputs (self , exec_node_id : str , unselected_field : str ) -> None :
186186 for edge in self ._state .execution_graph ._get_input_edges (exec_node_id , unselected_field ):
187- if edge .source .node_id in self ._state .executed :
188- continue
189- if self . _state . indegree [ exec_node_id ] == 0 :
190- raise RuntimeError ( f" indegree underflow for { exec_node_id } when pruning { unselected_field } " )
191- self ._state .indegree [ exec_node_id ] -= 1
187+ if edge .source .node_id not in self ._state .executed :
188+ if self . _state . indegree [ exec_node_id ] == 0 :
189+ raise RuntimeError ( f" indegree underflow for { exec_node_id } when pruning { unselected_field } " )
190+ self . _state . indegree [ exec_node_id ] -= 1
191+ self ._state .execution_graph . delete_edge ( edge )
192192
193193 def _apply_branch_resolution (
194194 self ,
@@ -243,6 +243,10 @@ def is_deferred_by_unresolved_if(self, exec_node_id: str) -> bool:
243243 return False
244244
245245 def mark_exec_node_skipped (self , exec_node_id : str ) -> None :
246+ state = self ._state ._get_prepared_exec_metadata (exec_node_id ).state
247+ if state in ("executed" , "skipped" ):
248+ return
249+
246250 self ._state ._remove_from_ready_queues (exec_node_id )
247251 self ._state ._set_prepared_exec_state (exec_node_id , "skipped" )
248252 self ._state .executed .add (exec_node_id )
@@ -356,7 +360,7 @@ def _initialize_execution_node(self, exec_node_id: str) -> None:
356360 def _get_collect_iteration_mappings (self , parent_node_ids : list [str ]) -> list [tuple [str , str ]]:
357361 all_iteration_mappings : list [tuple [str , str ]] = []
358362 for source_node_id in parent_node_ids :
359- prepared_nodes = self ._state . source_prepared_mapping [ source_node_id ]
363+ prepared_nodes = self ._get_prepared_nodes_for_source ( source_node_id )
360364 all_iteration_mappings .extend ((source_node_id , prepared_id ) for prepared_id in prepared_nodes )
361365 return all_iteration_mappings
362366
@@ -414,7 +418,11 @@ def get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None
414418 return [n for n in nx .ancestors (g , node_id ) if isinstance (self ._state .graph .get_node (n ), IterateInvocation )]
415419
416420 def _get_prepared_nodes_for_source (self , source_node_id : str ) -> set [str ]:
417- return self ._state .source_prepared_mapping [source_node_id ]
421+ return {
422+ exec_node_id
423+ for exec_node_id in self ._state .source_prepared_mapping [source_node_id ]
424+ if self ._state ._get_prepared_exec_metadata (exec_node_id ).state != "skipped"
425+ }
418426
419427 def _get_parent_iterator_exec_nodes (
420428 self , source_node_id : str , graph : nx .DiGraph , prepared_iterator_nodes : list [str ]
@@ -471,10 +479,15 @@ def get_iteration_node(
471479 prepared_iterator_nodes : list [str ],
472480 ) -> Optional [str ]:
473481 prepared_nodes = self ._get_prepared_nodes_for_source (source_node_id )
474- if len (prepared_nodes ) == 1 :
482+ if len (prepared_nodes ) == 1 and not prepared_iterator_nodes :
475483 return next (iter (prepared_nodes ))
476484
477485 parent_iterators = self ._get_parent_iterator_exec_nodes (source_node_id , graph , prepared_iterator_nodes )
486+ if len (prepared_nodes ) == 1 :
487+ prepared_node_id = next (iter (prepared_nodes ))
488+ if self ._matches_parent_iterators (prepared_node_id , parent_iterators , execution_graph ):
489+ return prepared_node_id
490+ return None
478491
479492 direct_iterator_match = self ._get_direct_prepared_iterator_match (
480493 prepared_nodes , prepared_iterator_nodes , parent_iterators , execution_graph
@@ -733,6 +746,12 @@ def _sort_collect_input_edges(self, input_edges: list[Edge], field_name: str) ->
733746 def _get_copied_result_value (self , edge : Edge ) -> Any :
734747 return copydeep (getattr (self ._state .results [edge .source .node_id ], edge .source .field ))
735748
749+ def _try_get_copied_result_value (self , edge : Edge ) -> tuple [bool , Any ]:
750+ source_output = self ._state .results .get (edge .source .node_id )
751+ if source_output is None :
752+ return False , None
753+ return True , copydeep (getattr (source_output , edge .source .field ))
754+
736755 def _build_collect_collection (self , input_edges : list [Edge ]) -> list [Any ]:
737756 item_edges = self ._sort_collect_input_edges (input_edges , ITEM_FIELD )
738757 collection_edges = self ._sort_collect_input_edges (input_edges , COLLECTION_FIELD )
@@ -761,7 +780,20 @@ def _prepare_collect_inputs(self, node: "CollectInvocation", input_edges: list[E
761780 def _prepare_if_inputs (self , node : IfInvocation , input_edges : list [Edge ]) -> None :
762781 selected_field = self ._state ._resolved_if_exec_branches .get (node .id )
763782 allowed_fields = {"condition" , selected_field } if selected_field is not None else {"condition" }
764- self ._set_node_inputs (node , input_edges , allowed_fields )
783+
784+ for edge in input_edges :
785+ if edge .destination .field not in allowed_fields :
786+ continue
787+
788+ found_value , copied_value = self ._try_get_copied_result_value (edge )
789+ if not found_value :
790+ iteration_path = self ._state ._get_iteration_path (node .id )
791+ raise RuntimeError (
792+ "IfInvocation selected input edge points at an exec node with no stored result output: "
793+ f"if_exec_id={ node .id } , source_exec_id={ edge .source .node_id } , iteration_path={ iteration_path } "
794+ )
795+
796+ setattr (node , edge .destination .field , copied_value )
765797
766798 def _prepare_default_inputs (self , node : BaseInvocation , input_edges : list [Edge ]) -> None :
767799 self ._set_node_inputs (node , input_edges )
0 commit comments