Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/transform/common/access_ptr_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*!
* \file access_ptr_utils.h
* \brief Shared utilities for tl.access_ptr lowering helpers.
*/
#ifndef TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_
#define TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_

#include <tvm/tirx/expr.h>

namespace tvm {
namespace tl {

namespace detail {

template <typename VisitExprFn>
tirx::BufferLoad VisitAccessPtrBase(const tvm::PrimExpr &expr,
VisitExprFn &&visit_expr) {
const auto *base_load_node = expr.as<tirx::BufferLoadNode>();
ICHECK(base_load_node) << "tl.access_ptr base must be BufferLoad, but got "
<< expr;
tirx::BufferLoad base_load =
tvm::ffi::GetRef<tirx::BufferLoad>(base_load_node);

tvm::ffi::Array<tvm::PrimExpr> indices;
bool changed = false;
for (const tvm::PrimExpr &index : base_load->indices) {
tvm::PrimExpr new_index = visit_expr(index);
changed = changed || !new_index.same_as(index);
indices.push_back(new_index);
}

tvm::ffi::Optional<tvm::PrimExpr> predicate = base_load->predicate;
if (predicate.defined()) {
tvm::PrimExpr new_predicate = visit_expr(predicate.value());
changed = changed || !new_predicate.same_as(predicate.value());
predicate = new_predicate;
}

if (!changed) {
return base_load;
}
return tirx::BufferLoad(base_load->buffer, indices, predicate,
base_load->span);
}

} // namespace detail

} // namespace tl
} // namespace tvm

#endif // TVM_TL_TRANSFORM_COMMON_ACCESS_PTR_UTILS_H_
177 changes: 157 additions & 20 deletions src/transform/legalize_safe_memory_access.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
#include <tvm/tirx/stmt_functor.h>
#include <tvm/tirx/transform.h>

#include <optional>
#include <utility>

#include "../op/builtin.h"
#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "common/access_ptr_utils.h"
#include "loop_partition.h"
#include "loop_vectorize.h"

Expand All @@ -28,6 +30,16 @@ using namespace tirx;
using namespace ffi;
using arith::IRMutatorWithAnalyzer;

int GetConstAccessMask(const PrimExpr &expr) {
const auto *imm = expr.as<IntImmNode>();
ICHECK(imm) << "access_ptr rw_mask must be an integer constant, got " << expr;
return static_cast<int>(imm->value);
}

bool AccessMaskMayUse(const PrimExpr &expr, int required_mask) {
return (GetConstAccessMask(expr) & required_mask) != 0;
}

// SafeMemChecker for a BufferLoad/BufferStore node:
// 1. Identify BufferLoad and BufferStore nodes.
// 2. For each index, compare against the buffer's shape.
Expand Down Expand Up @@ -59,6 +71,36 @@ struct SafeMemChecker : public StmtExprVisitor {
}
}

void VisitExpr_(const CallNode *op) final {
if (!op->op.same_as(tl::access_ptr())) {
StmtExprVisitor::VisitExpr_(op);
return;
}

ICHECK_EQ(op->args.size(), 3U)
<< "tl.access_ptr expects 3 arguments, but got " << op->args;
const auto *base_load = op->args[0].as<BufferLoadNode>();
ICHECK(base_load) << "tl.access_ptr base must be BufferLoad, but got "
<< op->args[0];

int rw_mask = GetConstAccessMask(op->args[2]);
CheckBufferIndices(base_load->buffer, base_load->indices,
/*is_load=*/(rw_mask & kAccessRead) != 0,
!disableOOBWarning &&
!IsGlobalBuffer(base_load->buffer));

if (recursively_collect_conds_) {
for (const PrimExpr &index : base_load->indices) {
VisitExpr(index);
}
if (base_load->predicate.defined()) {
VisitExpr(base_load->predicate.value());
}
VisitExpr(op->args[1]);
VisitExpr(op->args[2]);
}
}

void VisitStmt_(const BufferStoreNode *op) final {
// Check if the buffer is in global scope
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false,
Expand Down Expand Up @@ -234,10 +276,63 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
}

private:
struct AccessPtrInfo {
BufferLoad base_load;
PrimExpr rw_mask;
};

// Constructor initializing the base class with the analyzer
SafeMemorysRewriter(arith::Analyzer *analyzer)
: arith::IRMutatorWithAnalyzer(analyzer) {}
// Constructor initializing the base class with the analyzer

PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tl::access_ptr())) {
return VisitAccessPtrCall(op);
}

