@@ -352,7 +352,7 @@ def batch_intervals(
352352 )
353353 for snapshot , intervals in merged_intervals .items ()
354354 }
355- snapshot_batches = {}
355+ snapshot_batches : t . Dict [ Snapshot , Intervals ] = {}
356356 all_unready_intervals : t .Dict [str , set [Interval ]] = {}
357357 for snapshot_id in dag :
358358 if snapshot_id not in snapshot_intervals :
@@ -364,13 +364,22 @@ def batch_intervals(
364364
365365 adapter = self .snapshot_evaluator .get_adapter (snapshot .model_gateway )
366366
367+ parent_intervals : Intervals = []
368+ for parent_id in snapshot .parents :
369+ parent_snapshot , _ = snapshot_intervals .get (parent_id , (None , []))
370+ if not parent_snapshot or parent_snapshot .is_external :
371+ continue
372+
373+ parent_intervals .extend (snapshot_batches [parent_snapshot ])
374+
367375 context = ExecutionContext (
368376 adapter ,
369377 self .snapshots_by_name ,
370378 deployability_index ,
371379 default_dialect = adapter .dialect ,
372380 default_catalog = self .default_catalog ,
373381 is_restatement = is_restatement ,
382+ parent_intervals = parent_intervals ,
374383 )
375384
376385 intervals = self ._check_ready_intervals (
@@ -538,6 +547,10 @@ def run_node(node: SchedulingUnit) -> None:
538547 execution_time = execution_time ,
539548 )
540549 else :
550+ # If batch_index > 0, then the target table must exist since the first batch would have created it
551+ target_table_exists = (
552+ snapshot .snapshot_id not in snapshots_to_create or node .batch_index > 0
553+ )
541554 audit_results = self .evaluate (
542555 snapshot = snapshot ,
543556 environment_naming_info = environment_naming_info ,
@@ -548,7 +561,7 @@ def run_node(node: SchedulingUnit) -> None:
548561 batch_index = node .batch_index ,
549562 allow_destructive_snapshots = allow_destructive_snapshots ,
550563 allow_additive_snapshots = allow_additive_snapshots ,
551- target_table_exists = snapshot . snapshot_id not in snapshots_to_create ,
564+ target_table_exists = target_table_exists ,
552565 selected_models = selected_models ,
553566 )
554567
@@ -646,6 +659,7 @@ def _dag(
646659 }
647660 snapshots_to_create = snapshots_to_create or set ()
648661 original_snapshots_to_create = snapshots_to_create .copy ()
662+ upstream_dependencies_cache : t .Dict [SnapshotId , t .Set [SchedulingUnit ]] = {}
649663
650664 snapshot_dag = snapshot_dag or snapshots_to_dag (batches )
651665 dag = DAG [SchedulingUnit ]()
@@ -657,12 +671,15 @@ def _dag(
657671 snapshot = self .snapshots_by_name [snapshot_id .name ]
658672 intervals = intervals_per_snapshot .get (snapshot .name , [])
659673
660- upstream_dependencies : t .List [SchedulingUnit ] = []
674+ upstream_dependencies : t .Set [SchedulingUnit ] = set ()
661675
662676 for p_sid in snapshot .parents :
663- upstream_dependencies .extend (
677+ upstream_dependencies .update (
664678 self ._find_upstream_dependencies (
665- p_sid , intervals_per_snapshot , original_snapshots_to_create
679+ p_sid ,
680+ intervals_per_snapshot ,
681+ original_snapshots_to_create ,
682+ upstream_dependencies_cache ,
666683 )
667684 )
668685
@@ -713,29 +730,42 @@ def _find_upstream_dependencies(
713730 parent_sid : SnapshotId ,
714731 intervals_per_snapshot : t .Dict [str , Intervals ],
715732 snapshots_to_create : t .Set [SnapshotId ],
716- ) -> t .List [SchedulingUnit ]:
733+ cache : t .Dict [SnapshotId , t .Set [SchedulingUnit ]],
734+ ) -> t .Set [SchedulingUnit ]:
717735 if parent_sid not in self .snapshots :
718- return []
736+ return set ()
737+ if parent_sid in cache :
738+ return cache [parent_sid ]
719739
720740 p_intervals = intervals_per_snapshot .get (parent_sid .name , [])
721741
742+ parent_node : t .Optional [SchedulingUnit ] = None
722743 if p_intervals :
723744 if len (p_intervals ) > 1 :
724- return [DummyNode (snapshot_name = parent_sid .name )]
725- interval = p_intervals [0 ]
726- return [EvaluateNode (snapshot_name = parent_sid .name , interval = interval , batch_index = 0 )]
727- if parent_sid in snapshots_to_create :
728- return [CreateNode (snapshot_name = parent_sid .name )]
745+ parent_node = DummyNode (snapshot_name = parent_sid .name )
746+ else :
747+ interval = p_intervals [0 ]
748+ parent_node = EvaluateNode (
749+ snapshot_name = parent_sid .name , interval = interval , batch_index = 0
750+ )
751+ elif parent_sid in snapshots_to_create :
752+ parent_node = CreateNode (snapshot_name = parent_sid .name )
753+
754+ if parent_node is not None :
755+ cache [parent_sid ] = {parent_node }
756+ return {parent_node }
757+
729758 # This snapshot has no intervals and doesn't need creation which means
730759 # that it can be a transitive dependency
731- transitive_deps : t .List [SchedulingUnit ] = []
760+ transitive_deps : t .Set [SchedulingUnit ] = set ()
732761 parent_snapshot = self .snapshots [parent_sid ]
733762 for grandparent_sid in parent_snapshot .parents :
734- transitive_deps .extend (
763+ transitive_deps .update (
735764 self ._find_upstream_dependencies (
736- grandparent_sid , intervals_per_snapshot , snapshots_to_create
765+ grandparent_sid , intervals_per_snapshot , snapshots_to_create , cache
737766 )
738767 )
768+ cache [parent_sid ] = transitive_deps
739769 return transitive_deps
740770
741771 def _run_or_audit (
0 commit comments