Skip to content

Commit ad58d3e

Browse files
authored
[REFACTOR][TIRX] Add IntImm common scalar ctor and streamline MakeConst (#19797)
## Summary Common bool, int32, and int64 scalar constants show up throughout TIRX and related lowering code. Named constructors make these call sites easier to read than repeated `DataType` spelling, and avoid routing obvious scalar constants through the generic `MakeConst` helper. ## Usage guideline Prefer direct `IntImm` or `FloatImm` construction when dtype is known to be scalar integer or floating point. This makes the compiled code more compact and efficient. Keep `MakeConst` for generic overload cases where dtype can be integer, floating point, or vector-valued and the caller needs its scalar/vector dispatch. This PR establishes the scalar-constant construction policy: - Prefer `IntImm::Bool`, `IntImm::Int32`, and `IntImm::Int64` for common known scalar bool, int32, and int64 constants. - Prefer direct `IntImm` or `FloatImm` construction when dtype is known to be scalar integer or floating point. - Keep `MakeConst` for generic overload cases where dtype can be integer, floating point, or vector-valued and the caller needs its scalar/vector dispatch. - Phase out `make_zero` in favor of explicit scalar constructors, or `ConstHandle(0)` for null handles. ## Changes - Add `IntImm::Bool`, `IntImm::Int32`, and `IntImm::Int64` helpers. - Rename `make_const` to `MakeConst` and document it as the generic/vector construction helper. - Migrate deterministic scalar constant construction to the clearer APIs while keeping generic and vector-aware paths on `MakeConst`. - Remove `make_zero`, `const_true`, `const_false`, and the unused `tirx.const_true` registry entry.
1 parent 35a35b8 commit ad58d3e

183 files changed

Lines changed: 1219 additions & 1269 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/tvm/ir/expr.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,33 @@ class IntImm : public PrimExpr {
515515
*/
516516
TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span());
517517

518+
/*!
519+
* \brief Construct a scalar boolean constant.
520+
* \param value The boolean value.
521+
* \param span The location of this object in the source code.
522+
*/
523+
static IntImm Bool(bool value, Span span = Span()) {
524+
return IntImm(DataType::Bool(), value, span);
525+
}
526+
527+
/*!
528+
* \brief Construct a scalar int32 constant.
529+
* \param value The integer value.
530+
* \param span The location of this object in the source code.
531+
*/
532+
static IntImm Int32(int64_t value, Span span = Span()) {
533+
return IntImm(DataType::Int(32), value, span);
534+
}
535+
536+
/*!
537+
* \brief Construct a scalar int64 constant.
538+
* \param value The integer value.
539+
* \param span The location of this object in the source code.
540+
*/
541+
static IntImm Int64(int64_t value, Span span = Span()) {
542+
return IntImm(DataType::Int(64), value, span);
543+
}
544+
518545
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntImm, PrimExpr, IntImmNode);
519546
TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode);
520547
};
@@ -636,7 +663,7 @@ struct TypeTraits<FloatImm> : public ObjectRefWithFallbackTraitsBase<FloatImm, d
636663

637664
// define automatic conversion from bool, int64_t, double to PrimExpr
638665
TVM_FFI_INLINE PrimExpr TypeTraits<PrimExpr>::ConvertFallbackValue(StrictBool value) {
639-
return IntImm(DataType::Bool(), value, Span());
666+
return IntImm::Bool(value);
640667
}
641668

642669
TVM_FFI_INLINE PrimExpr TypeTraits<PrimExpr>::ConvertFallbackValue(int64_t value) {

include/tvm/ir/node_functor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ namespace tvm {
4747
* return prefix + "IntImm"
4848
* });
4949
*
50-
* Expr x = make_const(1);
50+
* Expr x = MakeConst(1);
5151
* Expr y = x + x;
5252
* // dispatch to IntImm, outputs "MyIntImm"
5353
* LOG(INFO) << tostr(x, "My");