PrimExpr expr = IRMutatorWithAnalyzer::VisitExpr_(op);
const auto *call_node = expr.as<CallNode>();
if (!call_node || !call_node->op.as<OpNode>()) {
return expr;
}
Call call = Downcast<Call>(expr);
Op call_op = Downcast<Op>(call->op);
if (!IsAtomicOp(call_op) || call.dtype().is_handle()) {
return call;
}

Array<PrimExpr> conditions = CollectCallAccessPtrConditions(call);
if (conditions.empty()) {
return call;
}

std::optional<AccessPtrInfo> fallback_ptr =
TryGetAccessPtrInfo(call->args[0]);
if (!fallback_ptr.has_value()) {
return call;
}

PrimExpr safe_value = GetSafeValue(fallback_ptr->base_load->buffer);
if (safe_value.dtype() != call.dtype()) {
safe_value = Cast(call.dtype(), safe_value);
}
safe_value = analyzer_->Simplify(safe_value);
return if_then_else(CombineConditions(conditions), call, safe_value);
}

PrimExpr VisitAccessPtrCall(const CallNode *op) {
ICHECK_EQ(op->args.size(), 3U)
<< "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)";
auto visit_expr = [this](const PrimExpr &expr) {
return this->VisitExpr(expr);
};
Array<PrimExpr> args{
detail::VisitAccessPtrBase(op->args[0], visit_expr),
VisitExpr(op->args[1]),
VisitExpr(op->args[2]),
};
return Call(op->dtype, op->op, args, op->annotations, op->span);
}

PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
Expand Down Expand Up @@ -325,23 +420,25 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
static constexpr int kCPAsyncDstPtrArg = 0;
static constexpr int kCPAsyncSrcPtrArg = 1;

