Skip to content

Commit 2cc64e5

Browse files
committed
[REFACTOR][TIRX] Add IntImm common scalar constructors
Common bool, int32, and int64 scalar constants show up throughout TIRX and related lowering code, and named constructors make these call sites easier to read than repeated DataType spelling. This PR establishes the scalar-constant construction policy: - Prefer IntImm::Bool, IntImm::Int32, and IntImm::Int64 for common known scalar bool, int32, and int64 constants. - Prefer direct IntImm or FloatImm construction when dtype is known to be scalar integer or floating point. - Keep make_const for generic overload cases where dtype can be integer, floating point, or vector-valued and the caller needs its scalar/vector dispatch. - Phase out make_zero in favor of explicit scalar constructors, or ConstHandle(0) for null handles.
1 parent 949be81 commit 2cc64e5

164 files changed

Lines changed: 949 additions & 998 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/tvm/ir/expr.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,33 @@ class IntImm : public PrimExpr {
515515
*/
516516
TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span());
517517

518+
/*!
519+
* \brief Construct a scalar boolean constant.
520+
* \param value The boolean value.
521+
* \param span The location of this object in the source code.
522+
*/
523+
static IntImm Bool(bool value, Span span = Span()) {
524+
return IntImm(DataType::Bool(), value, span);
525+
}
526+
527+
/*!
528+
* \brief Construct a scalar int32 constant.
529+
* \param value The integer value.
530+
* \param span The location of this object in the source code.
531+
*/
532+
static IntImm Int32(int64_t value, Span span = Span()) {
533+
return IntImm(DataType::Int(32), value, span);
534+
}
535+
536+
/*!
537+
* \brief Construct a scalar int64 constant.
538+
* \param value The integer value.
539+
* \param span The location of this object in the source code.
540+
*/
541+
static IntImm Int64(int64_t value, Span span = Span()) {
542+
return IntImm(DataType::Int(64), value, span);
543+
}
544+
518545
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntImm, PrimExpr, IntImmNode);
519546
TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode);
520547
};
@@ -636,7 +663,7 @@ struct TypeTraits<FloatImm> : public ObjectRefWithFallbackTraitsBase<FloatImm, d
636663

637664
// define automatic conversion from bool, int64_t, double to PrimExpr
638665
TVM_FFI_INLINE PrimExpr TypeTraits<PrimExpr>::ConvertFallbackValue(StrictBool value) {
639-
return IntImm(DataType::Bool(), value, Span());
666+
return IntImm::Bool(value);
640667
}
641668

