Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,33 @@ class IntImm : public PrimExpr {
*/
TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span());

/*!
* \brief Construct a scalar boolean constant.
* \param value The boolean value.
* \param span The location of this object in the source code.
*/
static IntImm Bool(bool value, Span span = Span()) {
return IntImm(DataType::Bool(), value, span);
}

/*!
* \brief Construct a scalar int32 constant.
* \param value The integer value.
* \param span The location of this object in the source code.
*/
static IntImm Int32(int64_t value, Span span = Span()) {
return IntImm(DataType::Int(32), value, span);
}

/*!
* \brief Construct a scalar int64 constant.
* \param value The integer value.
* \param span The location of this object in the source code.
*/
static IntImm Int64(int64_t value, Span span = Span()) {
return IntImm(DataType::Int(64), value, span);
}

TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntImm, PrimExpr, IntImmNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode);
};
Expand Down Expand Up @@ -636,7 +663,7 @@ struct TypeTraits<FloatImm> : public ObjectRefWithFallbackTraitsBase<FloatImm, d

// define automatic conversion from bool, int64_t, double to PrimExpr
TVM_FFI_INLINE PrimExpr TypeTraits<PrimExpr>::ConvertFallbackValue(StrictBool value) {
return IntImm(DataType::Bool(), value, Span());
return IntImm::Bool(value);
}

