Skip to content

Commit 9958ebe

Browse files
committed
[REFACTOR][TIRX] Add IntImm common scalar ctor and streamline MakeConst
Common bool, int32, and int64 scalar constants show up throughout TIRX and related lowering code, and named constructors make these call sites easier to read than repeated DataType spelling. This PR establishes the scalar-constant construction policy and renames make_const to MakeConst to match the public helper naming style. - Prefer IntImm::Bool, IntImm::Int32, and IntImm::Int64 for common known scalar bool, int32, and int64 constants. - Prefer direct IntImm or FloatImm construction when dtype is known to be scalar integer or floating point. This makes the compiled code more compact and efficient. Keep MakeConst for generic overload cases where dtype can be integer, floating point, or vector-valued and the caller needs its scalar/vector dispatch. - Phase out make_zero in favor of explicit scalar constructors, or ConstHandle(0) for null handles.
1 parent 649388e commit 9958ebe

183 files changed

Lines changed: 1219 additions & 1269 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/tvm/ir/expr.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,33 @@ class IntImm : public PrimExpr {
515515
*/
516516
TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span());
517517

518+
/*!
519+
* \brief Construct a scalar boolean constant.
520+
* \param value The boolean value.
521+
* \param span The location of this object in the source code.
522+
*/
523+
static IntImm Bool(bool value, Span span = Span()) {
524+
return IntImm(DataType::Bool(), value, span);
525+
}
526+
527+
/*!
528+
* \brief Construct a scalar int32 constant.
529+
* \param value The integer value.
530+
* \param span The location of this object in the source code.
531+
*/
532+
static IntImm Int32(int64_t value, Span span = Span()) {
533+
return IntImm(DataType::Int(32), value, span);
534+
}
535+
536+
/*!
537+
* \brief Construct a scalar int64 constant.
538+
* \param value The integer value.
539+
* \param span The location of this object in the source code.
540+
*/
541+
static IntImm Int64(int64_t value, Span span = Span()) {
542+
return IntImm(DataType::Int(64), value, span);
543+
}
544+
518545
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntImm, PrimExpr, IntImmNode);
519546
TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode);
520547
};
@@ -636,7 +663,7 @@ struct TypeTraits<FloatImm> : public ObjectRefWithFallbackTraitsBase<FloatImm, d
636663

637664
// define automatic conversion from bool, int64_t, double to PrimExpr
638665
TVM_FFI_INLINE PrimExpr TypeTraits<PrimExpr>::ConvertFallbackValue(StrictBool value) {
639-
return IntImm(DataType::Bool(), value, Span());
666+
return IntImm::Bool(value);
640667
}
641668

642669
TVM_FFI_INLINE PrimExpr TypeTraits<PrimExpr>::ConvertFallbackValue(int64_t value) {

include/tvm/ir/node_functor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ namespace tvm {
4747
* return prefix + "IntImm"
4848
* });
4949
*
50-
* Expr x = make_const(1);
50+
* Expr x = MakeConst(1);
5151
* Expr y = x + x;
5252
* // dispatch to IntImm, outputs "MyIntImm"
5353
* LOG(INFO) << tostr(x, "My");