642669
TVM_FFI_INLINE PrimExpr TypeTraits<PrimExpr>::ConvertFallbackValue(int64_t value) {

include/tvm/script/printer/doc.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,15 +277,15 @@ class LiteralDoc : public ExprDoc {
277277
* \param p The object path
278278
*/
279279
static LiteralDoc Int(int64_t v, const ffi::Optional<AccessPath>& p) {
280-
return LiteralDoc(IntImm(DataType::Int(64), v), p);
280+
return LiteralDoc(IntImm::Int64(v), p);
281281
}
282282
/*!
283283
* \brief Create a LiteralDoc to represent boolean.
284284
* \param v The boolean value.
285285
* \param p The object path
286286
*/
287287
static LiteralDoc Boolean(bool v, const ffi::Optional<AccessPath>& p) {
288-
return LiteralDoc(IntImm(DataType::Bool(), v), p);
288+
return LiteralDoc(IntImm::Bool(v), p);
289289
}
290290
/*!
291291
* \brief Create a LiteralDoc to represent float.

include/tvm/tirx/buffer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ class Buffer : public ffi::ObjectRef {
206206
* \param input_extent The extent of ptr.
207207
*/
208208
TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
209-
int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
209+
int content_lanes = 1, PrimExpr offset = IntImm::Int32(0),
210210
ffi::Optional<PrimExpr> input_extent = std::nullopt) const;
211211
/*!
212212
* \brief Create an Expr that does a vector load at begin index.

include/tvm/tirx/op.h

Lines changed: 35 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,13 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
816816

817817
/*!
818818
* \brief Make a const value with certain data type.
819-
* \param t The target type.
819+
*
820+
* This helper is intended for generic overload sites where the dtype may be
821+
* integer, floating point, or vector-valued. When the dtype is statically known
822+
* to be a scalar integer or floating point type, callers should use IntImm or
823+
* FloatImm directly for more compact and efficient code.
824+
*
825+
* \param dtype The target type.
820826
* \param value The input value
821827
* \return the result expression.
822828
* \tparam ValueType The constant value type
@@ -825,32 +831,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
825831
template <typename ValueType,
826832
typename = typename std::enable_if<std::is_standard_layout<ValueType>::value &&
827833
std::is_trivial<ValueType>::value>::type>
828-
inline PrimExpr make_const(DataType t, ValueType value, Span span = Span());
829-
/*!
830-
* \brief Make a const zero expr.
831-
* \param t The target type.
832-
* \param span The location of this operation in the source.
833-
* \return the result expression.
834-
*/
835-
inline PrimExpr make_zero(DataType t, Span span = Span());
836-
/*!
837-
* \brief Make a constant true expression.
838-
* \param lanes The number of lanes in the bool
839-
* \param span The location of this operation in the source.
840-
* \return The result expression.
841-
*/
842-
inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
843-
return make_const(DataType::Bool(lanes), 1);
844-
}
834+
inline PrimExpr make_const(DataType dtype, ValueType value, Span span = Span());
845835
/*!
846-
* \brief Make a constant false expression.
847-
* \param lanes The number of lanes in the bool
836+
* \brief Make a constant handle value.
837+
* \param value The integer handle value.
848838
* \param span The location of this operation in the source.
849839
* \return The result expression.
850840
*/
851-
inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
852-
return make_const(DataType::Bool(lanes), 0);
853-
}
841+
inline PrimExpr ConstHandle(int32_t value, Span span = Span());
854842
/*!
855843
* \brief Get x as constant int expression.
856844
* \param x The expression
@@ -981,53 +969,52 @@ inline bool is_no_op(const tirx::Stmt& stmt) {
981969
}
982970

983971
template <typename ValueType>
984-
inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) {
985-
if (t.is_int() || t.is_bool()) return IntImm(t, static_cast<int64_t>(value), span);
986-
if (t.is_uint()) {
972+
inline PrimExpr MakeConstScalar(DataType dtype, ValueType value, Span span = Span()) {
973+
if (dtype.is_int() || dtype.is_bool()) return IntImm(dtype, static_cast<int64_t>(value), span);
974+
if (dtype.is_uint()) {
987975
// Use IntImm if it is a small integer
988976
uint64_t uval = static_cast<uint64_t>(value);
989977
if (value < static_cast<ValueType>(0)) {
990978
TVM_FFI_THROW(InternalError) << "cannot make uint from negative value " << value;
991979
} else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
992-
return IntImm(t, static_cast<int64_t>(value), span);
980+
return IntImm(dtype, static_cast<int64_t>(value), span);
993981
} else {
994982
uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
995983
uint64_t low = uval & mask;
996984
uint64_t high = uval >> 32U;
997-
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
985+
return LargeUIntImm(dtype, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
998986
}
999987
}
1000-
if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float6() || t.is_float4())
1001-
return FloatImm(t, static_cast<double>(value), span);
1002-
TVM_FFI_THROW(InternalError) << "cannot make const for type " << t;
988+
if (dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() ||
989+
dtype.is_float4()) {
990+
return FloatImm(dtype, static_cast<double>(value), span);
991+
}
992+
TVM_FFI_THROW(InternalError) << "cannot make const for type " << dtype;
1003993
throw;
1004994
}
1005995

1006996
template <>
1007-
inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {
1008-
return MakeConstScalar(t, static_cast<int>(value), span);
997+
inline PrimExpr MakeConstScalar(DataType dtype, bool value, Span span) {
998+
return MakeConstScalar(dtype, static_cast<int>(value), span);
1009999
}
10101000

10111001
template <typename ValueType, typename>
1012-
inline PrimExpr make_const(DataType t, ValueType value, Span span) {
1013-
if (t.is_scalar()) {
1014-
return MakeConstScalar(t, value, span);
1002+
inline PrimExpr make_const(DataType dtype, ValueType value, Span span) {
1003+
if (dtype.is_scalar()) {
1004+
return MakeConstScalar(dtype, value, span);
10151005
} else {
1016-
if (t.is_fixed_length_vector()) {
1017-
return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
1006+
if (dtype.is_fixed_length_vector()) {
1007+
return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), dtype.lanes(), span);
10181008
} else {
1019-
PrimExpr lanes =
1020-
tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}), t.vscale_factor());
1021-
return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span);
1009+
PrimExpr lanes = tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}),
1010+
dtype.vscale_factor());
1011+
return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), lanes, span);
10221012
}
10231013
}
10241014
}
10251015

1026-
inline PrimExpr make_zero(DataType t, Span span) {
1027-
if (t.is_handle()) {
1028-
return reinterpret(t, make_const(DataType::UInt(64), 0, span));
1029-
}
1030-
return make_const(t, 0, span);
1016+
inline PrimExpr ConstHandle(int32_t value, Span span) {
1017+
return reinterpret(DataType::Handle(), IntImm(DataType::UInt(64), value, span));
10311018
}
10321019

10331020
} // namespace tirx
@@ -1049,7 +1036,7 @@ inline PrimExpr make_zero(DataType t, Span span) {
10491036
return Name(a, tirx::make_const(a.dtype(), b)); \
10501037
} \
10511038
inline PrimExpr Name(const PrimExpr& a, double b) { \
1052-
return Name(a, tirx::make_const(DataType::Float(64), b)); \
1039+
return Name(a, FloatImm(DataType::Float(64), b)); \
10531040
}
10541041

10551042
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \
@@ -1066,7 +1053,7 @@ inline PrimExpr make_zero(DataType t, Span span) {
10661053
return Name(a, tirx::make_const(a.dtype(), b), span); \
10671054
} \
10681055
inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \
1069-
return Name(a, tirx::make_const(DataType::Float(64), b), span); \
1056+
return Name(a, FloatImm(DataType::Float(64), b), span); \
10701057
}
10711058

10721059
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \

include/tvm/topi/detail/broadcast.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ inline tvm::ffi::Array<tvm::PrimExpr> InputIndexFromBroadcast(
130130
// Only inject 0 here if we have not yet reached the dimension of I
131131
// (i.e. this must be a 1)
132132
if (!found && (ovars.size() - i) <= expected_dims) {
133-
ivars.push_back(tvm::tirx::make_zero(ovars[i].dtype()));
133+
ivars.push_back(tvm::IntImm(ovars[i].dtype(), 0));
134134
}
135135
}
136136
TVM_FFI_ICHECK(expected_dims == ivars.size());

include/tvm/topi/detail/extern.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,12 @@ inline PrimExpr pack_buffer(Buffer buf) {
108108
} else {
109109
strides = 0;
110110
}
111-
ffi::Array<PrimExpr> pack_args{
112-
buf->data,
113-
shape,
114-
strides,
115-
make_const(DataType::Int(32), static_cast<int64_t>(buf->shape.size())),
116-
make_const(buf->dtype, 0),
117-
buf->elem_offset};
111+
ffi::Array<PrimExpr> pack_args{buf->data,
112+
shape,
113+
strides,
114+
IntImm::Int32(static_cast<int64_t>(buf->shape.size())),
115+
make_const(buf->dtype, 0),
116+
buf->elem_offset};
118117
return tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_array(), pack_args);
119118
}
120119

include/tvm/topi/elemwise.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag
209209
return compute(
210210
x->shape,
211211
[&](const ffi::Array<Var>& i) {
212-
PrimExpr zero = make_zero(x->dtype);
212+
PrimExpr zero = make_const(x->dtype, 0);
213213
PrimExpr one = make_const(x->dtype, 1);
214214
PrimExpr minus_one = make_const(x->dtype, -1);
215215
auto s1 = tvm::tirx::Select((x(i) < zero), minus_one, zero);
@@ -392,19 +392,19 @@ inline Tensor full_like(const Tensor& x, const PrimExpr fill_value,
392392
* y = exp(f) = 1 + 2 * P(x**2)/(Q(x**2) - P(x**2))
393393
*/
394394
inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string tag) {
395-
auto x_hi = make_const(DataType::Float(32), 88.3762626647950f);
396-
auto x_lo = make_const(DataType::Float(32), -88.3762626647949f);
397-
auto log2e = make_const(DataType::Float(32), 1.44269504088896341f);
398-
auto ln2 = make_const(DataType::Float(32), 0.6931471805599453f);
399-
PrimExpr p[6] = {make_const(DataType::Float(32), 1.9875691500E-4f),
400-
make_const(DataType::Float(32), 1.3981999507E-3f),
401-
make_const(DataType::Float(32), 8.3334519073E-3f),
402-
make_const(DataType::Float(32), 4.1665795894E-2f),
403-
make_const(DataType::Float(32), 1.6666665459E-1f),
404-
make_const(DataType::Float(32), 5.0000001201E-1f)};
405-
auto one = make_const(DataType::Float(32), 1.0f);
406-
auto one_half = make_const(DataType::Float(32), 0.5f);
407-
auto b = make_const(DataType::Float(32), 127.0f);
395+
auto x_hi = FloatImm(DataType::Float(32), 88.3762626647950f);
396+
auto x_lo = FloatImm(DataType::Float(32), -88.3762626647949f);
397+
auto log2e = FloatImm(DataType::Float(32), 1.44269504088896341f);
398+
auto ln2 = FloatImm(DataType::Float(32), 0.6931471805599453f);
399+
PrimExpr p[6] = {FloatImm(DataType::Float(32), 1.9875691500E-4f),
400+
FloatImm(DataType::Float(32), 1.3981999507E-3f),
401+
FloatImm(DataType::Float(32), 8.3334519073E-3f),
402+
FloatImm(DataType::Float(32), 4.1665795894E-2f),
403+
FloatImm(DataType::Float(32), 1.6666665459E-1f),
404+
FloatImm(DataType::Float(32), 5.0000001201E-1f)};
405+
auto one = FloatImm(DataType::Float(32), 1.0f);
406+
auto one_half = FloatImm(DataType::Float(32), 0.5f);
407+
auto b = FloatImm(DataType::Float(32), 127.0f);
408408

409409
return compute(
410410
_x->shape,

include/tvm/topi/nn.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,12 @@ inline tvm::te::Tensor pad(
232232
if (pad_mode == "constant") {
233233
return tvm::if_then_else(
234234
foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
235-
const_true(1), sel),
235+
IntImm::Bool(true), sel),
236236
t(indices), pad_value);
237237
} else if (pad_mode == "edge" || pad_mode == "reflect") {
238238
return tvm::if_then_else(
239239
foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
240-
const_true(1), sel),
240+
IntImm::Bool(true), sel),
241241
t(indices), t(pad_idx));
242242
}
243243
}
@@ -534,7 +534,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data,
534534
<< padded_input << ")"
535535
<< " must be divisible by its block size (" << block_size << ")";
536536

