@@ -816,7 +816,13 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
816816
817817/* !
818818 * \brief Make a const value with certain data type.
819- * \param t The target type.
819+ *
820+ * This helper is intended for generic overload sites where the dtype may be
821+ * integer, floating point, or vector-valued. When the dtype is statically known
822+ * to be a scalar integer or floating point type, callers should use IntImm or
823+ * FloatImm directly for more compact and efficient code.
824+ *
825+ * \param dtype The target type.
820826 * \param value The input value
821827 * \return the result expression.
822828 * \tparam ValueType The constant value type
@@ -825,32 +831,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) {
825831template <typename ValueType,
826832 typename = typename std::enable_if<std::is_standard_layout<ValueType>::value &&
827833 std::is_trivial<ValueType>::value>::type>
828- inline PrimExpr make_const (DataType t, ValueType value, Span span = Span());
829- /* !
830- * \brief Make a const zero expr.
831- * \param t The target type.
832- * \param span The location of this operation in the source.
833- * \return the result expression.
834- */
835- inline PrimExpr make_zero (DataType t, Span span = Span());
836- /* !
837- * \brief Make a constant true expression.
838- * \param lanes The number of lanes in the bool
839- * \param span The location of this operation in the source.
840- * \return The result expression.
841- */
842- inline PrimExpr const_true (int lanes = 1 , Span span = Span()) {
843- return make_const (DataType::Bool (lanes), 1 );
844- }
834+ inline PrimExpr make_const (DataType dtype, ValueType value, Span span = Span());
845835/* !
846- * \brief Make a constant false expression .
847- * \param lanes The number of lanes in the bool
836+ * \brief Make a constant handle value .
837+ * \param value The integer handle value.
848838 * \param span The location of this operation in the source.
849839 * \return The result expression.
850840 */
851- inline PrimExpr const_false (int lanes = 1 , Span span = Span()) {
852- return make_const (DataType::Bool (lanes), 0 );
853- }
841+ inline PrimExpr ConstHandle (int32_t value, Span span = Span());
854842/* !
855843 * \brief Get x as constant int expression.
856844 * \param x The expression
@@ -981,53 +969,52 @@ inline bool is_no_op(const tirx::Stmt& stmt) {
981969}
982970
983971template <typename ValueType>
984- inline PrimExpr MakeConstScalar (DataType t , ValueType value, Span span = Span()) {
985- if (t .is_int () || t .is_bool ()) return IntImm (t , static_cast <int64_t >(value), span);
986- if (t .is_uint ()) {
972+ inline PrimExpr MakeConstScalar (DataType dtype , ValueType value, Span span = Span()) {
973+ if (dtype .is_int () || dtype .is_bool ()) return IntImm (dtype , static_cast <int64_t >(value), span);
974+ if (dtype .is_uint ()) {
987975 // Use IntImm if it is a small integer
988976 uint64_t uval = static_cast <uint64_t >(value);
989977 if (value < static_cast <ValueType>(0 )) {
990978 TVM_FFI_THROW (InternalError) << " cannot make uint from negative value " << value;
991979 } else if (uval <= static_cast <uint64_t >(std::numeric_limits<int64_t >::max ())) {
992- return IntImm (t , static_cast <int64_t >(value), span);
980+ return IntImm (dtype , static_cast <int64_t >(value), span);
993981 } else {
994982 uint64_t mask = (static_cast <uint64_t >(1 ) << 32U ) - 1U ;
995983 uint64_t low = uval & mask;
996984 uint64_t high = uval >> 32U ;
997- return LargeUIntImm (t , static_cast <int64_t >(low), static_cast <int64_t >(high), span);
985+ return LargeUIntImm (dtype , static_cast <int64_t >(low), static_cast <int64_t >(high), span);
998986 }
999987 }
1000- if (t.is_float () || t.is_bfloat16 () || t.is_float8 () || t.is_float6 () || t.is_float4 ())
1001- return FloatImm (t, static_cast <double >(value), span);
1002- TVM_FFI_THROW (InternalError) << " cannot make const for type " << t;
988+ if (dtype.is_float () || dtype.is_bfloat16 () || dtype.is_float8 () || dtype.is_float6 () ||
989+ dtype.is_float4 ()) {
990+ return FloatImm (dtype, static_cast <double >(value), span);
991+ }
992+ TVM_FFI_THROW (InternalError) << " cannot make const for type " << dtype;
1003993 throw ;
1004994}
1005995
1006996template <>
1007- inline PrimExpr MakeConstScalar (DataType t , bool value, Span span) {
1008- return MakeConstScalar (t , static_cast <int >(value), span);
997+ inline PrimExpr MakeConstScalar (DataType dtype , bool value, Span span) {
998+ return MakeConstScalar (dtype , static_cast <int >(value), span);
1009999}
10101000
10111001template <typename ValueType, typename >
1012- inline PrimExpr make_const (DataType t , ValueType value, Span span) {
1013- if (t .is_scalar ()) {
1014- return MakeConstScalar (t , value, span);
1002+ inline PrimExpr make_const (DataType dtype , ValueType value, Span span) {
1003+ if (dtype .is_scalar ()) {
1004+ return MakeConstScalar (dtype , value, span);
10151005 } else {
1016- if (t .is_fixed_length_vector ()) {
1017- return tirx::Broadcast (MakeConstScalar (t .element_of (), value, span), t .lanes (), span);
1006+ if (dtype .is_fixed_length_vector ()) {
1007+ return tirx::Broadcast (MakeConstScalar (dtype .element_of (), value, span), dtype .lanes (), span);
10181008 } else {
1019- PrimExpr lanes =
1020- tirx::Mul ( tirx::Call ( DataType::Int ( 32 ), tirx::builtin::vscale (), {}), t .vscale_factor ());
1021- return tirx::Broadcast (MakeConstScalar (t .element_of (), value, span), lanes, span);
1009+ PrimExpr lanes = tirx::Mul ( tirx::Call ( DataType::Int ( 32 ), tirx::builtin::vscale (), {}),
1010+ dtype .vscale_factor ());
1011+ return tirx::Broadcast (MakeConstScalar (dtype .element_of (), value, span), lanes, span);
10221012 }
10231013 }
10241014}
10251015
1026- inline PrimExpr make_zero (DataType t, Span span) {
1027- if (t.is_handle ()) {
1028- return reinterpret (t, make_const (DataType::UInt (64 ), 0 , span));
1029- }
1030- return make_const (t, 0 , span);
1016+ inline PrimExpr ConstHandle (int32_t value, Span span) {
1017+ return reinterpret (DataType::Handle (), IntImm (DataType::UInt (64 ), value, span));
10311018}
10321019
10331020} // namespace tirx
@@ -1049,7 +1036,7 @@ inline PrimExpr make_zero(DataType t, Span span) {
10491036 return Name (a, tirx::make_const (a.dtype (), b)); \
10501037 } \
10511038 inline PrimExpr Name (const PrimExpr& a, double b) { \
1052- return Name (a, tirx::make_const (DataType::Float (64 ), b)); \
1039+ return Name (a, FloatImm (DataType::Float (64 ), b)); \
10531040 }
10541041
10551042#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED (Name ) \
@@ -1066,7 +1053,7 @@ inline PrimExpr make_zero(DataType t, Span span) {
10661053 return Name (a, tirx::make_const (a.dtype (), b), span); \
10671054 } \
10681055 inline PrimExpr Name (const PrimExpr& a, double b, Span span = Span()) { \
1069- return Name (a, tirx::make_const (DataType::Float (64 ), b), span); \
1056+ return Name (a, FloatImm (DataType::Float (64 ), b), span); \
10701057 }
10711058
10721059#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD (Name ) \
0 commit comments