include/tvm/script/printer/doc.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,15 +277,15 @@ class LiteralDoc : public ExprDoc {
277277
* \param p The object path
278278
*/
279279
static LiteralDoc Int(int64_t v, const ffi::Optional<AccessPath>& p) {
280-
return LiteralDoc(IntImm(DataType::Int(64), v), p);
280+
return LiteralDoc(IntImm::Int64(v), p);
281281
}
282282
/*!
283283
* \brief Create a LiteralDoc to represent boolean.
284284
* \param v The boolean value.
285285
* \param p The object path
286286
*/
287287
static LiteralDoc Boolean(bool v, const ffi::Optional<AccessPath>& p) {
288-
return LiteralDoc(IntImm(DataType::Bool(), v), p);
288+
return LiteralDoc(IntImm::Bool(v), p);
289289
}
290290
/*!
291291
* \brief Create a LiteralDoc to represent float.

include/tvm/tirx/buffer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ class Buffer : public ffi::ObjectRef {
206206
* \param input_extent The extent of ptr.
207207
*/
208208
TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
209-
int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
209+
int content_lanes = 1, PrimExpr offset = IntImm::Int32(0),
210210
ffi::Optional<PrimExpr> input_extent = std::nullopt) const;
211211
/*!
212212
* \brief Create an Expr that does a vector load at begin index.

include/tvm/tirx/op.h

Lines changed: 48 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ namespace tvm {
5252
// Most common operators can be overloaded by argument type(PrimExpr).
5353
// So we put them under the root namespace.
5454
//
55-
// We put more developer oriented APIs -- make_const and is_const under tirx
55+
// We put more developer oriented APIs -- MakeConst and is_const under tirx
5656
// as they are more specific to the tirx namespace.
5757

5858
/*!
@@ -816,7 +816,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
816816

817817
/*!
818818
* \brief Make a const value with certain data type.
819-
* \param t The target type.
819+
*
820+
* Prefer direct IntImm or FloatImm construction when dtype is known to be
821+
* scalar integer or floating point. This makes the compiled code more compact
822+
* and efficient. Keep MakeConst for generic overload cases where dtype can be
823+
* integer, floating point, or vector-valued and the caller needs its
824+
* scalar/vector dispatch.
825+
*
826+
* \param dtype The target type.
820827
* \param value The input value
821828
* \return the result expression.
822829
* \tparam ValueType The constant value type
@@ -825,32 +832,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
825832
template <typename ValueType,
826833
typename = typename std::enable_if<std::is_standard_layout<ValueType>::value &&
827834
std::is_trivial<ValueType>::value>::type>
828-
inline PrimExpr make_const(DataType t, ValueType value, Span span = Span());
829-
/*!
830-
* \brief Make a const zero expr.
831-
* \param t The target type.
832-
* \param span The location of this operation in the source.
833-
* \return the result expression.
834-
*/
835-
inline PrimExpr make_zero(DataType t, Span span = Span());
836-
/*!
837-
* \brief Make a constant true expression.
838-
* \param lanes The number of lanes in the bool
839-
* \param span The location of this operation in the source.
840-
* \return The result expression.
841-
*/
842-
inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
843-
return make_const(DataType::Bool(lanes), 1);
844-
}
835+
inline PrimExpr MakeConst(DataType dtype, ValueType value, Span span = Span());
845836
/*!
846-
* \brief Make a constant false expression.
847-
* \param lanes The number of lanes in the bool
837+
* \brief Make a constant handle value.
838+
* \param value The integer payload to reinterpret as a handle.
848839
* \param span The location of this operation in the source.
849840
* \return The result expression.
850841
*/
851-
inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
852-
return make_const(DataType::Bool(lanes), 0);
853-
}
842+
inline PrimExpr ConstHandle(int64_t value, Span span = Span());
854843
/*!
855844
* \brief Get x as constant int expression.
856845
* \param x The expression
@@ -981,53 +970,52 @@ inline bool is_no_op(const tirx::Stmt& stmt) {
981970
}
982971

983972
template <typename ValueType>
984-
inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) {
985-
if (t.is_int() || t.is_bool()) return IntImm(t, static_cast<int64_t>(value), span);
986-
if (t.is_uint()) {
973+
inline PrimExpr MakeConstScalar(DataType dtype, ValueType value, Span span = Span()) {
974+
if (dtype.is_int() || dtype.is_bool()) return IntImm(dtype, static_cast<int64_t>(value), span);
975+
if (dtype.is_uint()) {
987976
// Use IntImm if it is a small integer
988977
uint64_t uval = static_cast<uint64_t>(value);
989978
if (value < static_cast<ValueType>(0)) {
990979
TVM_FFI_THROW(InternalError) << "cannot make uint from negative value " << value;
991980
} else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
992-
return IntImm(t, static_cast<int64_t>(value), span);
981+
return IntImm(dtype, static_cast<int64_t>(value), span);
993982
} else {
994983
uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
995984
uint64_t low = uval & mask;
996985
uint64_t high = uval >> 32U;
997-
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
986+
return LargeUIntImm(dtype, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
998987
}
999988
}
1000-
if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float6() || t.is_float4())
1001-
return FloatImm(t, static_cast<double>(value), span);
1002-
TVM_FFI_THROW(InternalError) << "cannot make const for type " << t;
989+
if (dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() ||
990+
dtype.is_float4()) {
991+
return FloatImm(dtype, static_cast<double>(value), span);
992+
}
993+
TVM_FFI_THROW(InternalError) << "cannot make const for type " << dtype;
1003994
throw;
1004995
}
1005996

1006997
template <>
1007-
inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {
1008-
return MakeConstScalar(t, static_cast<int>(value), span);
998+
inline PrimExpr MakeConstScalar(DataType dtype, bool value, Span span) {
999+
return MakeConstScalar(dtype, static_cast<int>(value), span);
10091000
}
10101001

10111002
template <typename ValueType, typename>
1012-
inline PrimExpr make_const(DataType t, ValueType value, Span span) {
1013-
if (t.is_scalar()) {
1014-
return MakeConstScalar(t, value, span);
1003+
inline PrimExpr MakeConst(DataType dtype, ValueType value, Span span) {
1004+
if (dtype.is_scalar()) {
1005+
return MakeConstScalar(dtype, value, span);
10151006
} else {
1016-
if (t.is_fixed_length_vector()) {
1017-
return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
1007+
if (dtype.is_fixed_length_vector()) {
1008+
return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), dtype.lanes(), span);
10181009
} else {
1019-
PrimExpr lanes =
1020-
tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}), t.vscale_factor());
1021-
return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span);
1010+
PrimExpr lanes = tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}),
1011+
dtype.vscale_factor());
1012+
return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), lanes, span);
10221013
}
10231014
}
10241015
}
10251016

1026-
inline PrimExpr make_zero(DataType t, Span span) {
1027-
if (t.is_handle()) {
1028-
return reinterpret(t, make_const(DataType::UInt(64), 0, span));
1029-
}
1030-
return make_const(t, 0, span);
1017+
inline PrimExpr ConstHandle(int64_t value, Span span) {
1018+
return reinterpret(DataType::Handle(), IntImm(DataType::UInt(64), value, span));
10311019
}
10321020

10331021
} // namespace tirx
@@ -1043,13 +1031,13 @@ inline PrimExpr make_zero(DataType t, Span span) {
10431031
inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \
10441032
inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
10451033
inline PrimExpr Name(int a, const PrimExpr& b) { \
1046-
return Name(tirx::make_const(b.dtype(), a), b); \
1034+
return Name(tirx::MakeConst(b.dtype(), a), b); \
10471035
} \
10481036
inline PrimExpr Name(const PrimExpr& a, int b) { \
1049-
return Name(a, tirx::make_const(a.dtype(), b)); \
1037+
return Name(a, tirx::MakeConst(a.dtype(), b)); \
10501038
} \
10511039
inline PrimExpr Name(const PrimExpr& a, double b) { \
1052-
return Name(a, tirx::make_const(DataType::Float(64), b)); \
1040+
return Name(a, FloatImm(DataType::Float(64), b)); \
10531041
}
10541042

10551043
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \
@@ -1060,13 +1048,13 @@ inline PrimExpr make_zero(DataType t, Span span) {
10601048
return Name(PrimExpr(a), b, span); \
10611049
} \
10621050
inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
1063-
return Name(tirx::make_const(b.dtype(), a), b, span); \
1051+
return Name(tirx::MakeConst(b.dtype(), a), b, span); \
10641052
} \
10651053
inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
1066-
return Name(a, tirx::make_const(a.dtype(), b), span); \
1054+
return Name(a, tirx::MakeConst(a.dtype(), b), span); \
10671055
} \
10681056
inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \
1069-
return Name(a, tirx::make_const(DataType::Float(64), b), span); \
1057+
return Name(a, FloatImm(DataType::Float(64), b), span); \
10701058
}
10711059

10721060
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
@@ -1081,18 +1069,18 @@ inline PrimExpr make_zero(DataType t, Span span) {
10811069
return Name(PrimExpr(a), b, span); \
10821070
}
10831071

1084-
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
1085-
inline PrimExpr Name(const PrimExpr& a, int b) { \
1086-
return Name(a, tirx::make_const(a.dtype(), b)); \
1087-
} \
1088-
inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::make_const(b.dtype(), a), b); }
1072+
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
1073+
inline PrimExpr Name(const PrimExpr& a, int b) { \
1074+
return Name(a, tirx::MakeConst(a.dtype(), b)); \
1075+
} \
1076+
inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::MakeConst(b.dtype(), a), b); }
10891077

10901078
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
10911079
inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
1092-
return Name(a, tirx::make_const(a.dtype(), b), span); \
1080+
return Name(a, tirx::MakeConst(a.dtype(), b), span); \
10931081
} \
10941082
inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
1095-
return Name(tirx::make_const(b.dtype(), a), b, span); \
1083+
return Name(tirx::MakeConst(b.dtype(), a), b, span); \
10961084
}
10971085

10981086
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);

include/tvm/topi/detail/broadcast.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ inline tvm::ffi::Array<tvm::PrimExpr> InputIndexFromBroadcast(
130130
// Only inject 0 here if we have not yet reached the dimension of I
131131
// (i.e. this must be a 1)
132132
if (!found && (ovars.size() - i) <= expected_dims) {
133-
ivars.push_back(tvm::tirx::make_zero(ovars[i].dtype()));
133+
ivars.push_back(tvm::IntImm(ovars[i].dtype(), 0));
134134
}
135135
}
136136
TVM_FFI_ICHECK(expected_dims == ivars.size());

include/tvm/topi/detail/extern.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,12 @@ inline PrimExpr pack_buffer(Buffer buf) {
108108
} else {
109109
strides = 0;
110110
}
111-
ffi::Array<PrimExpr> pack_args{
112-
buf->data,
113-
shape,
114-
strides,
115-
make_const(DataType::Int(32), static_cast<int64_t>(buf->shape.size())),
116-
make_const(buf->dtype, 0),
117-
buf->elem_offset};
111+
ffi::Array<PrimExpr> pack_args{buf->data,
112+
shape,
113+
strides,
114+
IntImm::Int32(static_cast<int64_t>(buf->shape.size())),
115+
MakeConst(buf->dtype, 0),
116+
buf->elem_offset};
118117
return tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_array(), pack_args);
119118
}
120119

include/tvm/topi/detail/strided_slice.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ inline ffi::Array<PrimExpr> StridedSliceCanonicalizeBegin(const ffi::Array<PrimE
9999
if (ishape[ax]->IsInstance<tvm::IntImmNode>()) {
100100
int64_t dim_i = GetConstInt(ishape[ax]);
101101
int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]);
102-
begin_expr.push_back(make_const(dtype, begin_i));
102+
begin_expr.push_back(MakeConst(dtype, begin_i));
103103
} else {
104104
auto idim = ishape[ax];
105-
auto b_expr = make_const(dtype, begin[i]);
105+
auto b_expr = MakeConst(dtype, begin[i]);
106106
PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
107107
auto s = strides[i];
108108
if (s < 0) {

0 commit comments

Comments
 (0)