From 36abe145bd61b5bd3e00fed12c9884b0074e28e3 Mon Sep 17 00:00:00 2001 From: VitalyR Date: Wed, 6 May 2026 18:28:07 +0800 Subject: [PATCH 1/7] [Bugfix] Fix issue 2123 access_ptr lowering Add the issue 2123 regression test and keep the access_ptr base BufferLoad intact through safe-memory legalization and lowering. Closes #2123 --- src/transform/legalize_safe_memory_access.cc | 42 ++++++++++++- src/transform/lower_access_ptr.cc | 40 +++++++++--- .../python/issue/test_tilelang_issue_2123.py | 63 +++++++++++++++++++ 3 files changed, 137 insertions(+), 8 deletions(-) create mode 100644 testing/python/issue/test_tilelang_issue_2123.py diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index de85e0ac78..722d1160b8 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -233,7 +233,47 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { // Constructor initializing the base class with the analyzer SafeMemorysRewriter(arith::Analyzer *analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {} - // Constructor initializing the base class with the analyzer + + BufferLoad VisitAccessPtrBase(const PrimExpr &expr) { + const auto *base_load_node = expr.as(); + ICHECK(base_load_node) + << "tl.access_ptr base must be BufferLoad, but got " << expr; + BufferLoad base_load = tvm::ffi::GetRef(base_load_node); + + Array indices; + bool changed = false; + for (const PrimExpr &index : base_load->indices) { + PrimExpr new_index = VisitExpr(index); + changed = changed || !new_index.same_as(index); + indices.push_back(new_index); + } + Optional predicate = base_load->predicate; + if (predicate.defined()) { + PrimExpr new_predicate = VisitExpr(predicate.value()); + changed = changed || !new_predicate.same_as(predicate.value()); + predicate = new_predicate; + } + + if (!changed) { + return base_load; + } + return BufferLoad(base_load->buffer, indices, predicate, base_load->span); + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (!op->op.same_as(tl::access_ptr())) { + return IRMutatorWithAnalyzer::VisitExpr_(op); + } + + ICHECK_EQ(op->args.size(), 3U) + << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; + Array args{ + VisitAccessPtrBase(op->args[0]), + VisitExpr(op->args[1]), + VisitExpr(op->args[2]), + }; + return Call(op->dtype, op->op, args, op->annotations, op->span); + } PrimExpr VisitExpr_(const BufferLoadNode *op) final { auto load = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); diff --git a/src/transform/lower_access_ptr.cc b/src/transform/lower_access_ptr.cc index 96f566c33a..053714c5f8 100644 --- a/src/transform/lower_access_ptr.cc +++ b/src/transform/lower_access_ptr.cc @@ -77,20 +77,19 @@ PrimExpr LinearOffsetFromLoad(const BufferLoad &load) { class AccessPtrLowerer : public StmtExprMutator { public: PrimExpr VisitExpr_(const CallNode *op) final { - Call call = Downcast(StmtExprMutator::VisitExpr_(op)); - if (!call->op.same_as(tl::access_ptr())) { - return std::move(call); + if (!op->op.same_as(tl::access_ptr())) { + return StmtExprMutator::VisitExpr_(op); } - ICHECK_EQ(call->args.size(), 3U) + ICHECK_EQ(op->args.size(), 3U) << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; - BufferLoad base_load = Downcast(call->args[0]); + BufferLoad base_load = VisitAccessPtrBase(op->args[0]); Buffer buffer = base_load->buffer; ICHECK(buffer.defined()); - PrimExpr extent = call->args[1]; - PrimExpr rw_mask = call->args[2]; + PrimExpr extent = VisitExpr(op->args[1]); + PrimExpr rw_mask = VisitExpr(op->args[2]); PrimExpr ptype = tir::TypeAnnotation(buffer->dtype); PrimExpr data = buffer->data; @@ -99,6 +98,33 @@ class AccessPtrLowerer : public StmtExprMutator { Array args{ptype, data, offset, extent, rw_mask}; return Call(DataType::Handle(), builtin::tvm_access_ptr(), args); } + +private: + BufferLoad VisitAccessPtrBase(const PrimExpr &expr) { + const auto *base_load_node = expr.as(); + ICHECK(base_load_node) + << "tl.access_ptr base must be BufferLoad, but got " << expr; + BufferLoad base_load = tvm::ffi::GetRef(base_load_node); + + Array indices; + bool changed = false; + for (const PrimExpr &index : base_load->indices) { + PrimExpr new_index = VisitExpr(index); + changed = changed || !new_index.same_as(index); + indices.push_back(new_index); + } + Optional predicate = base_load->predicate; + if (predicate.defined()) { + PrimExpr new_predicate = VisitExpr(predicate.value()); + changed = changed || !new_predicate.same_as(predicate.value()); + predicate = new_predicate; + } + + if (!changed) { + return base_load; + } + return BufferLoad(base_load->buffer, indices, predicate, base_load->span); + } }; PrimFunc LowerAccessPtrPrimFunc(PrimFunc f) { diff --git a/testing/python/issue/test_tilelang_issue_2123.py b/testing/python/issue/test_tilelang_issue_2123.py new file mode 100644 index 0000000000..5bd45df4d7 --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_2123.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import tilelang +import tilelang.testing +import tilelang.language as T +from tilelang import tvm +from tilelang.engine.phase import LowerAndLegalize + + +def issue_2123_atomic_load_repro(num_tiles, threads=32): + @T.prim_func + def kernel(status: T.Tensor((num_tiles,), T.int32), out: T.Tensor((1,), T.int32)): + with T.Kernel(num_tiles, threads=threads) as tile: + look = T.alloc_var(T.int32) + state = T.alloc_var(T.int32) + done = T.alloc_var(T.bool) + tx = T.get_thread_binding() + if tx == 0: + look = tile - 1 + done = look < 0 + state = 0 + while not done: + state = T.atomic_load(status[look], memory_order="acquire") + if state != 0: + done = True + else: + look -= 1 + done = look < 0 + if tile == num_tiles - 1: + out[0] = state + + return kernel + + +def _has_op_call(func, op_name): + found = False + + def _visit(node): + nonlocal found + if ( + isinstance(node, tvm.tir.Call) + and isinstance(node.op, tvm.ir.Op) + and node.op.name == op_name + ): + found = True + + tvm.tir.stmt_functor.post_order_visit(func.body, _visit) + return found + + +def test_issue_2123_atomic_load_lower_access_ptr(): + target = tvm.target.Target("cuda", host="llvm") + func = issue_2123_atomic_load_repro(4).with_attr("global_symbol", "main") + mod = tvm.IRModule.from_expr(func) + + lowered = LowerAndLegalize(mod, target) + + assert _has_op_call(lowered["main"], "tir.tvm_access_ptr") + assert not _has_op_call(lowered["main"], "tl.access_ptr") + + +if __name__ == "__main__": + tilelang.testing.main() From 7204f781ab3801935b2e53f52d9559975684c8ad Mon Sep 17 00:00:00 2001 From: VitalyR Date: Thu, 7 May 2026 03:29:31 +0800 Subject: [PATCH 2/7] Deduplicate access_ptr base load handling Move shared tl.access_ptr BufferLoad base visitation into a common transform helper so LowerAccessPtr and safe-memory legalization stay aligned. Extend the issue 2123 regression with a direct LowerAccessPtr pass check while keeping the LowerAndLegalize pipeline coverage. --- src/transform/common/access_ptr_utils.h | 45 +++++++++++++++++++ src/transform/legalize_safe_memory_access.cc | 32 +++---------- src/transform/lower_access_ptr.cc | 33 +++----------- .../python/issue/test_tilelang_issue_2123.py | 26 +++++++---- 4 files changed, 73 insertions(+), 63 deletions(-) create mode 100644 src/transform/common/access_ptr_utils.h diff --git a/src/transform/common/access_ptr_utils.h b/src/transform/common/access_ptr_utils.h new file mode 100644 index 0000000000..1f3b04af42 --- /dev/null +++ b/src/transform/common/access_ptr_utils.h @@ -0,0 +1,45 @@ +#ifndef TILELANG_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ +#define TILELANG_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ + +#include +#include + +namespace tvm { +namespace tl { +using namespace tir; + +namespace detail { + +template +BufferLoad VisitAccessPtrBase(const PrimExpr &expr, VisitExprFn &&visit_expr) { + const auto *base_load_node = expr.as(); + ICHECK(base_load_node) + << "tl.access_ptr base must be BufferLoad, but got " << expr; + BufferLoad base_load = tvm::ffi::GetRef(base_load_node); + + Array indices; + bool changed = false; + for (const PrimExpr &index : base_load->indices) { + PrimExpr new_index = visit_expr(index); + changed = changed || !new_index.same_as(index); + indices.push_back(new_index); + } + + Optional predicate = base_load->predicate; + if (predicate.defined()) { + PrimExpr new_predicate = visit_expr(predicate.value()); + changed = changed || !new_predicate.same_as(predicate.value()); + predicate = new_predicate; + } + + if (!changed) { + return base_load; + } + return BufferLoad(base_load->buffer, indices, predicate, base_load->span); +} + +} // namespace detail +} // namespace tl +} // namespace tvm + +#endif // TILELANG_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index 722d1160b8..3302820e0f 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -15,6 +15,7 @@ #include "../op/builtin.h" #include "../op/parallel.h" #include "arith/ir_mutator_with_analyzer.h" +#include "common/access_ptr_utils.h" #include "loop_partition.h" #include "loop_vectorize.h" @@ -234,32 +235,6 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { SafeMemorysRewriter(arith::Analyzer *analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {} - BufferLoad VisitAccessPtrBase(const PrimExpr &expr) { - const auto *base_load_node = expr.as(); - ICHECK(base_load_node) - << "tl.access_ptr base must be BufferLoad, but got " << expr; - BufferLoad base_load = tvm::ffi::GetRef(base_load_node); - - Array indices; - bool changed = false; - for (const PrimExpr &index : base_load->indices) { - PrimExpr new_index = VisitExpr(index); - changed = changed || !new_index.same_as(index); - indices.push_back(new_index); - } - Optional predicate = base_load->predicate; - if (predicate.defined()) { - PrimExpr new_predicate = VisitExpr(predicate.value()); - changed = changed || !new_predicate.same_as(predicate.value()); - predicate = new_predicate; - } - - if (!changed) { - return base_load; - } - return BufferLoad(base_load->buffer, indices, predicate, base_load->span); - } - PrimExpr VisitExpr_(const CallNode *op) final { if (!op->op.same_as(tl::access_ptr())) { return IRMutatorWithAnalyzer::VisitExpr_(op); @@ -267,8 +242,11 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { ICHECK_EQ(op->args.size(), 3U) << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; + auto visit_expr = [this](const PrimExpr &expr) { + return this->VisitExpr(expr); + }; Array args{ - VisitAccessPtrBase(op->args[0]), + detail::VisitAccessPtrBase(op->args[0], visit_expr), VisitExpr(op->args[1]), VisitExpr(op->args[2]), }; diff --git a/src/transform/lower_access_ptr.cc b/src/transform/lower_access_ptr.cc index 053714c5f8..bb80cbd197 100644 --- a/src/transform/lower_access_ptr.cc +++ b/src/transform/lower_access_ptr.cc @@ -11,6 +11,7 @@ #include #include "../op/builtin.h" +#include "common/access_ptr_utils.h" namespace tvm { namespace tl { @@ -84,7 +85,10 @@ class AccessPtrLowerer : public StmtExprMutator { ICHECK_EQ(op->args.size(), 3U) << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; - BufferLoad base_load = VisitAccessPtrBase(op->args[0]); + auto visit_expr = [this](const PrimExpr &expr) { + return this->VisitExpr(expr); + }; + BufferLoad base_load = detail::VisitAccessPtrBase(op->args[0], visit_expr); Buffer buffer = base_load->buffer; ICHECK(buffer.defined()); @@ -98,33 +102,6 @@ class AccessPtrLowerer : public StmtExprMutator { Array args{ptype, data, offset, extent, rw_mask}; return Call(DataType::Handle(), builtin::tvm_access_ptr(), args); } - -private: - BufferLoad VisitAccessPtrBase(const PrimExpr &expr) { - const auto *base_load_node = expr.as(); - ICHECK(base_load_node) - << "tl.access_ptr base must be BufferLoad, but got " << expr; - BufferLoad base_load = tvm::ffi::GetRef(base_load_node); - - Array indices; - bool changed = false; - for (const PrimExpr &index : base_load->indices) { - PrimExpr new_index = VisitExpr(index); - changed = changed || !new_index.same_as(index); - indices.push_back(new_index); - } - Optional predicate = base_load->predicate; - if (predicate.defined()) { - PrimExpr new_predicate = VisitExpr(predicate.value()); - changed = changed || !new_predicate.same_as(predicate.value()); - predicate = new_predicate; - } - - if (!changed) { - return base_load; - } - return BufferLoad(base_load->buffer, indices, predicate, base_load->span); - } }; PrimFunc LowerAccessPtrPrimFunc(PrimFunc f) { diff --git a/testing/python/issue/test_tilelang_issue_2123.py b/testing/python/issue/test_tilelang_issue_2123.py index 5bd45df4d7..a589d8307b 100644 --- a/testing/python/issue/test_tilelang_issue_2123.py +++ b/testing/python/issue/test_tilelang_issue_2123.py @@ -5,6 +5,7 @@ import tilelang.language as T from tilelang import tvm from tilelang.engine.phase import LowerAndLegalize +from tilelang.transform import LowerAccessPtr def issue_2123_atomic_load_repro(num_tiles, threads=32): @@ -37,26 +38,35 @@ def _has_op_call(func, op_name): def _visit(node): nonlocal found - if ( - isinstance(node, tvm.tir.Call) - and isinstance(node.op, tvm.ir.Op) - and node.op.name == op_name - ): + if isinstance(node, tvm.tir.Call) and isinstance(node.op, tvm.ir.Op) and node.op.name == op_name: found = True tvm.tir.stmt_functor.post_order_visit(func.body, _visit) return found -def test_issue_2123_atomic_load_lower_access_ptr(): +def _assert_access_ptr_lowered(mod): + assert _has_op_call(mod["main"], "tir.tvm_access_ptr") + assert not _has_op_call(mod["main"], "tl.access_ptr") + + +def test_issue_2123_atomic_load_lower_access_ptr_direct(): + func = issue_2123_atomic_load_repro(4).with_attr("global_symbol", "main") + mod = tvm.IRModule.from_expr(func) + + lowered = LowerAccessPtr()(mod) + + _assert_access_ptr_lowered(lowered) + + +def test_issue_2123_atomic_load_lower_access_ptr_pipeline(): target = tvm.target.Target("cuda", host="llvm") func = issue_2123_atomic_load_repro(4).with_attr("global_symbol", "main") mod = tvm.IRModule.from_expr(func) lowered = LowerAndLegalize(mod, target) - assert _has_op_call(lowered["main"], "tir.tvm_access_ptr") - assert not _has_op_call(lowered["main"], "tl.access_ptr") + _assert_access_ptr_lowered(lowered) if __name__ == "__main__": From 14e93088737d2f999b93c66bd580f4cf62a0b2ab Mon Sep 17 00:00:00 2001 From: VitalyR Date: Thu, 7 May 2026 03:33:25 +0800 Subject: [PATCH 3/7] Format access_ptr helper header --- src/transform/common/access_ptr_utils.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transform/common/access_ptr_utils.h b/src/transform/common/access_ptr_utils.h index 1f3b04af42..4bdf8f04ab 100644 --- a/src/transform/common/access_ptr_utils.h +++ b/src/transform/common/access_ptr_utils.h @@ -13,8 +13,8 @@ namespace detail { template BufferLoad VisitAccessPtrBase(const PrimExpr &expr, VisitExprFn &&visit_expr) { const auto *base_load_node = expr.as(); - ICHECK(base_load_node) - << "tl.access_ptr base must be BufferLoad, but got " << expr; + ICHECK(base_load_node) << "tl.access_ptr base must be BufferLoad, but got " + << expr; BufferLoad base_load = tvm::ffi::GetRef(base_load_node); Array indices; @@ -38,8 +38,8 @@ BufferLoad VisitAccessPtrBase(const PrimExpr &expr, VisitExprFn &&visit_expr) { return BufferLoad(base_load->buffer, indices, predicate, base_load->span); } -} // namespace detail -} // namespace tl -} // namespace tvm +} // namespace detail +} // namespace tl +} // namespace tvm -#endif // TILELANG_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ +#endif // TILELANG_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ From 31a7fff095e50bf77cf8f02517131a71b567fc43 Mon Sep 17 00:00:00 2001 From: VitalyR Date: Thu, 7 May 2026 03:50:04 +0800 Subject: [PATCH 4/7] Narrow access_ptr helper scope --- src/transform/common/access_ptr_utils.h | 45 -------------------- src/transform/legalize_safe_memory_access.cc | 32 +++++++++++++- src/transform/lower_access_ptr.cc | 33 +++++++++++--- 3 files changed, 58 insertions(+), 52 deletions(-) delete mode 100644 src/transform/common/access_ptr_utils.h diff --git a/src/transform/common/access_ptr_utils.h b/src/transform/common/access_ptr_utils.h deleted file mode 100644 index 4bdf8f04ab..0000000000 --- a/src/transform/common/access_ptr_utils.h +++ /dev/null @@ -1,45 +0,0 @@ -#ifndef TILELANG_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ -#define TILELANG_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ - -#include -#include - -namespace tvm { -namespace tl { -using namespace tir; - -namespace detail { - -template -BufferLoad VisitAccessPtrBase(const PrimExpr &expr, VisitExprFn &&visit_expr) { - const auto *base_load_node = expr.as(); - ICHECK(base_load_node) << "tl.access_ptr base must be BufferLoad, but got " - << expr; - BufferLoad base_load = tvm::ffi::GetRef(base_load_node); - - Array indices; - bool changed = false; - for (const PrimExpr &index : base_load->indices) { - PrimExpr new_index = visit_expr(index); - changed = changed || !new_index.same_as(index); - indices.push_back(new_index); - } - - Optional predicate = base_load->predicate; - if (predicate.defined()) { - PrimExpr new_predicate = visit_expr(predicate.value()); - changed = changed || !new_predicate.same_as(predicate.value()); - predicate = new_predicate; - } - - if (!changed) { - return base_load; - } - return BufferLoad(base_load->buffer, indices, predicate, base_load->span); -} - -} // namespace detail -} // namespace tl -} // namespace tvm - -#endif // TILELANG_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index 3302820e0f..64d4dc6c10 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -15,7 +15,6 @@ #include "../op/builtin.h" #include "../op/parallel.h" #include "arith/ir_mutator_with_analyzer.h" -#include "common/access_ptr_utils.h" #include "loop_partition.h" #include "loop_vectorize.h" @@ -235,6 +234,35 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { SafeMemorysRewriter(arith::Analyzer *analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {} + template + BufferLoad VisitAccessPtrBase(const PrimExpr &expr, + VisitExprFn &&visit_expr) { + const auto *base_load_node = expr.as(); + ICHECK(base_load_node) << "tl.access_ptr base must be BufferLoad, but got " + << expr; + BufferLoad base_load = tvm::ffi::GetRef(base_load_node); + + Array indices; + bool changed = false; + for (const PrimExpr &index : base_load->indices) { + PrimExpr new_index = visit_expr(index); + changed = changed || !new_index.same_as(index); + indices.push_back(new_index); + } + + Optional predicate = base_load->predicate; + if (predicate.defined()) { + PrimExpr new_predicate = visit_expr(predicate.value()); + changed = changed || !new_predicate.same_as(predicate.value()); + predicate = new_predicate; + } + + if (!changed) { + return base_load; + } + return BufferLoad(base_load->buffer, indices, predicate, base_load->span); + } + PrimExpr VisitExpr_(const CallNode *op) final { if (!op->op.same_as(tl::access_ptr())) { return IRMutatorWithAnalyzer::VisitExpr_(op); @@ -246,7 +274,7 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { return this->VisitExpr(expr); }; Array args{ - detail::VisitAccessPtrBase(op->args[0], visit_expr), + VisitAccessPtrBase(op->args[0], visit_expr), VisitExpr(op->args[1]), VisitExpr(op->args[2]), }; diff --git a/src/transform/lower_access_ptr.cc b/src/transform/lower_access_ptr.cc index bb80cbd197..f6ad89ffdf 100644 --- a/src/transform/lower_access_ptr.cc +++ b/src/transform/lower_access_ptr.cc @@ -11,7 +11,6 @@ #include #include "../op/builtin.h" -#include "common/access_ptr_utils.h" namespace tvm { namespace tl { @@ -85,10 +84,7 @@ class AccessPtrLowerer : public StmtExprMutator { ICHECK_EQ(op->args.size(), 3U) << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; - auto visit_expr = [this](const PrimExpr &expr) { - return this->VisitExpr(expr); - }; - BufferLoad base_load = detail::VisitAccessPtrBase(op->args[0], visit_expr); + BufferLoad base_load = VisitAccessPtrBase(op->args[0]); Buffer buffer = base_load->buffer; ICHECK(buffer.defined()); @@ -102,6 +98,33 @@ class AccessPtrLowerer : public StmtExprMutator { Array args{ptype, data, offset, extent, rw_mask}; return Call(DataType::Handle(), builtin::tvm_access_ptr(), args); } + +private: + BufferLoad VisitAccessPtrBase(const PrimExpr &expr) { + const auto *base_load_node = expr.as(); + ICHECK(base_load_node) << "tl.access_ptr base must be BufferLoad, but got " + << expr; + BufferLoad base_load = tvm::ffi::GetRef(base_load_node); + + Array indices; + bool changed = false; + for (const PrimExpr &index : base_load->indices) { + PrimExpr new_index = VisitExpr(index); + changed = changed || !new_index.same_as(index); + indices.push_back(new_index); + } + Optional predicate = base_load->predicate; + if (predicate.defined()) { + PrimExpr new_predicate = VisitExpr(predicate.value()); + changed = changed || !new_predicate.same_as(predicate.value()); + predicate = new_predicate; + } + + if (!changed) { + return base_load; + } + return BufferLoad(base_load->buffer, indices, predicate, base_load->span); + } }; PrimFunc LowerAccessPtrPrimFunc(PrimFunc f) { From 0b4fdcc1d47d9988ad917193d99eb924dc119bce Mon Sep 17 00:00:00 2001 From: VitalyR Date: Thu, 7 May 2026 04:31:44 +0800 Subject: [PATCH 5/7] Share access_ptr base helper --- src/transform/common/access_ptr_utils.h | 49 ++++++++++++++++++++ src/transform/legalize_safe_memory_access.cc | 32 +------------ src/transform/lower_access_ptr.cc | 33 ++----------- 3 files changed, 56 insertions(+), 58 deletions(-) create mode 100644 src/transform/common/access_ptr_utils.h diff --git a/src/transform/common/access_ptr_utils.h b/src/transform/common/access_ptr_utils.h new file mode 100644 index 0000000000..f21d6de625 --- /dev/null +++ b/src/transform/common/access_ptr_utils.h @@ -0,0 +1,49 @@ +/*! + * \file access_ptr_utils.h + * \brief Shared utilities for tl.access_ptr lowering helpers. + */ +#ifndef TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ +#define TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ + +#include +#include + +namespace tvm { +namespace tl { + +namespace detail { + +template +BufferLoad VisitAccessPtrBase(const PrimExpr &expr, VisitExprFn &&visit_expr) { + const auto *base_load_node = expr.as(); + ICHECK(base_load_node) << "tl.access_ptr base must be BufferLoad, but got " + << expr; + BufferLoad base_load = tvm::ffi::GetRef(base_load_node); + + Array indices; + bool changed = false; + for (const PrimExpr &index : base_load->indices) { + PrimExpr new_index = visit_expr(index); + changed = changed || !new_index.same_as(index); + indices.push_back(new_index); + } + + Optional predicate = base_load->predicate; + if (predicate.defined()) { + PrimExpr new_predicate = visit_expr(predicate.value()); + changed = changed || !new_predicate.same_as(predicate.value()); + predicate = new_predicate; + } + + if (!changed) { + return base_load; + } + return BufferLoad(base_load->buffer, indices, predicate, base_load->span); +} + +} // namespace detail + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index 64d4dc6c10..3302820e0f 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -15,6 +15,7 @@ #include "../op/builtin.h" #include "../op/parallel.h" #include "arith/ir_mutator_with_analyzer.h" +#include "common/access_ptr_utils.h" #include "loop_partition.h" #include "loop_vectorize.h" @@ -234,35 +235,6 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { SafeMemorysRewriter(arith::Analyzer *analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {} - template - BufferLoad VisitAccessPtrBase(const PrimExpr &expr, - VisitExprFn &&visit_expr) { - const auto *base_load_node = expr.as(); - ICHECK(base_load_node) << "tl.access_ptr base must be BufferLoad, but got " - << expr; - BufferLoad base_load = tvm::ffi::GetRef(base_load_node); - - Array indices; - bool changed = false; - for (const PrimExpr &index : base_load->indices) { - PrimExpr new_index = visit_expr(index); - changed = changed || !new_index.same_as(index); - indices.push_back(new_index); - } - - Optional predicate = base_load->predicate; - if (predicate.defined()) { - PrimExpr new_predicate = visit_expr(predicate.value()); - changed = changed || !new_predicate.same_as(predicate.value()); - predicate = new_predicate; - } - - if (!changed) { - return base_load; - } - return BufferLoad(base_load->buffer, indices, predicate, base_load->span); - } - PrimExpr VisitExpr_(const CallNode *op) final { if (!op->op.same_as(tl::access_ptr())) { return IRMutatorWithAnalyzer::VisitExpr_(op); @@ -274,7 +246,7 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { return this->VisitExpr(expr); }; Array args{ - VisitAccessPtrBase(op->args[0], visit_expr), + detail::VisitAccessPtrBase(op->args[0], visit_expr), VisitExpr(op->args[1]), VisitExpr(op->args[2]), }; diff --git a/src/transform/lower_access_ptr.cc b/src/transform/lower_access_ptr.cc index f6ad89ffdf..bb80cbd197 100644 --- a/src/transform/lower_access_ptr.cc +++ b/src/transform/lower_access_ptr.cc @@ -11,6 +11,7 @@ #include #include "../op/builtin.h" +#include "common/access_ptr_utils.h" namespace tvm { namespace tl { @@ -84,7 +85,10 @@ class AccessPtrLowerer : public StmtExprMutator { ICHECK_EQ(op->args.size(), 3U) << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; - BufferLoad base_load = VisitAccessPtrBase(op->args[0]); + auto visit_expr = [this](const PrimExpr &expr) { + return this->VisitExpr(expr); + }; + BufferLoad base_load = detail::VisitAccessPtrBase(op->args[0], visit_expr); Buffer buffer = base_load->buffer; ICHECK(buffer.defined()); @@ -98,33 +102,6 @@ class AccessPtrLowerer : public StmtExprMutator { Array args{ptype, data, offset, extent, rw_mask}; return Call(DataType::Handle(), builtin::tvm_access_ptr(), args); } - -private: - BufferLoad VisitAccessPtrBase(const PrimExpr &expr) { - const auto *base_load_node = expr.as(); - ICHECK(base_load_node) << "tl.access_ptr base must be BufferLoad, but got " - << expr; - BufferLoad base_load = tvm::ffi::GetRef(base_load_node); - - Array indices; - bool changed = false; - for (const PrimExpr &index : base_load->indices) { - PrimExpr new_index = VisitExpr(index); - changed = changed || !new_index.same_as(index); - indices.push_back(new_index); - } - Optional predicate = base_load->predicate; - if (predicate.defined()) { - PrimExpr new_predicate = VisitExpr(predicate.value()); - changed = changed || !new_predicate.same_as(predicate.value()); - predicate = new_predicate; - } - - if (!changed) { - return base_load; - } - return BufferLoad(base_load->buffer, indices, predicate, base_load->span); - } }; PrimFunc LowerAccessPtrPrimFunc(PrimFunc f) { From 11a1660ee3fb31edd8ec4b42b026ecbbd026fb55 Mon Sep 17 00:00:00 2001 From: VitalyR Date: Thu, 7 May 2026 04:48:15 +0800 Subject: [PATCH 6/7] Address access_ptr helper review comments --- src/transform/common/access_ptr_utils.h | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/transform/common/access_ptr_utils.h b/src/transform/common/access_ptr_utils.h index f21d6de625..9ecc40d9f6 100644 --- a/src/transform/common/access_ptr_utils.h +++ b/src/transform/common/access_ptr_utils.h @@ -5,7 +5,6 @@ #ifndef TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ #define TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ -#include #include namespace tvm { @@ -14,23 +13,24 @@ namespace tl { namespace detail { template -BufferLoad VisitAccessPtrBase(const PrimExpr &expr, VisitExprFn &&visit_expr) { - const auto *base_load_node = expr.as(); +tir::BufferLoad VisitAccessPtrBase(const tvm::PrimExpr &expr, + VisitExprFn &&visit_expr) { + const auto *base_load_node = expr.as(); ICHECK(base_load_node) << "tl.access_ptr base must be BufferLoad, but got " << expr; - BufferLoad base_load = tvm::ffi::GetRef(base_load_node); + tir::BufferLoad base_load = tvm::ffi::GetRef(base_load_node); - Array indices; + tvm::Array indices; bool changed = false; - for (const PrimExpr &index : base_load->indices) { - PrimExpr new_index = visit_expr(index); + for (const tvm::PrimExpr &index : base_load->indices) { + tvm::PrimExpr new_index = visit_expr(index); changed = changed || !new_index.same_as(index); indices.push_back(new_index); } - Optional predicate = base_load->predicate; + tvm::Optional predicate = base_load->predicate; if (predicate.defined()) { - PrimExpr new_predicate = visit_expr(predicate.value()); + tvm::PrimExpr new_predicate = visit_expr(predicate.value()); changed = changed || !new_predicate.same_as(predicate.value()); predicate = new_predicate; } @@ -38,7 +38,8 @@ BufferLoad VisitAccessPtrBase(const PrimExpr &expr, VisitExprFn &&visit_expr) { if (!changed) { return base_load; } - return BufferLoad(base_load->buffer, indices, predicate, base_load->span); + return tir::BufferLoad(base_load->buffer, indices, predicate, + base_load->span); } } // namespace detail From 866f17a9bbc0b2601dd0b7961befdd44be383c81 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 20 May 2026 18:53:19 +0800 Subject: [PATCH 7/7] [Transform] Preserve access_ptr buffer loads in safe memory legalization --- src/transform/common/access_ptr_utils.h | 19 +- src/transform/legalize_safe_memory_access.cc | 161 ++++++++-- .../python/issue/test_tilelang_issue_2123.py | 9 +- ...g_transform_legalize_safe_memory_access.py | 289 ++++++++++++++++-- 4 files changed, 418 insertions(+), 60 deletions(-) diff --git a/src/transform/common/access_ptr_utils.h b/src/transform/common/access_ptr_utils.h index 9ecc40d9f6..ce356faf0a 100644 --- a/src/transform/common/access_ptr_utils.h +++ b/src/transform/common/access_ptr_utils.h @@ -5,7 +5,7 @@ #ifndef TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ #define TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ -#include +#include namespace tvm { namespace tl { @@ -13,14 +13,15 @@ namespace tl { namespace detail { template -tir::BufferLoad VisitAccessPtrBase(const tvm::PrimExpr &expr, - VisitExprFn &&visit_expr) { - const auto *base_load_node = expr.as(); +tirx::BufferLoad VisitAccessPtrBase(const tvm::PrimExpr &expr, + VisitExprFn &&visit_expr) { + const auto *base_load_node = expr.as(); ICHECK(base_load_node) << "tl.access_ptr base must be BufferLoad, but got " << expr; - tir::BufferLoad base_load = tvm::ffi::GetRef(base_load_node); + tirx::BufferLoad base_load = + tvm::ffi::GetRef(base_load_node); - tvm::Array indices; + tvm::ffi::Array indices; bool changed = false; for (const tvm::PrimExpr &index : base_load->indices) { tvm::PrimExpr new_index = visit_expr(index); @@ -28,7 +29,7 @@ tir::BufferLoad VisitAccessPtrBase(const tvm::PrimExpr &expr, indices.push_back(new_index); } - tvm::Optional predicate = base_load->predicate; + tvm::ffi::Optional predicate = base_load->predicate; if (predicate.defined()) { tvm::PrimExpr new_predicate = visit_expr(predicate.value()); changed = changed || !new_predicate.same_as(predicate.value()); @@ -38,8 +39,8 @@ tir::BufferLoad VisitAccessPtrBase(const tvm::PrimExpr &expr, if (!changed) { return base_load; } - return tir::BufferLoad(base_load->buffer, indices, predicate, - base_load->span); + return tirx::BufferLoad(base_load->buffer, indices, predicate, + base_load->span); } } // namespace detail diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index 86b9c028f0..90de1be19f 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -13,6 +13,7 @@ #include #include +#include #include #include "../op/builtin.h" @@ -29,6 +30,16 @@ using namespace tirx; using namespace ffi; using arith::IRMutatorWithAnalyzer; +int GetConstAccessMask(const PrimExpr &expr) { + const auto *imm = expr.as(); + ICHECK(imm) << "access_ptr rw_mask must be an integer constant, got " << expr; + return static_cast(imm->value); +} + +bool AccessMaskMayUse(const PrimExpr &expr, int required_mask) { + return (GetConstAccessMask(expr) & required_mask) != 0; +} + // SafeMemChecker for a BufferLoad/BufferStore node: // 1. Identify BufferLoad and BufferStore nodes. // 2. For each index, compare against the buffer's shape. @@ -60,6 +71,36 @@ struct SafeMemChecker : public StmtExprVisitor { } } + void VisitExpr_(const CallNode *op) final { + if (!op->op.same_as(tl::access_ptr())) { + StmtExprVisitor::VisitExpr_(op); + return; + } + + ICHECK_EQ(op->args.size(), 3U) + << "tl.access_ptr expects 3 arguments, but got " << op->args; + const auto *base_load = op->args[0].as(); + ICHECK(base_load) << "tl.access_ptr base must be BufferLoad, but got " + << op->args[0]; + + int rw_mask = GetConstAccessMask(op->args[2]); + CheckBufferIndices(base_load->buffer, base_load->indices, + /*is_load=*/(rw_mask & kAccessRead) != 0, + !disableOOBWarning && + !IsGlobalBuffer(base_load->buffer)); + + if (recursively_collect_conds_) { + for (const PrimExpr &index : base_load->indices) { + VisitExpr(index); + } + if (base_load->predicate.defined()) { + VisitExpr(base_load->predicate.value()); + } + VisitExpr(op->args[1]); + VisitExpr(op->args[2]); + } + } + void VisitStmt_(const BufferStoreNode *op) final { // Check if the buffer is in global scope CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false, @@ -235,15 +276,51 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { } private: + struct AccessPtrInfo { + BufferLoad base_load; + PrimExpr rw_mask; + }; + // Constructor initializing the base class with the analyzer SafeMemorysRewriter(arith::Analyzer *analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {} PrimExpr VisitExpr_(const CallNode *op) final { - if (!op->op.same_as(tl::access_ptr())) { - return IRMutatorWithAnalyzer::VisitExpr_(op); + if (op->op.same_as(tl::access_ptr())) { + return VisitAccessPtrCall(op); + } + + PrimExpr expr = IRMutatorWithAnalyzer::VisitExpr_(op); + const auto *call_node = expr.as(); + if (!call_node || !call_node->op.as()) { + return expr; + } + Call call = Downcast(expr); + Op call_op = Downcast(call->op); + if (!IsAtomicOp(call_op) || call.dtype().is_handle()) { + return call; + } + + Array conditions = CollectCallAccessPtrConditions(call); + if (conditions.empty()) { + return call; + } + + std::optional fallback_ptr = + TryGetAccessPtrInfo(call->args[0]); + if (!fallback_ptr.has_value()) { + return call; } + PrimExpr safe_value = GetSafeValue(fallback_ptr->base_load->buffer); + if (safe_value.dtype() != call.dtype()) { + safe_value = Cast(call.dtype(), safe_value); + } + safe_value = analyzer_->Simplify(safe_value); + return if_then_else(CombineConditions(conditions), call, safe_value); + } + + PrimExpr VisitAccessPtrCall(const CallNode *op) { ICHECK_EQ(op->args.size(), 3U) << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; auto visit_expr = [this](const PrimExpr &expr) { @@ -343,10 +420,11 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { static constexpr int kCPAsyncDstPtrArg = 0; static constexpr int kCPAsyncSrcPtrArg = 1; - BufferLoad GetBaseLoadFromAccessPtrExpr(const PrimExpr &expr) { + std::optional TryGetAccessPtrInfo(const PrimExpr &expr) { const auto *ptr_call = expr.as(); - ICHECK(ptr_call) << "cp.async expects access_ptr arguments, but got " - << expr; + if (!ptr_call) { + return std::nullopt; + } if (ptr_call->op.same_as(tl::access_ptr())) { ICHECK_EQ(ptr_call->args.size(), 3U) @@ -354,12 +432,13 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { const auto *base_load = ptr_call->args[0].as(); ICHECK(base_load) << "tl.access_ptr base must be BufferLoad, but got " << ptr_call->args[0]; - return Downcast(ptr_call->args[0]); + return AccessPtrInfo{Downcast(ptr_call->args[0]), + ptr_call->args[2]}; } - ICHECK(ptr_call->op.same_as(builtin::tvm_access_ptr())) - << "cp.async expects tl.access_ptr or tvm_access_ptr, but got " - << ptr_call->op; + if (!ptr_call->op.same_as(builtin::tvm_access_ptr())) { + return std::nullopt; + } ICHECK_EQ(ptr_call->args.size(), 5U) << "tvm_access_ptr expects 5 arguments, but got " << ptr_call->args; const auto *var = ptr_call->args[1].as(); @@ -370,7 +449,43 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { << "Buffer data var " << buffer_data << " is not registered in buffer_data_to_buffer_."; Buffer flat = buffer_data_to_buffer_[buffer_data].GetFlattenedBuffer(); - return BufferLoad(flat, Array{ptr_call->args[2]}); + return AccessPtrInfo{BufferLoad(flat, Array{ptr_call->args[2]}), + ptr_call->args[4]}; + } + + AccessPtrInfo GetRequiredAccessPtrInfo(const PrimExpr &expr, + const char *context) { + std::optional info = TryGetAccessPtrInfo(expr); + ICHECK(info.has_value()) + << context << " expects tl.access_ptr or tvm_access_ptr, got " << expr; + return info.value(); + } + + Array CollectAccessPtrConditions(const PrimExpr &expr, + int required_mask) { + Array conditions; + std::optional info = TryGetAccessPtrInfo(expr); + if (!info.has_value() || !AccessMaskMayUse(info->rw_mask, required_mask)) { + return conditions; + } + + SafeMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); + bool is_load = (GetConstAccessMask(info->rw_mask) & kAccessRead) != 0; + checker.CheckBufferIndices(info->base_load->buffer, + info->base_load->indices, is_load, + /*throw_warning=*/false); + return checker.GetConditions(); + } + + Array CollectCallAccessPtrConditions(const Call &call) { + Array conditions; + for (const PrimExpr &arg : call->args) { + for (const PrimExpr &cond : + CollectAccessPtrConditions(arg, kAccessReadWrite)) { + conditions.push_back(cond); + } + } + return conditions; } bool NeedsEvaluateBoundaryCheck(const Call &call) { @@ -382,11 +497,15 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { ICHECK_GE(call->args.size(), 3U) << "cp.async expects at least 3 arguments, but got " << call->args; Array conditions; - BufferLoad src_base_load = - GetBaseLoadFromAccessPtrExpr(call->args[kCPAsyncSrcPtrArg]); + AccessPtrInfo src_info = + GetRequiredAccessPtrInfo(call->args[kCPAsyncSrcPtrArg], "cp.async"); + if (!AccessMaskMayUse(src_info.rw_mask, kAccessRead)) { + return conditions; + } SafeMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); - checker.CheckBufferIndices(src_base_load->buffer, src_base_load->indices, + checker.CheckBufferIndices(src_info.base_load->buffer, + src_info.base_load->indices, /*is_load=*/true, /*throw_warning=*/false); return checker.GetConditions(); } @@ -394,9 +513,9 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { Buffer GetCPAsyncSourceBuffer(const Call &call) { ICHECK_GE(call->args.size(), 3U) << "cp.async expects at least 3 arguments, but got " << call->args; - BufferLoad src_base_load = - GetBaseLoadFromAccessPtrExpr(call->args[kCPAsyncSrcPtrArg]); - return src_base_load->buffer; + AccessPtrInfo src_info = + GetRequiredAccessPtrInfo(call->args[kCPAsyncSrcPtrArg], "cp.async"); + return src_info.base_load->buffer; } PrimExpr CombineConditions(const Array &conditions) { @@ -423,15 +542,15 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { ICHECK_GE(call->args.size(), 3U) << "cp.async expects at least 3 arguments, but got " << call->args; - BufferLoad dst_base_load = - GetBaseLoadFromAccessPtrExpr(call->args[kCPAsyncDstPtrArg]); + AccessPtrInfo dst_info = + GetRequiredAccessPtrInfo(call->args[kCPAsyncDstPtrArg], "cp.async"); Buffer src_buffer = GetCPAsyncSourceBuffer(call); PrimExpr combined = CombineConditions(conditions); Optional existing_predicate = GetCPAsyncPredicate(call); PrimExpr safe_value = GetSafeValue(src_buffer); - DataType dst_dtype = dst_base_load->buffer->dtype; + DataType dst_dtype = dst_info.base_load->buffer->dtype; if (safe_value.dtype() != dst_dtype) { safe_value = Cast(dst_dtype, safe_value); } @@ -452,8 +571,8 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { Call(call->dtype, call->op, new_args, call->annotations, call->span)); } - Stmt else_case = - BufferStore(dst_base_load->buffer, safe_value, dst_base_load->indices); + Stmt else_case = BufferStore(dst_info.base_load->buffer, safe_value, + dst_info.base_load->indices); return IfThenElse(combined, evaluate, else_case); } diff --git a/testing/python/issue/test_tilelang_issue_2123.py b/testing/python/issue/test_tilelang_issue_2123.py index a589d8307b..265612b92e 100644 --- a/testing/python/issue/test_tilelang_issue_2123.py +++ b/testing/python/issue/test_tilelang_issue_2123.py @@ -4,6 +4,8 @@ import tilelang.testing import tilelang.language as T from tilelang import tvm +from tvm import tirx +from tvm.tirx import op from tilelang.engine.phase import LowerAndLegalize from tilelang.transform import LowerAccessPtr @@ -35,18 +37,19 @@ def kernel(status: T.Tensor((num_tiles,), T.int32), out: T.Tensor((1,), T.int32) def _has_op_call(func, op_name): found = False + target_op = op.Op.get(op_name) def _visit(node): nonlocal found - if isinstance(node, tvm.tir.Call) and isinstance(node.op, tvm.ir.Op) and node.op.name == op_name: + if isinstance(node, tirx.Call) and node.op.same_as(target_op): found = True - tvm.tir.stmt_functor.post_order_visit(func.body, _visit) + tirx.stmt_functor.post_order_visit(func.body, _visit) return found def _assert_access_ptr_lowered(mod): - assert _has_op_call(mod["main"], "tir.tvm_access_ptr") + assert _has_op_call(mod["main"], "tirx.tvm_access_ptr") assert not _has_op_call(mod["main"], "tl.access_ptr") diff --git a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py index f2f97701a8..169d4bcc55 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py +++ b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py @@ -5,11 +5,12 @@ from tvm.tirx.stmt_functor import ir_transform, post_order_visit -def _strip_block_reads_writes(stmt): - """Strip reads and writes from all blocks, replacing them with empty lists.""" +def _strip_block_reads_writes(stmt, strip_annotations: bool = False): + """Strip non-behavioral block metadata before structural comparison.""" def _postorder(op): if isinstance(op, tvm.tirx.SBlock): + annotations = {} if strip_annotations else op.annotations return tvm.tirx.SBlock( op.iter_vars, [], @@ -19,7 +20,7 @@ def _postorder(op): op.init, op.alloc_buffers, op.match_buffers, - op.annotations, + annotations, ) return ir_transform(stmt, None, _postorder) @@ -40,6 +41,19 @@ def _visit(node): return calls +def _is_call_to(expr, op_name): + return isinstance(expr, tvm.tirx.Call) and isinstance(expr.op, tvm.ir.Op) and str(expr.op.name) == op_name + + +def _is_int_zero(expr): + return isinstance(expr, tvm.tirx.IntImm) and int(expr.value) == 0 + + +def _assert_tl_access_ptr_bases_are_buffer_loads(stmt): + for call in _collect_call_nodes(stmt, "tl.access_ptr"): + assert isinstance(call.args[0], tvm.tirx.BufferLoad) + + def _count_if_then_else(stmt): count = 0 @@ -52,6 +66,18 @@ def _visit(node): return count +def _assert_legalize_matches_expected(before, expected, strip_annotations: bool = False): + mod = tvm.IRModule({before.attrs["global_symbol"]: before}) + transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) + + _assert_tl_access_ptr_bases_are_buffer_loads(before.body) + _assert_tl_access_ptr_bases_are_buffer_loads(transformed["main"].body) + tvm.ir.assert_structural_equal( + _strip_block_reads_writes(transformed["main"].body, strip_annotations), + _strip_block_reads_writes(expected.body, strip_annotations), + ) + + def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): dtype = T.float32 @@ -174,18 +200,18 @@ def assert_oob_store_legalize(M: int = 64, N: int = 64): ) -def cp_async_access_ptr_legalize(N: int = 16, offset: int = 10): +def cp_async_access_ptr_legalize(): dtype = T.float16 @T.prim_func def main( - A: T.Tensor((N,), dtype=dtype), + A: T.Tensor((16,), dtype=dtype), ): - A_shared = T.alloc_buffer((N,), dtype=dtype, scope="shared") + A_shared = T.alloc_buffer((16,), dtype=dtype, scope="shared") for i in T.serial(4): T.ptx_cp_async( T.access_ptr(A_shared[i * 4], "w", 4), - T.access_ptr(A[i * 4 + offset], "r", 4), + T.access_ptr(A[i * 4 + 8], "r", 4), 4, ) T.ptx_commit_group() @@ -193,15 +219,15 @@ def main( @T.prim_func def expected( - A: T.Tensor((N,), dtype=dtype), + A: T.Tensor((16,), dtype=dtype), ): - A_shared = T.alloc_buffer((N,), dtype=dtype, scope="shared") + A_shared = T.alloc_buffer((16,), dtype=dtype, scope="shared") for i in T.serial(4): T.ptx_cp_async( T.access_ptr(A_shared[i * 4], "w", 4), - T.access_ptr(A[i * 4 + offset], "r", 4), + T.access_ptr(A[i * 4 + 8], "r", 4), 4, - i * 4 + offset < N, + i < 2, ) T.ptx_commit_group() T.ptx_wait_group(0) @@ -209,32 +235,33 @@ def expected( return main, expected -def assert_cp_async_access_ptr_legalize(N: int = 16): - func, _ = cp_async_access_ptr_legalize(N) +def assert_cp_async_access_ptr_legalize(): + func, expected = cp_async_access_ptr_legalize() mod = tvm.IRModule({func.attrs["global_symbol"]: func}) transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) body = transformed["main"].body cp_async_calls = _collect_call_nodes(body, {"tirx.ptx_cp_async", "tl.ptx_cp_async"}) assert len(cp_async_calls) > 0 assert all(len(call.args) == 4 for call in cp_async_calls) + _assert_legalize_matches_expected(func, expected) -def cp_async_access_ptr_nonzero_safe_value_legalize(N: int = 16, offset: int = 10, pad_value: int = 3): +def cp_async_access_ptr_nonzero_safe_value_legalize(): dtype = T.float16 @T.prim_func def main( - A: T.Tensor((N,), dtype=dtype), + A: T.Tensor((16,), dtype=dtype), ): with T.sblock("root"): T.reads() T.writes() - T.sblock_attr({"safe_value_map": {A.data: T.float16(pad_value)}}) - A_shared = T.sblock_alloc_buffer((N,), dtype=dtype, scope="shared") + T.sblock_attr({"safe_value_map": {A.data: T.float16(3)}}) + A_shared = T.sblock_alloc_buffer((16,), dtype=dtype, scope="shared") for i in T.serial(4): T.ptx_cp_async( T.access_ptr(A_shared[i * 4], "w", 4), - T.access_ptr(A[i * 4 + offset], "r", 4), + T.access_ptr(A[i * 4 + 8], "r", 4), 4, ) T.ptx_commit_group() @@ -242,28 +269,166 @@ def main( @T.prim_func def expected( - A: T.Tensor((N,), dtype=dtype), + A: T.Tensor((16,), dtype=dtype), ): with T.sblock("root"): T.reads() T.writes() - T.sblock_attr({"safe_value_map": {A.data: T.float16(pad_value)}}) - A_shared = T.sblock_alloc_buffer((N,), dtype=dtype, scope="shared") + T.sblock_attr({"safe_value_map": {A.data: T.float16(3)}}) + A_shared = T.sblock_alloc_buffer((16,), dtype=dtype, scope="shared") for i in T.serial(4): - if i * 4 + offset < N: + if i < 2: T.ptx_cp_async( T.access_ptr(A_shared[i * 4], "w", 4), - T.access_ptr(A[i * 4 + offset], "r", 4), + T.access_ptr(A[i * 4 + 8], "r", 4), 4, ) + else: + A_shared[i * 4] = T.float16(3) T.ptx_commit_group() T.ptx_wait_group(0) return main, expected -def assert_cp_async_access_ptr_nonzero_safe_value_legalize(N: int = 16): - func, _ = cp_async_access_ptr_nonzero_safe_value_legalize(N) +def atomic_load_access_ptr_legalize(): + dtype = T.int32 + + @T.prim_func + def main( + A: T.Tensor((16,), dtype=dtype), + out: T.Tensor((4,), dtype=dtype), + ): + for i in T.serial(4): + out[i] = T.atomic_load(A[i * 4 + 10], memory_order="acquire") + + @T.prim_func + def expected( + A: T.Tensor((16,), dtype=dtype), + out: T.Tensor((4,), dtype=dtype), + ): + for i in T.serial(4): + out[i] = T.if_then_else( + i < 2, + T.atomic_load(A[i * 4 + 10], memory_order="acquire"), + T.int32(0), + ) + + return main, expected + + +def atomic_add_return_access_ptr_legalize(): + dtype = T.int32 + + @T.prim_func + def main( + A: T.Tensor((16,), dtype=dtype), + out: T.Tensor((4,), dtype=dtype), + ): + for i in T.serial(4): + out[i] = T.atomic_add(A[i * 4 + 10], T.int32(1), return_prev=True) + + @T.prim_func + def expected( + A: T.Tensor((16,), dtype=dtype), + out: T.Tensor((4,), dtype=dtype), + ): + for i in T.serial(4): + out[i] = T.if_then_else( + i < 2, + T.atomic_add(A[i * 4 + 10], T.int32(1), return_prev=True), + T.int32(0), + ) + + return main, expected + + +def atomic_store_access_ptr_legalize(): + dtype = T.int32 + + @T.prim_func + def main( + A: T.Tensor((16,), dtype=dtype), + ): + for i in T.serial(4): + T.atomic_store(A[i * 4 + 10], T.int32(1), memory_order="release") + + @T.prim_func + def expected( + A: T.Tensor((16,), dtype=dtype), + ): + for i in T.serial(4): + if i * 4 + 10 < 16: + T.atomic_store(A[i * 4 + 10], T.int32(1), memory_order="release") + + return main, expected + + +def call_extern_access_ptr_mask_legalize(access_type: str): + dtype = T.int32 + + @T.prim_func + def main( + A: T.Tensor((16,), dtype=dtype), + ): + for i in T.serial(4): + T.call_extern( + "handle", + "use_ptr", + T.access_ptr(A[i * 4 + 10], access_type, 1), + ) + + @T.prim_func + def expected( + A: T.Tensor((16,), dtype=dtype), + ): + for i in T.serial(4): + if i * 4 + 10 < 16: + T.call_extern( + "handle", + "use_ptr", + T.access_ptr(A[i * 4 + 10], access_type, 1), + ) + + return main, expected + + +def call_extern_multiple_access_ptrs_legalize(): + dtype = T.int32 + + @T.prim_func + def main( + A: T.Tensor((16,), dtype=dtype), + B: T.Tensor((12,), dtype=dtype), + ): + for i in T.serial(4): + T.call_extern( + "handle", + "use_two_ptrs", + T.access_ptr(A[i * 4 + 10], "r", 1), + T.access_ptr(B[i * 4 + 6], "w", 1), + ) + + @T.prim_func + def expected( + A: T.Tensor((16,), dtype=dtype), + B: T.Tensor((12,), dtype=dtype), + ): + for i in T.serial(4): + if i * 4 + 6 < 12: # noqa: SIM102 + if i * 4 + 10 < 16: + T.call_extern( + "handle", + "use_two_ptrs", + T.access_ptr(A[i * 4 + 10], "r", 1), + T.access_ptr(B[i * 4 + 6], "w", 1), + ) + + return main, expected + + +def assert_cp_async_access_ptr_nonzero_safe_value_legalize(): + func, expected = cp_async_access_ptr_nonzero_safe_value_legalize() mod = tvm.IRModule({func.attrs["global_symbol"]: func}) transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) body = transformed["main"].body @@ -271,6 +436,48 @@ def assert_cp_async_access_ptr_nonzero_safe_value_legalize(N: int = 16): assert len(cp_async_calls) > 0 assert all(len(call.args) == 3 for call in cp_async_calls) assert _count_if_then_else(body) > 0 + _assert_legalize_matches_expected(func, expected, strip_annotations=True) + + +def assert_atomic_load_access_ptr_legalize(): + func, expected = atomic_load_access_ptr_legalize() + mod = tvm.IRModule({func.attrs["global_symbol"]: func}) + transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) + body = transformed["main"].body + + _assert_tl_access_ptr_bases_are_buffer_loads(body) + if_then_else_calls = _collect_call_nodes(body, "tirx.if_then_else") + assert any( + len(call.args) == 3 and _is_call_to(call.args[1], "tl.atomic_load_elem_op") and _is_int_zero(call.args[2]) + for call in if_then_else_calls + ) + _assert_legalize_matches_expected(func, expected) + + +def assert_atomic_add_return_access_ptr_legalize(): + func, expected = atomic_add_return_access_ptr_legalize() + _assert_legalize_matches_expected(func, expected) + + +def assert_atomic_store_access_ptr_legalize(): + func, expected = atomic_store_access_ptr_legalize() + _assert_legalize_matches_expected(func, expected) + + +def assert_call_extern_access_ptr_mask_legalize(access_type: str): + func, expected = call_extern_access_ptr_mask_legalize(access_type) + mod = tvm.IRModule({func.attrs["global_symbol"]: func}) + transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) + body = transformed["main"].body + + _assert_tl_access_ptr_bases_are_buffer_loads(body) + assert _count_if_then_else(body) > 0 + _assert_legalize_matches_expected(func, expected) + + +def assert_call_extern_multiple_access_ptrs_legalize(): + func, expected = call_extern_multiple_access_ptrs_legalize() + _assert_legalize_matches_expected(func, expected) def test_vectorize_access(): @@ -286,11 +493,39 @@ def test_oob_store(): def test_cp_async_access_ptr_oob(): - assert_cp_async_access_ptr_legalize(16) + assert_cp_async_access_ptr_legalize() def test_cp_async_access_ptr_nonzero_safe_value_oob(): - assert_cp_async_access_ptr_nonzero_safe_value_legalize(16) + assert_cp_async_access_ptr_nonzero_safe_value_legalize() + + +def test_atomic_load_access_ptr_oob(): + assert_atomic_load_access_ptr_legalize() + + +def test_atomic_add_return_access_ptr_oob(): + assert_atomic_add_return_access_ptr_legalize() + + +def test_atomic_store_access_ptr_oob(): + assert_atomic_store_access_ptr_legalize() + + +def test_call_extern_access_ptr_read_mask_oob(): + assert_call_extern_access_ptr_mask_legalize("r") + + +def test_call_extern_access_ptr_write_mask_oob(): + assert_call_extern_access_ptr_mask_legalize("w") + + +def test_call_extern_access_ptr_readwrite_mask_oob(): + assert_call_extern_access_ptr_mask_legalize("rw") + + +def test_call_extern_multiple_access_ptrs_oob(): + assert_call_extern_multiple_access_ptrs_legalize() if __name__ == "__main__":