Skip to content

Commit 1d6d8a2

Browse files
committed
[TIRx] Replace vars in buffer strides and elem_offset
Signed-off-by: Guan-Ming (Wesley) Chiu <105915352+guan404ming@users.noreply.github.com>
1 parent 9808108 commit 1d6d8a2

2 files changed

Lines changed: 21 additions & 3 deletions

File tree

python/tvm/tirx/transform/common.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from tvm.tirx.stmt_functor import StmtExprMutator, StmtMutator
3737

3838

39-
# FIXME: this pass does not replace var in the shape/layout of a buffer
4039
class BufferReplacer(StmtExprMutator):
4140
"""
4241
Replace buffer with another buffer.
@@ -63,6 +62,10 @@ def mutate_buffer(self, buffer: Buffer):
6362
self.buffer_attr_var_mutated = False
6463
new_data = self.visit_expr(buffer.data)
6564
new_shape = [self.visit_expr(expr) for expr in buffer.shape]
65+
new_strides = [self.visit_expr(expr) for expr in buffer.strides]
66+
new_elem_offset = (
67+
self.visit_expr(buffer.elem_offset) if buffer.elem_offset is not None else None
68+
)
6669
if isinstance(buffer.layout, TileLayout):
6770
new_shard = []
6871
new_replicate = []
@@ -90,8 +93,8 @@ def mutate_buffer(self, buffer: Buffer):
9093
buffer.dtype,
9194
buffer.name,
9295
new_data,
93-
buffer.strides,
94-
buffer.elem_offset,
96+
new_strides,
97+
new_elem_offset,
9598
buffer.scope(),
9699
buffer.data_alignment,
97100
buffer.offset_factor,

tests/python/tirx/test_op.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,21 @@ def test_buffer_replacer_no_shared_default():
5757
assert len(r2.buffer_map) == 0
5858

5959

60+
def test_buffer_replacer_replaces_strides_and_elem_offset():
61+
"""Vars in buffer strides/elem_offset must be replaced, not passed through."""
62+
from tvm.tirx import BufferStore, Var
63+
from tvm.tirx.transform.common import BufferReplacer
64+
65+
n = Var("n", "int32")
66+
m = Var("m", "int32")
67+
A = decl_buffer((64,), "float32", strides=[n], elem_offset=n)
68+
store = BufferStore(A, 1.0, [0])
69+
70+
new = BufferReplacer(var_map={n: m})(store)
71+
assert new.buffer.strides[0].same_as(m)
72+
assert new.buffer.elem_offset.same_as(m)
73+
74+
6075
def test_gemm_async_partial_scale_factor():
6176
"""Regression test for F7: gemm_async must reject partial scale factors."""
6277
from tvm.tirx.script.builder.tirx import gemm_async

0 commit comments

Comments
 (0)