@@ -107,7 +107,6 @@ def token_order_pass(root_block: Block, alias_result: AliasResult):
107107 root_tok = _make_token_var (root_block .ctx , root_block .loc )
108108 token_map = defaultdict (lambda : root_tok )
109109 _to_token_order_in_block (root_block , context , token_map )
110- # Ensures Operation.parent_block is correctly set
111110 root_block [:0 ] = (MakeToken (result_vars = (root_tok ,), loc = root_block .loc ),)
112111
113112
@@ -164,7 +163,8 @@ def _to_token_order_in_block(block: Block,
164163 last_store_key = _last_store_key (alias_set )
165164
166165 input_tok , maybe_input_tok_join_op = _get_input_token (last_store_key , op ,
167- token_map , None )
166+ token_map , None ,
167+ block .ctx )
168168 if maybe_input_tok_join_op :
169169 operations .append (maybe_input_tok_join_op )
170170
@@ -185,7 +185,8 @@ def _to_token_order_in_block(block: Block,
185185 if (
186186 isinstance (op , TileStore )
187187 and (parallel_store_ops := _try_loop_parallel_store (op , context .alias_result ,
188- token_map , innermost_loop_info ))
188+ token_map , innermost_loop_info ,
189+ block .ctx ))
189190 ):
190191 operations .extend (parallel_store_ops )
191192 continue
@@ -195,7 +196,7 @@ def _to_token_order_in_block(block: Block,
195196 last_store_key = _last_store_key (alias_set )
196197
197198 input_tok , maybe_input_tok_join_op = _get_input_token (last_op_key , op , token_map ,
198- None )
199+ None , block . ctx )
199200 if maybe_input_tok_join_op :
200201 operations .append (maybe_input_tok_join_op )
201202
@@ -211,7 +212,7 @@ def _to_token_order_in_block(block: Block,
211212 last_store_key = _last_store_key (alias_set )
212213
213214 input_tok , maybe_input_tok_join_op = _get_input_token (last_op_key , op , token_map ,
214- op .memory_order )
215+ op .memory_order , block . ctx )
215216 if maybe_input_tok_join_op :
216217 operations .append (maybe_input_tok_join_op )
217218
@@ -389,13 +390,14 @@ def should_join(other_key):
389390def _get_input_token (token_key : TokenKey ,
390391 op : Operation ,
391392 token_map : Dict [TokenKey , Var ],
392- memory_order : MemoryOrder | None ) -> Tuple [Var , Operation | None ]:
393+ memory_order : MemoryOrder | None ,
394+ ctx : IRContext ) -> Tuple [Var , Operation | None ]:
393395 tokens_to_join = _collect_join_tokens (token_key , token_map , memory_order )
394396
395397 if len (tokens_to_join ) == 1 :
396398 return tokens_to_join [0 ], None
397399
398- ret_tok = _make_token_var (op . parent_block . ctx , op .loc )
400+ ret_tok = _make_token_var (ctx , op .loc )
399401 ret_op = JoinTokens (tokens = tuple (tokens_to_join ), result_vars = (ret_tok ,), loc = op .loc )
400402 return ret_tok , ret_op
401403
@@ -498,7 +500,8 @@ def _try_loop_parallel_store(
498500 store_op : TileStore ,
499501 alias_result : AliasResult ,
500502 token_map : Dict [TokenKey , Var ],
501- innermost_loop_info : Optional [InnermostLoopInfo ]
503+ innermost_loop_info : Optional [InnermostLoopInfo ],
504+ ctx : IRContext ,
502505) -> Optional [Tuple [Operation , ...] | Operation ]:
503506
504507 if (not innermost_loop_info or
@@ -515,7 +518,7 @@ def _try_loop_parallel_store(
515518
516519 if (ACQUIRE_TOKEN_KEY in token_map and
517520 before_loop_last_op_tok is not token_map [ACQUIRE_TOKEN_KEY ]):
518- input_tok = _make_token_var (store_op . parent_block . ctx , store_op .loc )
521+ input_tok = _make_token_var (ctx , store_op .loc )
519522 maybe_input_tok_join_op = JoinTokens (
520523 tokens = (before_loop_last_op_tok , token_map [ACQUIRE_TOKEN_KEY ]),
521524 result_vars = (input_tok ,), loc = store_op .loc )
@@ -528,7 +531,7 @@ def _try_loop_parallel_store(
528531
529532 # Eagerly join with loop_last_op_tok
530533 loop_last_op_tok = token_map [last_op_key ]
531- new_last_op_tok = _make_token_var (store_op . parent_block . ctx , store_op .loc )
534+ new_last_op_tok = _make_token_var (ctx , store_op .loc )
532535 join_op = JoinTokens (tokens = (loop_last_op_tok , result_tok ),
533536 result_vars = (new_last_op_tok ,), loc = store_op .loc )
534537
0 commit comments