We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 799a8ab commit 5ade61c — Copy full SHA for 5ade61c
3 files changed
changelog.d/fix-rewrite-pattern.md
@@ -0,0 +1,4 @@
1
+<!--- SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2
+<!--- SPDX-License-Identifier: Apache-2.0 -->
3
+
4
+- Fixed a bug where a pattern match attempted to remove a value that is used by the new operation
src/cuda/tile/_passes/rewrite_patterns.py
@@ -150,6 +150,11 @@ def rewrite_patterns(root_block: Block):
150
# External use -- can't rewrite
151
continue
152
153
+ new_inputs = set(v.name for op in r.to_add for v in op.all_inputs())
154
+ if deleted_results & new_inputs:
155
+ # New operations use deleted results -- can't rewrite
156
+ continue
157
158
# For now, we insert the new operations at the location of the last matched op.
159
# This is not always correct for maintaining topological sorting, in case if matches
160
# have multiple outputs. However, currently we only care about rewriting subgraphs
test/test_fma.py
@@ -56,6 +56,28 @@ def add_mul_kernel(x, y, z, output,
56
ct.store(output, index=(bidx, 0), tile=output_tile)
57
58
59
+@ct.kernel
60
+def mul_add_same_operand_kernel(x, output,
61
+ TILE: ct.Constant[int],
62
+ DIM: ct.Constant[int]):
63
+ bidx = ct.bid(0)
64
+ tx = ct.load(x, index=(bidx, 0), shape=(TILE, DIM))
65
+ tmp = tx * tx
66
+ output_tile = tmp + tmp
67
+ ct.store(output, index=(bidx, 0), tile=output_tile)
68
69
70
+def test_fma_skip_when_new_op_uses_deleted_var():
71
+ shape = (128, 32)
72
+ x = make_tensor(shape, dtype=torch.float32, device='cuda')
73
+ output = make_tensor(shape, dtype=torch.float32, device='cuda')
74
+ TILE = 32
75
+ grid = (ceil(shape[0] / TILE), 1, 1)
76
+ ct.launch(torch.cuda.current_stream(), grid, mul_add_same_operand_kernel,
77
+ (x, output, TILE, shape[1]))
78
+ assert_close(output, 2 * x * x, atol=1e-3, rtol=1e-3)
79
80
81
@pytest.mark.use_mlir
82
@pytest.mark.parametrize(
83
"kernel, kernel_ref",
0 commit comments