@@ -353,16 +353,21 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
353353 * the input function signature in the analyzer.
354354 * \param func The function to be analyzed.
355355 * \param ana The analyzer which contains the TIR var upper bounds.
356+ * \param dom_map The domain map of the TIR variables.
356357 */
357- void SetTIRVarUpperBound (Function func, arith::Analyzer* ana) {
358+ void SetTIRVarUpperBound (Function func, arith::Analyzer* ana,
359+ Map<tir::Var, arith::IntSet>* dom_map) {
358360 // Use the attribute-annotated TIR var upper bounds as the TIR var values for
359361 // memory planning.
360362 // NOTE: we only apply the annotated upper bounds to the TIR variables that
361363 // appear in the **function signature**.
362364 Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
363365 func->GetAttr <Map<ObjectRef, ObjectRef>>(" tir_var_upper_bound" )
364366 .value_or (Map<ObjectRef, ObjectRef>());
367+ Array<ObjectRef> non_negative_var_attr_raw =
368+ func->GetAttr <Array<ObjectRef>>(" tir_non_negative_var" ).value_or (Array<ObjectRef>());
365369 std::unordered_map<String, IntImm> var_upper_bound_attr;
370+ std::unordered_set<String> non_negative_var_attr;
366371 // We manually check the value type to ensure the values are all positive IntImm.
367372 for (auto it : var_upper_bound_attr_raw) {
368373 const auto * key = it.first .as <StringObj>();
@@ -378,13 +383,23 @@ void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
378383 << value->value << " is got." ;
379384 var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
380385 }
386+ for (ObjectRef var_name : non_negative_var_attr_raw) {
387+ const auto * key = var_name.as <StringObj>();
388+ CHECK (key != nullptr ) << " The element of attr `tir_non_negative_var` should be string. However "
389+ << key->GetTypeKey () << " is got." ;
390+ non_negative_var_attr.insert (GetRef<String>(key));
391+ }
381392 Array<tir::Var> var_in_signature = TIRVarsInStructInfo (GetStructInfo (func));
382393 for (const tir::Var& tir_var : var_in_signature) {
383394 auto it = var_upper_bound_attr.find (tir_var->name_hint );
384395 if (it != var_upper_bound_attr.end ()) {
385- ana->Bind (tir_var,
386- tvm::Range::FromMinExtent (tvm::IntImm (DataType::Int (64 ), 0 ),
387- tvm::IntImm (DataType::Int (64 ), (*it).second ->value + 1 )));
396+ tvm::Range range =
397+ tvm::Range::FromMinExtent (tvm::IntImm (DataType::Int (64 ), 0 ),
398+ tvm::IntImm (DataType::Int (64 ), (*it).second ->value + 1 ));
399+ ana->Bind (tir_var, range);
400+ dom_map->Set (tir_var, arith::IntSet::FromRange (range));
401+ } else if (non_negative_var_attr.count (tir_var->name_hint )) {
402+ ana->MarkGlobalNonNegValue (tir_var);
388403 }
389404 }
390405}
@@ -398,14 +413,20 @@ void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
398413 * \return The upper-bounded shape. When a dimension's upper bound
399414 * cannot be determined, we keep the dimension unchanged.
400415 */
401- Array<PrimExpr> GetUpperBoundShape (Array<PrimExpr> shape, arith::Analyzer* ana) {
416+ Array<PrimExpr> GetUpperBoundShape (Array<PrimExpr> shape, arith::Analyzer* ana,
417+ const Map<tir::Var, arith::IntSet>& dom_map) {
402418 // Use the upper bounds of TIR vars as their values.
403419 Array<PrimExpr> upper_bounded_shape;
404420 upper_bounded_shape.reserve (shape.size ());
405421 for (const PrimExpr& dim_len : shape) {
406422 int64_t max_bound = ana->const_int_bound (dim_len)->max_value ;
407423 if (max_bound == std::numeric_limits<int64_t >::max ()) {
408- upper_bounded_shape.push_back (dim_len);
424+ arith::IntSet int_set = ana->int_set (dim_len, dom_map);
425+ if (int_set.HasUpperBound ()) {
426+ upper_bounded_shape.push_back (int_set.max ());
427+ } else {
428+ upper_bounded_shape.push_back (dim_len);
429+ }
409430 } else {
410431 upper_bounded_shape.push_back (tvm::IntImm (DataType::Int (64 ), max_bound));
411432 }
@@ -462,7 +483,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
462483
463484 void VisitExpr_ (const FunctionNode* func) final {
464485 // Set the upper bound of TIR variables in the analyzer.
465- SetTIRVarUpperBound (GetRef<Function>(func), analyzer_);
486+ SetTIRVarUpperBound (GetRef<Function>(func), analyzer_, &dom_map_ );
466487 // Recurse into the function to get its tokens.
467488 Tokens body_tokens = GetTokens (func->body );
468489 // Discard the tokens used by the function return value, as they are external referenced.
@@ -565,7 +586,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
565586
566587 // Use the upper bounds of TIR vars as their values. The upper bound shape can still be dynamic
567588 // if the upper bounds of some variables are not provided.
568- Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape (shape->values , analyzer_);
589+ Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape (shape->values , analyzer_, dom_map_ );
569590
570591 // Create and set token.
571592 StringImm storage_scope = Downcast<StringImm>(call->args [3 ]);
@@ -641,6 +662,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
641662 const IRModule& ctx_mod_;
642663 /* ! \brief The arithmetic analyzer. */
643664 arith::Analyzer* analyzer_;
665+ /* ! \brief The domain map of dynamic TIR variables for analysis. */
666+ Map<tir::Var, arith::IntSet> dom_map_;
644667 /* ! \brief The mapping from each token to the binding block where it is created. */
645668 std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> token2block_;
646669 /* ! \brief The mapping from each token to the Exprs that are using this token. */
@@ -816,7 +839,7 @@ class StorageAllocationRewriter : public ExprMutator {
816839 plan_dynamic_output_ = static_cast <bool >(
817840 func_->GetAttr <IntImm>(plan_dyn_attr_).value_or (IntImm (DataType::Int (32 ), 0 ))->value );
818841 if (plan_dynamic_output_) {
819- SetTIRVarUpperBound (GetRef<Function>(func_), &ana_);
842+ SetTIRVarUpperBound (GetRef<Function>(func_), &ana_, &dom_map_ );
820843 }
821844 token2storage_var_.clear ();
822845 Function func = Downcast<Function>(this ->VisitExpr_ (func_));
@@ -879,7 +902,7 @@ class StorageAllocationRewriter : public ExprMutator {
879902 ICHECK_NOTNULL (sinfo);
880903 const auto * shape = sinfo->shape .as <ShapeExprNode>();
881904 ICHECK_NOTNULL (shape);
882- Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape (shape->values , &ana_);
905+ Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape (shape->values , &ana_, dom_map_ );
883906 if (!IsStaticShape (shape->values )) {
884907 ICHECK (!sinfo->IsUnknownDtype ());
885908 ICHECK_EQ (sinfo->dtype , Downcast<DataTypeImm>(call->args [1 ])->value );
@@ -906,6 +929,8 @@ class StorageAllocationRewriter : public ExprMutator {
906929
907930 /* ! \brief The arithmetic analyzer. */
908931 arith::Analyzer ana_;
932+ /* ! \brief The domain map of dynamic TIR variables for analysis. */
933+ Map<tir::Var, arith::IntSet> dom_map_;
909934 /* ! \brief A boolean indicating whether to plan dynamic-shape function output tensors. */
910935 bool plan_dynamic_output_;
911936 /* !
0 commit comments