Skip to content

Commit ca63835

Browse files
committed
remove parent block from IR
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent 58e0ccd commit ca63835

File tree

2 files changed: 13 additions and 32 deletions.

src/cuda/tile/_ir/ir.py

Lines changed: 0 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -679,14 +679,6 @@ def result_var(self) -> Var:
679679
raise ValueError(f"Operation {self.op} has {len(self.result_vars)} results")
680680
return self.result_vars[0]
681681

682-
@property
683-
def parent_block(self) -> Block:
684-
return self._parent_block
685-
686-
@parent_block.setter
687-
def parent_block(self, block: Block):
688-
self._parent_block = block
689-
690682
def generate_bytecode(self, ctx: "BytecodeContext"):
691683
raise NotImplementedError(f"Operation {self.op} must implement generate_bytecode")
692684

@@ -788,12 +780,9 @@ def empty_like_self(self: "Block") -> "Block":
788780

789781
def append(self, op: Operation):
790782
self._operations.append(op)
791-
op.parent_block = self
792783

793784
def extend(self, ops: Sequence[Operation]):
794785
self._operations.extend(ops)
795-
for op in ops:
796-
op.parent_block = self
797786

798787
def __len__(self):
799788
return len(self._operations)
@@ -814,14 +803,10 @@ def __delitem__(self, i):
814803
self._replace(i if isinstance(i, slice) else slice(i, i + 1), ())
815804

816805
def _replace(self, s: slice, new_ops: Sequence[Operation]):
817-
for op in new_ops:
818-
op.parent_block = self
819806
self._operations[s] = new_ops
820807

821808
def detach_all(self):
822809
ret, self._operations = self._operations, []
823-
for op in ret:
824-
op.parent_block = None
825810
return ret
826811

827812
@property
@@ -830,14 +815,7 @@ def operations(self) -> Sequence[Operation]:
830815

831816
@operations.setter
832817
def operations(self, ops: Sequence[Operation]):
833-
# Clear parent links on the old ops
834-
for old in self._operations:
835-
old.parent_block = None
836-
837-
# Replace list and re-link parents
838818
self._operations = list(ops)
839-
for op in self._operations:
840-
op.parent_block = self
841819

842820
def make_temp_var(self, loc: Loc) -> Var:
843821
return self.ctx.make_temp(loc)

src/cuda/tile/_passes/token_order.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,6 @@ def token_order_pass(root_block: Block, alias_result: AliasResult):
107107
root_tok = _make_token_var(root_block.ctx, root_block.loc)
108108
token_map = defaultdict(lambda: root_tok)
109109
_to_token_order_in_block(root_block, context, token_map)
110-
# Ensures Operation.parent_block is correctly set
111110
root_block[:0] = (MakeToken(result_vars=(root_tok,), loc=root_block.loc),)
112111

113112

@@ -164,7 +163,8 @@ def _to_token_order_in_block(block: Block,
164163
last_store_key = _last_store_key(alias_set)
165164

166165
input_tok, maybe_input_tok_join_op = _get_input_token(last_store_key, op,
167-
token_map, None)
166+
token_map, None,
167+
block.ctx)
168168
if maybe_input_tok_join_op:
169169
operations.append(maybe_input_tok_join_op)
170170

@@ -185,7 +185,8 @@ def _to_token_order_in_block(block: Block,
185185
if (
186186
isinstance(op, TileStore)
187187
and (parallel_store_ops := _try_loop_parallel_store(op, context.alias_result,
188-
token_map, innermost_loop_info))
188+
token_map, innermost_loop_info,
189+
block.ctx))
189190
):
190191
operations.extend(parallel_store_ops)
191192
continue
@@ -195,7 +196,7 @@ def _to_token_order_in_block(block: Block,
195196
last_store_key = _last_store_key(alias_set)
196197

197198
input_tok, maybe_input_tok_join_op = _get_input_token(last_op_key, op, token_map,
198-
None)
199+
None, block.ctx)
199200
if maybe_input_tok_join_op:
200201
operations.append(maybe_input_tok_join_op)
201202

@@ -211,7 +212,7 @@ def _to_token_order_in_block(block: Block,
211212
last_store_key = _last_store_key(alias_set)
212213

213214
input_tok, maybe_input_tok_join_op = _get_input_token(last_op_key, op, token_map,
214-
op.memory_order)
215+
op.memory_order, block.ctx)
215216
if maybe_input_tok_join_op:
216217
operations.append(maybe_input_tok_join_op)
217218

@@ -389,13 +390,14 @@ def should_join(other_key):
389390
def _get_input_token(token_key: TokenKey,
390391
op: Operation,
391392
token_map: Dict[TokenKey, Var],
392-
memory_order: MemoryOrder | None) -> Tuple[Var, Operation | None]:
393+
memory_order: MemoryOrder | None,
394+
ctx: IRContext) -> Tuple[Var, Operation | None]:
393395
tokens_to_join = _collect_join_tokens(token_key, token_map, memory_order)
394396

395397
if len(tokens_to_join) == 1:
396398
return tokens_to_join[0], None
397399

398-
ret_tok = _make_token_var(op.parent_block.ctx, op.loc)
400+
ret_tok = _make_token_var(ctx, op.loc)
399401
ret_op = JoinTokens(tokens=tuple(tokens_to_join), result_vars=(ret_tok,), loc=op.loc)
400402
return ret_tok, ret_op
401403

@@ -498,7 +500,8 @@ def _try_loop_parallel_store(
498500
store_op: TileStore,
499501
alias_result: AliasResult,
500502
token_map: Dict[TokenKey, Var],
501-
innermost_loop_info: Optional[InnermostLoopInfo]
503+
innermost_loop_info: Optional[InnermostLoopInfo],
504+
ctx: IRContext,
502505
) -> Optional[Tuple[Operation, ...] | Operation]:
503506

504507
if (not innermost_loop_info or
@@ -515,7 +518,7 @@ def _try_loop_parallel_store(
515518

516519
if (ACQUIRE_TOKEN_KEY in token_map and
517520
before_loop_last_op_tok is not token_map[ACQUIRE_TOKEN_KEY]):
518-
input_tok = _make_token_var(store_op.parent_block.ctx, store_op.loc)
521+
input_tok = _make_token_var(ctx, store_op.loc)
519522
maybe_input_tok_join_op = JoinTokens(
520523
tokens=(before_loop_last_op_tok, token_map[ACQUIRE_TOKEN_KEY]),
521524
result_vars=(input_tok,), loc=store_op.loc)
@@ -528,7 +531,7 @@ def _try_loop_parallel_store(
528531

529532
# Eagerly join with loop_last_op_tok
530533
loop_last_op_tok = token_map[last_op_key]
531-
new_last_op_tok = _make_token_var(store_op.parent_block.ctx, store_op.loc)
534+
new_last_op_tok = _make_token_var(ctx, store_op.loc)
532535
join_op = JoinTokens(tokens=(loop_last_op_tok, result_tok),
533536
result_vars=(new_last_op_tok,), loc=store_op.loc)
534537

0 commit comments

Comments (0)