55from enum import Enum
66
77import accelforge .frontend .arch as arch
8+ from accelforge .frontend .arch ._flattened_arch import FlattenedArch
89from accelforge .util ._frozenset import oset
910from accelforge .frontend .mapping import (
1011 MappingNode ,
@@ -39,15 +40,14 @@ def insert_temporal_loops(
3940 ranks_with_tile_pattern : set ,
4041 workload : Workload ,
4142 _can_lower_outermost_memory : bool ,
42- flattened_arch : list [ arch . Leaf ] ,
43+ flattened_arch : FlattenedArch ,
4344 max_fused_loops : int ,
4445 fanouts : dict [str , int ],
4546 fusable_tensors : set [TensorName ],
4647 intermediate_tensors : set [TensorName ],
4748 let_non_intermediate_tensors_respawn_in_backing_storage : bool ,
4849 explore_loop_orders : bool ,
4950):
50- arch_node_names = [n .name for n in flattened_arch ]
5151 # First establish insertion points. Insertion points are:
5252 # - Below the last instance of the first memory
5353 # - Between any two TensorHolder nodes
@@ -88,7 +88,7 @@ def insert_temporal_loops(
8888 for s in split_mapping :
8989 # Within each split mapping group, sort by arch levels.
9090 # This can help create places to put spatial loops
91- s .sort (key = lambda tensor_holder : arch_node_names .index (tensor_holder .component ))
91+ s .sort (key = lambda tensor_holder : flattened_arch .index (tensor_holder .component ))
9292
9393 if sum (map (len , split_mapping )) != len (mapping ):
9494 raise RuntimeError ("BUG: number of storage nodes post-split != original" )
@@ -333,11 +333,10 @@ def _get_next_storages(i: int, toll_allowed: bool = False) -> list[TensorHolder]
333333def insert_spatial_loops (
334334 mapping : list [MappingNode ],
335335 einsum : Einsum ,
336- flattened_arch : list [ arch . Memory ] ,
336+ flattened_arch : FlattenedArch ,
337337 intermediate_tensors : set [TensorName ],
338338):
339339 nodes_with_fanout = [n for n in flattened_arch if n .get_fanout () > 1 ]
340- arch_node_names = [n .name for n in flattened_arch ]
341340 tensor2fully_relevant_rank_vars = einsum .tensor2directly_indexing_rank_variables
342341 simple_rank_variables = einsum ._simple_rank_variables
343342
@@ -346,7 +345,7 @@ def insert_spatial_loops(
346345 # above the fanout in the arch, and below any temporal loops in the
347346 # same block.
348347 insertion_point = _idx_below_lowest_tensor_holder_with_component_above_fanout (
349- node , mapping , arch_node_names
348+ node , mapping , flattened_arch
350349 )
351350 while insertion_point < len (mapping ) and isinstance (
352351 mapping [insertion_point ], Temporal
@@ -386,16 +385,16 @@ def _tensors_seen_above_point(idx, mapping):
386385
387386
388387def _idx_below_lowest_tensor_holder_with_component_above_fanout (
389- fanout_node , mapping , arch_node_names
388+ fanout_node , mapping , flattened_arch : FlattenedArch
390389):
391390 """Return the index right after the lowest TensorHolder whose component
392391 is above the fanout in the arch. If none found, returns len(mapping)."""
393- fanout_arch_idx = arch_node_names .index (fanout_node .name )
392+ fanout_arch_idx = flattened_arch .index (fanout_node .name )
394393 result = 0
395394 for i in range (len (mapping )):
396395 if not isinstance (mapping [i ], TensorHolder ):
397396 continue
398- if arch_node_names .index (mapping [i ].component ) < fanout_arch_idx :
397+ if flattened_arch .index (mapping [i ].component ) < fanout_arch_idx :
399398 result = i + 1
400399 return result
401400
0 commit comments