1313from cuda .tile ._exception import Loc , TileInternalError
1414from cuda .tile ._ir .ir import Block , IRContext , Var , Operation
1515from cuda .tile ._ir .ops import (
16- Assign , Break , Continue , EndBranch , IfElse ,
16+ Break , Continue , EndBranch , IfElse ,
1717 JoinTokens , LoadMemoryOperation , Loop , MakeToken ,
1818 MemoryOperation , StoreMemoryOperation , TileAtomicCAS , TileAtomicCASTokenOrdered ,
1919 TileAtomicRMW , TileAtomicRMWTokenOrdered , LoadPointer , LoadPointerTokenOrdered ,
@@ -106,7 +106,6 @@ class VarInfo:
@dataclass(frozen=True)
class TokenOrderContext:
    """Immutable bundle of analysis results used by the token-order pass.

    Built once in ``token_order_pass`` and handed to the per-block helpers.
    The former ``var_info`` field is gone: the context is constructed with
    exactly two arguments, ``alias_result`` and ``block_memory_effects``.
    """

    # Alias-analysis result supplied by the caller of ``token_order_pass``.
    alias_result: AliasResult
    # Memory effects keyed by block; populated by
    # ``_get_block_memory_effects`` before the context is created.
    block_memory_effects: Dict[Block, MemoryEffects]
111110
112111
@@ -123,8 +122,7 @@ class TokenOrderContext:
123122def token_order_pass (root_block : Block , alias_result : AliasResult ):
124123 block_memory_effects = {}
125124 _get_block_memory_effects (root_block , alias_result , block_memory_effects )
126- var_info = _get_var_info (root_block )
127- context = TokenOrderContext (alias_result , var_info , block_memory_effects )
125+ context = TokenOrderContext (alias_result , block_memory_effects )
128126
129127 root_tok = _make_token_var (root_block .ctx , root_block .loc )
130128 token_map = defaultdict (lambda : root_tok )
@@ -169,21 +167,6 @@ def get_memory_effects(cur_op):
169167 block_memory_effects [block ] = blk_mem_effects
170168
171169
172- # TODO: Assign ops should be gone at this point. Need to verify this and remove this logic.
173- def _get_var_info (root_block : Block ) -> VarInfo :
174- root_var = dict ()
175-
176- def traverse (block : Block ):
177- for op in block .operations :
178- if isinstance (op , Assign ):
179- root_var [op .result_var .name ] = root_var .get (op .value .name , op .value .name )
180- for block in op .nested_blocks :
181- traverse (block )
182-
183- traverse (root_block )
184- return VarInfo (root_var )
185-
186-
187170def _to_token_order_in_block (block : Block ,
188171 context : TokenOrderContext ,
189172 token_map : Dict [TokenKey , Var ],
@@ -523,7 +506,7 @@ def _get_parallel_stores(
523506 tile_store_candidates .add (mem_ops [0 ])
524507
525508 # Filter in stores that have non-overlapping indices
526- res = _filter_by_store_index (loop_op , tile_store_candidates , context . var_info )
509+ res = _filter_by_store_index (loop_op , tile_store_candidates )
527510 return res
528511
529512
@@ -539,13 +522,11 @@ def _get_nested_mem_effects(
539522
540523
541524def _filter_by_store_index (loop_op : Loop ,
542- tile_store_candidates : Set [Operation ],
543- var_info : VarInfo ) -> Set [Operation ]:
525+ tile_store_candidates : Set [Operation ]) -> Set [Operation ]:
544526
545527 def is_idx_injective (idx_var : Var ) -> bool :
546- root_idx_var = var_info .root_var .get (idx_var .name , idx_var .name )
547528 # TODO: allow more complex injective check: j = i * 2 + 3
548- return loop_op .is_for_loop and root_idx_var == loop_op .induction_var .name
529+ return loop_op .is_for_loop and idx_var . name == loop_op .induction_var .name
549530
550531 return set (store_op for store_op in tile_store_candidates
551532 if _get_input_var (store_op ).get_type ().elements_disjoint
0 commit comments