@@ -816,7 +816,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
816816
817817/* !
818818 * \brief Make a const value with certain data type.
819- * \param t The target type.
819+ *
820+ * Prefer direct IntImm or FloatImm construction when dtype is known to be
821+ * scalar integer or floating point. This makes the compiled code more compact
822+ * and efficient. Keep make_const for generic overload cases where dtype can be
823+ * integer, floating point, or vector-valued and the caller needs its
824+ * scalar/vector dispatch.
825+ *
826+ * \param dtype The target type.
820827 * \param value The input value
821828 * \return the result expression.
822829 * \tparam ValueType The constant value type
@@ -825,32 +832,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
825832template <typename ValueType,
826833 typename = typename std::enable_if<std::is_standard_layout<ValueType>::value &&
827834 std::is_trivial<ValueType>::value>::type>
828- inline PrimExpr make_const (DataType t, ValueType value, Span span = Span());
829- /* !
830- * \brief Make a const zero expr.
831- * \param t The target type.
832- * \param span The location of this operation in the source.
833- * \return the result expression.
834- */
835- inline PrimExpr make_zero (DataType t, Span span = Span());
836- /* !
837- * \brief Make a constant true expression.
838- * \param lanes The number of lanes in the bool
839- * \param span The location of this operation in the source.
840- * \return The result expression.
841- */
842- inline PrimExpr const_true (int lanes = 1 , Span span = Span()) {
843- return make_const (DataType::Bool (lanes), 1 );
844- }
835+ inline PrimExpr make_const (DataType dtype, ValueType value, Span span = Span());
845836/* !
846- * \brief Make a constant false expression .
847- * \param lanes The number of lanes in the bool
837+ * \brief Make a constant handle value .
838+ * \param value The integer payload to reinterpret as a handle.
848839 * \param span The location of this operation in the source.
849840 * \return The result expression.
850841 */
851- inline PrimExpr const_false (int lanes = 1 , Span span = Span()) {
852- return make_const (DataType::Bool (lanes), 0 );
853- }
842+ inline PrimExpr ConstHandle (int64_t value, Span span = Span());
854843/* !
855844 * \brief Get x as constant int expression.
856845 * \param x The expression
@@ -981,53 +970,52 @@ inline bool is_no_op(const tirx::Stmt& stmt) {
981970}
982971
983972template <typename ValueType>
984- inline PrimExpr MakeConstScalar (DataType t , ValueType value, Span span = Span()) {
985- if (t .is_int () || t .is_bool ()) return IntImm (t , static_cast <int64_t >(value), span);
986- if (t .is_uint ()) {
973+ inline PrimExpr MakeConstScalar (DataType dtype , ValueType value, Span span = Span()) {
974+ if (dtype .is_int () || dtype .is_bool ()) return IntImm (dtype , static_cast <int64_t >(value), span);
975+ if (dtype .is_uint ()) {
987976 // Use IntImm if it is a small integer
988977 uint64_t uval = static_cast <uint64_t >(value);
989978 if (value < static_cast <ValueType>(0 )) {
990979 TVM_FFI_THROW (InternalError) << " cannot make uint from negative value " << value;
991980 } else if (uval <= static_cast <uint64_t >(std::numeric_limits<int64_t >::max ())) {
992- return IntImm (t , static_cast <int64_t >(value), span);
981+ return IntImm (dtype , static_cast <int64_t >(value), span);
993982 } else {
994983 uint64_t mask = (static_cast <uint64_t >(1 ) << 32U ) - 1U ;
995984 uint64_t low = uval & mask;
996985 uint64_t high = uval >> 32U ;
997- return LargeUIntImm (t , static_cast <int64_t >(low), static_cast <int64_t >(high), span);
986+ return LargeUIntImm (dtype , static_cast <int64_t >(low), static_cast <int64_t >(high), span);
998987 }
999988 }
1000- if (t.is_float () || t.is_bfloat16 () || t.is_float8 () || t.is_float6 () || t.is_float4 ())
1001- return FloatImm (t, static_cast <double >(value), span);
1002- TVM_FFI_THROW (InternalError) << " cannot make const for type " << t;
989+ if (dtype.is_float () || dtype.is_bfloat16 () || dtype.is_float8 () || dtype.is_float6 () ||
990+ dtype.is_float4 ()) {
991+ return FloatImm (dtype, static_cast <double >(value), span);
992+ }
993+ TVM_FFI_THROW (InternalError) << " cannot make const for type " << dtype;
1003994 throw ;
1004995}
1005996
1006997template <>
1007- inline PrimExpr MakeConstScalar (DataType t , bool value, Span span) {
1008- return MakeConstScalar (t , static_cast <int >(value), span);
998+ inline PrimExpr MakeConstScalar (DataType dtype , bool value, Span span) {
999+ return MakeConstScalar (dtype , static_cast <int >(value), span);
10091000}
10101001
10111002template <typename ValueType, typename >
1012- inline PrimExpr make_const (DataType t , ValueType value, Span span) {
1013- if (t .is_scalar ()) {
1014- return MakeConstScalar (t , value, span);
1003+ inline PrimExpr make_const (DataType dtype , ValueType value, Span span) {
1004+ if (dtype .is_scalar ()) {
1005+ return MakeConstScalar (dtype , value, span);
10151006 } else {
1016- if (t .is_fixed_length_vector ()) {
1017- return tirx::Broadcast (MakeConstScalar (t .element_of (), value, span), t .lanes (), span);
1007+ if (dtype .is_fixed_length_vector ()) {
1008+ return tirx::Broadcast (MakeConstScalar (dtype .element_of (), value, span), dtype .lanes (), span);
10181009 } else {
1019- PrimExpr lanes =
1020- tirx::Mul ( tirx::Call ( DataType::Int ( 32 ), tirx::builtin::vscale (), {}), t .vscale_factor ());
1021- return tirx::Broadcast (MakeConstScalar (t .element_of (), value, span), lanes, span);
1010+ PrimExpr lanes = tirx::Mul ( tirx::Call ( DataType::Int ( 32 ), tirx::builtin::vscale (), {}),
1011+ dtype .vscale_factor ());
1012+ return tirx::Broadcast (MakeConstScalar (dtype .element_of (), value, span), lanes, span);
10221013 }
10231014 }
10241015}
10251016
1026- inline PrimExpr make_zero (DataType t, Span span) {
1027- if (t.is_handle ()) {
1028- return reinterpret (t, make_const (DataType::UInt (64 ), 0 , span));
1029- }
1030- return make_const (t, 0 , span);
1017+ inline PrimExpr ConstHandle (int64_t value, Span span) {
1018+ return reinterpret (DataType::Handle (), IntImm (DataType::UInt (64 ), value, span));
10311019}
10321020
10331021} // namespace tirx
@@ -1049,7 +1037,7 @@ inline PrimExpr make_zero(DataType t, Span span) {
10491037 return Name (a, tirx::make_const (a.dtype (), b)); \
10501038 } \
10511039 inline PrimExpr Name (const PrimExpr& a, double b) { \
1052- return Name (a, tirx::make_const (DataType::Float (64 ), b)); \
1040+ return Name (a, FloatImm (DataType::Float (64 ), b)); \
10531041 }
10541042
10551043#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED (Name ) \
@@ -1066,7 +1054,7 @@ inline PrimExpr make_zero(DataType t, Span span) {
10661054 return Name (a, tirx::make_const (a.dtype (), b), span); \
10671055 } \
10681056 inline PrimExpr Name (const PrimExpr& a, double b, Span span = Span()) { \
1069- return Name (a, tirx::make_const (DataType::Float (64 ), b), span); \
1057+ return Name (a, FloatImm (DataType::Float (64 ), b), span); \
10701058 }
10711059
10721060#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD (Name ) \
0 commit comments