BufferLoad GetBaseLoadFromAccessPtrExpr(const PrimExpr &expr) {
std::optional<AccessPtrInfo> TryGetAccessPtrInfo(const PrimExpr &expr) {
const auto *ptr_call = expr.as<CallNode>();
ICHECK(ptr_call) << "cp.async expects access_ptr arguments, but got "
<< expr;
if (!ptr_call) {
return std::nullopt;
}

if (ptr_call->op.same_as(tl::access_ptr())) {
ICHECK_EQ(ptr_call->args.size(), 3U)
<< "tl.access_ptr expects 3 arguments, but got " << ptr_call->args;
const auto *base_load = ptr_call->args[0].as<BufferLoadNode>();
ICHECK(base_load) << "tl.access_ptr base must be BufferLoad, but got "
<< ptr_call->args[0];
return Downcast<BufferLoad>(ptr_call->args[0]);
return AccessPtrInfo{Downcast<BufferLoad>(ptr_call->args[0]),
ptr_call->args[2]};
}

ICHECK(ptr_call->op.same_as(builtin::tvm_access_ptr()))
<< "cp.async expects tl.access_ptr or tvm_access_ptr, but got "
<< ptr_call->op;
if (!ptr_call->op.same_as(builtin::tvm_access_ptr())) {
return std::nullopt;
}
ICHECK_EQ(ptr_call->args.size(), 5U)
<< "tvm_access_ptr expects 5 arguments, but got " << ptr_call->args;
const auto *var = ptr_call->args[1].as<VarNode>();
Expand All @@ -352,7 +449,43 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
<< "Buffer data var " << buffer_data
<< " is not registered in buffer_data_to_buffer_.";
Buffer flat = buffer_data_to_buffer_[buffer_data].GetFlattenedBuffer();
return BufferLoad(flat, Array<PrimExpr>{ptr_call->args[2]});
return AccessPtrInfo{BufferLoad(flat, Array<PrimExpr>{ptr_call->args[2]}),
ptr_call->args[4]};
}

AccessPtrInfo GetRequiredAccessPtrInfo(const PrimExpr &expr,
const char *context) {
std::optional<AccessPtrInfo> info = TryGetAccessPtrInfo(expr);
ICHECK(info.has_value())
<< context << " expects tl.access_ptr or tvm_access_ptr, got " << expr;
return info.value();
}

Array<PrimExpr> CollectAccessPtrConditions(const PrimExpr &expr,
int required_mask) {
Array<PrimExpr> conditions;
std::optional<AccessPtrInfo> info = TryGetAccessPtrInfo(expr);
if (!info.has_value() || !AccessMaskMayUse(info->rw_mask, required_mask)) {
return conditions;
}

SafeMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
bool is_load = (GetConstAccessMask(info->rw_mask) & kAccessRead) != 0;
checker.CheckBufferIndices(info->base_load->buffer,
info->base_load->indices, is_load,
/*throw_warning=*/false);
return checker.GetConditions();
}

Array<PrimExpr> CollectCallAccessPtrConditions(const Call &call) {
Array<PrimExpr> conditions;
for (const PrimExpr &arg : call->args) {
for (const PrimExpr &cond :
CollectAccessPtrConditions(arg, kAccessReadWrite)) {
conditions.push_back(cond);
}
}
return conditions;
}

bool NeedsEvaluateBoundaryCheck(const Call &call) {
Expand All @@ -364,21 +497,25 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
ICHECK_GE(call->args.size(), 3U)
<< "cp.async expects at least 3 arguments, but got " << call->args;
Array<PrimExpr> conditions;
BufferLoad src_base_load =
GetBaseLoadFromAccessPtrExpr(call->args[kCPAsyncSrcPtrArg]);
AccessPtrInfo src_info =
GetRequiredAccessPtrInfo(call->args[kCPAsyncSrcPtrArg], "cp.async");
if (!AccessMaskMayUse(src_info.rw_mask, kAccessRead)) {
return conditions;
}

SafeMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
checker.CheckBufferIndices(src_base_load->buffer, src_base_load->indices,
checker.CheckBufferIndices(src_info.base_load->buffer,
src_info.base_load->indices,
/*is_load=*/true, /*throw_warning=*/false);
return checker.GetConditions();
}

Buffer GetCPAsyncSourceBuffer(const Call &call) {
ICHECK_GE(call->args.size(), 3U)
<< "cp.async expects at least 3 arguments, but got " << call->args;
BufferLoad src_base_load =
GetBaseLoadFromAccessPtrExpr(call->args[kCPAsyncSrcPtrArg]);
return src_base_load->buffer;
AccessPtrInfo src_info =
GetRequiredAccessPtrInfo(call->args[kCPAsyncSrcPtrArg], "cp.async");
return src_info.base_load->buffer;
}

PrimExpr CombineConditions(const Array<PrimExpr> &conditions) {
Expand All @@ -405,15 +542,15 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer {

ICHECK_GE(call->args.size(), 3U)
<< "cp.async expects at least 3 arguments, but got " << call->args;
BufferLoad dst_base_load =
GetBaseLoadFromAccessPtrExpr(call->args[kCPAsyncDstPtrArg]);
AccessPtrInfo dst_info =
GetRequiredAccessPtrInfo(call->args[kCPAsyncDstPtrArg], "cp.async");
Buffer src_buffer = GetCPAsyncSourceBuffer(call);

PrimExpr combined = CombineConditions(conditions);
Optional<PrimExpr> existing_predicate = GetCPAsyncPredicate(call);

PrimExpr safe_value = GetSafeValue(src_buffer);
DataType dst_dtype = dst_base_load->buffer->dtype;
DataType dst_dtype = dst_info.base_load->buffer->dtype;
if (safe_value.dtype() != dst_dtype) {
safe_value = Cast(dst_dtype, safe_value);
}
Expand All @@ -434,8 +571,8 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
Call(call->dtype, call->op, new_args, call->annotations, call->span));
}

Stmt else_case =
BufferStore(dst_base_load->buffer, safe_value, dst_base_load->indices);
Stmt else_case = BufferStore(dst_info.base_load->buffer, safe_value,
dst_info.base_load->indices);
return IfThenElse(combined, evaluate, else_case);
}

Expand Down
17 changes: 10 additions & 7 deletions src/transform/lower_access_ptr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <tvm/tirx/transform.h>

#include "../op/builtin.h"
#include "common/access_ptr_utils.h"

namespace tvm {
namespace tl {
Expand Down Expand Up @@ -79,20 +80,22 @@ PrimExpr LinearOffsetFromLoad(const BufferLoad &load) {
class AccessPtrLowerer : public StmtExprMutator {
public:
PrimExpr VisitExpr_(const CallNode *op) final {
Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (!call->op.same_as(tl::access_ptr())) {
return std::move(call);
if (!op->op.same_as(tl::access_ptr())) {
return StmtExprMutator::VisitExpr_(op);
}

ICHECK_EQ(call->args.size(), 3U)
ICHECK_EQ(op->args.size(), 3U)
<< "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)";

BufferLoad base_load = Downcast<BufferLoad>(call->args[0]);
auto visit_expr = [this](const PrimExpr &expr) {
return this->VisitExpr(expr);
};
BufferLoad base_load = detail::VisitAccessPtrBase(op->args[0], visit_expr);
Buffer buffer = base_load->buffer;
ICHECK(buffer.defined());

PrimExpr extent = call->args[1];
PrimExpr rw_mask = call->args[2];
PrimExpr extent = VisitExpr(op->args[1]);
PrimExpr rw_mask = VisitExpr(op->args[2]);

PrimExpr ptype = tirx::TypeAnnotation(buffer->dtype);
PrimExpr data = buffer->data;
Expand Down
Loading
Loading