include/tvm/script/printer/doc.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,15 +277,15 @@ class LiteralDoc : public ExprDoc {
277277
* \param p The object path
278278
*/
279279
static LiteralDoc Int(int64_t v, const ffi::Optional<AccessPath>& p) {
280-
return LiteralDoc(IntImm(DataType::Int(64), v), p);
280+
return LiteralDoc(IntImm::Int64(v), p);
281281
}
282282
/*!
283283
* \brief Create a LiteralDoc to represent boolean.
284284
* \param v The boolean value.
285285
* \param p The object path
286286
*/
287287
static LiteralDoc Boolean(bool v, const ffi::Optional<AccessPath>& p) {
288-
return LiteralDoc(IntImm(DataType::Bool(), v), p);
288+
return LiteralDoc(IntImm::Bool(v), p);
289289
}
290290
/*!
291291
* \brief Create a LiteralDoc to represent float.

include/tvm/tirx/buffer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ class Buffer : public ffi::ObjectRef {
206206
* \param input_extent The extent of ptr.
207207
*/
208208
TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
209-
int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
209+
int content_lanes = 1, PrimExpr offset = IntImm::Int32(0),
210210
ffi::Optional<PrimExpr> input_extent = std::nullopt) const;
211211
/*!
212212
* \brief Create an Expr that does a vector load at begin index.

include/tvm/tirx/op.h

Lines changed: 48 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ namespace tvm {
5252
// Most common operators can be overloaded by argument type(PrimExpr).
5353
// So we put them under the root namespace.
5454
//
55-
// We put more developer oriented APIs -- make_const and is_const under tirx
55+
// We put more developer oriented APIs -- MakeConst and is_const under tirx
5656
// as they are more specific to the tirx namespace.
5757

5858
/*!
@@ -816,7 +816,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
816816

817817
/*!
818818
* \brief Make a const value with certain data type.
819-
* \param t The target type.
819+
*
820+
* Prefer direct IntImm or FloatImm construction when dtype is known to be
821+
* scalar integer or floating point. This makes the compiled code more compact
822+
* and efficient. Keep MakeConst for generic overload cases where dtype can be
823+
* integer, floating point, or vector-valued and the caller needs its
824+
* scalar/vector dispatch.
825+
*
826+
* \param dtype The target type.
820827
* \param value The input value
821828
* \return the result expression.
822829
* \tparam ValueType The constant value type
@@ -825,32 +832,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
825832
template <typename ValueType,
826833
typename = typename std::enable_if<std::is_standard_layout<ValueType>::value &&
827834
std::is_trivial<ValueType>::value>::type>
828-
inline PrimExpr make_const(DataType t, ValueType value, Span span = Span());
829-
/*!
830-
* \brief Make a const zero expr.
831-
* \param t The target type.
832-
* \param span The location of this operation in the source.
833-
* \return the result expression.
834-
*/
835-
inline PrimExpr make_zero(DataType t, Span span = Span());
836-
/*!
837-
* \brief Make a constant true expression.
838-
* \param lanes The number of lanes in the bool
839-
* \param span The location of this operation in the source.
840-
* \return The result expression.
841-
*/
842-
inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
843-
return make_const(DataType::Bool(lanes), 1);
844-
}
835+
inline PrimExpr MakeConst(DataType dtype, ValueType value, Span span = Span());
845836
/*!
846-
* \brief Make a constant false expression.
847-
* \param lanes The number of lanes in the bool
837+
* \brief Make a constant handle value.
838+
* \param value The integer payload to reinterpret as a handle.
848839
* \param span The location of this operation in the source.
849840
* \return The result expression.
850841
*/
851-
inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
852-
return make_const(DataType::Bool(lanes), 0);
853-
}
842+
inline PrimExpr ConstHandle(int64_t value, Span span = Span());
854843
/*!
855844
* \brief Get x as constant int expression.
856845
* \param x The expression
@@ -981,53 +970,52 @@ inline bool is_no_op(const tirx::Stmt& stmt) {
981970
}
982971

983972
template <typename ValueType>
984-
inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) {
985-
if (t.is_int() || t.is_bool()) return IntImm(t, static_cast<int64_t>(value), span);
986-
if (t.is_uint()) {
973+
inline PrimExpr MakeConstScalar(DataType dtype, ValueType value, Span span = Span()) {
974+
if (dtype.is_int() || dtype.is_bool()) return IntImm(dtype, static_cast<int64_t>(value), span);
975+
if (dtype.is_uint()) {
987976
// Use IntImm if it is a small integer
988977
uint64_t uval = static_cast<uint64_t>(value);
989978
if (value < static_cast<ValueType>(0)) {
990979
TVM_FFI_THROW(InternalError) << "cannot make uint from negative value " << value;
991980
} else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
992-
return IntImm(t, static_cast<int64_t>(value), span);
981+
return IntImm(dtype, static_cast<int64_t>(value), span);
993982
} else {
994983
uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
995984
uint64_t low = uval & mask;
996985
uint64_t high = uval >> 32U;
997-
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
986+
return LargeUIntImm(dtype, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
998987
}
999988
}
1000-
if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float6() || t.is_float4())
1001-
return FloatImm(t, static_cast<double>(value), span);
1002-
TVM_FFI_THROW(InternalError) << "cannot make const for type " << t;
989+
if (dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() ||
990+
dtype.is_float4()) {
991+
return FloatImm(dtype, static_cast<double>(value), span);
992+
}
993+
TVM_FFI_THROW(InternalError) << "cannot make const for type " << dtype;
1003994
throw;
1004995
}
1005996

1006997
template <>
1007-
inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {
1008-
return MakeConstScalar(t, static_cast<int>(value), span);
998+
inline PrimExpr MakeConstScalar(DataType dtype, bool value, Span span) {
999+
return MakeConstScalar(dtype, static_cast<int>(value), span);
10091000
}
10101001

10111002
template <typename ValueType, typename>
1012-
inline PrimExpr make_const(DataType t, ValueType value, Span span) {
1013-
if (t.is_scalar()) {
1014-
return MakeConstScalar(t, value, span);
1003+
inline PrimExpr MakeConst(DataType dtype, ValueType value, Span span) {
1004+
if (dtype.is_scalar()) {
1005+
return MakeConstScalar(dtype, value, span);
10151006
} else {
1016-
if (t.is_fixed_length_vector()) {
1017-
return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
1007+
if (dtype.is_fixed_length_vector()) {
1008+
return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), dtype.lanes(), span);
10181009
} else {
1019-
PrimExpr lanes =
1020-
tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}), t.vscale_factor());
1021-
return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span);
1010+
PrimExpr lanes = tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}),
1011+
dtype.vscale_factor());
1012+
return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), lanes, span);
10221013
}
10231014
}
10241015
}
10251016

1026-
inline PrimExpr make_zero(DataType t, Span span) {
1027-
if (t.is_handle()) {
1028-
return reinterpret(t, make_const(DataType::UInt(64), 0, span));
1029-
}
1030-
return make_const(t, 0, span);
1017+
inline PrimExpr ConstHandle(int64_t value, Span span) {
1018+
return reinterpret(DataType::Handle(), IntImm(DataType::UInt(64), value, span));
10311019
}
10321020

10331021
} // namespace tirx
@@ -1043,13 +1031,13 @@ inline PrimExpr make_zero(DataType t, Span span) {
10431031
inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \
10441032
inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
10451033
inline PrimExpr Name(int a, const PrimExpr& b) { \
1046-
return Name(tirx::make_const(b.dtype(), a), b); \
1034+
return Name(tirx::MakeConst(b.dtype(), a), b); \
10471035
} \
10481036
inline PrimExpr Name(const PrimExpr& a, int b) { \
1049-
return Name(a, tirx::make_const(a.dtype(), b)); \
1037+
return Name(a, tirx::MakeConst(a.dtype(), b)); \
10501038
} \
10511039
inline PrimExpr Name(const PrimExpr& a, double b) { \
1052-
return Name(a, tirx::make_const(DataType::Float(64), b)); \
1040+
return Name(a, FloatImm(DataType::Float(64), b)); \
10531041
}
10541042

10551043
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \
@@ -1060,13 +1048,13 @@ inline PrimExpr make_zero(DataType t, Span span) {
10601048
return Name(PrimExpr(a), b, span); \
10611049
} \
10621050
inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
1063-
return Name(tirx::make_const(b.dtype(), a), b, span); \
1051+
return Name(tirx::MakeConst(b.dtype(), a), b, span); \
10641052
} \
10651053
inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
1066-
return Name(a, tirx::make_const(a.dtype(), b), span); \
1054+
return Name(a, tirx::MakeConst(a.dtype(), b), span); \
10671055
} \
10681056
inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \
1069-
return Name(a, tirx::make_const(DataType::Float(64), b), span); \
1057+
return Name(a, FloatImm(DataType::Float(64), b), span); \
10701058
}
10711059

10721060
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
@@ -1081,18 +1069,18 @@ inline PrimExpr make_zero(DataType t, Span span) {
10811069
return Name(PrimExpr(a), b, span); \
10821070
}
10831071

1084-
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
1085-
inline PrimExpr Name(const PrimExpr& a, int b) { \
1086-
return Name(a, tirx::make_const(a.dtype(), b)); \
1087-
} \
1088-
inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::make_const(b.dtype(), a), b); }
1072+
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
1073+
inline PrimExpr Name(const PrimExpr& a, int b) { \
1074+
return Name(a, tirx::MakeConst(a.dtype(), b)); \
1075+
} \
1076+
inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::MakeConst(b.dtype(), a), b); }
10891077

10901078
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
10911079
inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
1092-
return Name(a, tirx::make_const(a.dtype(), b), span); \
1080+
return Name(a, tirx::MakeConst(a.dtype(), b), span); \
10931081
} \
10941082
inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
1095-
return Name(tirx::make_const(b.dtype(), a), b, span); \
1083+
return Name(tirx::MakeConst(b.dtype(), a), b, span); \
10961084
}
10971085

10981086
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);

include/tvm/topi/detail/broadcast.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ inline tvm::ffi::Array<tvm::PrimExpr> InputIndexFromBroadcast(
130130
// Only inject 0 here if we have not yet reached the dimension of I
131131
// (i.e. this must be a 1)
132132
if (!found && (ovars.size() - i) <= expected_dims) {
133-
ivars.push_back(tvm::tirx::make_zero(ovars[i].dtype()));
133+
ivars.push_back(tvm::IntImm(ovars[i].dtype(), 0));
134134
}
135135
}
136136
TVM_FFI_ICHECK(expected_dims == ivars.size());

include/tvm/topi/detail/extern.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,12 @@ inline PrimExpr pack_buffer(Buffer buf) {
108108
} else {
109109
strides = 0;
110110
}
111-
ffi::Array<PrimExpr> pack_args{
112-
buf->data,
113-
shape,
114-
strides,
115-
make_const(DataType::Int(32), static_cast<int64_t>(buf->shape.size())),
116-
make_const(buf->dtype, 0),
117-
buf->elem_offset};
111+
ffi::Array<PrimExpr> pack_args{buf->data,
112+
shape,
113+
strides,
114+
IntImm::Int32(static_cast<int64_t>(buf->shape.size())),
115+
MakeConst(buf->dtype, 0),
116+
buf->elem_offset};
118117
return tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_array(), pack_args);
119118
}
120119

include/tvm/topi/detail/strided_slice.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ inline ffi::Array<PrimExpr> StridedSliceCanonicalizeBegin(const ffi::Array<PrimE
9999
if (ishape[ax]->IsInstance<tvm::IntImmNode>()) {
100100
int64_t dim_i = GetConstInt(ishape[ax]);
101101
int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]);
102-
begin_expr.push_back(make_const(dtype, begin_i));
102+
begin_expr.push_back(MakeConst(dtype, begin_i));
103103
} else {
104104
auto idim = ishape[ax];
105-
auto b_expr = make_const(dtype, begin[i]);
105+
auto b_expr = MakeConst(dtype, begin[i]);
106106
PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
107107
auto s = strides[i];
108108
if (s < 0) {

0 commit comments

Comments
 (0)