TVM_FFI_INLINE PrimExpr TypeTraits<PrimExpr>::ConvertFallbackValue(int64_t value) {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/node_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace tvm {
* return prefix + "IntImm"
* });
*
* Expr x = make_const(1);
* Expr x = MakeConst(1);
* Expr y = x + x;
* // dispatch to IntImm, outputs "MyIntImm"
* LOG(INFO) << tostr(x, "My");
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,15 +277,15 @@ class LiteralDoc : public ExprDoc {
* \param p The object path
*/
static LiteralDoc Int(int64_t v, const ffi::Optional<AccessPath>& p) {
return LiteralDoc(IntImm(DataType::Int(64), v), p);
return LiteralDoc(IntImm::Int64(v), p);
}
/*!
* \brief Create a LiteralDoc to represent boolean.
* \param v The boolean value.
* \param p The object path
*/
static LiteralDoc Boolean(bool v, const ffi::Optional<AccessPath>& p) {
return LiteralDoc(IntImm(DataType::Bool(), v), p);
return LiteralDoc(IntImm::Bool(v), p);
}
/*!
* \brief Create a LiteralDoc to represent float.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tirx/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class Buffer : public ffi::ObjectRef {
* \param input_extent The extent of ptr.
*/
TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
int content_lanes = 1, PrimExpr offset = IntImm::Int32(0),
ffi::Optional<PrimExpr> input_extent = std::nullopt) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
Expand Down
108 changes: 48 additions & 60 deletions include/tvm/tirx/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ namespace tvm {
// Most common operators can be overloaded by argument type(PrimExpr).
// So we put them under the root namespace.
//
// We put more developer oriented APIs -- make_const and is_const under tirx
// We put more developer oriented APIs -- MakeConst and is_const under tirx
// as they are more specific to the tirx namespace.

/*!
Expand Down Expand Up @@ -816,7 +816,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {

/*!
* \brief Make a const value with certain data type.
* \param t The target type.
*
* Prefer direct IntImm or FloatImm construction when dtype is known to be
* scalar integer or floating point. This makes the compiled code more compact
* and efficient. Keep MakeConst for generic overload cases where dtype can be
* integer, floating point, or vector-valued and the caller needs its
* scalar/vector dispatch.
*
* \param dtype The target type.
* \param value The input value
* \return the result expression.
* \tparam ValueType The constant value type
Expand All @@ -825,32 +832,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
template <typename ValueType,
typename = typename std::enable_if<std::is_standard_layout<ValueType>::value &&
std::is_trivial<ValueType>::value>::type>
inline PrimExpr make_const(DataType t, ValueType value, Span span = Span());
/*!
* \brief Make a const zero expr.
* \param t The target type.
* \param span The location of this operation in the source.
* \return the result expression.
*/
inline PrimExpr make_zero(DataType t, Span span = Span());
/*!
* \brief Make a constant true expression.
* \param lanes The number of lanes in the bool
* \param span The location of this operation in the source.
* \return The result expression.
*/
inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
return make_const(DataType::Bool(lanes), 1);
}
inline PrimExpr MakeConst(DataType dtype, ValueType value, Span span = Span());
/*!
* \brief Make a constant false expression.
* \param lanes The number of lanes in the bool
* \brief Make a constant handle value.
* \param value The integer payload to reinterpret as a handle.
* \param span The location of this operation in the source.
* \return The result expression.
*/
inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
return make_const(DataType::Bool(lanes), 0);
}
inline PrimExpr ConstHandle(int64_t value, Span span = Span());
/*!
* \brief Get x as constant int expression.
* \param x The expression
Expand Down Expand Up @@ -981,53 +970,52 @@ inline bool is_no_op(const tirx::Stmt& stmt) {
}

template <typename ValueType>
inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) {
if (t.is_int() || t.is_bool()) return IntImm(t, static_cast<int64_t>(value), span);
if (t.is_uint()) {
inline PrimExpr MakeConstScalar(DataType dtype, ValueType value, Span span = Span()) {
if (dtype.is_int() || dtype.is_bool()) return IntImm(dtype, static_cast<int64_t>(value), span);
if (dtype.is_uint()) {
// Use IntImm if it is a small integer
uint64_t uval = static_cast<uint64_t>(value);
if (value < static_cast<ValueType>(0)) {
TVM_FFI_THROW(InternalError) << "cannot make uint from negative value " << value;
} else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
return IntImm(t, static_cast<int64_t>(value), span);
return IntImm(dtype, static_cast<int64_t>(value), span);
} else {
uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
uint64_t low = uval & mask;
uint64_t high = uval >> 32U;
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
return LargeUIntImm(dtype, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
}
}
if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float6() || t.is_float4())
return FloatImm(t, static_cast<double>(value), span);
TVM_FFI_THROW(InternalError) << "cannot make const for type " << t;
if (dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() ||
dtype.is_float4()) {
return FloatImm(dtype, static_cast<double>(value), span);
}
TVM_FFI_THROW(InternalError) << "cannot make const for type " << dtype;
throw;
}

template <>
inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {
return MakeConstScalar(t, static_cast<int>(value), span);
inline PrimExpr MakeConstScalar(DataType dtype, bool value, Span span) {
return MakeConstScalar(dtype, static_cast<int>(value), span);
}

template <typename ValueType, typename>
inline PrimExpr make_const(DataType t, ValueType value, Span span) {
if (t.is_scalar()) {
return MakeConstScalar(t, value, span);
inline PrimExpr MakeConst(DataType dtype, ValueType value, Span span) {
if (dtype.is_scalar()) {
return MakeConstScalar(dtype, value, span);
} else {
if (t.is_fixed_length_vector()) {
return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span);
if (dtype.is_fixed_length_vector()) {
return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), dtype.lanes(), span);
} else {
PrimExpr lanes =
tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}), t.vscale_factor());
return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span);
PrimExpr lanes = tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}),
dtype.vscale_factor());
return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), lanes, span);
}
}
}

inline PrimExpr make_zero(DataType t, Span span) {
if (t.is_handle()) {
return reinterpret(t, make_const(DataType::UInt(64), 0, span));
}
return make_const(t, 0, span);
inline PrimExpr ConstHandle(int64_t value, Span span) {
return reinterpret(DataType::Handle(), IntImm(DataType::UInt(64), value, span));
}

} // namespace tirx
Expand All @@ -1043,13 +1031,13 @@ inline PrimExpr make_zero(DataType t, Span span) {
inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \
inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
inline PrimExpr Name(int a, const PrimExpr& b) { \
return Name(tirx::make_const(b.dtype(), a), b); \
return Name(tirx::MakeConst(b.dtype(), a), b); \
} \
inline PrimExpr Name(const PrimExpr& a, int b) { \
return Name(a, tirx::make_const(a.dtype(), b)); \
return Name(a, tirx::MakeConst(a.dtype(), b)); \
} \
inline PrimExpr Name(const PrimExpr& a, double b) { \
return Name(a, tirx::make_const(DataType::Float(64), b)); \
return Name(a, FloatImm(DataType::Float(64), b)); \
}

#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \
Expand All @@ -1060,13 +1048,13 @@ inline PrimExpr make_zero(DataType t, Span span) {
return Name(PrimExpr(a), b, span); \
} \
inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
return Name(tirx::make_const(b.dtype(), a), b, span); \
return Name(tirx::MakeConst(b.dtype(), a), b, span); \
} \
inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
return Name(a, tirx::make_const(a.dtype(), b), span); \
return Name(a, tirx::MakeConst(a.dtype(), b), span); \
} \
inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \
return Name(a, tirx::make_const(DataType::Float(64), b), span); \
return Name(a, FloatImm(DataType::Float(64), b), span); \
}

#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
Expand All @@ -1081,18 +1069,18 @@ inline PrimExpr make_zero(DataType t, Span span) {
return Name(PrimExpr(a), b, span); \
}

#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
inline PrimExpr Name(const PrimExpr& a, int b) { \
return Name(a, tirx::make_const(a.dtype(), b)); \
} \
inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::make_const(b.dtype(), a), b); }
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
inline PrimExpr Name(const PrimExpr& a, int b) { \
return Name(a, tirx::MakeConst(a.dtype(), b)); \
} \
inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::MakeConst(b.dtype(), a), b); }

#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
return Name(a, tirx::make_const(a.dtype(), b), span); \
return Name(a, tirx::MakeConst(a.dtype(), b), span); \
} \
inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
return Name(tirx::make_const(b.dtype(), a), b, span); \
return Name(tirx::MakeConst(b.dtype(), a), b, span); \
}

TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/topi/detail/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ inline tvm::ffi::Array<tvm::PrimExpr> InputIndexFromBroadcast(
// Only inject 0 here if we have not yet reached the dimension of I
// (i.e. this must be a 1)
if (!found && (ovars.size() - i) <= expected_dims) {
ivars.push_back(tvm::tirx::make_zero(ovars[i].dtype()));
ivars.push_back(tvm::IntImm(ovars[i].dtype(), 0));
}
}
TVM_FFI_ICHECK(expected_dims == ivars.size());
Expand Down
13 changes: 6 additions & 7 deletions include/tvm/topi/detail/extern.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,12 @@ inline PrimExpr pack_buffer(Buffer buf) {
} else {
strides = 0;
}
ffi::Array<PrimExpr> pack_args{
buf->data,
shape,
strides,
make_const(DataType::Int(32), static_cast<int64_t>(buf->shape.size())),
make_const(buf->dtype, 0),
buf->elem_offset};
ffi::Array<PrimExpr> pack_args{buf->data,
shape,
strides,
IntImm::Int32(static_cast<int64_t>(buf->shape.size())),
MakeConst(buf->dtype, 0),
buf->elem_offset};
return tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_array(), pack_args);
}

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/topi/detail/strided_slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ inline ffi::Array<PrimExpr> StridedSliceCanonicalizeBegin(const ffi::Array<PrimE
if (ishape[ax]->IsInstance<tvm::IntImmNode>()) {
int64_t dim_i = GetConstInt(ishape[ax]);
int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]);
begin_expr.push_back(make_const(dtype, begin_i));
begin_expr.push_back(MakeConst(dtype, begin_i));
} else {
auto idim = ishape[ax];
auto b_expr = make_const(dtype, begin[i]);
auto b_expr = MakeConst(dtype, begin[i]);
PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
auto s = strides[i];
if (s < 0) {
Expand Down
Loading
Loading