@@ -52,7 +52,7 @@ namespace tvm {
5252// Most common operators can be overloaded by argument type(PrimExpr).
5353// So we put them under the root namespace.
5454//
55- // We put more developer oriented APIs -- make_const and is_const under tirx
55+ // We put more developer oriented APIs -- MakeConst and is_const under tirx
5656// as they are more specific to the tirx namespace.
5757
5858/* !
@@ -816,7 +816,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
816816
817817/* !
818818 * \brief Make a const value with certain data type.
819- * \param t The target type.
819+ *
820+ * Prefer direct IntImm or FloatImm construction when dtype is known to be
821+ * scalar integer or floating point. This makes the compiled code more compact
822+ * and efficient. Keep MakeConst for generic overload cases where dtype can be
823+ * integer, floating point, or vector-valued and the caller needs its
824+ * scalar/vector dispatch.
825+ *
826+ * \param dtype The target type.
820827 * \param value The input value
821828 * \return the result expression.
822829 * \tparam ValueType The constant value type
@@ -825,32 +832,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
825832template <typename ValueType,
826833 typename = typename std::enable_if<std::is_standard_layout<ValueType>::value &&
827834 std::is_trivial<ValueType>::value>::type>
828- inline PrimExpr make_const (DataType t, ValueType value, Span span = Span());
829- /* !
830- * \brief Make a const zero expr.
831- * \param t The target type.
832- * \param span The location of this operation in the source.
833- * \return the result expression.
834- */
835- inline PrimExpr make_zero (DataType t, Span span = Span());
836- /* !
837- * \brief Make a constant true expression.
838- * \param lanes The number of lanes in the bool
839- * \param span The location of this operation in the source.
840- * \return The result expression.
841- */
842- inline PrimExpr const_true (int lanes = 1 , Span span = Span()) {
843- return make_const (DataType::Bool (lanes), 1 );
844- }
835+ inline PrimExpr MakeConst (DataType dtype, ValueType value, Span span = Span());
845836/* !
846- * \brief Make a constant false expression .
847- * \param lanes The number of lanes in the bool
837+ * \brief Make a constant handle value .
838+ * \param value The integer payload to reinterpret as a handle.
848839 * \param span The location of this operation in the source.
849840 * \return The result expression.
850841 */
851- inline PrimExpr const_false (int lanes = 1 , Span span = Span()) {
852- return make_const (DataType::Bool (lanes), 0 );
853- }
842+ inline PrimExpr ConstHandle (int64_t value, Span span = Span());
854843/* !
855844 * \brief Get x as constant int expression.
856845 * \param x The expression
@@ -981,53 +970,52 @@ inline bool is_no_op(const tirx::Stmt& stmt) {
981970}
982971
983972template <typename ValueType>
984- inline PrimExpr MakeConstScalar (DataType t , ValueType value, Span span = Span()) {
985- if (t .is_int () || t .is_bool ()) return IntImm (t , static_cast <int64_t >(value), span);
986- if (t .is_uint ()) {
973+ inline PrimExpr MakeConstScalar (DataType dtype , ValueType value, Span span = Span()) {
974+ if (dtype .is_int () || dtype .is_bool ()) return IntImm (dtype , static_cast <int64_t >(value), span);
975+ if (dtype .is_uint ()) {
987976 // Use IntImm if it is a small integer
988977 uint64_t uval = static_cast <uint64_t >(value);
989978 if (value < static_cast <ValueType>(0 )) {
990979 TVM_FFI_THROW (InternalError) << " cannot make uint from negative value " << value;
991980 } else if (uval <= static_cast <uint64_t >(std::numeric_limits<int64_t >::max ())) {
992- return IntImm (t , static_cast <int64_t >(value), span);
981+ return IntImm (dtype , static_cast <int64_t >(value), span);
993982 } else {
994983 uint64_t mask = (static_cast <uint64_t >(1 ) << 32U ) - 1U ;
995984 uint64_t low = uval & mask;
996985 uint64_t high = uval >> 32U ;
997- return LargeUIntImm (t , static_cast <int64_t >(low), static_cast <int64_t >(high), span);
986+ return LargeUIntImm (dtype , static_cast <int64_t >(low), static_cast <int64_t >(high), span);
998987 }
999988 }
1000- if (t.is_float () || t.is_bfloat16 () || t.is_float8 () || t.is_float6 () || t.is_float4 ())
1001- return FloatImm (t, static_cast <double >(value), span);
1002- TVM_FFI_THROW (InternalError) << " cannot make const for type " << t;
989+ if (dtype.is_float () || dtype.is_bfloat16 () || dtype.is_float8 () || dtype.is_float6 () ||
990+ dtype.is_float4 ()) {
991+ return FloatImm (dtype, static_cast <double >(value), span);
992+ }
993+ TVM_FFI_THROW (InternalError) << " cannot make const for type " << dtype;
1003994 throw ;
1004995}
1005996
1006997template <>
1007- inline PrimExpr MakeConstScalar (DataType t , bool value, Span span) {
1008- return MakeConstScalar (t , static_cast <int >(value), span);
998+ inline PrimExpr MakeConstScalar (DataType dtype , bool value, Span span) {
999+ return MakeConstScalar (dtype , static_cast <int >(value), span);
10091000}
10101001
10111002template <typename ValueType, typename >
1012- inline PrimExpr make_const (DataType t , ValueType value, Span span) {
1013- if (t .is_scalar ()) {
1014- return MakeConstScalar (t , value, span);
1003+ inline PrimExpr MakeConst (DataType dtype , ValueType value, Span span) {
1004+ if (dtype .is_scalar ()) {
1005+ return MakeConstScalar (dtype , value, span);
10151006 } else {
1016- if (t .is_fixed_length_vector ()) {
1017- return tirx::Broadcast (MakeConstScalar (t .element_of (), value, span), t .lanes (), span);
1007+ if (dtype .is_fixed_length_vector ()) {
1008+ return tirx::Broadcast (MakeConstScalar (dtype .element_of (), value, span), dtype .lanes (), span);
10181009 } else {
1019- PrimExpr lanes =
1020- tirx::Mul ( tirx::Call ( DataType::Int ( 32 ), tirx::builtin::vscale (), {}), t .vscale_factor ());
1021- return tirx::Broadcast (MakeConstScalar (t .element_of (), value, span), lanes, span);
1010+ PrimExpr lanes = tirx::Mul ( tirx::Call ( DataType::Int ( 32 ), tirx::builtin::vscale (), {}),
1011+ dtype .vscale_factor ());
1012+ return tirx::Broadcast (MakeConstScalar (dtype .element_of (), value, span), lanes, span);
10221013 }
10231014 }
10241015}
10251016
1026- inline PrimExpr make_zero (DataType t, Span span) {
1027- if (t.is_handle ()) {
1028- return reinterpret (t, make_const (DataType::UInt (64 ), 0 , span));
1029- }
1030- return make_const (t, 0 , span);
1017+ inline PrimExpr ConstHandle (int64_t value, Span span) {
1018+ return reinterpret (DataType::Handle (), IntImm (DataType::UInt (64 ), value, span));
10311019}
10321020
10331021} // namespace tirx
@@ -1043,13 +1031,13 @@ inline PrimExpr make_zero(DataType t, Span span) {
10431031 inline PrimExpr Name (const PrimExpr& a, float b) { return Name (a, PrimExpr (b)); } \
10441032 inline PrimExpr Name (float a, const PrimExpr& b) { return Name (PrimExpr (a), b); } \
10451033 inline PrimExpr Name (int a, const PrimExpr& b) { \
1046- return Name (tirx::make_const (b.dtype (), a), b); \
1034+ return Name (tirx::MakeConst (b.dtype (), a), b); \
10471035 } \
10481036 inline PrimExpr Name (const PrimExpr& a, int b) { \
1049- return Name (a, tirx::make_const (a.dtype (), b)); \
1037+ return Name (a, tirx::MakeConst (a.dtype (), b)); \
10501038 } \
10511039 inline PrimExpr Name (const PrimExpr& a, double b) { \
1052- return Name (a, tirx::make_const (DataType::Float (64 ), b)); \
1040+ return Name (a, FloatImm (DataType::Float (64 ), b)); \
10531041 }
10541042
10551043#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED (Name ) \
@@ -1060,13 +1048,13 @@ inline PrimExpr make_zero(DataType t, Span span) {
10601048 return Name (PrimExpr (a), b, span); \
10611049 } \
10621050 inline PrimExpr Name (int a, const PrimExpr& b, Span span = Span()) { \
1063- return Name (tirx::make_const (b.dtype (), a), b, span); \
1051+ return Name (tirx::MakeConst (b.dtype (), a), b, span); \
10641052 } \
10651053 inline PrimExpr Name (const PrimExpr& a, int b, Span span = Span()) { \
1066- return Name (a, tirx::make_const (a.dtype (), b), span); \
1054+ return Name (a, tirx::MakeConst (a.dtype (), b), span); \
10671055 } \
10681056 inline PrimExpr Name (const PrimExpr& a, double b, Span span = Span()) { \
1069- return Name (a, tirx::make_const (DataType::Float (64 ), b), span); \
1057+ return Name (a, FloatImm (DataType::Float (64 ), b), span); \
10701058 }
10711059
10721060#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD (Name ) \
@@ -1081,18 +1069,18 @@ inline PrimExpr make_zero(DataType t, Span span) {
10811069 return Name (PrimExpr (a), b, span); \
10821070 }
10831071
1084- #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD (Name ) \
1085- inline PrimExpr Name (const PrimExpr& a, int b) { \
1086- return Name (a, tirx::make_const (a.dtype (), b)); \
1087- } \
1088- inline PrimExpr Name (int a, const PrimExpr& b) { return Name (tirx::make_const (b.dtype (), a), b); }
1072+ #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD (Name ) \
1073+ inline PrimExpr Name (const PrimExpr& a, int b) { \
1074+ return Name (a, tirx::MakeConst (a.dtype (), b)); \
1075+ } \
1076+ inline PrimExpr Name (int a, const PrimExpr& b) { return Name (tirx::MakeConst (b.dtype (), a), b); }
10891077
10901078#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED (Name ) \
10911079 inline PrimExpr Name (const PrimExpr& a, int b, Span span = Span()) { \
1092- return Name (a, tirx::make_const (a.dtype (), b), span); \
1080+ return Name (a, tirx::MakeConst (a.dtype (), b), span); \
10931081 } \
10941082 inline PrimExpr Name (int a, const PrimExpr& b, Span span = Span()) { \
1095- return Name (tirx::make_const (b.dtype (), a), b, span); \
1083+ return Name (tirx::MakeConst (b.dtype (), a), b, span); \
10961084 }
10971085
10981086TVM_DEFINE_ASSIGN_OP_OVERLOAD (operator +=, operator +);
0 commit comments