diff --git a/src/transform/common/access_ptr_utils.h b/src/transform/common/access_ptr_utils.h new file mode 100644 index 0000000000..ce356faf0a --- /dev/null +++ b/src/transform/common/access_ptr_utils.h @@ -0,0 +1,51 @@ +/*! + * \file access_ptr_utils.h + * \brief Shared utilities for tl.access_ptr lowering helpers. + */ +#ifndef TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ +#define TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ + +#include + +namespace tvm { +namespace tl { + +namespace detail { + +template +tirx::BufferLoad VisitAccessPtrBase(const tvm::PrimExpr &expr, + VisitExprFn &&visit_expr) { + const auto *base_load_node = expr.as(); + ICHECK(base_load_node) << "tl.access_ptr base must be BufferLoad, but got " + << expr; + tirx::BufferLoad base_load = + tvm::ffi::GetRef(base_load_node); + + tvm::ffi::Array indices; + bool changed = false; + for (const tvm::PrimExpr &index : base_load->indices) { + tvm::PrimExpr new_index = visit_expr(index); + changed = changed || !new_index.same_as(index); + indices.push_back(new_index); + } + + tvm::ffi::Optional predicate = base_load->predicate; + if (predicate.defined()) { + tvm::PrimExpr new_predicate = visit_expr(predicate.value()); + changed = changed || !new_predicate.same_as(predicate.value()); + predicate = new_predicate; + } + + if (!changed) { + return base_load; + } + return tirx::BufferLoad(base_load->buffer, indices, predicate, + base_load->span); +} + +} // namespace detail + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_ diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index 0e67ab7021..90de1be19f 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -13,11 +13,13 @@ #include #include +#include #include #include "../op/builtin.h" #include "../op/parallel.h" #include "arith/ir_mutator_with_analyzer.h" +#include "common/access_ptr_utils.h" #include "loop_partition.h" #include "loop_vectorize.h" @@ -28,6 +30,16 @@ using namespace tirx; using namespace ffi; using arith::IRMutatorWithAnalyzer; +int GetConstAccessMask(const PrimExpr &expr) { + const auto *imm = expr.as(); + ICHECK(imm) << "access_ptr rw_mask must be an integer constant, got " << expr; + return static_cast(imm->value); +} + +bool AccessMaskMayUse(const PrimExpr &expr, int required_mask) { + return (GetConstAccessMask(expr) & required_mask) != 0; +} + // SafeMemChecker for a BufferLoad/BufferStore node: // 1. Identify BufferLoad and BufferStore nodes. // 2. For each index, compare against the buffer's shape. @@ -59,6 +71,36 @@ struct SafeMemChecker : public StmtExprVisitor { } } + void VisitExpr_(const CallNode *op) final { + if (!op->op.same_as(tl::access_ptr())) { + StmtExprVisitor::VisitExpr_(op); + return; + } + + ICHECK_EQ(op->args.size(), 3U) + << "tl.access_ptr expects 3 arguments, but got " << op->args; + const auto *base_load = op->args[0].as(); + ICHECK(base_load) << "tl.access_ptr base must be BufferLoad, but got " + << op->args[0]; + + int rw_mask = GetConstAccessMask(op->args[2]); + CheckBufferIndices(base_load->buffer, base_load->indices, + /*is_load=*/(rw_mask & kAccessRead) != 0, + !disableOOBWarning && + !IsGlobalBuffer(base_load->buffer)); + + if (recursively_collect_conds_) { + for (const PrimExpr &index : base_load->indices) { + VisitExpr(index); + } + if (base_load->predicate.defined()) { + VisitExpr(base_load->predicate.value()); + } + VisitExpr(op->args[1]); + VisitExpr(op->args[2]); + } + } + void VisitStmt_(const BufferStoreNode *op) final { // Check if the buffer is in global scope CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false, @@ -234,10 +276,63 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { } private: + struct AccessPtrInfo { + BufferLoad base_load; + PrimExpr rw_mask; + }; + // Constructor initializing the base class with the analyzer SafeMemorysRewriter(arith::Analyzer *analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {} - // Constructor initializing the base class with the analyzer + + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(tl::access_ptr())) { + return VisitAccessPtrCall(op); + } + + PrimExpr expr = IRMutatorWithAnalyzer::VisitExpr_(op); + const auto *call_node = expr.as(); + if (!call_node || !call_node->op.as()) { + return expr; + } + Call call = Downcast(expr); + Op call_op = Downcast(call->op); + if (!IsAtomicOp(call_op) || call.dtype().is_handle()) { + return call; + } + + Array conditions = CollectCallAccessPtrConditions(call); + if (conditions.empty()) { + return call; + } + + std::optional fallback_ptr = + TryGetAccessPtrInfo(call->args[0]); + if (!fallback_ptr.has_value()) { + return call; + } + + PrimExpr safe_value = GetSafeValue(fallback_ptr->base_load->buffer); + if (safe_value.dtype() != call.dtype()) { + safe_value = Cast(call.dtype(), safe_value); + } + safe_value = analyzer_->Simplify(safe_value); + return if_then_else(CombineConditions(conditions), call, safe_value); + } + + PrimExpr VisitAccessPtrCall(const CallNode *op) { + ICHECK_EQ(op->args.size(), 3U) + << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; + auto visit_expr = [this](const PrimExpr &expr) { + return this->VisitExpr(expr); + }; + Array args{ + detail::VisitAccessPtrBase(op->args[0], visit_expr), + VisitExpr(op->args[1]), + VisitExpr(op->args[2]), + }; + return Call(op->dtype, op->op, args, op->annotations, op->span); + } PrimExpr VisitExpr_(const BufferLoadNode *op) final { auto load = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); @@ -325,10 +420,11 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { static constexpr int kCPAsyncDstPtrArg = 0; static constexpr int kCPAsyncSrcPtrArg = 1; - BufferLoad GetBaseLoadFromAccessPtrExpr(const PrimExpr &expr) { + std::optional TryGetAccessPtrInfo(const PrimExpr &expr) { const auto *ptr_call = expr.as(); - ICHECK(ptr_call) << "cp.async expects access_ptr arguments, but got " - << expr; + if (!ptr_call) { + return std::nullopt; + } if (ptr_call->op.same_as(tl::access_ptr())) { ICHECK_EQ(ptr_call->args.size(), 3U) @@ -336,12 +432,13 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { const auto *base_load = ptr_call->args[0].as(); ICHECK(base_load) << "tl.access_ptr base must be BufferLoad, but got " << ptr_call->args[0]; - return Downcast(ptr_call->args[0]); + return AccessPtrInfo{Downcast(ptr_call->args[0]), + ptr_call->args[2]}; } - ICHECK(ptr_call->op.same_as(builtin::tvm_access_ptr())) - << "cp.async expects tl.access_ptr or tvm_access_ptr, but got " - << ptr_call->op; + if (!ptr_call->op.same_as(builtin::tvm_access_ptr())) { + return std::nullopt; + } ICHECK_EQ(ptr_call->args.size(), 5U) << "tvm_access_ptr expects 5 arguments, but got " << ptr_call->args; const auto *var = ptr_call->args[1].as(); @@ -352,7 +449,43 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { << "Buffer data var " << buffer_data << " is not registered in buffer_data_to_buffer_."; Buffer flat = buffer_data_to_buffer_[buffer_data].GetFlattenedBuffer(); - return BufferLoad(flat, Array{ptr_call->args[2]}); + return AccessPtrInfo{BufferLoad(flat, Array{ptr_call->args[2]}), + ptr_call->args[4]}; + } + + AccessPtrInfo GetRequiredAccessPtrInfo(const PrimExpr &expr, + const char *context) { + std::optional info = TryGetAccessPtrInfo(expr); + ICHECK(info.has_value()) + << context << " expects tl.access_ptr or tvm_access_ptr, got " << expr; + return info.value(); + } + + Array CollectAccessPtrConditions(const PrimExpr &expr, + int required_mask) { + Array conditions; + std::optional info = TryGetAccessPtrInfo(expr); + if (!info.has_value() || !AccessMaskMayUse(info->rw_mask, required_mask)) { + return conditions; + } + + SafeMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); + bool is_load = (GetConstAccessMask(info->rw_mask) & kAccessRead) != 0; + checker.CheckBufferIndices(info->base_load->buffer, + info->base_load->indices, is_load, + /*throw_warning=*/false); + return checker.GetConditions(); + } + + Array CollectCallAccessPtrConditions(const Call &call) { + Array conditions; + for (const PrimExpr &arg : call->args) { + for (const PrimExpr &cond : + CollectAccessPtrConditions(arg, kAccessReadWrite)) { + conditions.push_back(cond); + } + } + return conditions; } bool NeedsEvaluateBoundaryCheck(const Call &call) { @@ -364,11 +497,15 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { ICHECK_GE(call->args.size(), 3U) << "cp.async expects at least 3 arguments, but got " << call->args; Array conditions; - BufferLoad src_base_load = - GetBaseLoadFromAccessPtrExpr(call->args[kCPAsyncSrcPtrArg]); + AccessPtrInfo src_info = + GetRequiredAccessPtrInfo(call->args[kCPAsyncSrcPtrArg], "cp.async"); + if (!AccessMaskMayUse(src_info.rw_mask, kAccessRead)) { + return conditions; + } SafeMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); - checker.CheckBufferIndices(src_base_load->buffer, src_base_load->indices, + checker.CheckBufferIndices(src_info.base_load->buffer, + src_info.base_load->indices, /*is_load=*/true, /*throw_warning=*/false); return checker.GetConditions(); } @@ -376,9 +513,9 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { Buffer GetCPAsyncSourceBuffer(const Call &call) { ICHECK_GE(call->args.size(), 3U) << "cp.async expects at least 3 arguments, but got " << call->args; - BufferLoad src_base_load = - GetBaseLoadFromAccessPtrExpr(call->args[kCPAsyncSrcPtrArg]); - return src_base_load->buffer; + AccessPtrInfo src_info = + GetRequiredAccessPtrInfo(call->args[kCPAsyncSrcPtrArg], "cp.async"); + return src_info.base_load->buffer; } PrimExpr CombineConditions(const Array &conditions) { @@ -405,15 +542,15 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { ICHECK_GE(call->args.size(), 3U) << "cp.async expects at least 3 arguments, but got " << call->args; - BufferLoad dst_base_load = - GetBaseLoadFromAccessPtrExpr(call->args[kCPAsyncDstPtrArg]); + AccessPtrInfo dst_info = + GetRequiredAccessPtrInfo(call->args[kCPAsyncDstPtrArg], "cp.async"); Buffer src_buffer = GetCPAsyncSourceBuffer(call); PrimExpr combined = CombineConditions(conditions); Optional existing_predicate = GetCPAsyncPredicate(call); PrimExpr safe_value = GetSafeValue(src_buffer); - DataType dst_dtype = dst_base_load->buffer->dtype; + DataType dst_dtype = dst_info.base_load->buffer->dtype; if (safe_value.dtype() != dst_dtype) { safe_value = Cast(dst_dtype, safe_value); } @@ -434,8 +571,8 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { Call(call->dtype, call->op, new_args, call->annotations, call->span)); } - Stmt else_case = - BufferStore(dst_base_load->buffer, safe_value, dst_base_load->indices); + Stmt else_case = BufferStore(dst_info.base_load->buffer, safe_value, + dst_info.base_load->indices); return IfThenElse(combined, evaluate, else_case); } diff --git a/src/transform/lower_access_ptr.cc b/src/transform/lower_access_ptr.cc index 44485b9551..bbe4373e3e 100644 --- a/src/transform/lower_access_ptr.cc +++ b/src/transform/lower_access_ptr.cc @@ -12,6 +12,7 @@ #include #include "../op/builtin.h" +#include "common/access_ptr_utils.h" namespace tvm { namespace tl { @@ -79,20 +80,22 @@ PrimExpr LinearOffsetFromLoad(const BufferLoad &load) { class AccessPtrLowerer : public StmtExprMutator { public: PrimExpr VisitExpr_(const CallNode *op) final { - Call call = Downcast(StmtExprMutator::VisitExpr_(op)); - if (!call->op.same_as(tl::access_ptr())) { - return std::move(call); + if (!op->op.same_as(tl::access_ptr())) { + return StmtExprMutator::VisitExpr_(op); } - ICHECK_EQ(call->args.size(), 3U) + ICHECK_EQ(op->args.size(), 3U) << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; - BufferLoad base_load = Downcast(call->args[0]); + auto visit_expr = [this](const PrimExpr &expr) { + return this->VisitExpr(expr); + }; + BufferLoad base_load = detail::VisitAccessPtrBase(op->args[0], visit_expr); Buffer buffer = base_load->buffer; ICHECK(buffer.defined()); - PrimExpr extent = call->args[1]; - PrimExpr rw_mask = call->args[2]; + PrimExpr extent = VisitExpr(op->args[1]); + PrimExpr rw_mask = VisitExpr(op->args[2]); PrimExpr ptype = tirx::TypeAnnotation(buffer->dtype); PrimExpr data = buffer->data; diff --git a/testing/python/issue/test_tilelang_issue_2123.py b/testing/python/issue/test_tilelang_issue_2123.py new file mode 100644 index 0000000000..265612b92e --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_2123.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import tilelang +import tilelang.testing +import tilelang.language as T +from tilelang import tvm +from tvm import tirx +from tvm.tirx import op +from tilelang.engine.phase import LowerAndLegalize +from tilelang.transform import LowerAccessPtr + + +def issue_2123_atomic_load_repro(num_tiles, threads=32): + @T.prim_func + def kernel(status: T.Tensor((num_tiles,), T.int32), out: T.Tensor((1,), T.int32)): + with T.Kernel(num_tiles, threads=threads) as tile: + look = T.alloc_var(T.int32) + state = T.alloc_var(T.int32) + done = T.alloc_var(T.bool) + tx = T.get_thread_binding() + if tx == 0: + look = tile - 1 + done = look < 0 + state = 0 + while not done: + state = T.atomic_load(status[look], memory_order="acquire") + if state != 0: + done = True + else: + look -= 1 + done = look < 0 + if tile == num_tiles - 1: + out[0] = state + + return kernel + + +def _has_op_call(func, op_name): + found = False + target_op = op.Op.get(op_name) + + def _visit(node): + nonlocal found + if isinstance(node, tirx.Call) and node.op.same_as(target_op): + found = True + + tirx.stmt_functor.post_order_visit(func.body, _visit) + return found + + +def _assert_access_ptr_lowered(mod): + assert _has_op_call(mod["main"], "tirx.tvm_access_ptr") + assert not _has_op_call(mod["main"], "tl.access_ptr") + + +def test_issue_2123_atomic_load_lower_access_ptr_direct(): + func = issue_2123_atomic_load_repro(4).with_attr("global_symbol", "main") + mod = tvm.IRModule.from_expr(func) + + lowered = LowerAccessPtr()(mod) + + _assert_access_ptr_lowered(lowered) + + +def test_issue_2123_atomic_load_lower_access_ptr_pipeline(): + target = tvm.target.Target("cuda", host="llvm") + func = issue_2123_atomic_load_repro(4).with_attr("global_symbol", "main") + mod = tvm.IRModule.from_expr(func) + + lowered = LowerAndLegalize(mod, target) + + _assert_access_ptr_lowered(lowered) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py index f2f97701a8..169d4bcc55 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py +++ b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py @@ -5,11 +5,12 @@ from tvm.tirx.stmt_functor import ir_transform, post_order_visit -def _strip_block_reads_writes(stmt): - """Strip reads and writes from all blocks, replacing them with empty lists.""" +def _strip_block_reads_writes(stmt, strip_annotations: bool = False): + """Strip non-behavioral block metadata before structural comparison.""" def _postorder(op): if isinstance(op, tvm.tirx.SBlock): + annotations = {} if strip_annotations else op.annotations return tvm.tirx.SBlock( op.iter_vars, [], @@ -19,7 +20,7 @@ def _postorder(op): op.init, op.alloc_buffers, op.match_buffers, - op.annotations, + annotations, ) return ir_transform(stmt, None, _postorder) @@ -40,6 +41,19 @@ def _visit(node): return calls +def _is_call_to(expr, op_name): + return isinstance(expr, tvm.tirx.Call) and isinstance(expr.op, tvm.ir.Op) and str(expr.op.name) == op_name + + +def _is_int_zero(expr): + return isinstance(expr, tvm.tirx.IntImm) and int(expr.value) == 0 + + +def _assert_tl_access_ptr_bases_are_buffer_loads(stmt): + for call in _collect_call_nodes(stmt, "tl.access_ptr"): + assert isinstance(call.args[0], tvm.tirx.BufferLoad) + + def _count_if_then_else(stmt): count = 0 @@ -52,6 +66,18 @@ def _visit(node): return count +def _assert_legalize_matches_expected(before, expected, strip_annotations: bool = False): + mod = tvm.IRModule({before.attrs["global_symbol"]: before}) + transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) + + _assert_tl_access_ptr_bases_are_buffer_loads(before.body) + _assert_tl_access_ptr_bases_are_buffer_loads(transformed["main"].body) + tvm.ir.assert_structural_equal( + _strip_block_reads_writes(transformed["main"].body, strip_annotations), + _strip_block_reads_writes(expected.body, strip_annotations), + ) + + def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): dtype = T.float32 @@ -174,18 +200,18 @@ def assert_oob_store_legalize(M: int = 64, N: int = 64): ) -def cp_async_access_ptr_legalize(N: int = 16, offset: int = 10): +def cp_async_access_ptr_legalize(): dtype = T.float16 @T.prim_func def main( - A: T.Tensor((N,), dtype=dtype), + A: T.Tensor((16,), dtype=dtype), ): - A_shared = T.alloc_buffer((N,), dtype=dtype, scope="shared") + A_shared = T.alloc_buffer((16,), dtype=dtype, scope="shared") for i in T.serial(4): T.ptx_cp_async( T.access_ptr(A_shared[i * 4], "w", 4), - T.access_ptr(A[i * 4 + offset], "r", 4), + T.access_ptr(A[i * 4 + 8], "r", 4), 4, ) T.ptx_commit_group() @@ -193,15 +219,15 @@ def main( @T.prim_func def expected( - A: T.Tensor((N,), dtype=dtype), + A: T.Tensor((16,), dtype=dtype), ): - A_shared = T.alloc_buffer((N,), dtype=dtype, scope="shared") + A_shared = T.alloc_buffer((16,), dtype=dtype, scope="shared") for i in T.serial(4): T.ptx_cp_async( T.access_ptr(A_shared[i * 4], "w", 4), - T.access_ptr(A[i * 4 + offset], "r", 4), + T.access_ptr(A[i * 4 + 8], "r", 4), 4, - i * 4 + offset < N, + i < 2, ) T.ptx_commit_group() T.ptx_wait_group(0) @@ -209,32 +235,33 @@ def expected( return main, expected -def assert_cp_async_access_ptr_legalize(N: int = 16): - func, _ = cp_async_access_ptr_legalize(N) +def assert_cp_async_access_ptr_legalize(): + func, expected = cp_async_access_ptr_legalize() mod = tvm.IRModule({func.attrs["global_symbol"]: func}) transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) body = transformed["main"].body cp_async_calls = _collect_call_nodes(body, {"tirx.ptx_cp_async", "tl.ptx_cp_async"}) assert len(cp_async_calls) > 0 assert all(len(call.args) == 4 for call in cp_async_calls) + _assert_legalize_matches_expected(func, expected) -def cp_async_access_ptr_nonzero_safe_value_legalize(N: int = 16, offset: int = 10, pad_value: int = 3): +def cp_async_access_ptr_nonzero_safe_value_legalize(): dtype = T.float16 @T.prim_func def main( - A: T.Tensor((N,), dtype=dtype), + A: T.Tensor((16,), dtype=dtype), ): with T.sblock("root"): T.reads() T.writes() - T.sblock_attr({"safe_value_map": {A.data: T.float16(pad_value)}}) - A_shared = T.sblock_alloc_buffer((N,), dtype=dtype, scope="shared") + T.sblock_attr({"safe_value_map": {A.data: T.float16(3)}}) + A_shared = T.sblock_alloc_buffer((16,), dtype=dtype, scope="shared") for i in T.serial(4): T.ptx_cp_async( T.access_ptr(A_shared[i * 4], "w", 4), - T.access_ptr(A[i * 4 + offset], "r", 4), + T.access_ptr(A[i * 4 + 8], "r", 4), 4, ) T.ptx_commit_group() @@ -242,28 +269,166 @@ def main( @T.prim_func def expected( - A: T.Tensor((N,), dtype=dtype), + A: T.Tensor((16,), dtype=dtype), ): with T.sblock("root"): T.reads() T.writes() - T.sblock_attr({"safe_value_map": {A.data: T.float16(pad_value)}}) - A_shared = T.sblock_alloc_buffer((N,), dtype=dtype, scope="shared") + T.sblock_attr({"safe_value_map": {A.data: T.float16(3)}}) + A_shared = T.sblock_alloc_buffer((16,), dtype=dtype, scope="shared") for i in T.serial(4): - if i * 4 + offset < N: + if i < 2: T.ptx_cp_async( T.access_ptr(A_shared[i * 4], "w", 4), - T.access_ptr(A[i * 4 + offset], "r", 4), + T.access_ptr(A[i * 4 + 8], "r", 4), 4, ) + else: + A_shared[i * 4] = T.float16(3) T.ptx_commit_group() T.ptx_wait_group(0) return main, expected -def assert_cp_async_access_ptr_nonzero_safe_value_legalize(N: int = 16): - func, _ = cp_async_access_ptr_nonzero_safe_value_legalize(N) +def atomic_load_access_ptr_legalize(): + dtype = T.int32 + + @T.prim_func + def main( + A: T.Tensor((16,), dtype=dtype), + out: T.Tensor((4,), dtype=dtype), + ): + for i in T.serial(4): + out[i] = T.atomic_load(A[i * 4 + 10], memory_order="acquire") + + @T.prim_func + def expected( + A: T.Tensor((16,), dtype=dtype), + out: T.Tensor((4,), dtype=dtype), + ): + for i in T.serial(4): + out[i] = T.if_then_else( + i < 2, + T.atomic_load(A[i * 4 + 10], memory_order="acquire"), + T.int32(0), + ) + + return main, expected + + +def atomic_add_return_access_ptr_legalize(): + dtype = T.int32 + + @T.prim_func + def main( + A: T.Tensor((16,), dtype=dtype), + out: T.Tensor((4,), dtype=dtype), + ): + for i in T.serial(4): + out[i] = T.atomic_add(A[i * 4 + 10], T.int32(1), return_prev=True) + + @T.prim_func + def expected( + A: T.Tensor((16,), dtype=dtype), + out: T.Tensor((4,), dtype=dtype), + ): + for i in T.serial(4): + out[i] = T.if_then_else( + i < 2, + T.atomic_add(A[i * 4 + 10], T.int32(1), return_prev=True), + T.int32(0), + ) + + return main, expected + + +def atomic_store_access_ptr_legalize(): + dtype = T.int32 + + @T.prim_func + def main( + A: T.Tensor((16,), dtype=dtype), + ): + for i in T.serial(4): + T.atomic_store(A[i * 4 + 10], T.int32(1), memory_order="release") + + @T.prim_func + def expected( + A: T.Tensor((16,), dtype=dtype), + ): + for i in T.serial(4): + if i * 4 + 10 < 16: + T.atomic_store(A[i * 4 + 10], T.int32(1), memory_order="release") + + return main, expected + + +def call_extern_access_ptr_mask_legalize(access_type: str): + dtype = T.int32 + + @T.prim_func + def main( + A: T.Tensor((16,), dtype=dtype), + ): + for i in T.serial(4): + T.call_extern( + "handle", + "use_ptr", + T.access_ptr(A[i * 4 + 10], access_type, 1), + ) + + @T.prim_func + def expected( + A: T.Tensor((16,), dtype=dtype), + ): + for i in T.serial(4): + if i * 4 + 10 < 16: + T.call_extern( + "handle", + "use_ptr", + T.access_ptr(A[i * 4 + 10], access_type, 1), + ) + + return main, expected + + +def call_extern_multiple_access_ptrs_legalize(): + dtype = T.int32 + + @T.prim_func + def main( + A: T.Tensor((16,), dtype=dtype), + B: T.Tensor((12,), dtype=dtype), + ): + for i in T.serial(4): + T.call_extern( + "handle", + "use_two_ptrs", + T.access_ptr(A[i * 4 + 10], "r", 1), + T.access_ptr(B[i * 4 + 6], "w", 1), + ) + + @T.prim_func + def expected( + A: T.Tensor((16,), dtype=dtype), + B: T.Tensor((12,), dtype=dtype), + ): + for i in T.serial(4): + if i * 4 + 6 < 12: # noqa: SIM102 + if i * 4 + 10 < 16: + T.call_extern( + "handle", + "use_two_ptrs", + T.access_ptr(A[i * 4 + 10], "r", 1), + T.access_ptr(B[i * 4 + 6], "w", 1), + ) + + return main, expected + + +def assert_cp_async_access_ptr_nonzero_safe_value_legalize(): + func, expected = cp_async_access_ptr_nonzero_safe_value_legalize() mod = tvm.IRModule({func.attrs["global_symbol"]: func}) transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) body = transformed["main"].body @@ -271,6 +436,48 @@ def assert_cp_async_access_ptr_nonzero_safe_value_legalize(N: int = 16): assert len(cp_async_calls) > 0 assert all(len(call.args) == 3 for call in cp_async_calls) assert _count_if_then_else(body) > 0 + _assert_legalize_matches_expected(func, expected, strip_annotations=True) + + +def assert_atomic_load_access_ptr_legalize(): + func, expected = atomic_load_access_ptr_legalize() + mod = tvm.IRModule({func.attrs["global_symbol"]: func}) + transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) + body = transformed["main"].body + + _assert_tl_access_ptr_bases_are_buffer_loads(body) + if_then_else_calls = _collect_call_nodes(body, "tirx.if_then_else") + assert any( + len(call.args) == 3 and _is_call_to(call.args[1], "tl.atomic_load_elem_op") and _is_int_zero(call.args[2]) + for call in if_then_else_calls + ) + _assert_legalize_matches_expected(func, expected) + + +def assert_atomic_add_return_access_ptr_legalize(): + func, expected = atomic_add_return_access_ptr_legalize() + _assert_legalize_matches_expected(func, expected) + + +def assert_atomic_store_access_ptr_legalize(): + func, expected = atomic_store_access_ptr_legalize() + _assert_legalize_matches_expected(func, expected) + + +def assert_call_extern_access_ptr_mask_legalize(access_type: str): + func, expected = call_extern_access_ptr_mask_legalize(access_type) + mod = tvm.IRModule({func.attrs["global_symbol"]: func}) + transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) + body = transformed["main"].body + + _assert_tl_access_ptr_bases_are_buffer_loads(body) + assert _count_if_then_else(body) > 0 + _assert_legalize_matches_expected(func, expected) + + +def assert_call_extern_multiple_access_ptrs_legalize(): + func, expected = call_extern_multiple_access_ptrs_legalize() + _assert_legalize_matches_expected(func, expected) def test_vectorize_access(): @@ -286,11 +493,39 @@ def test_oob_store(): def test_cp_async_access_ptr_oob(): - assert_cp_async_access_ptr_legalize(16) + assert_cp_async_access_ptr_legalize() def test_cp_async_access_ptr_nonzero_safe_value_oob(): - assert_cp_async_access_ptr_nonzero_safe_value_legalize(16) + assert_cp_async_access_ptr_nonzero_safe_value_legalize() + + +def test_atomic_load_access_ptr_oob(): + assert_atomic_load_access_ptr_legalize() + + +def test_atomic_add_return_access_ptr_oob(): + assert_atomic_add_return_access_ptr_legalize() + + +def test_atomic_store_access_ptr_oob(): + assert_atomic_store_access_ptr_legalize() + + +def test_call_extern_access_ptr_read_mask_oob(): + assert_call_extern_access_ptr_mask_legalize("r") + + +def test_call_extern_access_ptr_write_mask_oob(): + assert_call_extern_access_ptr_mask_legalize("w") + + +def test_call_extern_access_ptr_readwrite_mask_oob(): + assert_call_extern_access_ptr_mask_legalize("rw") + + +def test_call_extern_multiple_access_ptrs_oob(): + assert_call_extern_multiple_access_ptrs_legalize() if __name__ == "__main__":