537-
PrimExpr bs = IntImm(DataType::Int(64), block_shape[i - 1]);
537+
PrimExpr bs = IntImm::Int64(block_shape[i - 1]);
538538
r_shape.push_back(div(padded_shape[i], bs));
539539
r_shape.push_back(bs);
540540
block_shape_prod *= bs;
@@ -549,7 +549,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data,
549549
}
550550
o_shape.push_back(tvm::PrimExpr(batch) * block_shape_prod);
551551
for (size_t i = 1; i <= num_block_dims; i++) {
552-
PrimExpr bs = IntImm(DataType::Int(64), block_shape[i - 1]);
552+
PrimExpr bs = IntImm::Int64(block_shape[i - 1]);
553553
o_shape.push_back(div(padded_shape[i], bs));
554554
}
555555
// append remaining shape
@@ -595,7 +595,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
595595
int batch = static_cast<int>(GetConstInt(in_shape[0]));
596596

597597
for (size_t i = 0; i < num_block_dims; i++) {
598-
PrimExpr bs = IntImm(DataType::Int(64), block_shape[i]);
598+
PrimExpr bs = IntImm::Int64(block_shape[i]);
599599
r_shape.push_back(bs);
600600
block_shape_prod *= bs;
601601
}
@@ -614,7 +614,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
614614
ffi::Array<PrimExpr> r_p_shape;
615615
r_p_shape.push_back(batch / block_shape_prod);
616616
for (size_t i = 1; i <= num_block_dims; i++) {
617-
PrimExpr bs = IntImm(DataType::Int(64), block_shape[i - 1]);
617+
PrimExpr bs = IntImm::Int64(block_shape[i - 1]);
618618
r_p_shape.push_back(in_shape[i] * bs);
619619
}
620620
for (size_t i = num_block_dims + 1; i < num_input_dims; i++) {

include/tvm/topi/nn/bnn.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis,
7171
start_idx.push_back(i == static_cast<size_t>(axis) ? indices[i] * 32
7272
: static_cast<PrimExpr>(indices[i]));
7373
}
74-
auto packed = make_const(DataType::UInt(32), 0);
74+
PrimExpr packed = IntImm(DataType::UInt(32), 0);
7575
for (size_t j = 0; j < 32; ++j) {
7676
ffi::Array<PrimExpr> idx;
7777
for (size_t i = 0; i < n; ++i) {

include/tvm/topi/nn/group_norm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor&
126126

127127
auto temp_x = temp_x_x2[0];
128128
auto temp_x2 = temp_x_x2[1];
129-
auto reduce_extent = make_const(DataType::Float(32), 1);
129+
PrimExpr reduce_extent = FloatImm(DataType::Float(32), 1);
130130
for (auto axis : new_axes) {
131131
reduce_extent *= data_reshaped->shape[axis];
132132
}

0 commit comments

Comments
 (0)