Skip to content

Commit 0a3fe22

Browse files
authored
[Relax] Enhance symbolic expr estimation in memory planning (#16872)
This PR enhances the symbolic expression upper bound estimation in static memory planning. Prior to this PR, we were not able to estimate the upper bound of `a * b` when `a` has an upper bound but `b` does not. This PR enhances the estimation with arith::IntSet. We also introduce another TIR attribute, `tir_non_negative_var`, which indicates the TIR variables that are non-negative, for use in memory planning. A new unit test is introduced for this enhancement.
1 parent 3f09e7f commit 0a3fe22

2 files changed

Lines changed: 137 additions & 10 deletions

File tree

src/relax/transform/static_plan_block_memory.cc

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -353,16 +353,21 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
353353
* the input function signature in the analyzer.
354354
* \param func The function to be analyzed.
355355
* \param ana The analyzer which contains the TIR var upper bounds.
356+
* \param dom_map The domain map of the TIR variables.
356357
*/
357-
void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
358+
void SetTIRVarUpperBound(Function func, arith::Analyzer* ana,
359+
Map<tir::Var, arith::IntSet>* dom_map) {
358360
// Use the attribute-annotated TIR var upper bounds as the TIR var values for
359361
// memory planning.
360362
// NOTE: we only apply the annotated upper bounds to the TIR variables that
361363
// appear in the **function signature**.
362364
Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
363365
func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")
364366
.value_or(Map<ObjectRef, ObjectRef>());
367+
Array<ObjectRef> non_negative_var_attr_raw =
368+
func->GetAttr<Array<ObjectRef>>("tir_non_negative_var").value_or(Array<ObjectRef>());
365369
std::unordered_map<String, IntImm> var_upper_bound_attr;
370+
std::unordered_set<String> non_negative_var_attr;
366371
// We manually check the value type to ensure the values are all positive IntImm.
367372
for (auto it : var_upper_bound_attr_raw) {
368373
const auto* key = it.first.as<StringObj>();
@@ -378,13 +383,23 @@ void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
378383
<< value->value << " is got.";
379384
var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
380385
}
386+
for (ObjectRef var_name : non_negative_var_attr_raw) {
387+
const auto* key = var_name.as<StringObj>();
388+
CHECK(key != nullptr) << "The element of attr `tir_non_negative_var` should be string. However "
389+
<< key->GetTypeKey() << " is got.";
390+
non_negative_var_attr.insert(GetRef<String>(key));
391+
}
381392
Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(func));
382393
for (const tir::Var& tir_var : var_in_signature) {
383394
auto it = var_upper_bound_attr.find(tir_var->name_hint);
384395
if (it != var_upper_bound_attr.end()) {
385-
ana->Bind(tir_var,
386-
tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
387-
tvm::IntImm(DataType::Int(64), (*it).second->value + 1)));
396+
tvm::Range range =
397+
tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
398+
tvm::IntImm(DataType::Int(64), (*it).second->value + 1));
399+
ana->Bind(tir_var, range);
400+
dom_map->Set(tir_var, arith::IntSet::FromRange(range));
401+
} else if (non_negative_var_attr.count(tir_var->name_hint)) {
402+
ana->MarkGlobalNonNegValue(tir_var);
388403
}
389404
}
390405
}
@@ -398,14 +413,20 @@ void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
398413
* \return The upper-bounded shape. When a dimension's upper bound
399414
* cannot be determined, we keep the dimension unchanged.
400415
*/
401-
Array<PrimExpr> GetUpperBoundShape(Array<PrimExpr> shape, arith::Analyzer* ana) {
416+
Array<PrimExpr> GetUpperBoundShape(Array<PrimExpr> shape, arith::Analyzer* ana,
417+
const Map<tir::Var, arith::IntSet>& dom_map) {
402418
// Use the upper bounds of TIR vars as their values.
403419
Array<PrimExpr> upper_bounded_shape;
404420
upper_bounded_shape.reserve(shape.size());
405421
for (const PrimExpr& dim_len : shape) {
406422
int64_t max_bound = ana->const_int_bound(dim_len)->max_value;
407423
if (max_bound == std::numeric_limits<int64_t>::max()) {
408-
upper_bounded_shape.push_back(dim_len);
424+
arith::IntSet int_set = ana->int_set(dim_len, dom_map);
425+
if (int_set.HasUpperBound()) {
426+
upper_bounded_shape.push_back(int_set.max());
427+
} else {
428+
upper_bounded_shape.push_back(dim_len);
429+
}
409430
} else {
410431
upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
411432
}
@@ -462,7 +483,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
462483

463484
void VisitExpr_(const FunctionNode* func) final {
464485
// Set the upper bound of TIR variables in the analyzer.
465-
SetTIRVarUpperBound(GetRef<Function>(func), analyzer_);
486+
SetTIRVarUpperBound(GetRef<Function>(func), analyzer_, &dom_map_);
466487
// Recurse into the function to get its tokens.
467488
Tokens body_tokens = GetTokens(func->body);
468489
// Discard the tokens used by the function return value, as they are externally referenced.
@@ -565,7 +586,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
565586

566587
// Use the upper bounds of TIR vars as their values. The upper bound shape can still be dynamic
567588
// if the upper bounds of some variables are not provided.
568-
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, analyzer_);
589+
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, analyzer_, dom_map_);
569590

570591
// Create and set token.
571592
StringImm storage_scope = Downcast<StringImm>(call->args[3]);
@@ -641,6 +662,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
641662
const IRModule& ctx_mod_;
642663
/*! \brief The arithmetic analyzer. */
643664
arith::Analyzer* analyzer_;
665+
/*! \brief The domain map of dynamic TIR variables for analysis. */
666+
Map<tir::Var, arith::IntSet> dom_map_;
644667
/*! \brief The mapping from each token to the binding block where it is created. */
645668
std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> token2block_;
646669
/*! \brief The mapping from each token to the Exprs that are using this token. */
@@ -816,7 +839,7 @@ class StorageAllocationRewriter : public ExprMutator {
816839
plan_dynamic_output_ = static_cast<bool>(
817840
func_->GetAttr<IntImm>(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 0))->value);
818841
if (plan_dynamic_output_) {
819-
SetTIRVarUpperBound(GetRef<Function>(func_), &ana_);
842+
SetTIRVarUpperBound(GetRef<Function>(func_), &ana_, &dom_map_);
820843
}
821844
token2storage_var_.clear();
822845
Function func = Downcast<Function>(this->VisitExpr_(func_));
@@ -879,7 +902,7 @@ class StorageAllocationRewriter : public ExprMutator {
879902
ICHECK_NOTNULL(sinfo);
880903
const auto* shape = sinfo->shape.as<ShapeExprNode>();
881904
ICHECK_NOTNULL(shape);
882-
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);
905+
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_, dom_map_);
883906
if (!IsStaticShape(shape->values)) {
884907
ICHECK(!sinfo->IsUnknownDtype());
885908
ICHECK_EQ(sinfo->dtype, Downcast<DataTypeImm>(call->args[1])->value);
@@ -906,6 +929,8 @@ class StorageAllocationRewriter : public ExprMutator {
906929

907930
/*! \brief The arithmetic analyzer. */
908931
arith::Analyzer ana_;
932+
/*! \brief The domain map of dynamic TIR variables for analysis. */
933+
Map<tir::Var, arith::IntSet> dom_map_;
909934
/*! \brief A boolean indicating whether to plan dynamic-shape function output tensors. */
910935
bool plan_dynamic_output_;
911936
/*!

tests/python/relax/test_transform_static_plan_block_memory.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,5 +1347,107 @@ def main(x: R.Tensor((2, "n"), dtype="float32")):
13471347
relax.transform.StaticPlanBlockMemory()(Module)
13481348

13491349

1350+
def test_add():
1351+
@I.ir_module
1352+
class Module:
1353+
@T.prim_func(private=True)
1354+
def cumsum(var_A: T.handle, var_A_1: T.handle, var_exclusive_scan_thrust: T.handle):
1355+
T.evaluate(0)
1356+
1357+
@R.function
1358+
def main(
1359+
probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")
1360+
) -> R.Tensor(("batch_size", "vocab_size"), dtype="float32"):
1361+
batch_size = T.int64()
1362+
vocab_size = T.int64()
1363+
R.func_attr(
1364+
{
1365+
"relax.force_pure": 1,
1366+
"relax.memory_plan_dynamic_func_output": 1,
1367+
"tir_var_upper_bound": {"batch_size": 32},
1368+
"tir_non_negative_var": ["vocab_size"],
1369+
}
1370+
)
1371+
cls = Module
1372+
lv1: R.Tensor(
1373+
(2 * (batch_size * vocab_size * 4) + 4194304,),
1374+
dtype="uint8",
1375+
) = R.builtin.alloc_tensor(
1376+
R.shape([2 * (batch_size * vocab_size * 4) + 4194304]),
1377+
R.dtype("uint8"),
1378+
R.prim_value(0),
1379+
R.str("global"),
1380+
)
1381+
alloc1: R.Tensor((batch_size, vocab_size), dtype="float32") = R.builtin.alloc_tensor(
1382+
R.shape([batch_size, vocab_size]),
1383+
R.dtype("float32"),
1384+
R.prim_value(0),
1385+
R.str("global"),
1386+
)
1387+
cls.cumsum(probs, lv1, alloc1)
1388+
cumsum: R.Tensor((batch_size, vocab_size), dtype="float32") = alloc1
1389+
lv1_1: R.Tensor((batch_size, vocab_size), dtype="int32") = R.call_packed(
1390+
"vm.builtin.reshape",
1391+
cumsum,
1392+
R.shape([batch_size, vocab_size]),
1393+
sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float"),),
1394+
)
1395+
return lv1_1
1396+
1397+
@I.ir_module
1398+
class Expected:
1399+
@T.prim_func(private=True)
1400+
def cumsum(var_A: T.handle, var_A_1: T.handle, var_exclusive_scan_thrust: T.handle):
1401+
T.evaluate(0)
1402+
1403+
@R.function
1404+
def main(
1405+
probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")
1406+
) -> R.Tensor(("batch_size", "vocab_size"), dtype="int32"):
1407+
batch_size = T.int64()
1408+
vocab_size = T.int64()
1409+
R.func_attr(
1410+
{
1411+
"relax.force_pure": 1,
1412+
"tir_non_negative_var": ["vocab_size"],
1413+
"tir_var_upper_bound": {"batch_size": 32},
1414+
}
1415+
)
1416+
cls = Expected
1417+
storage: R.Object = R.memory.alloc_storage(
1418+
R.shape([32 * vocab_size * 4 * 2 + 4194304]),
1419+
R.prim_value(0),
1420+
R.str("global"),
1421+
R.dtype("uint8"),
1422+
)
1423+
lv1: R.Tensor(
1424+
(2 * (batch_size * vocab_size * 4) + 4194304,),
1425+
dtype="uint8",
1426+
) = R.memory.alloc_tensor(
1427+
storage,
1428+
R.prim_value(0),
1429+
R.shape([2 * (batch_size * vocab_size * 4) + 4194304]),
1430+
R.dtype("uint8"),
1431+
)
1432+
storage1: R.Object = R.memory.alloc_storage(
1433+
R.shape([128 * vocab_size]), R.prim_value(0), R.str("global"), R.dtype("float32")
1434+
)
1435+
alloc1: R.Tensor((batch_size, vocab_size), dtype="float32") = R.memory.alloc_tensor(
1436+
storage1, R.prim_value(0), R.shape([batch_size, vocab_size]), R.dtype("float32")
1437+
)
1438+
cls.cumsum(probs, lv1, alloc1)
1439+
cumsum: R.Tensor((batch_size, vocab_size), dtype="float32") = alloc1
1440+
lv1_1: R.Tensor((batch_size, vocab_size), dtype="int32") = R.call_packed(
1441+
"vm.builtin.reshape",
1442+
cumsum,
1443+
R.shape([batch_size, vocab_size]),
1444+
sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),),
1445+
)
1446+
return lv1_1
1447+
1448+
mod = relax.transform.StaticPlanBlockMemory()(Module)
1449+
tvm.ir.assert_structural_equal(mod, Expected)
1450+
1451+
13501452
if __name__ == "__main__":
13511453
tvm.testing.main()

0 commit comments

Comments
 (0)