Skip to content

Commit 5ade61c

Browse files
committed
Fix rewrite pattern
Skip pattern rewrite if the removed variable is used by the new operation.

Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent 799a8ab commit 5ade61c

3 files changed

Lines changed: 31 additions & 0 deletions

File tree

changelog.d/fix-rewrite-pattern.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Fixed a bug when pattern match attempted to remove a value that is used by the new operation

src/cuda/tile/_passes/rewrite_patterns.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ def rewrite_patterns(root_block: Block):
150150
# External use -- can't rewrite
151151
continue
152152

153+
new_inputs = set(v.name for op in r.to_add for v in op.all_inputs())
154+
if deleted_results & new_inputs:
155+
# New operations use deleted results -- can't rewrite
156+
continue
157+
153158
# For now, we insert the new operations at the location of the last matched op.
154159
# This is not always correct for maintaining topological sorting, in case if matches
155160
# have multiple outputs. However, currently we only care about rewriting subgraphs

test/test_fma.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,28 @@ def add_mul_kernel(x, y, z, output,
5656
ct.store(output, index=(bidx, 0), tile=output_tile)
5757

5858

59+
@ct.kernel
def mul_add_same_operand_kernel(x, output,
                                TILE: ct.Constant[int],
                                DIM: ct.Constant[int]):
    """Tile-wise compute of ``2 * x * x``.

    The multiply result feeds *both* operands of the add, so a mul+add
    fusion that deletes the multiply's result would leave the new op
    reading a removed value — presumably the case the rewrite-pattern
    fix in this commit guards against (TODO confirm against the pass).
    """
    block = ct.bid(0)
    tile = ct.load(x, index=(block, 0), shape=(TILE, DIM))
    # Reuse the same intermediate on both sides of the add.
    squared = tile * tile
    doubled = squared + squared
    ct.store(output, index=(block, 0), tile=doubled)
68+
69+
70+
def test_fma_skip_when_new_op_uses_deleted_var():
    """Rewriter must not fuse mul+add when the add consumes the mul's
    result twice: the would-be-deleted value is an input of the new op.
    The kernel must still produce 2 * x * x.
    """
    shape = (128, 32)
    TILE = 32
    # Allocate input first, then output, matching the original fixture order.
    x = make_tensor(shape, dtype=torch.float32, device='cuda')
    output = make_tensor(shape, dtype=torch.float32, device='cuda')
    grid = (ceil(shape[0] / TILE), 1, 1)
    ct.launch(torch.cuda.current_stream(), grid,
              mul_add_same_operand_kernel,
              (x, output, TILE, shape[1]))
    assert_close(output, 2 * x * x, atol=1e-3, rtol=1e-3)
79+
80+
5981
@pytest.mark.use_mlir
6082
@pytest.mark.parametrize(
6183
"kernel, kernel_ref",

0 commit comments

Comments
 (0)