diff --git a/include/tvm/ir/base_expr.h b/include/tvm/ir/base_expr.h new file mode 100644 index 000000000000..0a844bb3ba8e --- /dev/null +++ b/include/tvm/ir/base_expr.h @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/base_expr.h + * \brief Base expression and primitive type nodes. + */ +#ifndef TVM_IR_BASE_EXPR_H_ +#define TVM_IR_BASE_EXPR_H_ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { + +/*! + * \brief Type is the base type of all types. + * + * TVM's type system contains following subclasses: + * + * - PrimType: type of primitive type values used in the low-level IR. + * - FuncType: type of a function. + * - TensorType: type of certain Tensor values in the expression. + * + * There are also advanced types to support generic(polymorphic types). + * \sa Type + */ +class TypeNode : public ffi::Object { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + // span do not participate in structural equal and hash. 
+ refl::ObjectDef().def_ro("span", &TypeNode::span, refl::DefaultValue(Span()), + refl::AttachFieldFlag::SEqHashIgnore()); + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; + + static constexpr const uint32_t _type_child_slots = 14; + TVM_FFI_DECLARE_OBJECT_INFO("ir.Type", TypeNode, ffi::Object); +}; + +/*! + * \brief Managed reference to TypeNode. + * \sa TypeNode + */ +class Type : public ffi::ObjectRef { + public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Type, ffi::ObjectRef, TypeNode); +}; + +/*! + * \brief Primitive data types used in the low-level IR. + * + * PrimType represents POD-values and handles that are + * not automatically managed by the runtime. + * + * \sa PrimType + */ +class PrimTypeNode final : public TypeNode { + public: + /*! + * \brief The raw DLPack dtype represented by this primitive type. + */ + DLDataType dtype; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("dtype", &PrimTypeNode::dtype); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.PrimType", PrimTypeNode, TypeNode); +}; + +/* + * \brief Managed reference to PrimTypeNode. + * \sa PrimTypeNode + */ +class PrimType final : public Type { + public: + /*! + * \brief Construct from a raw DLPack dtype. + * \param dtype The corresponding DLPack dtype. + */ + TVM_DLL explicit PrimType(DLDataType dtype); + + /*! + * \brief Construct from DLPack dtype fields. + * \param code The DLPack dtype code. + * \param bits The scalar bit width. + * \param lanes The fixed lane count. + */ + TVM_DLL PrimType(DLDataTypeCode code, int bits, int lanes = 1); + + /*! \brief Construct a signed integer type with fixed lanes. */ + TVM_DLL static PrimType Int(int bits, int lanes = 1); + /*! \brief Construct an unsigned integer type with fixed lanes. */ + TVM_DLL static PrimType UInt(int bits, int lanes = 1); + /*! \brief Construct a floating-point type with fixed lanes. 
*/ + TVM_DLL static PrimType Float(int bits, int lanes = 1); + /*! \brief Construct a bfloat type with fixed lanes. */ + TVM_DLL static PrimType BFloat(int bits, int lanes = 1); + /*! \brief Construct a boolean type with fixed lanes. */ + TVM_DLL static PrimType Bool(int lanes = 1); + /*! \brief Construct an opaque handle type. */ + TVM_DLL static PrimType Handle(int bits = 64, int lanes = 1); + /*! \brief Construct the void sentinel type, encoded as handle(0, 0). */ + TVM_DLL static PrimType Void(); + /*! + * \brief Construct a scalable vector type. + * \param code The DLPack dtype code. + * \param bits The scalar bit width. + * \param lanes The positive vscale factor to encode in the DLPack lane field. + */ + TVM_DLL static PrimType ScalableVector(DLDataTypeCode code, int bits, int lanes); + + /*! \return The DLPack dtype code. */ + TVM_FFI_INLINE DLDataTypeCode code() const { + return static_cast(static_cast(get()->dtype.code)); + } + + /*! \return The scalar bit width. */ + TVM_FFI_INLINE int32_t bits() const { return get()->dtype.bits; } + + /*! + * \return The fixed lane count. + * \note Throws on scalable vector types, where the encoded lane field stores a vscale factor. + */ + TVM_FFI_INLINE int32_t lanes() const { + int16_t encoded_lanes = static_cast(get()->dtype.lanes); + if (TVM_FFI_PREDICT_FALSE(encoded_lanes < 0)) { + TVM_FFI_THROW(InternalError) + << "Can't fetch the lanes of a scalable vector at a compile time."; + } + return encoded_lanes; + } + + /*! + * \brief Check the scalar element code and bit width. + * \note Lane count and scalable-vector encoding are intentionally ignored. + */ + TVM_FFI_INLINE bool MatchesElementType(DLDataTypeCode code, int bits) const { + DLDataType dtype = get()->dtype; + return dtype.code == static_cast(code) && dtype.bits == bits; + } + + /*! + * \brief Check whether the dtype code matches any of the provided DLPack codes. + * \note Bit width and lanes are intentionally ignored. 
+ */ + template + TVM_FFI_INLINE bool MatchesCode(Codes... codes) const { + uint8_t dtype_code = get()->dtype.code; + return ((dtype_code == static_cast(codes)) || ...); + } + + /*! \brief Whether this type is a scalar, excluding fixed and scalable vectors. */ + TVM_FFI_INLINE bool IsScalar() const { + int16_t encoded_lanes = static_cast(get()->dtype.lanes); + return encoded_lanes == 1; + } + + /*! \brief Whether this type is the void sentinel `handle(0, 0)`. */ + TVM_FFI_INLINE bool IsVoid() const { + DLDataType dtype = get()->dtype; + return dtype.code == static_cast(DLDataTypeCode::kDLOpaqueHandle) && dtype.bits == 0 && + static_cast(dtype.lanes) == 0; + } + + /*! \brief Whether this type is an opaque handle, excluding the void sentinel. */ + TVM_FFI_INLINE bool IsHandle() const { + return this->code() == DLDataTypeCode::kDLOpaqueHandle && !this->IsVoid(); + } + + /*! \brief Whether this type is a scalable vector. */ + TVM_FFI_INLINE bool IsScalableVector() const { + return static_cast(get()->dtype.lanes) < -1; + } + + /*! \brief Whether this type is a fixed-length vector. */ + TVM_FFI_INLINE bool IsFixedLengthVector() const { + return static_cast(get()->dtype.lanes) > 1; + } + + /*! + * \brief Return the number of bytes needed to store one value of this type. + * + * This uses the same packed sub-byte dtype sizing rule as runtime tensors. + * Scalable vector types have no compile-time storage size and are rejected. + */ + TVM_DLL size_t StorageBytes() const; + + /*! \brief Return the same type with a different dtype code, preserving bits and lanes. */ + TVM_FFI_INLINE PrimType WithCode(DLDataTypeCode code) const { + DLDataType dtype = get()->dtype; + int16_t encoded_lanes = static_cast(dtype.lanes); + if (encoded_lanes < -1) { + return ScalableVector(code, dtype.bits, -encoded_lanes); + } + return PrimType(code, dtype.bits, encoded_lanes); + } + + /*! \brief Return the same type with a different scalar bit width, preserving code and lanes. 
*/ + TVM_FFI_INLINE PrimType WithBits(int bits) const { + DLDataType dtype = get()->dtype; + int16_t encoded_lanes = static_cast(dtype.lanes); + if (encoded_lanes < -1) { + return ScalableVector(this->code(), bits, -encoded_lanes); + } + return PrimType(this->code(), bits, encoded_lanes); + } + + /*! \brief Return the same scalar element type with a fixed lane count. */ + TVM_FFI_INLINE PrimType WithLanes(int lanes) const { + return PrimType(this->code(), this->bits(), lanes); + } + + /*! \return The vscale factor encoded in a scalable vector type. */ + TVM_FFI_INLINE int32_t VScaleFactor() const { + int16_t encoded_lanes = static_cast(get()->dtype.lanes); + if (encoded_lanes >= -1) { + TVM_FFI_THROW(InternalError) << "A fixed length vector doesn't have a vscale factor."; + } + return -encoded_lanes; + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrimType, Type, PrimTypeNode); +}; + +inline bool operator==(const PrimType& lhs, const PrimType& rhs) { + return lhs->dtype == rhs->dtype; +} + +inline bool operator!=(const PrimType& lhs, const PrimType& rhs) { return !(lhs == rhs); } + +/*! + * \brief Base type of all the expressions. + * \sa Expr + */ +class BaseExprNode : public ffi::Object { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + /*! + * \brief The deduced or annotated type of the expression. + * + * This field is intentionally nullable because type information may + * be populated by later analysis passes instead of expression + * constructors. + */ + mutable Type ty; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + // span and ty do not participate in structural equal and hash. 
+ refl::ObjectDef() + .def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span()), + refl::AttachFieldFlag::SEqHashIgnore()) + .def_ro("ty", &BaseExprNode::ty, refl::DefaultValue(Type()), + refl::AttachFieldFlag::SEqHashIgnore()); + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; + + static constexpr const uint32_t _type_child_slots = 64; + TVM_FFI_DECLARE_OBJECT_INFO("ir.BaseExpr", BaseExprNode, ffi::Object); +}; + +/*! + * \brief Managed reference to BaseExprNode. + * \sa BaseExprNode + */ +class BaseExpr : public ffi::ObjectRef { + public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BaseExpr, ffi::ObjectRef, BaseExprNode); +}; + +namespace ffi { +template <> +inline constexpr bool use_default_type_traits_v = false; + +template <> +struct TypeTraits : public ObjectRefWithFallbackTraitsBase { + TVM_FFI_INLINE static PrimType ConvertFallbackValue(DLDataType dtype) { return PrimType(dtype); } +}; +} // namespace ffi + +} // namespace tvm + +#endif // TVM_IR_BASE_EXPR_H_ diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b81e4c2feda7..70e1ffeb480c 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -24,12 +24,13 @@ #ifndef TVM_IR_EXPR_H_ #define TVM_IR_EXPR_H_ +#include #include #include #include +#include #include #include -#include #include #include @@ -54,82 +55,6 @@ class VirtualDevice; * There are also advanced types to support generic(polymorphic types). * \sa Type */ -class TypeNode : public ffi::Object { - public: - /*! - * \brief Span that points to the original source code. - * Reserved debug information. - */ - mutable Span span; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - // span do not participate in structural equal and hash. 
- refl::ObjectDef().def_ro("span", &TypeNode::span, refl::DefaultValue(Span()), - refl::AttachFieldFlag::SEqHashIgnore()); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - - static constexpr const uint32_t _type_child_slots = 14; - TVM_FFI_DECLARE_OBJECT_INFO("ir.Type", TypeNode, ffi::Object); -}; - -/*! - * \brief Managed reference to TypeNode. - * \sa TypeNode - */ -class Type : public ffi::ObjectRef { - public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Type, ffi::ObjectRef, TypeNode); -}; - -/*! - * \brief Base type of all the expressions. - * \sa Expr - */ -class BaseExprNode : public ffi::Object { - public: - /*! - * \brief Span that points to the original source code. - * Reserved debug information. - */ - mutable Span span; - - /*! - * \brief The deduced or annotated type of the expression. - * - * This field is intentionally nullable because type information may - * be populated by later analysis passes instead of expression - * constructors. - */ - mutable Type ty; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - // span and ty do not participate in structural equal and hash. - refl::ObjectDef() - .def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span()), - refl::AttachFieldFlag::SEqHashIgnore()) - .def_ro("ty", &BaseExprNode::ty, refl::DefaultValue(Type()), - refl::AttachFieldFlag::SEqHashIgnore()); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - - static constexpr const uint32_t _type_child_slots = 64; - TVM_FFI_DECLARE_OBJECT_INFO("ir.BaseExpr", BaseExprNode, ffi::Object); -}; - -/*! - * \brief Managed reference to BaseExprNode. - * \sa BaseExprNode - */ -class BaseExpr : public ffi::ObjectRef { - public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BaseExpr, ffi::ObjectRef, BaseExprNode); -}; - /*! * \brief Base node of all primitive expressions. 
* @@ -144,25 +69,16 @@ class BaseExpr : public ffi::ObjectRef { */ class PrimExprNode : public BaseExprNode { public: - /*! - * \brief The runtime data type of the primitive expression. - * - * runtime::DataType(dtype) provides coarse grained type information - * during compile time and runtime. It is eagerly built in - * PrimExpr expression construction and can be used for - * quick type checking. - * - * dtype is sufficient to decide the Type of the PrimExpr - * when it corresponds to POD value types such as i32. - * - * When dtype is DataType::Handle(), the expression could corresponds to - * a more fine-grained Type, and we can get the type by running lazy type inference. - */ - DataType dtype; + /*! \return the primitive type of this expression node. */ + PrimType ty() const { + TVM_FFI_DCHECK(this->BaseExprNode::ty.defined()); + TVM_FFI_DCHECK(this->BaseExprNode::ty->IsInstance()); + return ffi::GetRef(static_cast(this->BaseExprNode::ty.get())); + } static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("dtype", &PrimExprNode::dtype); + refl::ObjectDef(); } static constexpr const uint32_t _type_child_slots = 40; @@ -186,8 +102,13 @@ class PrimExpr : public BaseExpr { */ TVM_DLL PrimExpr(float value); // NOLINT(*) - /*! \return the data type of this expression. */ - DataType dtype() const { return static_cast(get())->dtype; } + /*! \return the primitive type of this expression. */ + PrimType ty() const { + const auto* node = static_cast(get()); + TVM_FFI_DCHECK(node->BaseExprNode::ty.defined()); + TVM_FFI_DCHECK(node->BaseExprNode::ty->IsInstance()); + return ffi::GetRef(static_cast(node->BaseExprNode::ty.get())); + } TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimExpr, BaseExpr, PrimExprNode); @@ -554,11 +475,11 @@ class IntImm : public PrimExpr { public: /*! * \brief Constructor. - * \param dtype The data type of the value. + * \param value_ty The primitive type of the value. * \param value The internal value. 
* \param span The location of this object in the source code. */ - TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span()); + TVM_DLL IntImm(PrimType value_ty, int64_t value, Span span = Span()); /*! * \brief Construct a scalar boolean constant. @@ -566,7 +487,7 @@ class IntImm : public PrimExpr { * \param span The location of this object in the source code. */ static IntImm Bool(bool value, Span span = Span()) { - return IntImm(DataType::Bool(), value, span); + return IntImm(PrimType::Bool(), value, span); } /*! @@ -575,7 +496,7 @@ class IntImm : public PrimExpr { * \param span The location of this object in the source code. */ static IntImm Int32(int64_t value, Span span = Span()) { - return IntImm(DataType::Int(32), value, span); + return IntImm(PrimType::Int(32), value, span); } /*! @@ -584,7 +505,7 @@ class IntImm : public PrimExpr { * \param span The location of this object in the source code. */ static IntImm Int64(int64_t value, Span span = Span()) { - return IntImm(DataType::Int(64), value, span); + return IntImm(PrimType::Int(64), value, span); } TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntImm, PrimExpr, IntImmNode); @@ -616,11 +537,11 @@ class FloatImm : public PrimExpr { public: /*! * \brief Constructor. - * \param dtype The data type of the value. + * \param value_ty The primitive type of the value. * \param value The internal value. * \param span The location in the source code. 
*/ - TVM_DLL FloatImm(DataType dtype, double value, Span span = Span()); + TVM_DLL FloatImm(PrimType value_ty, double value, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloatImm, PrimExpr, FloatImmNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode); @@ -688,11 +609,11 @@ inline constexpr bool use_default_type_traits_v = false; template <> struct TypeTraits : public ObjectRefWithFallbackTraitsBase { TVM_FFI_INLINE static IntImm ConvertFallbackValue(int64_t value) { - auto dtype = + auto value_ty = (value > std::numeric_limits::max() || value < std::numeric_limits::min()) - ? DataType::Int(64) - : DataType::Int(32); - return IntImm(dtype, value); + ? PrimType::Int(64) + : PrimType::Int(32); + return IntImm(value_ty, value); } }; @@ -702,7 +623,7 @@ inline constexpr bool use_default_type_traits_v = false; template <> struct TypeTraits : public ObjectRefWithFallbackTraitsBase { TVM_FFI_INLINE static FloatImm ConvertFallbackValue(double value) { - return FloatImm(runtime::DataType::Float(32), value); + return FloatImm(PrimType::Float(32), value); } }; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 9c56d0376405..f63b5d261500 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -26,21 +26,19 @@ * * This file contains types that are common across IR variants. * - * ## Relation between Type and runtime::DataType + * ## Relation between Type and DLPack dtype * - * Besides Type, we also store a dtype field in the low-level PrimExpr. - * runtime::DataType(dtype) provides coarse grained type information - * during compile time and runtime. It is eagerly built in - * low-level expression construction and can be used for - * quick type checking in the low-level IR. - * For example, when an Expr's dtype is int32, - * we know for sure that its type is also int32. + * PrimExpr stores a PrimType in its `ty` field, backed by a DLPack + * `DLDataType`. 
This provides coarse grained scalar/vector element type + * information during compile time and runtime. It is eagerly built in + * low-level expression construction and can be used for quick type checking + * in the low-level IR. For example, when an Expr's dtype is int32, we know + * for sure that its PrimType is also int32. * * On the other hand, Type provides more fine grained information. - * For example, a low level expression can have DataType::Handle() as - * its dtype and MemRef[float32] as its type. - * Types are usually lazily constructed via type checking, - * so they may not readily be available during IR construction. + * For example, a low level expression can have a handle dtype while a + * node-specific type annotation records a + * PointerType to a float32 element. * * The unified Type serves as a common bridge across IR dialects. * For example, we require all the functions to have a type signature, @@ -49,55 +47,16 @@ #ifndef TVM_IR_TYPE_H_ #define TVM_IR_TYPE_H_ -#include #include +#include #include -#include +#include #include -#include #include namespace tvm { -/*! - * \brief Primitive data types used in the low-level IR. - * - * PrimType represents POD-values and handles that are - * not automatically managed by the runtime. - * - * \sa PrimType - */ -class PrimTypeNode : public TypeNode { - public: - /*! - * \brief The corresponding dtype field. - */ - runtime::DataType dtype; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("dtype", &PrimTypeNode::dtype); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.PrimType", PrimTypeNode, TypeNode); -}; - -/* - * \brief Managed reference to PrimTypeNode. - * \sa PrimTypeNode - */ -class PrimType : public Type { - public: - /*! - * \brief Constructor - * \param dtype The corresponding dtype. 
- * \param span The span - */ - TVM_DLL explicit PrimType(runtime::DataType dtype, Span span = Span()); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrimType, Type, PrimTypeNode); -}; - /*! * \brief Low-level raw pointer type. * diff --git a/include/tvm/relax/attrs/create.h b/include/tvm/relax/attrs/create.h index 14a3402f2503..76ef219a862c 100644 --- a/include/tvm/relax/attrs/create.h +++ b/include/tvm/relax/attrs/create.h @@ -31,7 +31,7 @@ namespace relax { /*! \brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */ struct InitAttrs : public AttrsNode { - DataType dtype; + DLDataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/datatype.h b/include/tvm/relax/attrs/datatype.h index f67223edb546..aeac65e64484 100644 --- a/include/tvm/relax/attrs/datatype.h +++ b/include/tvm/relax/attrs/datatype.h @@ -31,7 +31,7 @@ namespace relax { /*! \brief Attributes used in astype operator */ struct AstypeAttrs : public AttrsNode { - DataType dtype; + DLDataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -42,7 +42,7 @@ struct AstypeAttrs : public AttrsNode { /*! 
\brief Attributes used in wrap_param operator */ struct WrapParamAttrs : public AttrsNode { - DataType dtype; + DLDataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h index c9a720374036..8f512f28e55f 100644 --- a/include/tvm/relax/attrs/image.h +++ b/include/tvm/relax/attrs/image.h @@ -39,7 +39,7 @@ struct Resize2DAttrs : public AttrsNode { double cubic_alpha; int cubic_exclude; double extrapolation_value; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -88,7 +88,7 @@ struct Resize3DAttrs : public AttrsNode { double cubic_alpha; int cubic_exclude; double extrapolation_value; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/linear_algebra.h b/include/tvm/relax/attrs/linear_algebra.h index 817885edb871..19a5982bfe12 100644 --- a/include/tvm/relax/attrs/linear_algebra.h +++ b/include/tvm/relax/attrs/linear_algebra.h @@ -31,7 +31,7 @@ namespace relax { /*! 
\brief Attributes for matmul operator */ struct MatmulAttrs : public AttrsNode { - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 52d9c40d742d..aa3c0f4736f0 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -38,7 +38,7 @@ struct Conv1DAttrs : public AttrsNode { ffi::String data_layout; ffi::String kernel_layout; ffi::String out_layout; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -82,7 +82,7 @@ struct Conv2DAttrs : public AttrsNode { ffi::String data_layout; ffi::String kernel_layout; ffi::String out_layout; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -128,7 +128,7 @@ struct Conv3DAttrs : public AttrsNode { ffi::String data_layout; ffi::String kernel_layout; ffi::String out_layout; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -177,7 +177,7 @@ struct Conv1DTransposeAttrs : public AttrsNode { ffi::String data_layout; ffi::String kernel_layout; ffi::String out_layout; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -226,7 +226,7 @@ struct Conv2DTransposeAttrs : public AttrsNode { ffi::String data_layout; ffi::String kernel_layout; ffi::String out_layout; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -277,7 +277,7 @@ struct Conv3DTransposeAttrs : public AttrsNode { ffi::String data_layout; ffi::String kernel_layout; ffi::String out_layout; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/qdq.h 
b/include/tvm/relax/attrs/qdq.h index 83ec2223c3c7..be95b9e7b8ed 100644 --- a/include/tvm/relax/attrs/qdq.h +++ b/include/tvm/relax/attrs/qdq.h @@ -31,7 +31,7 @@ namespace relax { /*! \brief Attributes for relax.quantize/relax.dequantize operator */ struct QuantizeAttrs : public AttrsNode { - DataType out_dtype; + DLDataType out_dtype; int axis; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/sampling.h b/include/tvm/relax/attrs/sampling.h index 11bbfb6eba31..07b7de25e553 100644 --- a/include/tvm/relax/attrs/sampling.h +++ b/include/tvm/relax/attrs/sampling.h @@ -31,13 +31,13 @@ namespace relax { /*! \brief Attributes used in multinomial_from_uniform operator */ struct MultinomialFromUniformAttrs : public AttrsNode { - DataType dtype; + DLDataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro( "dtype", &MultinomialFromUniformAttrs::dtype, "Data type of the output indices.", - refl::DefaultValue(DataType::Int(64))); + refl::DefaultValue((DLDataType{kDLInt, 64, 1}))); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MultinomialFromUniformAttrs", MultinomialFromUniformAttrs, AttrsNode); diff --git a/include/tvm/relax/attrs/sorting.h b/include/tvm/relax/attrs/sorting.h index e8bf65d55a43..ef21bf9a637e 100644 --- a/include/tvm/relax/attrs/sorting.h +++ b/include/tvm/relax/attrs/sorting.h @@ -54,7 +54,7 @@ struct SortAttrs : public AttrsNode { struct ArgsortAttrs : public AttrsNode { int axis; bool descending; - DataType dtype; + DLDataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -68,7 +68,7 @@ struct ArgsortAttrs : public AttrsNode { "If it is not specified, it defaults to the ascending order.", refl::DefaultValue(false)) .def_ro("dtype", &ArgsortAttrs::dtype, "DType of the output indices.", - refl::DefaultValue(DataType::Void())); + refl::DefaultValue((DLDataType{kDLOpaqueHandle, 0, 0}))); } 
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgsortAttrs", ArgsortAttrs, AttrsNode); }; // struct ArgsortAttrs @@ -79,7 +79,7 @@ struct TopKAttrs : public AttrsNode { int axis; bool largest; ffi::String ret_type; - DataType dtype; + DLDataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -98,7 +98,7 @@ struct TopKAttrs : public AttrsNode { "By default, return the largest k elements.", refl::DefaultValue(true)) .def_ro("dtype", &TopKAttrs::dtype, "Data type of the output indices.", - refl::DefaultValue(DataType::Void())); + refl::DefaultValue((DLDataType{kDLOpaqueHandle, 0, 0}))); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TopKAttrs", TopKAttrs, AttrsNode); }; // struct TopKAttrs diff --git a/include/tvm/relax/attrs/statistical.h b/include/tvm/relax/attrs/statistical.h index 66996c802cc3..a815e0e07e51 100644 --- a/include/tvm/relax/attrs/statistical.h +++ b/include/tvm/relax/attrs/statistical.h @@ -50,7 +50,7 @@ struct StatisticalAttrs : public AttrsNode { /*! \brief Attributes used in scan operators like cumsum, cumprod */ struct ScanopAttrs : public AttrsNode { ffi::Optional axis; - DataType dtype; + DLDataType dtype; bool exclusive = false; static void RegisterReflection() { diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 27894da3addd..0511395f8a67 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -116,8 +116,8 @@ class DFPattern : public ffi::ObjectRef { TVM_DLL AttrPattern HasAttr(const ffi::Map& attrs) const; /*! \brief Syntatic Sugar for creating a TypePattern */ TVM_DLL TypePattern HasType(const Type& ty) const; - /*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */ - TVM_DLL DataTypePattern HasDtype(const DataType& dtype) const; + /*! \brief Syntatic Sugar for creating a DataTypePattern with a dtype */ + TVM_DLL DataTypePattern HasDtype(DLDataType dtype) const; /*! 
\brief Syntatic Sugar for creating a DataTypePattern with a data type's name */ TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const; /*! \brief Syntatic Sugar for creating a ShapePattern */ @@ -860,7 +860,7 @@ class SameShapeConstraint : public DFConstraint { class DataTypePatternNode : public DFPatternNode { public: DFPattern pattern; /*!< The root pattern to match */ - DataType dtype; /*!< The data type to match */ + DLDataType dtype; /*!< The data type to match */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -878,7 +878,7 @@ class DataTypePatternNode : public DFPatternNode { */ class DataTypePattern : public DFPattern { public: - TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype); + TVM_DLL DataTypePattern(DFPattern pattern, DLDataType dtype); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataTypePattern, DFPattern, DataTypePatternNode); }; diff --git a/include/tvm/relax/distributed/global_info.h b/include/tvm/relax/distributed/global_info.h index 62ff904fc1a4..0347ec3b85a8 100644 --- a/include/tvm/relax/distributed/global_info.h +++ b/include/tvm/relax/distributed/global_info.h @@ -25,6 +25,7 @@ #ifndef TVM_RELAX_DISTRIBUTED_GLOBAL_INFO_H_ #define TVM_RELAX_DISTRIBUTED_GLOBAL_INFO_H_ +#include #include #include namespace tvm { diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 937091255b6f..0b75bf27a7d2 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -471,7 +471,7 @@ class StringImm : public LeafExpr { class DataTypeImmNode : public LeafExprNode { public: /*! \brief The data value. */ - DataType value; + DLDataType value; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -491,7 +491,7 @@ class DataTypeImm : public LeafExpr { * \param value The value input. * \param span The source span of the expression. 
*/ - TVM_DLL explicit DataTypeImm(DataType value, Span span = Span()); + TVM_DLL explicit DataTypeImm(DLDataType value, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataTypeImm, LeafExpr, DataTypeImmNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DataTypeImmNode); diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index d0d0d1bb5441..5c757ba15161 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -663,9 +663,8 @@ TVM_DLL Pass DataflowUseInplaceCalls(); * * \note Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first. */ -TVM_DLL Pass -ToMixedPrecision(const DataType& out_dtype, - ffi::Optional> fp16_input_names = std::nullopt); +TVM_DLL Pass ToMixedPrecision( + DLDataType out_dtype, ffi::Optional> fp16_input_names = std::nullopt); /*! * \brief Rewrite a Relax module for executing with CUDA graph. This pass identifies diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 9c27b627a7d6..a77a3cc66c38 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -124,7 +124,7 @@ class ShapeTypeNode : public TypeNode { * \brief The number of dimension of the shape, can be unknown. * \sa kUnknownNDim */ - int ndim; + int ndim{kUnknownNDim}; /*! \return Whether the type contains unknown ndim. */ bool IsUnknownNdim() const { return ndim == kUnknownNDim; } @@ -174,19 +174,19 @@ class TensorTypeNode : public TypeNode { * is expected to be executed. */ ffi::Optional vdevice; - /*! \brief The content data type, use void to denote the dtype is unknown. */ - DataType dtype; + /*! \brief The content dtype, use void to denote the dtype is unknown. */ + tvm::PrimType dtype{DLDataType{kDLOpaqueHandle, 0, 0}}; /*! * \brief The number of dimension of the tensor, can be unknown. * \sa kUnknownNDim */ - int ndim; + int ndim{kUnknownNDim}; /*! \return Whether the type contains unknown ndim. */ bool IsUnknownNdim() const { return ndim == kUnknownNDim; } /*! 
\return Whether the type contains unknown dtype. */ - bool IsUnknownDtype() const { return dtype.is_void(); } + bool IsUnknownDtype() const { return dtype->dtype == DLDataType{kDLOpaqueHandle, 0, 0}; } /*! \return Shape if it is known. */ ffi::Optional> GetShape() const { @@ -230,7 +230,7 @@ class TensorType : public Type { * * \note shape must already be normalized. */ - TVM_DLL TensorType(Expr shape, DataType dtype, ffi::Optional vdevice = std::nullopt, + TVM_DLL TensorType(Expr shape, tvm::PrimType dtype, ffi::Optional vdevice = std::nullopt, Span span = Span()); /*! @@ -240,7 +240,7 @@ class TensorType : public Type { * \param vdevice The virtual device. * \param span The span of the AST. */ - TVM_DLL TensorType(DataType dtype, int ndim, ffi::Optional vdevice = std::nullopt, + TVM_DLL TensorType(tvm::PrimType dtype, int ndim, ffi::Optional vdevice = std::nullopt, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TensorType, Type, TensorTypeNode); diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h deleted file mode 100644 index 9f230cac824e..000000000000 --- a/include/tvm/runtime/data_type.h +++ /dev/null @@ -1,522 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file tvm/runtime/data_type.h - * \brief Primitive runtime data type. - */ -// Acknowledgement: DataType structure design originates from Halide. -#ifndef TVM_RUNTIME_DATA_TYPE_H_ -#define TVM_RUNTIME_DATA_TYPE_H_ - -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace runtime { - -/*! - * \brief Runtime primitive data type. - * - * This class is a thin wrapper of DLDataType. - * We also make use of DataType in compiler to store quick hint - */ -class DataType { - public: - /*! - * \brief Type code for the DataType. - * - * DLPack consistency: - * 1) kInt is consistent with kDLInt - * 2) kUInt is consistent with kDLUInt - * 3) kFloat is consistent with kDLFloat - */ - enum TypeCode { - kInt = kDLInt, - kUInt = kDLUInt, - kFloat = kDLFloat, - kHandle = kDLOpaqueHandle, - kBFloat = kDLBfloat, - kBool = kDLBool, - kFloat8_e3m4 = kDLFloat8_e3m4, - kFloat8_e4m3 = kDLFloat8_e4m3, - kFloat8_e4m3b11fnuz = kDLFloat8_e4m3b11fnuz, - kFloat8_e4m3fn = kDLFloat8_e4m3fn, - kFloat8_e4m3fnuz = kDLFloat8_e4m3fnuz, - kFloat8_e5m2 = kDLFloat8_e5m2, - kFloat8_e5m2fnuz = kDLFloat8_e5m2fnuz, - kFloat8_e8m0fnu = kDLFloat8_e8m0fnu, - kFloat6_e2m3fn = kDLFloat6_e2m3fn, - kFloat6_e3m2fn = kDLFloat6_e3m2fn, - kFloat4_e2m1fn = kDLFloat4_e2m1fn, - kCustomBegin = 129 - }; - /*! \brief default constructor */ - DataType() { data_ = DataType::Void(); } - /*! - * \brief Constructor - * \param dtype The DLDataType - */ - explicit DataType(DLDataType dtype) : data_(dtype) {} - /*! - * \brief Constructor - * \param code The type code. - * \param bits The number of bits in the type. - * \param lanes The number of lanes. - * \param is_scalable Whether the data type is scalable. 
- */ - DataType(int code, int bits, int lanes, bool is_scalable = false) { - data_.code = static_cast(code); - data_.bits = static_cast(bits); - if (is_scalable) { - TVM_FFI_ICHECK(lanes > 1) << "Invalid value for vscale factor" << lanes; - } - data_.lanes = is_scalable ? static_cast(-lanes) : static_cast(lanes); - if (code == kBFloat) { - TVM_FFI_ICHECK_EQ(bits, 16); - } - if (code == kFloat8_e3m4 || code == kFloat8_e4m3 || code == kFloat8_e4m3b11fnuz || - code == kFloat8_e4m3fn || code == kFloat8_e4m3fnuz || code == kFloat8_e5m2 || - code == kFloat8_e5m2fnuz || code == kFloat8_e8m0fnu) { - TVM_FFI_ICHECK_EQ(bits, 8); - } - if (code == kFloat6_e2m3fn || code == kFloat6_e3m2fn) { - TVM_FFI_ICHECK_EQ(bits, 6); - } - if (code == kFloat4_e2m1fn) { - TVM_FFI_ICHECK_EQ(bits, 4); - } - } - /*! \return The type code. */ - int code() const { return static_cast(data_.code); } - /*! \return number of bits in the data. */ - int bits() const { return static_cast(data_.bits); } - /*! \return number of bytes to store each scalar. */ - int bytes() const { return (bits() + 7) / 8; } - /*! \return number of lanes in the data. */ - int lanes() const { - int lanes_as_int = static_cast(data_.lanes); - if (lanes_as_int < 0) { - TVM_FFI_THROW(InternalError) - << "Can't fetch the lanes of a scalable vector at a compile time."; - } - return lanes_as_int; - } - /*! \return the integer multiplier of vscale in a scalable vector. */ - int vscale_factor() const { - int lanes_as_int = static_cast(data_.lanes); - if (lanes_as_int >= -1) { - TVM_FFI_THROW(InternalError) << "A fixed length vector doesn't have a vscale factor."; - } - return -lanes_as_int; - } - /*! \return get vscale factor or lanes depending on scalability of the vector. */ - int get_lanes_or_vscale_factor() const { - return is_scalable_vector() ? vscale_factor() : lanes(); - } - /*! \return whether type is a scalar type. */ - bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; } - /*! 
\return whether type is a bool type. */ - bool is_bool() const { return code() == DataType::kBool; } - /*! \return whether type can be used in a predicate expression. */ - bool is_predicate_dtype() const { return is_bool() || (is_uint() && bits() == 1); } - /*! \return whether type is a float type. */ - bool is_float() const { return code() == DataType::kFloat; } - /*! \return whether type is a bfloat type. */ - bool is_bfloat() const { return code() == DataType::kBFloat; } - /*! \return whether type is any 8-bit custom Float8 variant. */ - bool is_float8() const { - return bits() == 8 && - (code() == DataType::kFloat8_e3m4 || code() == DataType::kFloat8_e4m3 || - code() == DataType::kFloat8_e4m3b11fnuz || code() == DataType::kFloat8_e4m3fn || - code() == DataType::kFloat8_e4m3fnuz || code() == DataType::kFloat8_e5m2 || - code() == DataType::kFloat8_e5m2fnuz || code() == DataType::kFloat8_e8m0fnu); - } - /*! \return whether type is any 6-bit custom Float6 variant. */ - bool is_float6() const { - return bits() == 6 && - (code() == DataType::kFloat6_e2m3fn || code() == DataType::kFloat6_e3m2fn); - } - /*! \return whether type is the 4-bit custom Float4_e2m1fn variant. */ - bool is_float4() const { return bits() == 4 && code() == DataType::kFloat4_e2m1fn; } - /*! \return whether type is Float8E3M4. */ - bool is_float8_e3m4() const { return bits() == 8 && code() == DataType::kFloat8_e3m4; } - /*! \return whether type is Float8E4M3. */ - bool is_float8_e4m3() const { return bits() == 8 && code() == DataType::kFloat8_e4m3; } - /*! \return whether type is Float8E4M3B11FNUZ. */ - bool is_float8_e4m3b11fnuz() const { - return bits() == 8 && code() == DataType::kFloat8_e4m3b11fnuz; - } - /*! \return whether type is Float8E4M3FN. */ - bool is_float8_e4m3fn() const { return bits() == 8 && code() == DataType::kFloat8_e4m3fn; } - /*! \return whether type is Float8E4M3FNUZ. */ - bool is_float8_e4m3fnuz() const { return bits() == 8 && code() == DataType::kFloat8_e4m3fnuz; } - /*! 
\return whether type is Float8E5M2. */ - bool is_float8_e5m2() const { return bits() == 8 && code() == DataType::kFloat8_e5m2; } - /*! \return whether type is Float8E5M2FNUZ. */ - bool is_float8_e5m2fnuz() const { return bits() == 8 && code() == DataType::kFloat8_e5m2fnuz; } - /*! \return whether type is Float8E8M0FNU. */ - bool is_float8_e8m0fnu() const { return bits() == 8 && code() == DataType::kFloat8_e8m0fnu; } - /*! \return whether type is Float6E2M3FN. */ - bool is_float6_e2m3fn() const { return bits() == 6 && code() == DataType::kFloat6_e2m3fn; } - /*! \return whether type is Float6E3M2FN. */ - bool is_float6_e3m2fn() const { return bits() == 6 && code() == DataType::kFloat6_e3m2fn; } - /*! \return whether type is Float4E2M1FN. */ - bool is_float4_e2m1fn() const { return bits() == 4 && code() == DataType::kFloat4_e2m1fn; } - /*! \return whether type is a float16 type. */ - bool is_float16() const { return is_float() && bits() == 16; } - /*! \return whether type is a bfloat16 type. */ - bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; } - /*! \return whether type is an int type. */ - bool is_int() const { return code() == DataType::kInt; } - /*! \return whether type is an uint type. */ - bool is_uint() const { return code() == DataType::kUInt; } - /*! \return whether type is a handle type. */ - bool is_handle() const { return code() == DataType::kHandle && !is_void(); } - /*! \return whether type is a vector type. */ - bool is_scalable_or_fixed_length_vector() const { - int encoded_lanes = static_cast(data_.lanes); - return (encoded_lanes < -1) || (1 < encoded_lanes); - } - /*! \return Whether the type is a fixed length vector. */ - bool is_fixed_length_vector() const { return static_cast(data_.lanes) > 1; } - /*! \return Whether the type is a scalable vector. */ - bool is_scalable_vector() const { return static_cast(data_.lanes) < -1; } - /*! \return whether type is a vector type. 
*/ - bool is_vector() const { return lanes() > 1; } - /*! \return whether type is a bool vector type. */ - bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && is_bool(); } - /*! \return whether type is a Void type. */ - bool is_void() const { - return code() == DataType::kHandle && bits() == 0 && static_cast(data_.lanes) == 0; - } - /*! - * \brief Create a new data type by change lanes to a specified value. - * \param lanes The target number of lanes. - * \return the result type. - */ - DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); } - /*! - * \brief Create a new scalable vector data type by changing the vscale multiplier to a specified - * value. We'll use the data_.lanes field for this value. \param vscale_factor The vscale - * multiplier. \return A copy of the old DataType with the number of scalable lanes. - */ - DataType with_scalable_vscale_factor(int vscale_factor) const { - return DataType(data_.code, data_.bits, -vscale_factor); - } - /*! - * \brief Create a new data type by change bits to a specified value. - * \param bits The target number of bits. - * \return the result type. - */ - DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); } - /*! - * \brief Get the scalar version of the type. - * \return the result type. - */ - DataType element_of() const { return with_lanes(1); } - /*! - * \brief Assignment operator. - */ - DataType& operator=(const DataType& rhs) { - if (this == &rhs) { - return *this; - } - data_ = rhs.data_; - return *this; - } - /*! - * \brief Equal comparator. - * \param other The data type to compare against. - * \return The comparison result. - */ - bool operator==(const DataType& other) const { - return data_.code == other.data_.code && data_.bits == other.data_.bits && - data_.lanes == other.data_.lanes; - } - /*! - * \brief NotEqual comparator. - * \param other The data type to compare against. - * \return The comparison result. 
- */ - bool operator!=(const DataType& other) const { return !operator==(other); } - /*! - * \brief Converter to DLDataType - * \return the result. - */ - operator DLDataType() const { return data_; } - - /*! - * \brief Construct an int type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes. - * \return The constructed data type. - */ - static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); } - /*! - * \brief Construct an uint type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes. - * \param is_scalable Whether the data type is scalable. - * \return The constructed data type. - */ - static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) { - return DataType(kDLUInt, bits, lanes, is_scalable); - } - /*! - * \brief Construct an float type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); } - /*! - * \brief Construct an bfloat type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); } - /*! - * \brief Construct float8 e3m4 datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E3M4(int lanes = 1) { return DataType(kFloat8_e3m4, 8, lanes); } - - /*! - * \brief Construct float8 e4m3 datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3, 8, lanes); } - - /*! - * \brief Construct float8 e4m3b11fnuz datatype. - * \param lanes The number of lanes - * \return The constructed data type. 
- */ - static DataType Float8E4M3B11FNUZ(int lanes = 1) { - return DataType(kFloat8_e4m3b11fnuz, 8, lanes); - } - - /*! - * \brief Construct float8 e4m3fn datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E4M3FN(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); } - - /*! - * \brief Construct float8 e4m3fnuz datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E4M3FNUZ(int lanes = 1) { return DataType(kFloat8_e4m3fnuz, 8, lanes); } - - /*! - * \brief Construct float8 e5m2 datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); } - - /*! - * \brief Construct float8 e5m2fnuz datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E5M2FNUZ(int lanes = 1) { return DataType(kFloat8_e5m2fnuz, 8, lanes); } - - /*! - * \brief Construct float8 e8m0fnu datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E8M0FNU(int lanes = 1) { return DataType(kFloat8_e8m0fnu, 8, lanes); } - - /*! - * \brief Construct float6 e2m3fn datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float6E2M3FN(int lanes = 1) { return DataType(kFloat6_e2m3fn, 6, lanes); } - - /*! - * \brief Construct float6 e3m2fn datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float6E3M2FN(int lanes = 1) { return DataType(kFloat6_e3m2fn, 6, lanes); } - - /*! - * \brief Construct float4 e2m1fn datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); } - /*! - * \brief Construct a bool type. 
- * \param lanes The number of lanes. - * \param is_scalable Whether the data type is scalable. - * \return The constructed data type. - */ - static DataType Bool(int lanes = 1, bool is_scalable = false) { - return DataType(kDLBool, 8, lanes, is_scalable); - } - /*! - * \brief Construct a handle type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); } - /*! - * \brief Construct a Void type. - * \return The constructed data type. - */ - static DataType Void() { return DataType(kHandle, 0, 0); } - /*! - * \brief Get the corresponding type of TVMShapeIndex. - * \return The type of TVM shape index. - */ - static DataType ShapeIndex() { - if (std::is_signed::value) { - return DataType::Int(sizeof(ffi::Shape::index_type) * 8); - } else { - return DataType::UInt(sizeof(ffi::Shape::index_type) * 8); - } - } - - private: - DLDataType data_; -}; - -/*! - * \brief Get the number of bytes needed in a vector. - * \param dtype The data type. - * \return Number of bytes needed. - */ -inline int GetVectorBytes(DataType dtype) { - int data_bits = dtype.bits() * dtype.lanes(); - // allow bool to exist - if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || - dtype == DataType::Int(1) || dtype == DataType::Float4E2M1FN() || - dtype == DataType::Float6E2M3FN() || dtype == DataType::Float6E3M2FN()) { - return 1; - } - TVM_FFI_ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; - return data_bits / 8; -} - -/*! - * \brief Check whether type matches the given spec. - * \param t The type - * \param code The type code. - * \param bits The number of bits to be matched. - * \param lanes The number of lanes in the type. 
- */ -inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) { - return t.code == code && t.bits == bits && t.lanes == lanes; -} -/*! - * \brief Check whether two types are equal . - * \param lhs The left operand. - * \param rhs The right operand. - */ -inline bool TypeEqual(DLDataType lhs, DLDataType rhs) { - return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; -} - -inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) - return os << dtype.operator DLDataType(); -} -} // namespace runtime - -using DataType = runtime::DataType; - -namespace ffi { - -// runtime::DataType -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDataType; - - TVM_FFI_INLINE static void CopyToAnyView(const runtime::DataType& src, TVMFFIAny* result) { - // clear padding part to ensure the equality check can always check the v_uint64 part - result->v_uint64 = 0; - result->zero_padding = 0; - result->type_index = TypeIndex::kTVMFFIDataType; - result->v_dtype = src; - } - - TVM_FFI_INLINE static void MoveToAny(runtime::DataType src, TVMFFIAny* result) { - // clear padding part to ensure the equality check can always check the v_uint64 part - result->v_uint64 = 0; - result->zero_padding = 0; - result->type_index = TypeIndex::kTVMFFIDataType; - result->v_dtype = src; - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - auto opt_dtype = TypeTraits::TryCastFromAnyView(src); - if (opt_dtype) { - return runtime::DataType(opt_dtype.value()); - } - return std::nullopt; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return TypeTraits::CheckAnyStrict(src); - } - - TVM_FFI_INLINE static runtime::DataType CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return runtime::DataType(TypeTraits::CopyFromAnyViewAfterCheck(src)); - } - - TVM_FFI_INLINE static std::string TypeStr() { return 
ffi::StaticTypeKey::kTVMFFIDataType; } - - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(ffi::StaticTypeKey::kTVMFFIDataType) + R"("})"; - } -}; - -} // namespace ffi -} // namespace tvm - -namespace std { -template <> -struct hash { - inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; } - std::size_t operator()(tvm::DataType const& dtype) const { - int a = dtype.code(); - int b = dtype.bits(); - int c = dtype.lanes(); - int d = cantor_pairing_function(a, b); - return cantor_pairing_function(c, d); - } -}; -} // namespace std - -#endif // TVM_RUNTIME_DATA_TYPE_H_ diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h index a9487c866acc..9d66a09507c5 100644 --- a/include/tvm/runtime/disco/builtin.h +++ b/include/tvm/runtime/disco/builtin.h @@ -19,8 +19,8 @@ #ifndef TVM_RUNTIME_DISCO_BUILTIN_H_ #define TVM_RUNTIME_DISCO_BUILTIN_H_ +#include #include -#include #include #include @@ -70,7 +70,7 @@ TVM_RUNTIME_DLL ffi::Module LoadVMModule(std::string path, ffi::Optional * \param device The device the Tensor is created on. If None, use the thread local default device * \return The Tensor created */ -TVM_RUNTIME_DLL Tensor DiscoEmptyTensor(ffi::Shape shape, DataType dtype, +TVM_RUNTIME_DLL Tensor DiscoEmptyTensor(ffi::Shape shape, DLDataType dtype, ffi::Optional device); /*! 
* \brief Perform an allreduce operation using the underlying communication library diff --git a/include/tvm/runtime/tensor.h b/include/tvm/runtime/tensor.h index d3497c8ff78f..cb93c4abd741 100644 --- a/include/tvm/runtime/tensor.h +++ b/include/tvm/runtime/tensor.h @@ -26,10 +26,10 @@ #include #include +#include #include #include #include -#include #include #include #include @@ -59,7 +59,7 @@ class Tensor : public tvm::ffi::Tensor { Tensor(const ffi::Tensor& other) : tvm::ffi::Tensor(other) {} // NOLINT(*) ffi::ShapeView Shape() const { return this->shape(); } - runtime::DataType DataType() const { return runtime::DataType(this->dtype()); } + DLDataType DataType() const { return this->dtype(); } // DLPack handling static Tensor FromDLPack(DLManagedTensor* tensor) { diff --git a/include/tvm/runtime/vm/bytecode.h b/include/tvm/runtime/vm/bytecode.h index 0f1927e0cbcb..ea246da5d354 100644 --- a/include/tvm/runtime/vm/bytecode.h +++ b/include/tvm/runtime/vm/bytecode.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_VM_BYTECODE_H_ #define TVM_RUNTIME_VM_BYTECODE_H_ +#include #include -#include #include #include diff --git a/include/tvm/runtime/vm/tensor_cache_support.h b/include/tvm/runtime/vm/tensor_cache_support.h index ea997f0755bd..b112043c376f 100644 --- a/include/tvm/runtime/vm/tensor_cache_support.h +++ b/include/tvm/runtime/vm/tensor_cache_support.h @@ -54,7 +54,7 @@ struct TensorCacheMetadata { /*! \brief Shape of the parameter */ ffi::Shape shape; /*! \brief Data type of the parameter */ - DataType dtype; + DLDataType dtype; /*! \brief Format of the parameter */ std::string format; /*! \brief Number of bytes */ diff --git a/include/tvm/s_tir/data_layout.h b/include/tvm/s_tir/data_layout.h index 48836c5a53d5..ee6d51832dba 100644 --- a/include/tvm/s_tir/data_layout.h +++ b/include/tvm/s_tir/data_layout.h @@ -140,10 +140,10 @@ class SLayout : public ffi::ObjectRef { * the corresponding lower case with factor size * indicates the split dimension. 
* return undefined layout if "__undef__" is passed. - * \param dtype The dtype of generated axes vars in the returned layout. + * \param index_ty The type of generated axes vars in the returned layout. * It is required to be integer type. */ - TVM_DLL SLayout(const std::string& name, DataType dtype = DataType::Int(32)); // NOLINT(*) + TVM_DLL SLayout(const std::string& name, PrimType index_ty = PrimType::Int(32)); // NOLINT(*) /*! * \brief access the internal node container diff --git a/include/tvm/s_tir/meta_schedule/arg_info.h b/include/tvm/s_tir/meta_schedule/arg_info.h index 463e73b0e246..a346a73dd441 100644 --- a/include/tvm/s_tir/meta_schedule/arg_info.h +++ b/include/tvm/s_tir/meta_schedule/arg_info.h @@ -20,9 +20,9 @@ #define TVM_S_TIR_META_SCHEDULE_ARG_INFO_H_ #include +#include #include #include -#include #include namespace tvm { @@ -77,7 +77,7 @@ class ArgInfo : public ffi::ObjectRef { class TensorInfoNode : public ArgInfoNode { public: /*! \brief The data type of the tensor. */ - runtime::DataType dtype; + DLDataType dtype; /*! \brief The shape of the tensor. */ ffi::Shape shape; @@ -104,7 +104,7 @@ class TensorInfo : public ArgInfo { * \param dtype The data type of the tensor argument. * \param shape The shape tuple of the tensor argument. */ - TVM_DLL explicit TensorInfo(runtime::DataType dtype, ffi::Shape shape); + TVM_DLL explicit TensorInfo(DLDataType dtype, ffi::Shape shape); /*! * \brief Parse the argument information from a JSON object. * \param json_obj The json object to parse. diff --git a/include/tvm/script/printer/config.h b/include/tvm/script/printer/config.h index beea4042470c..e0ed32d38094 100644 --- a/include/tvm/script/printer/config.h +++ b/include/tvm/script/printer/config.h @@ -30,10 +30,11 @@ #include #include #include +#include #include #include #include -#include +#include #include @@ -53,15 +54,15 @@ class PrinterConfigNode : public ffi::Object { */ ffi::String module_alias = "cls"; /*! 
\brief Default buffer dtype */ - DataType buffer_dtype = DataType::Float(32); + DLDataType buffer_dtype = DLDataType{kDLFloat, 32, 1}; /*! \brief Default data type of integer literals */ - DataType int_dtype = DataType::Int(32); + DLDataType int_dtype = DLDataType{kDLInt, 32, 1}; /*! * \brief Default data type of float literals. Right now we always print out the explicit type * of floating point values, so setting it to Void means we do not print without the * T.float32/T.float64 wrapper. */ - DataType float_dtype = DataType::Void(); + DLDataType float_dtype = DLDataType{kDLOpaqueHandle, 0, 0}; /*! \brief Whether or not to verbose print expressions. */ bool verbose_expr = false; /*! \brief Number of spaces used for indentation*/ diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 2389c1b50d15..bc90e5365734 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -19,10 +19,11 @@ #ifndef TVM_SCRIPT_PRINTER_DOC_H_ #define TVM_SCRIPT_PRINTER_DOC_H_ +#include #include #include #include -#include +#include #include #include @@ -293,7 +294,7 @@ class LiteralDoc : public ExprDoc { * \param p The object path */ static LiteralDoc Float(double v, const ffi::Optional& p) { - return LiteralDoc(FloatImm(DataType::Float(64), v), p); + return LiteralDoc(FloatImm(PrimType::Float(64), v), p); } /*! * \brief Create a LiteralDoc to represent string. @@ -308,8 +309,9 @@ class LiteralDoc : public ExprDoc { * \param v The string value. * \param p The object path */ - static LiteralDoc DataType(const runtime::DataType& v, const ffi::Optional& p) { - std::string dtype = v.is_void() ? "void" : ffi::DLDataTypeToString(v); + static LiteralDoc DataType(DLDataType v, const ffi::Optional& p) { + std::string dtype = + v == DLDataType{kDLOpaqueHandle, 0, 0} ? "void" : ffi::DLDataTypeToString(v); return LiteralDoc::Str(dtype, p); } /*! 
diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 98249c6f30bd..e9c82265ff27 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -333,7 +333,7 @@ inline TDoc IRDocsifierNode::AsDoc(const Any& value, const AccessPath& path) con return LiteralDoc::Str(string_value, path).as_or_throw(); } case ffi::TypeIndex::kTVMFFIDataType: - return LiteralDoc::DataType(value.as().value(), path).as_or_throw(); + return LiteralDoc::DataType(value.as().value(), path).as_or_throw(); case ffi::TypeIndex::kTVMFFIDevice: return LiteralDoc::Device(value.as().value(), path).as_or_throw(); default: { diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index c9d35a77fe99..ba5267a8ce85 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -34,6 +34,7 @@ #include #include +#include #include namespace tvm { @@ -67,11 +68,11 @@ class TVM_DLL OperationNode : public ffi::Object { /*! \return number of outputs */ virtual int num_outputs() const = 0; /*! - * \brief Get data type. i-th output tensor. + * \brief Get the primitive element type of the i-th output tensor. * \param i The output index. - * \return type of i-th output. + * \return primitive element type of i-th output. */ - virtual DataType output_dtype(size_t i) const = 0; + virtual PrimType output_dtype(size_t i) const = 0; /*! * \brief Get shape of i-th output tensor. * \param i The output index. @@ -101,11 +102,11 @@ class PlaceholderOpNode : public OperationNode { public: /*! \brief The shape of the input */ ffi::Array shape; - /*! \brief The data type of the input. */ - DataType dtype; + /*! \brief The dtype of the input. */ + PrimType dtype{DLDataType{kDLOpaqueHandle, 0, 0}}; // override behavior. 
int num_outputs() const final; - DataType output_dtype(size_t i) const final; + PrimType output_dtype(size_t i) const final; ffi::Array output_shape(size_t i) const final; ffi::Array InputTensors() const final; @@ -124,7 +125,9 @@ class PlaceholderOpNode : public OperationNode { */ class PlaceholderOp : public Operation { public: - TVM_DLL PlaceholderOp(std::string name, ffi::Array shape, DataType dtype); + TVM_DLL PlaceholderOp(std::string name, ffi::Array shape, PrimType dtype); + PlaceholderOp(std::string name, ffi::Array shape, DLDataType dtype) + : PlaceholderOp(std::move(name), std::move(shape), PrimType(dtype)) {} TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PlaceholderOp, Operation, PlaceholderOpNode); }; @@ -162,7 +165,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { ComputeOpNode() {} // override functions int num_outputs() const final; - DataType output_dtype(size_t i) const final; + PrimType output_dtype(size_t i) const final; ffi::Array InputTensors() const final; static void RegisterReflection() { @@ -217,7 +220,7 @@ class ScanOpNode : public OperationNode { ScanOpNode() {} // override behavior. int num_outputs() const final; - DataType output_dtype(size_t i) const final; + PrimType output_dtype(size_t i) const final; ffi::Array output_shape(size_t i) const final; ffi::Array InputTensors() const final; @@ -266,7 +269,7 @@ class ExternOpNode : public OperationNode { ExternOpNode() {} // override functions int num_outputs() const final; - DataType output_dtype(size_t i) const final; + PrimType output_dtype(size_t i) const final; ffi::Array output_shape(size_t i) const final; ffi::Array InputTensors() const final; @@ -299,7 +302,7 @@ class ExternOp : public Operation { * \param name_hint The name hint for the expression * \param t The type of the expression */ -TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32)); +TVM_DLL Var var(std::string name_hint, PrimType t = PrimType::Int(32)); /*! 
* \brief Create a new IterVar that represents an axis in thread. @@ -329,9 +332,14 @@ using FBatchCompute = std::function(const ffi::Array& * \param dtype the data type of the tensor. * \param name The name of the Tensor. */ -TVM_DLL Tensor placeholder(ffi::Array shape, DataType dtype = DataType::Float(32), +TVM_DLL Tensor placeholder(ffi::Array shape, PrimType dtype = PrimType::Float(32), std::string name = "placeholder"); +inline Tensor placeholder(ffi::Array shape, DLDataType dtype, + std::string name = "placeholder") { + return placeholder(std::move(shape), PrimType(dtype), std::move(name)); +} + /*! * \brief Construct a new tensor by computing over shape, * using the computation rule: result_tensor[axis] = fcompute(axis) diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index ed07a35fb2da..760d308623f8 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -71,8 +71,8 @@ class TensorNode : public DataProducerNode { public: /*! \brief The shape of the tensor */ ffi::Array shape; - /*! \brief data type in the content of the tensor */ - DataType dtype; + /*! \brief dtype in the content of the tensor */ + PrimType dtype{DLDataType{kDLOpaqueHandle, 0, 0}}; /*! \brief the source operation, can be None */ Operation op; /*! 
\brief the output index from source operation */ @@ -82,7 +82,7 @@ class TensorNode : public DataProducerNode { ffi::Array GetShape() const final { return shape; } - DataType GetDataType() const final { return dtype; } + PrimType GetDataType() const final { return dtype; } TVM_DLL PrimExpr ToPrimExpr() const final; @@ -108,7 +108,9 @@ class Tensor : public DataProducer { inline PrimExpr IndexTensor(ffi::Array indices, bool support_negative_indices) const; public: - TVM_DLL Tensor(ffi::Array shape, DataType dtype, Operation op, int value_index); + TVM_DLL Tensor(ffi::Array shape, PrimType dtype, Operation op, int value_index); + Tensor(ffi::Array shape, DLDataType dtype, Operation op, int value_index) + : Tensor(std::move(shape), PrimType(dtype), std::move(op), value_index) {} /*! * \brief check if two tensors equals each other. * \param other tensor to be checked. diff --git a/include/tvm/tirx/buffer.h b/include/tvm/tirx/buffer.h index 1456787d688b..71d4c974dbb8 100644 --- a/include/tvm/tirx/buffer.h +++ b/include/tvm/tirx/buffer.h @@ -40,11 +40,20 @@ namespace tirx { #define TVM_INDEX_DEFAULT_I64 1 #endif /*! \brief if TVM_INDEX_DEFAULT_I64 is set, return int64, otherwise return int32 */ -inline DataType DefaultIndexType() { +inline PrimType DefaultIndexPrimType() { #if TVM_INDEX_DEFAULT_I64 - return DataType::Int(64); + static const PrimType default_index_ty = PrimType::Int(64); #else - return DataType::Int(32); + static const PrimType default_index_ty = PrimType::Int(32); +#endif + return default_index_ty; +} + +inline DLDataType DefaultIndexType() { +#if TVM_INDEX_DEFAULT_I64 + return DLDataType{kDLInt, 64, 1}; +#else + return DLDataType{kDLInt, 32, 1}; #endif } @@ -67,8 +76,8 @@ class BufferNode : public ffi::Object { * \sa data_alignment The alignment of data in bytes. */ Var data; - /*! \brief data type in the content of the tensor */ - DataType dtype; + /*! \brief dtype in the content of the tensor */ + PrimType dtype{DLDataType{kDLOpaqueHandle, 0, 0}}; /*! 
\brief The type of the buffer prior to flattening * * This contains the shape as it is accessed by @@ -147,10 +156,13 @@ class BufferNode : public ffi::Object { } /*! \return preferred index type for this buffer node */ - DataType DefaultIndexType() const { - return shape.size() != 0 ? shape[0].dtype() : tvm::tirx::DefaultIndexType(); + DLDataType DefaultIndexType() const { + return shape.size() != 0 ? shape[0].ty()->dtype : tvm::tirx::DefaultIndexType(); } + /*! \return primitive element type for compiler-side uses. */ + PrimType ElementType() const { return dtype; } + /*! \brief Determine the offset in the buffer of the given index. * * Returns the buffer offset, in number of elements of type dtype, @@ -176,11 +188,19 @@ class Buffer : public ffi::ObjectRef { public: // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. - TVM_DLL Buffer(Var data, DataType dtype, ffi::Array shape, ffi::Array strides, + TVM_DLL Buffer(Var data, PrimType dtype, ffi::Array shape, ffi::Array strides, PrimExpr elem_offset, ffi::String name, int data_alignment, int offset_factor, BufferType buffer_type, ffi::Array axis_separators = {}, Span span = Span(), ffi::Optional layout = std::nullopt, ffi::Array allocated_addr = {}); + Buffer(Var data, DLDataType dtype, ffi::Array shape, ffi::Array strides, + PrimExpr elem_offset, ffi::String name, int data_alignment, int offset_factor, + BufferType buffer_type, ffi::Array axis_separators = {}, Span span = Span(), + ffi::Optional layout = std::nullopt, ffi::Array allocated_addr = {}) + : Buffer(std::move(data), PrimType(dtype), std::move(shape), std::move(strides), + std::move(elem_offset), std::move(name), data_alignment, offset_factor, buffer_type, + std::move(axis_separators), std::move(span), std::move(layout), + std::move(allocated_addr)) {} /*! 
* \brief Return a new buffer that is equivalent with current one @@ -205,7 +225,7 @@ class Buffer : public ffi::ObjectRef { * \param offset The offset of ptr. * \param input_extent The extent of ptr. */ - TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), + TVM_DLL PrimExpr access_ptr(int access_mask, PrimType ptr_type = PrimType::Handle(), int content_lanes = 1, PrimExpr offset = IntImm::Int32(0), ffi::Optional input_extent = std::nullopt) const; /*! @@ -215,7 +235,7 @@ class Buffer : public ffi::ObjectRef { * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be * loaded. The number lanes of the mask must be equal to the number of lanes in being loaded. */ - TVM_DLL PrimExpr vload(ffi::Array begin, DataType dtype, + TVM_DLL PrimExpr vload(ffi::Array begin, PrimType dtype, ffi::Optional predicate = std::nullopt) const; /*! * \brief Create a Stmt that does a vector store at begin index. @@ -267,7 +287,11 @@ class Buffer : public ffi::ObjectRef { /*! * \brief Return a new buffer with the dtype. */ - TVM_DLL Buffer with_dtype(DataType dtype) const; + TVM_DLL Buffer with_dtype(PrimType dtype) const; + Buffer with_dtype(DLDataType dtype) const { return with_dtype(PrimType(dtype)); } + + /*! \return primitive element type for compiler-side uses. */ + PrimType ElementType() const { return (*this)->ElementType(); } /*! * \brief Return a new buffer with the data. @@ -289,11 +313,20 @@ class Buffer : public ffi::ObjectRef { * \return The created buffer. * \sa Buffer for complete constructor. 
*/ -TVM_DLL Buffer decl_buffer(ffi::Array shape, DataType dtype = DataType::Float(32), +TVM_DLL Buffer decl_buffer(ffi::Array shape, + DLDataType dtype = DLDataType{kDLFloat, 32, 1}, ffi::String name = "buffer", ffi::String storage_scope = "", ffi::Optional> axis_separators = std::nullopt, Span span = Span()); +inline Buffer decl_buffer(ffi::Array shape, PrimType dtype, ffi::String name = "buffer", + ffi::String storage_scope = "", + ffi::Optional> axis_separators = std::nullopt, + Span span = Span()) { + return decl_buffer(std::move(shape), dtype->dtype, std::move(name), std::move(storage_scope), + std::move(axis_separators), std::move(span)); +} + /*! * \brief Base node for data producers. * @@ -316,10 +349,10 @@ class DataProducerNode : public PrimExprConvertibleNode { */ virtual ffi::Array GetShape() const = 0; /*! - * \brief Get the data type of the result. - * \return The data type. + * \brief Get the raw element dtype of the result. + * \return The raw dtype. */ - virtual DataType GetDataType() const = 0; + virtual PrimType GetDataType() const = 0; /*! * \brief Get the name hint of the data producer. * \return The data type. @@ -350,10 +383,18 @@ class DataProducer : public PrimExprConvertible { * \param compact If the statement has already bound to a compact buffer. 
* \param memory_scope memory scope of the buffer */ -TVM_DLL tirx::Buffer BufferWithOffsetAlignment(ffi::Array shape, DataType dtype, +TVM_DLL tirx::Buffer BufferWithOffsetAlignment(ffi::Array shape, DLDataType dtype, std::string name, int data_alignment, int offset_factor, bool compact, std::string memory_scope = ""); + +inline tirx::Buffer BufferWithOffsetAlignment(ffi::Array shape, PrimType dtype, + std::string name, int data_alignment, + int offset_factor, bool compact, + std::string memory_scope = "") { + return BufferWithOffsetAlignment(std::move(shape), dtype->dtype, std::move(name), data_alignment, + offset_factor, compact, std::move(memory_scope)); +} } // namespace tirx } // namespace tvm #endif // TVM_TIR_BUFFER_H_ diff --git a/include/tvm/tirx/expr.h b/include/tvm/tirx/expr.h index cd51108b0d23..bf4c9004e84d 100644 --- a/include/tvm/tirx/expr.h +++ b/include/tvm/tirx/expr.h @@ -27,13 +27,13 @@ #include #include +#include #include #include #include #include #include #include -#include #include #include @@ -96,7 +96,7 @@ class CastNode : public PrimExprNode { */ class Cast : public PrimExpr { public: - TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span()); + TVM_DLL Cast(PrimType value_ty, PrimExpr value, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Cast, PrimExpr, CastNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode); }; @@ -752,9 +752,9 @@ class CallNode : public PrimExprNode { */ class Call : public PrimExpr { public: - TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array args, Attrs attrs = Attrs(), + TVM_DLL Call(PrimType ret_ty, RelaxExpr op, ffi::Array args, Attrs attrs = Attrs(), Span span = Span()); - TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array args, Span span); + TVM_DLL Call(PrimType ret_ty, RelaxExpr op, ffi::Array args, Span span); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); }; diff --git a/include/tvm/tirx/op.h 
b/include/tvm/tirx/op.h index 416aff73ee29..be827b9ef534 100644 --- a/include/tvm/tirx/op.h +++ b/include/tvm/tirx/op.h @@ -39,6 +39,7 @@ #include #include #include +#include namespace tvm { @@ -58,34 +59,36 @@ namespace tvm { /*! * \brief Get the type of the expression under the unified type system. * - * This function could return a more refined type than - * the runtime type provided by expr->dtype + * This function could return a more refined type than the runtime dtype + * implied by PrimExpr::ty(). * * \param expr The input parameter. * \return The result type. * - * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. + * \sa tvm/ir/type.h for discussion about the relation between Type and DLPack dtype. */ TVM_DLL Type GetType(const PrimExpr& expr); /*! - * \brief Get the type corresponding to DataType - * \param dtype The data type + * \brief Get the type corresponding to a runtime DLPack dtype. + * \param dtype The runtime dtype. * \return The result type * - * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. + * \sa tvm/ir/type.h for discussion about the relation between Type and DLPack dtype. */ -TVM_DLL Type GetTypeFromRuntimeDataType(const DataType& dtype); +TVM_DLL Type GetTypeFromRuntimeDataType(DLDataType dtype); /*! - * \brief Get the implied DataType for storing values with type during runtime. + * \brief Get the implied DLPack dtype for storing values with type during runtime. * * \param type The input type. - * \return The result runtime::DataType. + * \return The result DLPack dtype. * - * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. + * \sa tvm/ir/type.h for discussion about the relation between Type and DLPack dtype. 
*/ -TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type); +TVM_DLL DLDataType GetRuntimeDLDataType(const Type& type); + +inline DLDataType GetRuntimeDataType(const Type& type) { return GetRuntimeDLDataType(type); } /*! * \brief Return the value. @@ -120,27 +123,27 @@ TVM_DLL PrimExpr break_loop(Span span = Span()); /*! * Query the maximum possible value of dtype. - * \param dtype The data type. + * \param dtype The primitive type. * \param span The location of this operation in the source. * \return the maximum possible value in this format. */ -TVM_DLL PrimExpr max_value(const DataType& dtype, Span span = Span()); +TVM_DLL PrimExpr max_value(PrimType dtype, Span span = Span()); /*! * Query the minimum possible value of dtype. - * \param dtype The data type. + * \param dtype The primitive type. * \param span The location of this operation in the source. * \return the minimum possible value in this format. */ -TVM_DLL PrimExpr min_value(const DataType& dtype, Span span = Span()); +TVM_DLL PrimExpr min_value(PrimType dtype, Span span = Span()); /*! * Get the value of infinity. - * \param dtype The data type. + * \param dtype The primitive type. * \param span The location of this operation in the source. * \return the infinity value in this format. */ -TVM_DLL PrimExpr infinity(const DataType& dtype, Span span = Span()); +TVM_DLL PrimExpr infinity(PrimType dtype, Span span = Span()); /*! * \brief cast value to type. @@ -151,7 +154,7 @@ TVM_DLL PrimExpr infinity(const DataType& dtype, Span span = Span()); * \return The result expression. * \note This function may return value if the type is the same. */ -TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value, Span span = Span()); +TVM_DLL PrimExpr cast(PrimType t, PrimExpr value, Span span = Span()); /*! * \brief perform reinterpret cast value to type. * @@ -161,7 +164,7 @@ TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value, Span span = Span()); * \return The result expression. 
* \note This function may return value if the type is the same. */ -TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span = Span()); +TVM_DLL PrimExpr reinterpret(PrimType t, PrimExpr value, Span span = Span()); /*! * \brief add operator * @@ -691,13 +694,13 @@ TVM_DLL PrimExpr trunc(PrimExpr x, Span span = Span()); /*! * \brief Construct a large uint constant by its low 32 bits and high 32bits. - * \param dtype The final data type. + * \param value_ty The final primitive type. * \param low The lower 32 bits. * \param high The higher 32 bits. * \param span The location of this operation in the source. * \return The constructed expression. */ -TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span span = Span()); +TVM_DLL PrimExpr LargeUIntImm(PrimType value_ty, int64_t low, int64_t high, Span span = Span()); /*! * \brief Execute a multiplication between two Q-numbers x and y @@ -731,29 +734,35 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s */ TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits); -inline void CheckMathUnaryOpInputDType(const char* op_name, DataType dtype) { - TVM_FFI_CHECK(dtype.is_float() || dtype.is_bfloat16(), TypeError) +inline void CheckMathUnaryOpInputDType(const char* op_name, const PrimType& dtype) { + TVM_FFI_CHECK(dtype.code() == DLDataTypeCode::kDLFloat || + dtype.MatchesElementType(DLDataTypeCode::kDLBfloat, 16), + TypeError) << "tirx." << op_name << " only supports floating-point inputs, but got " << dtype; } // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \ - inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ - static const Op op = Op::Get("tirx." 
#OpName); \ - CheckInputDType(#OpName, x.dtype()); \ - if (x.dtype().is_bfloat16()) { \ - DataType bf16_dtype = x.dtype(); \ - DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \ - PrimExpr x_fp32 = tirx::Cast(fp32_dtype, {x}, span); \ - PrimExpr result_fp32 = tirx::Call(fp32_dtype, op, {x_fp32}, {}, span); \ - return tirx::Cast(bf16_dtype, {result_fp32}, span); \ - } else { \ - return tirx::Call(x.dtype(), op, {x}, {}, span); \ - } \ +#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \ + inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ + static const Op op = Op::Get("tirx." #OpName); \ + PrimType x_ty = x.ty(); \ + CheckInputDType(#OpName, x_ty); \ + if (x_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { \ + PrimType bf16_ty = x_ty; \ + PrimType f32_ty = \ + x_ty.IsScalableVector() \ + ? PrimType::ScalableVector(DLDataTypeCode::kDLFloat, 32, x_ty.VScaleFactor()) \ + : PrimType::Float(32, x_ty.lanes()); \ + PrimExpr x_fp32 = tirx::Cast(f32_ty, x, span); \ + PrimExpr result_fp32 = tirx::Call(f32_ty, op, {x_fp32}, {}, span); \ + return tirx::Cast(bf16_ty, result_fp32, span); \ + } else { \ + return tirx::Call(x_ty, op, {x}, {}, span); \ + } \ } #define TVM_DECLARE_INTRIN_UNARY(OpName) \ - TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, [](const char*, DataType) {}) + TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, [](const char*, const PrimType&) {}) #define TVM_DECLARE_FLOAT_INTRIN_UNARY(OpName) \ TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckMathUnaryOpInputDType) @@ -787,7 +796,7 @@ TVM_DECLARE_INTRIN_UNARY(clz); #define TVM_DECLARE_INTRIN_BINARY(OpName) \ inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \ static const Op op = Op::Get("tirx." #OpName); \ - return tirx::Call(x.dtype(), op, {x, y}, {}, span); \ + return tirx::Call(x.ty(), op, {x, y}, {}, span); \ } TVM_DECLARE_INTRIN_BINARY(atan2); @@ -804,7 +813,7 @@ namespace tirx { * \param element_type The corresponding element type. 
* \return The check results */ -inline bool IsPointerType(const Type& type, const DataType& element_type) { +inline bool IsPointerType(const Type& type, DLDataType element_type) { if (!type.defined()) return false; if (const auto* ptr_type = type.as()) { if (const auto* prim_type = ptr_type->element_type.as()) { @@ -832,7 +841,7 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) { template ::value && std::is_trivial::value>::type> -inline PrimExpr MakeConst(DataType dtype, ValueType value, Span span = Span()); +inline PrimExpr MakeConst(PrimType dtype, ValueType value, Span span = Span()); /*! * \brief Make a constant handle value. * \param value The integer payload to reinterpret as a handle. @@ -970,9 +979,12 @@ inline bool is_no_op(const tirx::Stmt& stmt) { } template -inline PrimExpr MakeConstScalar(DataType dtype, ValueType value, Span span = Span()) { - if (dtype.is_int() || dtype.is_bool()) return IntImm(dtype, static_cast(value), span); - if (dtype.is_uint()) { +inline PrimExpr MakeConstScalar(PrimType dtype, ValueType value, Span span = Span()) { + DLDataTypeCode code = dtype.code(); + if (code == DLDataTypeCode::kDLInt || code == DLDataTypeCode::kDLBool) { + return IntImm(dtype, static_cast(value), span); + } + if (code == DLDataTypeCode::kDLUInt) { // Use IntImm if it is a small integer uint64_t uval = static_cast(value); if (value < static_cast(0)) { @@ -986,8 +998,13 @@ inline PrimExpr MakeConstScalar(DataType dtype, ValueType value, Span span = Spa return LargeUIntImm(dtype, static_cast(low), static_cast(high), span); } } - if (dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || - dtype.is_float4()) { + if (dtype.MatchesCode(DLDataTypeCode::kDLFloat, DLDataTypeCode::kDLFloat8_e3m4, + DLDataTypeCode::kDLFloat8_e4m3, DLDataTypeCode::kDLFloat8_e4m3b11fnuz, + DLDataTypeCode::kDLFloat8_e4m3fn, DLDataTypeCode::kDLFloat8_e4m3fnuz, + DLDataTypeCode::kDLFloat8_e5m2, DLDataTypeCode::kDLFloat8_e5m2fnuz, 
+ DLDataTypeCode::kDLFloat8_e8m0fnu, DLDataTypeCode::kDLFloat6_e2m3fn, + DLDataTypeCode::kDLFloat6_e3m2fn, DLDataTypeCode::kDLFloat4_e2m1fn) || + dtype.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { return FloatImm(dtype, static_cast(value), span); } TVM_FFI_THROW(InternalError) << "cannot make const for type " << dtype; @@ -995,27 +1012,26 @@ inline PrimExpr MakeConstScalar(DataType dtype, ValueType value, Span span = Spa } template <> -inline PrimExpr MakeConstScalar(DataType dtype, bool value, Span span) { +inline PrimExpr MakeConstScalar(PrimType dtype, bool value, Span span) { return MakeConstScalar(dtype, static_cast(value), span); } template -inline PrimExpr MakeConst(DataType dtype, ValueType value, Span span) { - if (dtype.is_scalar()) { +inline PrimExpr MakeConst(PrimType dtype, ValueType value, Span span) { + if (!dtype.IsScalableVector() && !dtype.IsFixedLengthVector()) { return MakeConstScalar(dtype, value, span); - } else { - if (dtype.is_fixed_length_vector()) { - return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), dtype.lanes(), span); - } else { - PrimExpr lanes = tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}), - dtype.vscale_factor()); - return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), lanes, span); - } } + PrimType elem_ty = dtype.WithLanes(1); + if (dtype.IsFixedLengthVector()) { + return tirx::Broadcast(MakeConstScalar(elem_ty, value, span), dtype.lanes(), span); + } + PrimExpr lanes = + tirx::Mul(tirx::Call(PrimType::Int(32), tirx::builtin::vscale(), {}), dtype.VScaleFactor()); + return tirx::Broadcast(MakeConstScalar(elem_ty, value, span), lanes, span); } inline PrimExpr ConstHandle(int64_t value, Span span) { - return reinterpret(DataType::Handle(), IntImm(DataType::UInt(64), value, span)); + return reinterpret(PrimType::Handle(), IntImm(PrimType::UInt(64), value, span)); } } // namespace tirx @@ -1027,17 +1043,13 @@ inline PrimExpr ConstHandle(int64_t value, Span span) 
{ return a; \ } -#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \ - inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \ - inline PrimExpr Name(int a, const PrimExpr& b) { \ - return Name(tirx::MakeConst(b.dtype(), a), b); \ - } \ - inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tirx::MakeConst(a.dtype(), b)); \ - } \ - inline PrimExpr Name(const PrimExpr& a, double b) { \ - return Name(a, FloatImm(DataType::Float(64), b)); \ +#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \ + inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \ + inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::MakeConst(b.ty(), a), b); } \ + inline PrimExpr Name(const PrimExpr& a, int b) { return Name(a, tirx::MakeConst(a.ty(), b)); } \ + inline PrimExpr Name(const PrimExpr& a, double b) { \ + return Name(a, FloatImm(PrimType::Float(64), b)); \ } #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \ @@ -1048,13 +1060,13 @@ inline PrimExpr ConstHandle(int64_t value, Span span) { return Name(PrimExpr(a), b, span); \ } \ inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \ - return Name(tirx::MakeConst(b.dtype(), a), b, span); \ + return Name(tirx::MakeConst(b.ty(), a), b, span); \ } \ inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \ - return Name(a, tirx::MakeConst(a.dtype(), b), span); \ + return Name(a, tirx::MakeConst(a.ty(), b), span); \ } \ inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \ - return Name(a, FloatImm(DataType::Float(64), b), span); \ + return Name(a, FloatImm(PrimType::Float(64), b), span); \ } #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ @@ -1069,18 +1081,16 @@ inline PrimExpr ConstHandle(int64_t value, Span span) { return 
Name(PrimExpr(a), b, span); \ } -#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tirx::MakeConst(a.dtype(), b)); \ - } \ - inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::MakeConst(b.dtype(), a), b); } +#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, int b) { return Name(a, tirx::MakeConst(a.ty(), b)); } \ + inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::MakeConst(b.ty(), a), b); } #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \ inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \ - return Name(a, tirx::MakeConst(a.dtype(), b), span); \ + return Name(a, tirx::MakeConst(a.ty(), b), span); \ } \ inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \ - return Name(tirx::MakeConst(b.dtype(), a), b, span); \ + return Name(tirx::MakeConst(b.ty(), a), b, span); \ } TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+); diff --git a/include/tvm/tirx/script/builder/ir.h b/include/tvm/tirx/script/builder/ir.h index ad18d7ac4001..684653134a55 100644 --- a/include/tvm/tirx/script/builder/ir.h +++ b/include/tvm/tirx/script/builder/ir.h @@ -57,7 +57,7 @@ using tvm::tirx::Var; * \param axis_separators The separators between input axes when generating flattened output axes. * \return The declared buffer. */ -Buffer BufferDecl(ffi::Array shape, DataType dtype, ffi::String buffer_name, +Buffer BufferDecl(ffi::Array shape, PrimType dtype, ffi::String buffer_name, ffi::Optional data, ffi::Optional> strides, ffi::Optional elem_offset, ffi::String storage_scope, int align, int offset_factor, ffi::String buffer_type, @@ -122,7 +122,7 @@ Type FuncRet(Type ret_type); * \return The matched buffer. 
*/ Buffer MatchBuffer(ffi::ObjectRef param, ffi::Array shape, - DataType dtype = DataType::Float(32), ffi::Optional data = std::nullopt, + PrimType dtype = PrimType::Float(32), ffi::Optional data = std::nullopt, ffi::Array strides = {}, PrimExpr elem_offset = PrimExpr(), ffi::String storage_scope = "global", int align = -1, int offset_factor = 0, ffi::String buffer_type = "default", @@ -197,7 +197,7 @@ void BlockAttrs(ffi::Map attrs); * T.prim_func(tirx=True). */ ffi::Variant SBlockAllocBuffer( - ffi::Array shape, DataType dtype = DataType::Float(32), + ffi::Array shape, PrimType dtype = PrimType::Float(32), ffi::Optional data = std::nullopt, ffi::Array strides = {}, PrimExpr elem_offset = PrimExpr(), ffi::String storage_scope = "", int align = -1, int offset_factor = 0, ffi::String buffer_type = "default", @@ -213,7 +213,7 @@ namespace axis { * \param dtype The data type of the iteration variable. * \return The iteration variable. */ -Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Var Spatial(Range dom, PrimExpr binding, PrimType dtype = PrimType::Int(32)); /*! * \brief The reduced block axis defining function. @@ -222,7 +222,7 @@ Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); * \param dtype The data type of the iteration variable. * \return The iteration variable. */ -Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Var Reduce(Range dom, PrimExpr binding, PrimType dtype = PrimType::Int(32)); /*! * \brief The scanning block axis defining function. @@ -231,7 +231,7 @@ Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); * \param dtype The data type of the iteration variable. * \return The iteration variable. */ -Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Var Scan(Range dom, PrimExpr binding, PrimType dtype = PrimType::Int(32)); /*! * \brief The opaque block axis defining function. 
@@ -240,7 +240,7 @@ Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); * \param dtype The data type of the iteration variable. * \return The iteration variable. */ -Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Var Opaque(Range dom, PrimExpr binding, PrimType dtype = PrimType::Int(32)); /*! * \brief The block axis remapping function. @@ -250,7 +250,7 @@ Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); * \return The iteration variables. */ ffi::Array Remap(ffi::String kinds, ffi::Array bindings, - DataType dtype = DataType::Int(32)); + PrimType dtype = PrimType::Int(32)); } // namespace axis @@ -412,7 +412,7 @@ ElseFrame Else(); * \param layout The layout of the buffer. * \return The declaration frame. */ -DeclBufferFrame DeclBuffer(ffi::Array shape, DataType dtype, ffi::String buffer_name, +DeclBufferFrame DeclBuffer(ffi::Array shape, PrimType dtype, ffi::String buffer_name, ffi::Optional data, ffi::Optional> strides, ffi::Optional elem_offset, ffi::String storage_scope, int align, int offset_factor, ffi::String buffer_type, @@ -428,7 +428,7 @@ DeclBufferFrame DeclBuffer(ffi::Array shape, DataType dtype, ffi::Stri * \param annotations Optional annotations for the allocation. * \return The allocated buffer. */ -Buffer AllocBuffer(ffi::Array shape, DataType dtype = DataType::Float(32), +Buffer AllocBuffer(ffi::Array shape, PrimType dtype = PrimType::Float(32), ffi::String storage_scope = "global", ffi::Optional> annotations = std::nullopt); @@ -465,7 +465,7 @@ ComposeOpFrame ComposeOp(ffi::Map workspace, * \param dtype The data type of the variable. * \return The result variable which gets bound to the thread env. */ -Var EnvThread(ffi::String thread_tag, DataType dtype = DataType::Int(32)); +Var EnvThread(ffi::String thread_tag, PrimType dtype = PrimType::Int(32)); /*! * \brief Store data in a buffer. 
@@ -494,21 +494,20 @@ void Evaluate(PrimExpr value); * \param is_size_var Whether the pointer is a size var. * * \param is_unknown_type Used to distinguish between - * `PrimType(DataType::Handle())` and - * `PointerType(PrimType(DataType::Void()))`. If true, resolve dtype + * `PrimType::Handle()` and `PointerType(PrimType(DLDataType{kDLOpaqueHandle, 0, 0}))`. + * If true, resolve dtype * of `Void()` as `PrimType`, and if false resolve dtype of `Void()` * as a `PointerType`. * * \return The pointer. */ -inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), - ffi::String storage_scope = "global", bool is_size_var = false, - bool is_unknown_type = false) { +inline Var Handle(PrimType dtype = PrimType::Handle(), ffi::String storage_scope = "global", + bool is_size_var = false, bool is_unknown_type = false) { Type type_annotation{nullptr}; if (is_unknown_type && storage_scope == "global") { - type_annotation = PrimType(runtime::DataType::Handle()); + type_annotation = PrimType::Handle(); } else { - type_annotation = PointerType(PrimType(dtype), storage_scope); + type_annotation = PointerType(dtype, storage_scope); } return is_size_var ? tvm::tirx::SizeVar("", type_annotation) : tvm::tirx::Var("", type_annotation); @@ -519,67 +518,67 @@ inline Var TensorMap() { return tvm::tirx::Var("", PointerType(TensorMapType())) #define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ inline PrimExpr FuncName(ffi::Optional expr = std::nullopt, \ bool is_size_var = false) { \ - DataType dtype = DType; \ + PrimType dtype(DType); \ return expr.defined() \ ? tvm::cast(dtype, expr.value()) \ : (is_size_var ? 
tvm::tirx::SizeVar("", dtype) : tvm::tirx::Var("", dtype)); \ } -#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##8, FDType(8)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##16, FDType(16)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64)); - -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(BFloat, DataType::BFloat); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int); - -#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x2, FDType(Size, 2)) \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, FDType(Size, 32)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, FDType(Size, 64)); - -#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType) \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, FDType, 8); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, FDType, 16); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64); - -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(BFloat, DataType::BFloat); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); - -#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x2, FDType(2)); \ - 
TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64)); - -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E3M4, DataType::Float8E3M4); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3, DataType::Float8E4M3); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3B11FNUZ, DataType::Float8E4M3B11FNUZ); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::Float8E4M3FN); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FNUZ, DataType::Float8E4M3FNUZ); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, DataType::Float8E5M2); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2FNUZ, DataType::Float8E5M2FNUZ); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E8M0FNU, DataType::Float8E8M0FNU); - -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E2M3FN, DataType::Float6E2M3FN); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E3M2FN, DataType::Float6E3M2FN); - -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::Float4E2M1FN); - -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); +#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, Code) \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##8, (DLDataType{Code, 8, 1})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##16, (DLDataType{Code, 16, 1})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##32, (DLDataType{Code, 32, 1})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##64, (DLDataType{Code, 64, 1})); + +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(BFloat, kDLBfloat); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, kDLFloat); 
+TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, kDLUInt); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, kDLInt); + +#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, Code, Size) \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x2, (DLDataType{Code, Size, 2})) \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, (DLDataType{Code, Size, 4})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, (DLDataType{Code, Size, 8})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, (DLDataType{Code, Size, 16})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, (DLDataType{Code, Size, 32})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, (DLDataType{Code, Size, 64})); + +#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, Code) \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, Code, 8); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, Code, 16); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, Code, 32); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, Code, 64); + +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(BFloat, kDLBfloat); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, kDLFloat); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, kDLUInt); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, kDLInt); + +#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, Code, Bits) \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType, (DLDataType{Code, Bits, 1})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x2, (DLDataType{Code, Bits, 2})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, (DLDataType{Code, Bits, 4})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, (DLDataType{Code, Bits, 8})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, (DLDataType{Code, Bits, 16})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, (DLDataType{Code, Bits, 32})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, (DLDataType{Code, Bits, 64})); + 
+TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E3M4, kDLFloat8_e3m4, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3, kDLFloat8_e4m3, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3B11FNUZ, kDLFloat8_e4m3b11fnuz, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, kDLFloat8_e4m3fn, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FNUZ, kDLFloat8_e4m3fnuz, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, kDLFloat8_e5m2, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2FNUZ, kDLFloat8_e5m2fnuz, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E8M0FNU, kDLFloat8_e8m0fnu, 8); + +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E2M3FN, kDLFloat6_e2m3fn, 6); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E3M2FN, kDLFloat6_e3m2fn, 6); + +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, kDLFloat4_e2m1fn, 4); + +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(Boolean, (DLDataType{kDLBool, 8, 1})); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(Void, (DLDataType{kDLOpaqueHandle, 0, 0})); #undef TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST diff --git a/include/tvm/tirx/stmt.h b/include/tvm/tirx/stmt.h index 1ed4d5acac54..7eb004f8cf25 100644 --- a/include/tvm/tirx/stmt.h +++ b/include/tvm/tirx/stmt.h @@ -1282,7 +1282,7 @@ inline bool IsPragmaKey(const std::string& attr_key) { * \param span The location of this object in the source code. * \return Expr a expression with dtype. */ -TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span()); +TVM_DLL PrimExpr TypeAnnotation(PrimType dtype, Span span = Span()); // overload printing of for type. 
TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind); diff --git a/include/tvm/tirx/var.h b/include/tvm/tirx/var.h index 8c536ef0d668..3a4746a3f6a2 100644 --- a/include/tvm/tirx/var.h +++ b/include/tvm/tirx/var.h @@ -24,9 +24,9 @@ #ifndef TVM_TIR_VAR_H_ #define TVM_TIR_VAR_H_ +#include #include #include -#include #include #include @@ -57,7 +57,7 @@ class VarNode : public PrimExprNode { * * It is an optional field that provides a refined type of the variable than dtype. * - * \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type. + * \sa tvm/ir/type.h for discussion of relations between DLPack dtype and Type. */ Type type_annotation; @@ -84,7 +84,7 @@ class Var : public PrimExpr { * \param dtype data type * \param span The location of this object in the source code. */ - TVM_DLL explicit Var(ffi::String name_hint = "v", DataType dtype = DataType::Int(32), + TVM_DLL explicit Var(ffi::String name_hint = "v", PrimType dtype = PrimType::Int(32), Span span = Span()); /*! * \brief Constructor which provides a more detailed type annotation. @@ -110,7 +110,7 @@ class Var : public PrimExpr { * \param dtype The specified dtype * \return The new variable */ - TVM_DLL Var copy_with_dtype(DataType dtype) const; + TVM_DLL Var copy_with_dtype(PrimType dtype) const; /*! * \brief Get pointer to the internal value. @@ -150,7 +150,7 @@ class SizeVar : public Var { * \param t data type * \param span The location of this object in the source code. */ - TVM_DLL explicit SizeVar(ffi::String name_hint = "s", DataType t = DataType::Int(32), + TVM_DLL explicit SizeVar(ffi::String name_hint = "s", PrimType t = PrimType::Int(32), Span span = Span()); /*! * \brief Constructor which provides a more detailed type annotation. 
diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index b0c6ac8f6722..26bf7c100ca5 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -252,7 +252,8 @@ TOPI_DEFINE_BCAST_OP(divide, { return div(a, b); }); * \return The result. */ TOPI_DEFINE_BCAST_OP(floor_divide, { - if (a.dtype().is_int() || a.dtype().is_uint()) { + PrimType a_ty = a.ty(); + if (a_ty.code() == DLDataTypeCode::kDLInt || a_ty.code() == DLDataTypeCode::kDLUInt) { return floordiv(a, b); } else { return floor(div(a, b)); @@ -287,7 +288,8 @@ TOPI_DEFINE_BCAST_OP(log_add_exp, { return logaddexp(a, b); }); * \return The result. */ TOPI_DEFINE_BCAST_OP(trunc_divide, { - if (a.dtype().is_int() || a.dtype().is_uint()) { + PrimType a_ty = a.ty(); + if (a_ty.code() == DLDataTypeCode::kDLInt || a_ty.code() == DLDataTypeCode::kDLUInt) { return truncdiv(a, b); } else { return trunc(div(a, b)); @@ -319,7 +321,8 @@ TOPI_DEFINE_BCAST_OP(mod, { return truncmod(a, b); }); * \return The result. */ TOPI_DEFINE_BCAST_OP(floor_mod, { - if (a.dtype().is_int() || a.dtype().is_uint()) { + PrimType a_ty = a.ty(); + if (a_ty.code() == DLDataTypeCode::kDLInt || a_ty.code() == DLDataTypeCode::kDLUInt) { return floormod(a, b); } else { return a - floor_divide(a, b) * b; @@ -338,7 +341,8 @@ TOPI_DEFINE_BCAST_OP(floor_mod, { * \return The result. */ TOPI_DEFINE_BCAST_OP(trunc_mod, { - if (a.dtype().is_int() || a.dtype().is_uint()) { + PrimType a_ty = a.ty(); + if (a_ty.code() == DLDataTypeCode::kDLInt || a_ty.code() == DLDataTypeCode::kDLUInt) { return truncmod(a, b); } else { return a - trunc_divide(a, b) * b; diff --git a/include/tvm/topi/contrib/cublas.h b/include/tvm/topi/contrib/cublas.h index 3590b7a54458..18ad4320f489 100644 --- a/include/tvm/topi/contrib/cublas.h +++ b/include/tvm/topi/contrib/cublas.h @@ -48,7 +48,7 @@ inline Tensor cublas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, b auto m = transb ? 
rhs->shape[0] : rhs->shape[1]; return make_extern( - {{n, m}}, {lhs->dtype}, {lhs, rhs}, + {{n, m}}, {lhs->GetDataType()}, {lhs, rhs}, [&](ffi::Array ins, ffi::Array outs) { return call_packed({StringImm("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); @@ -73,7 +73,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool tra auto m = transb ? rhs->shape[1] : rhs->shape[2]; return make_extern( - {{b, n, m}}, {lhs->dtype}, {lhs, rhs}, + {{b, n, m}}, {lhs->GetDataType()}, {lhs, rhs}, [&](ffi::Array ins, ffi::Array outs) { return call_packed({StringImm("tvm.contrib.cublas.batch_matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); diff --git a/include/tvm/topi/detail/broadcast.h b/include/tvm/topi/detail/broadcast.h index c9dce9eb7489..7c990c5c6e1a 100644 --- a/include/tvm/topi/detail/broadcast.h +++ b/include/tvm/topi/detail/broadcast.h @@ -42,10 +42,10 @@ struct BroadcastHelper { std::deque vars2; }; -static inline DataType CommonType(DataType type1, DataType type2) { - TVM_FFI_ICHECK(type1.is_scalar() && type2.is_scalar()); +static inline PrimType CommonType(const PrimType& type1, const PrimType& type2) { + TVM_FFI_ICHECK(type1.IsScalar() && type2.IsScalar()); TVM_FFI_ICHECK(type1.code() == type2.code()); - return DataType(type1.code(), std::max(type1.bits(), type2.bits()), /*lanes=*/1); + return type1.bits() < type2.bits() ? type1.WithBits(type2.bits()) : type1; } inline BroadcastHelper BroadcastShape(const tvm::ffi::Array& shape1, @@ -56,15 +56,15 @@ inline BroadcastHelper BroadcastShape(const tvm::ffi::Array& shap tvm::PrimExpr one(1); int i; - auto cast_if_needed = [](DataType to_type, PrimExpr expr) { - return to_type != expr.dtype() ? cast(to_type, expr) : expr; + auto cast_if_needed = [](PrimType to_type, PrimExpr expr) { + return to_type == expr.ty() ? 
expr : cast(to_type, expr); }; for (i = 1; i <= std::min(s1_size, s2_size); ++i) { // TODO(@icemelon9): Need to revisit this part const IntImmNode* static_size1 = shape1[s1_size - i].as(); const IntImmNode* static_size2 = shape2[s2_size - i].as(); - DataType common_type = CommonType(shape1[s1_size - i].dtype(), shape2[s2_size - i].dtype()); + PrimType common_type = CommonType(shape1[s1_size - i].ty(), shape2[s2_size - i].ty()); bh.all_vars.push_front(tvm::tirx::Var("dim", common_type)); if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) { @@ -104,7 +104,7 @@ inline BroadcastHelper BroadcastShape(const tvm::ffi::Array& shap auto& shape = (s1_size > s2_size) ? shape1 : shape2; auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2; for (; i <= max_size; ++i) { - bh.all_vars.push_front(tvm::tirx::Var("v", shape[max_size - 1].dtype())); + bh.all_vars.push_front(tvm::tirx::Var("v", shape[max_size - 1].ty())); bh.common_shape.push_front(shape[max_size - i]); vars.push_front(bh.all_vars[0]); } @@ -130,7 +130,7 @@ inline tvm::ffi::Array InputIndexFromBroadcast( // Only inject 0 here if we have not yet reached the dimension of I // (i.e. this must be a 1) if (!found && (ovars.size() - i) <= expected_dims) { - ivars.push_back(tvm::IntImm(ovars[i].dtype(), 0)); + ivars.push_back(tvm::IntImm(ovars[i].ty(), 0)); } } TVM_FFI_ICHECK(expected_dims == ivars.size()); diff --git a/include/tvm/topi/detail/extern.h b/include/tvm/topi/detail/extern.h index 161d5291c38e..b0ce2d713bee 100644 --- a/include/tvm/topi/detail/extern.h +++ b/include/tvm/topi/detail/extern.h @@ -28,6 +28,7 @@ #include #include +#include #include namespace tvm { @@ -61,7 +62,7 @@ using FExtern = std::function, ffi::Array)>; * element of out_types. 
*/ inline ffi::Array make_extern(const ffi::Array>& out_shapes, - const std::vector& out_types, + const std::vector& out_types, const ffi::Array& inputs, FExtern fextern, std::string name, std::string tag, ::tvm::ffi::Map attrs) { @@ -100,10 +101,10 @@ inline ffi::Array make_extern(const ffi::Array>& ou inline PrimExpr pack_buffer(Buffer buf) { TVM_FFI_ICHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; auto shape = - tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), buf->shape); + tvm::tirx::Call(PrimType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), buf->shape); PrimExpr strides; if (buf->strides.size() > 0) { - strides = tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), + strides = tvm::tirx::Call(PrimType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), buf->strides); } else { strides = 0; @@ -112,9 +113,9 @@ inline PrimExpr pack_buffer(Buffer buf) { shape, strides, IntImm::Int32(static_cast(buf->shape.size())), - MakeConst(buf->dtype, 0), + MakeConst(PrimType(buf->dtype), 0), buf->elem_offset}; - return tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_array(), pack_args); + return tvm::tirx::Call(PrimType::Handle(), tvm::tirx::builtin::tvm_stack_make_array(), pack_args); } /*! 
@@ -127,7 +128,7 @@ inline PrimExpr pack_buffer(Buffer buf) { * \return An expression representing the invocation */ inline PrimExpr call_packed(ffi::Array args) { - return tvm::tirx::Call(DataType::Int(32), tvm::tirx::builtin::tvm_call_packed(), args); + return tvm::tirx::Call(PrimType::Int(32), tvm::tirx::builtin::tvm_call_packed(), args); } } // namespace detail diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h index 19ee79a2086f..95ab3a38cbc0 100644 --- a/include/tvm/topi/detail/strided_slice.h +++ b/include/tvm/topi/detail/strided_slice.h @@ -91,7 +91,7 @@ inline ffi::Array StridedSliceCanonicalizeBegin(const ffi::Array& begin, const std::vector& strides, const ffi::Array& axes, - DataType dtype, + PrimType dtype, std::string slice_mode = "end") { ffi::Array begin_expr; for (size_t i = 0; i < axes.size(); ++i) { @@ -140,9 +140,9 @@ inline ffi::Array StridedSliceOutputShape( static_cast((interval + std::abs(strides[i]) - 1) / std::abs(strides[i])); TVM_FFI_ICHECK(strides[i] < 0 ? 
(end_i <= begin_i) : (begin_i <= end_i)) << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i; - out_shape.Set(ax, cast(out_shape[i].dtype(), PrimExpr(slice_size))); + out_shape.Set(ax, cast(out_shape[i].ty(), PrimExpr(slice_size))); } else { - out_shape.Set(ax, tvm::tirx::Var("dim", out_shape[i]->dtype)); + out_shape.Set(ax, tvm::tirx::Var("dim", out_shape[i].ty())); } } diff --git a/include/tvm/topi/detail/tensor_utils.h b/include/tvm/topi/detail/tensor_utils.h index d67ad6359434..82649cd0b387 100644 --- a/include/tvm/topi/detail/tensor_utils.h +++ b/include/tvm/topi/detail/tensor_utils.h @@ -70,10 +70,10 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const ffi::Arraydtype, -9.0), minimum(MakeConst(in->dtype, 9.0), in)); + PrimType input_type = in->GetDataType(); + auto x = maximum(MakeConst(input_type, -9.0), minimum(MakeConst(input_type, 9.0), in)); // The monomial coefficients of the numerator polynomial (odd). - auto alpha_1 = MakeConst(in->dtype, 4.89352455891786e-03); - auto alpha_3 = MakeConst(in->dtype, 6.37261928875436e-04); - auto alpha_5 = MakeConst(in->dtype, 1.48572235717979e-05); - auto alpha_7 = MakeConst(in->dtype, 5.12229709037114e-08); - auto alpha_9 = MakeConst(in->dtype, -8.60467152213735e-11); - auto alpha_11 = MakeConst(in->dtype, 2.00018790482477e-13); - auto alpha_13 = MakeConst(in->dtype, -2.76076847742355e-16); + auto alpha_1 = MakeConst(input_type, 4.89352455891786e-03); + auto alpha_3 = MakeConst(input_type, 6.37261928875436e-04); + auto alpha_5 = MakeConst(input_type, 1.48572235717979e-05); + auto alpha_7 = MakeConst(input_type, 5.12229709037114e-08); + auto alpha_9 = MakeConst(input_type, -8.60467152213735e-11); + auto alpha_11 = MakeConst(input_type, 2.00018790482477e-13); + auto alpha_13 = MakeConst(input_type, -2.76076847742355e-16); // The monomial coefficients of the denominator polynomial (even). 
- auto beta_0 = MakeConst(in->dtype, 4.89352518554385e-03); - auto beta_2 = MakeConst(in->dtype, 2.26843463243900e-03); - auto beta_4 = MakeConst(in->dtype, 1.18534705686654e-04); - auto beta_6 = MakeConst(in->dtype, 1.19825839466702e-06); + auto beta_0 = MakeConst(input_type, 4.89352518554385e-03); + auto beta_2 = MakeConst(input_type, 2.26843463243900e-03); + auto beta_4 = MakeConst(input_type, 1.18534705686654e-04); + auto beta_6 = MakeConst(input_type, 1.19825839466702e-06); return compute( x->shape, @@ -130,7 +131,7 @@ inline Tensor fast_tanh_float(const Tensor& in, std::string name, std::string ta */ inline Tensor fast_tanh(const Tensor& x, std::string name = "T_fast_tanh", std::string tag = kElementWise) { - if (x->dtype == DataType::Float(32)) { + if (x->GetDataType().MatchesElementType(DLDataTypeCode::kDLFloat, 32)) { // invoke fast_tanh_float implementation return fast_tanh_float(x, name, tag); } else { @@ -209,9 +210,10 @@ inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag return compute( x->shape, [&](const ffi::Array& i) { - PrimExpr zero = MakeConst(x->dtype, 0); - PrimExpr one = MakeConst(x->dtype, 1); - PrimExpr minus_one = MakeConst(x->dtype, -1); + PrimType x_type(x->GetDataType()); + PrimExpr zero = MakeConst(x_type, 0); + PrimExpr one = MakeConst(x_type, 1); + PrimExpr minus_one = MakeConst(x_type, -1); auto s1 = tvm::tirx::Select((x(i) < zero), minus_one, zero); auto s2 = tvm::tirx::Select((x(i) > zero), one, s1); return s2; @@ -232,7 +234,7 @@ inline Tensor rsqrt(const Tensor& x, std::string name = "tensor", std::string ta return compute( x->shape, [&](const ffi::Array& i) { - PrimExpr one = MakeConst(x->dtype, 1); + PrimExpr one = MakeConst(x->GetDataType(), 1); return one / tvm::sqrt(x(i)); }, name, tag); @@ -255,8 +257,9 @@ inline Tensor clip(const Tensor& x, const PrimExpr& a_min, const PrimExpr& a_max return compute( x->shape, [&](const ffi::Array& i) { - auto min_val = tvm::cast(x->dtype, a_min); - auto 
max_val = tvm::cast(x->dtype, a_max); + PrimType x_type(x->GetDataType()); + auto min_val = tvm::cast(x_type, a_min); + auto max_val = tvm::cast(x_type, a_max); return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*) }, name, tag); @@ -274,16 +277,24 @@ inline Tensor clip(const Tensor& x, const PrimExpr& a_min, const PrimExpr& a_max * * \return A Tensor whose op member is the cast operation */ -inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", +inline Tensor cast(const Tensor& x, PrimType type, std::string name, std::string tag); + +inline Tensor cast(const Tensor& x, DLDataType type, std::string name = "T_cast", + std::string tag = kElementWise) { + return cast(x, PrimType(type), std::move(name), std::move(tag)); +} + +inline Tensor cast(const Tensor& x, PrimType type, std::string name = "T_cast", std::string tag = kElementWise) { return compute( x->shape, [&](const ffi::Array& i) -> PrimExpr { auto expr = x(i); - if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { - if (expr.dtype().lanes() == type.lanes()) { + PrimType expr_ty = expr.ty(); + if (expr_ty.MatchesElementType(type.code(), type.bits())) { + if (expr_ty.lanes() == type.lanes()) { return expr; - } else if (expr.dtype().lanes() == 1 && type.is_vector()) { + } else if (expr_ty.lanes() == 1 && type.IsFixedLengthVector()) { return tvm::tirx::Broadcast(expr, type.lanes()); } } @@ -303,7 +314,14 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", * * \return A Tensor whose op member is the reinterpret operation */ -inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor", +inline Tensor reinterpret(const Tensor& x, PrimType type, std::string name, std::string tag); + +inline Tensor reinterpret(const Tensor& x, DLDataType type, std::string name = "tensor", + std::string tag = kElementWise) { + return reinterpret(x, PrimType(type), std::move(name), std::move(tag)); +} + +inline Tensor 
reinterpret(const Tensor& x, PrimType type, std::string name = "tensor", std::string tag = kElementWise) { return compute( x->shape, [&](const ffi::Array& i) { return reinterpret(type, x(i)); }, name, tag); @@ -344,7 +362,15 @@ inline Tensor elemwise_sum(const ffi::Array& xs, std::string name = "T_e * * \return A Tensor whose op member is the full operation */ -inline Tensor full(const ffi::Array& shape, DataType dtype, const PrimExpr fill_value, +inline Tensor full(const ffi::Array& shape, PrimType dtype, const PrimExpr fill_value, + std::string name, std::string tag); + +inline Tensor full(const ffi::Array& shape, DLDataType dtype, const PrimExpr fill_value, + std::string name = "T_full", std::string tag = kElementWise) { + return full(shape, PrimType(dtype), fill_value, std::move(name), std::move(tag)); +} + +inline Tensor full(const ffi::Array& shape, PrimType dtype, const PrimExpr fill_value, std::string name = "T_full", std::string tag = kElementWise) { PrimExpr ev = cast(dtype, fill_value); if (!ev.defined()) { @@ -366,7 +392,7 @@ inline Tensor full(const ffi::Array& shape, DataType dtype, const Prim */ inline Tensor full_like(const Tensor& x, const PrimExpr fill_value, std::string name = "T_full_like", std::string tag = kElementWise) { - PrimExpr ev = cast(x->dtype, fill_value); + PrimExpr ev = cast(x->GetDataType(), fill_value); return compute(x->shape, [&](const ffi::Array& i) { return ev; }, name, tag); } @@ -392,19 +418,17 @@ inline Tensor full_like(const Tensor& x, const PrimExpr fill_value, * y = exp(f) = 1 + 2 * P(x**2)/(Q(x**2) - P(x**2)) */ inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string tag) { - auto x_hi = FloatImm(DataType::Float(32), 88.3762626647950f); - auto x_lo = FloatImm(DataType::Float(32), -88.3762626647949f); - auto log2e = FloatImm(DataType::Float(32), 1.44269504088896341f); - auto ln2 = FloatImm(DataType::Float(32), 0.6931471805599453f); - PrimExpr p[6] = {FloatImm(DataType::Float(32), 1.9875691500E-4f), 
- FloatImm(DataType::Float(32), 1.3981999507E-3f), - FloatImm(DataType::Float(32), 8.3334519073E-3f), - FloatImm(DataType::Float(32), 4.1665795894E-2f), - FloatImm(DataType::Float(32), 1.6666665459E-1f), - FloatImm(DataType::Float(32), 5.0000001201E-1f)}; - auto one = FloatImm(DataType::Float(32), 1.0f); - auto one_half = FloatImm(DataType::Float(32), 0.5f); - auto b = FloatImm(DataType::Float(32), 127.0f); + PrimType f32_ty = PrimType::Float(32); + auto x_hi = FloatImm(f32_ty, 88.3762626647950f); + auto x_lo = FloatImm(f32_ty, -88.3762626647949f); + auto log2e = FloatImm(f32_ty, 1.44269504088896341f); + auto ln2 = FloatImm(f32_ty, 0.6931471805599453f); + PrimExpr p[6] = {FloatImm(f32_ty, 1.9875691500E-4f), FloatImm(f32_ty, 1.3981999507E-3f), + FloatImm(f32_ty, 8.3334519073E-3f), FloatImm(f32_ty, 4.1665795894E-2f), + FloatImm(f32_ty, 1.6666665459E-1f), FloatImm(f32_ty, 5.0000001201E-1f)}; + auto one = FloatImm(f32_ty, 1.0f); + auto one_half = FloatImm(f32_ty, 0.5f); + auto b = FloatImm(f32_ty, 127.0f); return compute( _x->shape, @@ -419,7 +443,7 @@ inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string t (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f + p[4]) * f + p[5]) * f * f + f + one; // Return 2^m * exp(r). 
auto ef = - tvm::reinterpret(DataType::Float(32), ::tvm::cast(DataType::Int(32), n + b) << 23); + tvm::reinterpret(PrimType::Float(32), ::tvm::cast(PrimType::Int(32), n + b) << 23); return ::tvm::max(ef * y, _x(i)); // NOLINT(*) }, name, tag); @@ -437,7 +461,7 @@ inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string t */ inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp", std::string tag = kElementWise) { - if (x->dtype == DataType::Float(32)) { + if (x->GetDataType().MatchesElementType(DLDataTypeCode::kDLFloat, 32)) { auto ret = fast_exp_float32(x, name, tag); return ret; } else { @@ -474,10 +498,11 @@ inline Tensor fast_erf_float16(const Tensor& data, std::string name, std::string */ inline Tensor fast_erf(const Tensor& x, std::string name = "T_fast_erf", std::string tag = kElementWise) { - if (x->dtype == DataType::Float(32)) { + PrimType x_type(x->GetDataType()); + if (x_type.MatchesElementType(DLDataTypeCode::kDLFloat, 32)) { auto ret = fast_erf_float32(x, name, tag); return ret; - } else if (x->dtype == DataType::Float(16)) { + } else if (x_type.MatchesElementType(DLDataTypeCode::kDLFloat, 16)) { auto ret = fast_erf_float16(x, name, tag); return ret; } else { diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 0a448620dae3..b864bfe53ea3 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -57,7 +57,7 @@ inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast< return tvm::te::compute( t->shape, [&](const tvm::ffi::Array& i) { - auto threshold_const = tvm::tirx::MakeConst(t->dtype, threshold); + auto threshold_const = tvm::tirx::MakeConst(tvm::PrimType(t->dtype), threshold); return tvm::max(t(i), threshold_const); }, name, tag); @@ -80,7 +80,7 @@ inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1, t->shape, [&](const tvm::ffi::Array& i) { auto value = t(i); - auto calpha = tvm::tirx::MakeConst(value.dtype(), alpha); + auto calpha = 
tvm::tirx::MakeConst(value.ty(), alpha); return tvm::tirx::Select(value > 0, value, value * calpha); }, name, tag); @@ -171,10 +171,10 @@ inline tvm::te::Tensor pad( tvm::ffi::Array pad_after_int32; for (const auto& ele : pad_before) { - pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); + pad_before_int32.push_back(tvm::cast(tvm::PrimType::Int(32), ele)); } for (const auto& ele : pad_after) { - pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); + pad_after_int32.push_back(tvm::cast(tvm::PrimType::Int(32), ele)); } tvm::ffi::Array output_shape; @@ -194,7 +194,7 @@ inline tvm::te::Tensor pad( } if (!pad_value.defined()) { - pad_value = tvm::tirx::MakeConst(t->dtype, 0); + pad_value = tvm::tirx::MakeConst(tvm::PrimType(t->dtype), 0); } auto l = [&](tvm::ffi::Array ovars) { @@ -495,19 +495,19 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, tvm::ffi::Array pad_after_int32; // pad size for batch dimension is 0 - pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0)); - pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0)); + pad_before_int32.push_back(tvm::cast(tvm::PrimType::Int(32), 0)); + pad_after_int32.push_back(tvm::cast(tvm::PrimType::Int(32), 0)); // insert pad sizes given for spatial dimensions for (const auto& ele : pad_before) { - pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); + pad_before_int32.push_back(tvm::cast(tvm::PrimType::Int(32), ele)); } for (const auto& ele : pad_after) { - pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); + pad_after_int32.push_back(tvm::cast(tvm::PrimType::Int(32), ele)); } // pad the input with paddings provided if (!pad_value.defined()) { - pad_value = tvm::tirx::MakeConst(data->dtype, 0); + pad_value = tvm::tirx::MakeConst(tvm::PrimType(data->dtype), 0); } padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value); @@ -629,9 +629,9 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, 
// Crop the start and end of dimensions of out ffi::Array> begin_idx, end_idx; ffi::Array strides; - DataType index_dtype = DataType::Int(64); + PrimType index_ty = PrimType::Int(64); for (size_t i = 0; i < r_p_shape.size(); ++i) { - strides.push_back(IntImm(index_dtype, 1)); + strides.push_back(IntImm(index_ty, 1)); if (i > 0 && i <= num_block_dims) { // prepare begin and end index for spatial dimensions int64_t begin_i = GetConstInt(crop_begin_list[i - 1]); @@ -640,12 +640,12 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, TVM_FFI_ICHECK_GT(out_i, (begin_i + end_i)) << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than" << " output size" << out_i << " vs " << (begin_i + end_i); - begin_idx.push_back(IntImm(index_dtype, begin_i)); - end_idx.push_back(IntImm(index_dtype, out_i - end_i)); + begin_idx.push_back(IntImm(index_ty, begin_i)); + end_idx.push_back(IntImm(index_ty, out_i - end_i)); } else { // ignore the batch and remaining dimension - begin_idx.push_back(IntImm(index_dtype, 0)); - end_idx.push_back(IntImm(index_dtype, GetConstInt(r_p_shape[i]))); + begin_idx.push_back(IntImm(index_ty, 0)); + end_idx.push_back(IntImm(index_ty, GetConstInt(r_p_shape[i]))); } } @@ -677,7 +677,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T [&](const tvm::ffi::Array& target_indices) { auto c = targets(); return tvm::tirx::Select(c != ignore_index, -predictions(c) * weights(c), - tvm::tirx::MakeConst(predictions->dtype, 0)); + tvm::tirx::MakeConst(tvm::PrimType(predictions->dtype), 0)); }, name, tag); if (reduction == "mean") { @@ -686,7 +686,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T [&](const tvm::ffi::Array& target_indices) { auto c = targets(); return tvm::tirx::Select(c != ignore_index, weights(c), - tvm::tirx::MakeConst(predictions->dtype, 0)); + tvm::tirx::MakeConst(tvm::PrimType(predictions->dtype), 0)); }, name, tag); return topi::divide(T, W); 
@@ -705,7 +705,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T pred_indices.push_back(target_indices[i]); // indices for multidimensional loss } return tvm::tirx::Select(c != ignore_index, -predictions(pred_indices) * weights(c), - tvm::tirx::MakeConst(predictions->dtype, 0)); + tvm::tirx::MakeConst(tvm::PrimType(predictions->dtype), 0)); }, name, tag); TVM_FFI_ICHECK(T->shape.size() != 0); @@ -715,7 +715,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T [&](const tvm::ffi::Array& target_indices) { auto c = targets(target_indices); return tvm::tirx::Select(c != ignore_index, weights(c), - tvm::tirx::MakeConst(predictions->dtype, 0)); + tvm::tirx::MakeConst(tvm::PrimType(predictions->dtype), 0)); }, name, tag); return topi::divide(topi::sum(T, tvm::ffi::Array(nullptr)), diff --git a/include/tvm/topi/nn/bnn.h b/include/tvm/topi/nn/bnn.h index 5faed879c005..56a6f3aaa815 100644 --- a/include/tvm/topi/nn/bnn.h +++ b/include/tvm/topi/nn/bnn.h @@ -71,14 +71,14 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, start_idx.push_back(i == static_cast(axis) ? indices[i] * 32 : static_cast(indices[i])); } - PrimExpr packed = IntImm(DataType::UInt(32), 0); + PrimExpr packed = IntImm(PrimType::UInt(32), 0); for (size_t j = 0; j < 32; ++j) { ffi::Array idx; for (size_t i = 0; i < n; ++i) { idx.push_back(i == static_cast(axis) ? 
start_idx[i] + static_cast(j) : start_idx[i]); } - auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0); + auto sign = tvm::cast(PrimType::UInt(32), data(idx) >= 0); packed = (packed | sign); if (j == 31) { return packed; @@ -101,8 +101,8 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight) { TVM_FFI_ICHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data"; TVM_FFI_ICHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight"; - TVM_FFI_ICHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data"; - TVM_FFI_ICHECK_EQ(weight->dtype, DataType::UInt(32)) << "binary_dense requires uint32 weight"; + TVM_FFI_ICHECK_EQ(data->dtype, PrimType::UInt(32)) << "binary_dense requires uint32 data"; + TVM_FFI_ICHECK_EQ(weight->dtype, PrimType::UInt(32)) << "binary_dense requires uint32 weight"; auto batch = data->shape[0]; auto in_dim = data->shape[1]; diff --git a/include/tvm/topi/nn/dense.h b/include/tvm/topi/nn/dense.h index be0030cd40d5..2c7b2330505e 100644 --- a/include/tvm/topi/nn/dense.h +++ b/include/tvm/topi/nn/dense.h @@ -46,7 +46,7 @@ using namespace tvm::te; * \return Tensor with shape [batch, out_dim] */ inline tvm::te::Tensor dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, const DataType& out_dtype) { + const tvm::te::Tensor& bias, const PrimType& out_dtype) { TVM_FFI_ICHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data"; TVM_FFI_ICHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight"; if (bias.defined()) { diff --git a/include/tvm/topi/nn/dilate.h b/include/tvm/topi/nn/dilate.h index 0c8ea395c701..f45543eda337 100644 --- a/include/tvm/topi/nn/dilate.h +++ b/include/tvm/topi/nn/dilate.h @@ -95,7 +95,7 @@ inline Tensor dilate(const Tensor& x, ffi::Array strides, double dilat if (not_zero.size() > 0) { auto all_not_zero = 
all(not_zero); return tvm::if_then_else(all_not_zero, x(index_tuple), - MakeConst(x->dtype, dilation_value)); + MakeConst(PrimType(x->dtype), dilation_value)); } return x(index_tuple); }, diff --git a/include/tvm/topi/nn/group_norm.h b/include/tvm/topi/nn/group_norm.h index 4962587a9396..7a778dea8ce5 100644 --- a/include/tvm/topi/nn/group_norm.h +++ b/include/tvm/topi/nn/group_norm.h @@ -45,9 +45,9 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& const auto& beta_type = beta.defined() ? beta->dtype : data_type; TVM_FFI_ICHECK(data_type == gamma_type && data_type == beta_type) << "group_norm: data, gamma and beta must have the same type"; - TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + TVM_FFI_ICHECK(data_type == PrimType::Float(32) || data_type == PrimType::Float(16)) << "group_norm: only support float32 and float16 for now"; - bool is_float16 = data_type == DataType::Float(16); + bool is_float16 = data_type == PrimType::Float(16); // reshape data C -> G, C/G int ndim = data->shape.size(); channel_axis = GetRealAxis(static_cast(ndim), ffi::Array({channel_axis}))[0]; @@ -65,7 +65,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& } Tensor data_reshaped; if (is_float16) { - data_reshaped = cast(reshape(data, new_shape), DataType::Float(32)); + data_reshaped = cast(reshape(data, new_shape), PrimType::Float(32)); } else { data_reshaped = reshape(data, new_shape); } @@ -126,7 +126,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& auto temp_x = temp_x_x2[0]; auto temp_x2 = temp_x_x2[1]; - PrimExpr reduce_extent = FloatImm(DataType::Float(32), 1); + PrimExpr reduce_extent = FloatImm(PrimType::Float(32), 1); for (auto axis : new_axes) { reduce_extent *= data_reshaped->shape[axis]; } @@ -142,10 +142,10 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& gamma_indices = {indices[channel_axis], 
indices[channel_axis + 1]}; auto mean = temp_x(non_reduce_indices) / reduce_extent; auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; - PrimExpr group_norm = - (data_reshaped(indices) - mean) * tvm::rsqrt(var + MakeConst(data->dtype, epsilon)); + PrimExpr group_norm = (data_reshaped(indices) - mean) * + tvm::rsqrt(var + MakeConst(PrimType(data->dtype), epsilon)); if (is_float16) { - group_norm = Cast(DataType::Float(16), group_norm); + group_norm = Cast(PrimType::Float(16), group_norm); } if (gamma.defined()) { group_norm = topi::multiply(group_norm, gamma_reshaped(gamma_indices)); diff --git a/include/tvm/topi/nn/instance_norm.h b/include/tvm/topi/nn/instance_norm.h index 60361e8bc681..e246d97a59df 100644 --- a/include/tvm/topi/nn/instance_norm.h +++ b/include/tvm/topi/nn/instance_norm.h @@ -58,9 +58,9 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso const auto& beta_type = beta.defined() ? beta->dtype : data_type; TVM_FFI_ICHECK(data_type == gamma_type && data_type == beta_type) << "instance_norm: data, gamma and beta must have the same type"; - TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + TVM_FFI_ICHECK(data_type == PrimType::Float(32) || data_type == PrimType::Float(16)) << "instance_norm: only support float32 and float16 for now"; - bool is_float16 = data_type == DataType::Float(16); + bool is_float16 = data_type == PrimType::Float(16); // sum x and x^2 auto ndim = data->shape.size(); TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; @@ -69,9 +69,10 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso auto target_shape = MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true); auto func = MakeTupleSumReducer(); + PrimType f32_ty = PrimType::Float(32); - auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func, - &data](const ffi::Array& indices) { + auto compute = [ndim, is_float16, 
&real_axis, &reduce_axes, &func, &data, + f32_ty](const ffi::Array& indices) { ffi::Array eval_range; int arg_counter = 0; int red_counter = 0; @@ -86,15 +87,14 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso arg_counter++; } } - auto square = [is_float16](const PrimExpr& x) { + auto square = [is_float16, f32_ty](const PrimExpr& x) { if (is_float16) { - return Cast(DataType::Float(32), x) * Cast(DataType::Float(32), x); + return Cast(f32_ty, x) * Cast(f32_ty, x); } return x * x; }; if (is_float16) { - return func({Cast(DataType::Float(32), data(eval_range)), square(data(eval_range))}, - reduce_axes, nullptr); + return func({Cast(f32_ty, data(eval_range)), square(data(eval_range))}, reduce_axes, nullptr); } else { return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr); } @@ -106,7 +106,7 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso auto temp_x = temp_x_x2[0]; auto temp_x2 = temp_x_x2[1]; - auto reduce_extent = MakeConst(data->dtype, 1); + auto reduce_extent = MakeConst(PrimType(data->dtype), 1); for (int i : real_axis) { reduce_extent *= data->shape[i]; } @@ -124,9 +124,9 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso channel = indices[channel_axis]; auto mean = temp_x(non_reduce_indices) / reduce_extent; auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; - auto instance_norm = (data(indices) - mean) * tvm::rsqrt(var + MakeConst(var->dtype, epsilon)); + auto instance_norm = (data(indices) - mean) * tvm::rsqrt(var + MakeConst(var.ty(), epsilon)); if (is_float16) { - instance_norm = Cast(DataType::Float(16), instance_norm); + instance_norm = Cast(PrimType::Float(16), instance_norm); } instance_norm = topi::multiply(instance_norm, gamma(channel)); if (beta.defined()) { diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index fb8155ef654a..8a995d7b91fe 100644 --- 
a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -57,9 +57,9 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& const auto& beta_type = beta.defined() ? beta->dtype : data_type; TVM_FFI_ICHECK(data_type == gamma_type && data_type == beta_type) << "layer_norm: data, gamma and beta must have the same type"; - TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + TVM_FFI_ICHECK(data_type == PrimType::Float(32) || data_type == PrimType::Float(16)) << "layer_norm: only support float32 and float16 for now"; - bool is_float16 = data_type == DataType::Float(16); + bool is_float16 = data_type == PrimType::Float(16); // Two-pass algorithm for improved numerical stability: // pass1: mean = E[x] // pass2: var = E[(x - mean)^2] @@ -69,6 +69,7 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& auto reduce_axes = MakeReduceAxes(real_axis, data); auto target_shape = MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/false); + PrimType f32_ty = PrimType::Float(32); auto make_eval_range = [&real_axis, &reduce_axes, ndim](const ffi::Array& non_reduce_indices) { @@ -91,17 +92,17 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& Tensor temp_sum = te::compute( target_shape, - [is_float16, &data, &reduce_axes, &make_eval_range](const ffi::Array& indices) { + [is_float16, &data, &reduce_axes, &make_eval_range, f32_ty](const ffi::Array& indices) { auto eval_range = make_eval_range(indices); PrimExpr x = data(eval_range); if (is_float16) { - x = Cast(DataType::Float(32), x); + x = Cast(f32_ty, x); } return sum(x, reduce_axes); }, data->op->name + "_sum", kCommReduce); - DataType reduce_dtype = is_float16 ? DataType::Float(32) : data->dtype; + PrimType reduce_dtype = is_float16 ? 
PrimType::Float(32) : PrimType(data->dtype); PrimExpr reduce_extent = MakeConst(reduce_dtype, 1); for (int i : real_axis) { reduce_extent *= data->shape[i]; @@ -115,12 +116,12 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& Tensor temp_var_sum = te::compute( target_shape, - [is_float16, &data, &reduce_axes, &make_eval_range, - &temp_mean](const ffi::Array& indices) { + [is_float16, &data, &reduce_axes, &make_eval_range, &temp_mean, + f32_ty](const ffi::Array& indices) { auto eval_range = make_eval_range(indices); PrimExpr x = data(eval_range); if (is_float16) { - x = Cast(DataType::Float(32), x); + x = Cast(f32_ty, x); } PrimExpr diff = x - temp_mean(indices); return sum(diff * diff, reduce_axes); @@ -138,9 +139,9 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& } auto mean = temp_mean(non_reduce_indices); auto var = temp_var_sum(non_reduce_indices) / reduce_extent; - auto layer_norm = (data(indices) - mean) * rsqrt(var + MakeConst(var->dtype, epsilon)); + auto layer_norm = (data(indices) - mean) * rsqrt(var + MakeConst(var.ty(), epsilon)); if (is_float16) { - layer_norm = Cast(DataType::Float(16), layer_norm); + layer_norm = Cast(PrimType::Float(16), layer_norm); } layer_norm = topi::multiply(layer_norm, gamma(reduce_indices)); if (beta.defined()) { diff --git a/include/tvm/topi/nn/local_response_norm.h b/include/tvm/topi/nn/local_response_norm.h index 7407448f88c5..4f411076387d 100644 --- a/include/tvm/topi/nn/local_response_norm.h +++ b/include/tvm/topi/nn/local_response_norm.h @@ -55,7 +55,8 @@ inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.00 TVM_FFI_ICHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input"; TVM_FFI_ICHECK_EQ(size % 2, 1) << "size should be odd number"; TVM_FFI_ICHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC"; - TVM_FFI_ICHECK(data->dtype.is_float()) << "datatype should be float"; + // LRN only requires a 
floating-point element kind; lane encoding is irrelevant here. + TVM_FFI_ICHECK_EQ(data->dtype.code(), DLDataTypeCode::kDLFloat) << "datatype should be float"; auto input_shape = data->shape; ffi::Array pad_before{0, 0, 0, 0}; ffi::Array pad_after{0, 0, 0, 0}; @@ -79,9 +80,9 @@ inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.00 }, "tensor", "sqr_sum"); } - PrimExpr alpha_imm = tvm::te::MakeConst(data->dtype, alpha); - PrimExpr beta_imm = tvm::te::MakeConst(data->dtype, beta); - PrimExpr bias_imm = tvm::te::MakeConst(data->dtype, bias); + PrimExpr alpha_imm = tvm::te::MakeConst(PrimType(data->dtype), alpha); + PrimExpr beta_imm = tvm::te::MakeConst(PrimType(data->dtype), beta); + PrimExpr bias_imm = tvm::te::MakeConst(PrimType(data->dtype), bias); auto sqrt_sum_up = tvm::te::compute( input_shape, [&](Var i, Var j, Var k, Var l) { diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index e8410d8add22..91b10e7d8df9 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -117,7 +117,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww"); auto argmax = MakeArgmaxReducer(); - auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + auto pad_x = + do_pad ? 
pad(x, pad_before, pad_after, tvm::min_value(PrimType(x->dtype)), "pad_temp") : x; auto mp_argmax = tvm::te::compute( out_shape, @@ -145,17 +146,17 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww); PrimExpr out_idx_lower_h = tirx::Select( - pad_inds[height_axis] < kernel_height, IntImm(pad_inds[height_axis].dtype(), 0), + pad_inds[height_axis] < kernel_height, IntImm(pad_inds[height_axis].ty(), 0), (pad_inds[height_axis] - kernel_height) / stride_height + 1); PrimExpr out_idx_lower_w = tirx::Select( - pad_inds[width_axis] < kernel_width, IntImm(pad_inds[width_axis].dtype(), 0), + pad_inds[width_axis] < kernel_width, IntImm(pad_inds[width_axis].ty(), 0), (pad_inds[width_axis] - kernel_width) / stride_width + 1); return tvm::sum( tvm::if_then_else(tirx::And(tirx::And(out_idx[height_axis] >= out_idx_lower_h, out_idx[width_axis] >= out_idx_lower_w), mp_inds(out_idx) == idx), - out_grad(out_idx), MakeConst(x->dtype, 0)), + out_grad(out_idx), MakeConst(PrimType(x->dtype), 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_max"); @@ -176,10 +177,10 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww)); PrimExpr out_idx_lower_h = - tirx::Select(pad_h_idx < kernel_height, IntImm(pad_h_idx.dtype(), 0), + tirx::Select(pad_h_idx < kernel_height, IntImm(pad_h_idx.ty(), 0), (pad_h_idx - kernel_height) / stride_height + 1); PrimExpr out_idx_lower_w = - tirx::Select(pad_w_idx < kernel_width, IntImm(pad_w_idx.dtype(), 0), + tirx::Select(pad_w_idx < kernel_width, IntImm(pad_w_idx.ty(), 0), (pad_w_idx - kernel_width) / stride_width + 1); PrimExpr divide_factor; // number of pooled elements @@ -191,16 +192,17 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, PrimExpr h_end = min(h_start + kernel_height, height); PrimExpr w_end = min(w_start + kernel_width, width); - h_start = 
max(h_start, IntImm(h_start.dtype(), 0)); - w_start = max(w_start, IntImm(w_start.dtype(), 0)); - divide_factor = max((h_end - h_start) * (w_end - w_start), MakeConst(h_end.dtype(), 1)); + h_start = max(h_start, IntImm(h_start.ty(), 0)); + w_start = max(w_start, IntImm(w_start.ty(), 0)); + divide_factor = max((h_end - h_start) * (w_end - w_start), MakeConst(h_end.ty(), 1)); } return tvm::sum( tvm::if_then_else(tirx::And(tirx::And(out_idx[height_axis] >= out_idx_lower_h, out_idx[height_axis] < out_height), tirx::And(out_idx[width_axis] >= out_idx_lower_w, out_idx[width_axis] < out_width)), - out_grad(out_idx) / divide_factor, MakeConst(out_grad->dtype, 0)), + out_grad(out_idx) / divide_factor, + MakeConst(PrimType(out_grad->dtype), 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_avg"); @@ -384,9 +386,9 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const ffi::Array& ou ffi::Array reduce_axes; std::tie(indices, reduce_axes) = get_iter_vars(output, false); - PrimExpr divide_factor = tvm::cast(x->dtype, 1); + PrimExpr divide_factor = tvm::cast(PrimType(x->dtype), 1); for (size_t i = 0; i < n_dim; ++i) { - divide_factor *= tvm::cast(DataType::Int(32), reduce_axes[i]->dom->extent); + divide_factor *= tvm::cast(PrimType::Int(32), reduce_axes[i]->dom->extent); } return div(pool_sum(indices), divide_factor); @@ -582,7 +584,8 @@ inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array& kernel_s ffi::Map attrs; if (pool_type == kMaxPool) { - auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + auto temp = + do_pad ? pad(x, pad_before, pad_after, tvm::min_value(PrimType(x->dtype)), "pad_temp") : x; attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.pool_max")); return tvm::te::compute( out_shape, @@ -657,7 +660,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array& kernel_s // number that represents the number of steps along the dilated kernel to reach a // non-padded value. 
Otherwise this should be 0. PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i]; - jumps_to_non_pad = max(jumps_to_non_pad, IntImm(jumps_to_non_pad.dtype(), 0)); + jumps_to_non_pad = max(jumps_to_non_pad, IntImm(jumps_to_non_pad.ty(), 0)); end[i] = min(end[i], data_shape[ii] - 1); num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1; diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h index 294d82054e3e..29f46918a754 100644 --- a/include/tvm/topi/nn/rms_norm.h +++ b/include/tvm/topi/nn/rms_norm.h @@ -54,8 +54,8 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Arra const auto& weight_type = weight.defined() ? weight->dtype : data_type; TVM_FFI_ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type"; - const auto& data_fp32 = cast(data, DataType::Float(32)); - const auto& weight_fp32 = cast(weight, DataType::Float(32)); + const auto& data_fp32 = cast(data, PrimType::Float(32)); + const auto& weight_fp32 = cast(weight, PrimType::Float(32)); auto square = multiply(data_fp32, data_fp32); auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true); @@ -63,7 +63,7 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Arra auto ndim = data_fp32->shape.size(); TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); - auto reduce_extent = MakeConst(data_fp32->dtype, 1); + auto reduce_extent = MakeConst(PrimType(data_fp32->dtype), 1); for (int i : real_axis) { reduce_extent *= data_fp32->shape[i]; } @@ -74,8 +74,8 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Arra non_reduce_indices.push_back(indices[i]); } } - auto output = - tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + MakeConst(data_type, epsilon)); + auto output = tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + + 
MakeConst(PrimType(data_type), epsilon)); return output; }; auto rsqrt_shape = ffi::Array(); diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index e6b4c5af1dea..fbea4a57eabf 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -259,7 +259,7 @@ inline Tensor CommReduceIdx(const Tensor& data, const ffi::Optional(ffi::Array lhs, ffi::Array rhs)>; /*! \brief An initializer function for a reduction */ -using FIdentity = std::function(std::vector types)>; +using FIdentity = std::function(std::vector types)>; /*! * \brief Create a commutative reducer for a reduction @@ -275,10 +275,10 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, return [fcombine, fidentity, name](ffi::Array exprs, const ffi::Array& axis, PrimExpr* condition) { ffi::Array lhs, rhs; - std::vector dtypes; + std::vector dtypes; for (size_t i = 0; i < exprs.size(); ++i) { - auto dtype = exprs[i].dtype(); + PrimType dtype = exprs[i].ty(); dtypes.push_back(dtype); lhs.push_back(var(name + "_lhs_" + std::to_string(i), dtype)); rhs.push_back(var(name + "_rhs_" + std::to_string(i), dtype)); @@ -330,7 +330,8 @@ inline PrimExpr ProdOp(PrimExpr source, ffi::Array axis, ffi::Array>& axis, bool keepdims = false, bool atleast1d = false) { - if (data->dtype.is_bool()) { + // Reduction dispatch only depends on boolean element kind; lane encoding is irrelevant here. 
+ if (data->dtype.code() == DLDataTypeCode::kDLBool) { return CommReduce(data, axis, tvm::any, keepdims, atleast1d); } else { return CommReduce(data, axis, tvm::sum, keepdims, atleast1d); @@ -477,7 +478,7 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { result.push_back(tvm::tirx::Select(is_smaller, lhs[1], rhs[1])); // val return result; }; - auto fidentity = [&](std::vector types) { + auto fidentity = [&](std::vector types) { ffi::Array result; result.push_back(tvm::tirx::MakeConst(types[0], -1)); // idx result.push_back(tvm::max_value(types[1])); // val @@ -539,7 +540,7 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { result.push_back(tvm::tirx::Select(is_bigger, lhs[1], rhs[1])); // val return result; }; - auto fidentity = [&](std::vector types) { + auto fidentity = [&](std::vector types) { ffi::Array result; result.push_back(tvm::tirx::MakeConst(types[0], -1)); // idx result.push_back(tvm::min_value(types[1])); // val @@ -601,7 +602,7 @@ inline FCommReduce MakeTupleSumReducer() { } return result; }; - auto fidentity = [](std::vector types) { + auto fidentity = [](std::vector types) { ffi::Array result; for (size_t i = 0; i < types.size(); ++i) { result.push_back(tvm::tirx::MakeConst(types[i], 0)); diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index e216cf86ced4..f2ede7af8aa0 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -44,8 +44,8 @@ #include #include +#include "tvm/ffi/dtype.h" #include "tvm/ir/expr.h" -#include "tvm/runtime/data_type.h" #include "tvm/tirx/expr.h" #include "tvm/tirx/op.h" #include "tvm/tirx/var.h" @@ -338,7 +338,8 @@ inline Tensor reshape(const Tensor& x, ffi::Array newshape, // If either the input shape or the target shape contains a zero, return an empty tensor. 
if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) { return compute( - target_shape, [&](const ffi::Array& indices) { return tvm::cast(x->dtype, 0); }, name, + target_shape, + [&](const ffi::Array& indices) { return tvm::cast(PrimType(x->dtype), 0); }, name, tag); } else { return compute( @@ -679,7 +680,7 @@ inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stri if (index->IsInstance() && extent->IsInstance() && stride->IsInstance()) { return tvm::IntImm( - tvm::DataType::Int(64), + tvm::PrimType::Int(64), StaticCanonicalizeIndex(GetConstInt(index), GetConstInt(extent), GetConstInt(stride))); } return DynamicCanonicalizeIndex(index, extent, stride); @@ -835,14 +836,14 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b bool assume_inbound = true, std::string name = "T_strided_slice_dynamic", std::string tag = topi::kInjective) { - DataType index_dtype = begin->shape[0]->dtype; + PrimType index_ty = begin->shape[0].ty(); const int64_t num_dynamic_axes = begin->shape[0].as()->value; TVM_FFI_ICHECK_EQ(end->shape[0].as()->value, num_dynamic_axes); TVM_FFI_ICHECK_EQ(strides->shape[0].as()->value, num_dynamic_axes); ffi::Array begin_expr, end_expr, strides_expr; for (int64_t i = 0; i < num_dynamic_axes; ++i) { - auto ind = MakeConst(index_dtype, i); + auto ind = MakeConst(index_ty, i); begin_expr.push_back(begin(ind)); end_expr.push_back(end(ind)); strides_expr.push_back(strides(ind)); @@ -874,10 +875,10 @@ inline ffi::Array StridedSliceOutputShape(const ffi::Array& axes.size() == strides.size()); std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); - DataType index_dtype = - (begin.size() > 0 && begin[0].defined()) ? begin[0].value()->dtype : DataType::Int(64); + PrimType index_ty = + (begin.size() > 0 && begin[0].defined()) ? 
begin[0].value().ty() : PrimType::Int(64); auto begin_canonicalized = - StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes, index_dtype, slice_mode); + StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes, index_ty, slice_mode); return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode, begin_canonicalized, true); } @@ -924,10 +925,10 @@ inline Tensor strided_slice_with_axes( std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); - DataType index_dtype = - (begin.size() > 0 && begin[0].defined()) ? begin[0].value()->dtype : DataType::Int(64); + PrimType index_ty = + (begin.size() > 0 && begin[0].defined()) ? begin[0].value().ty() : PrimType::Int(64); auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, normalized_axes, - index_dtype, slice_mode); + index_ty, slice_mode); auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, normalized_axes, slice_mode, begin_expr); @@ -938,7 +939,7 @@ inline Tensor strided_slice_with_axes( for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); for (size_t i = 0; i < normalized_axes.size(); ++i) { int64_t ax = normalized_axes[i]; - auto stride = MakeConst(strides[i]->dtype, strides_vec[i]); + auto stride = MakeConst(strides[i]->ty(), strides_vec[i]); PrimExpr ind = indices[ax] * stride + begin_expr[i]; real_indices.Set(ax, ind); } @@ -972,11 +973,11 @@ inline Tensor strided_slice(const Tensor& x, const ffi::Array> end_full(end); ffi::Array strides_full(strides); - DataType index_dtype = - (begin.size() > 0 && begin[0].defined()) ? begin[0].value()->dtype : DataType::Int(64); - const IntImm one = IntImm(index_dtype, 1); - const IntImm zero = IntImm(index_dtype, 0); - const IntImm max_range = max_value(index_dtype).as_or_throw(); + PrimType index_ty = + (begin.size() > 0 && begin[0].defined()) ? 
begin[0].value().ty() : PrimType::Int(64); + const IntImm one = IntImm(index_ty, 1); + const IntImm zero = IntImm(index_ty, 0); + const IntImm max_range = max_value(index_ty).as_or_throw(); for (size_t i = strides.size(); i < src_tensor_dim; ++i) { strides_full.push_back(one); @@ -1073,7 +1074,8 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, [&](const ffi::Array& out_index) { auto idx = tvm::if_then_else( indices(out_index) < 0 || indices(out_index) >= a_size, - tvm::FloatImm(a->dtype, std::numeric_limits::quiet_NaN()), indices(out_index)); + tvm::FloatImm(tvm::PrimType(a->dtype), std::numeric_limits::quiet_NaN()), + indices(out_index)); return a(UnravelIndex(idx, a_shape)); }, name, tag); @@ -1116,9 +1118,9 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub auto tid = out_index[axis]; auto bid = out_index[1 - axis]; len_index.push_back(bid); - PrimExpr ret = - tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), - tvm::tirx::MakeConst(data->dtype, mask_value), data(out_index)); + PrimExpr ret = tvm::if_then_else( + tvm::cast(PrimType(valid_length->dtype), tid) >= valid_length(len_index), + tvm::tirx::MakeConst(PrimType(data->dtype), mask_value), data(out_index)); return ret; }, name, tag); @@ -1293,7 +1295,7 @@ inline Tensor take(const Tensor& a, ffi::Variant indices, int PrimExpr in_bounds = idx >= 0 && idx < axis_dim; return tvm::if_then_else( in_bounds, a(real_indices), - tvm::tirx::MakeConst(a->dtype, std::numeric_limits::quiet_NaN())); + tvm::tirx::MakeConst(PrimType(a->dtype), std::numeric_limits::quiet_NaN())); }, name, tag); } else { // mode == "wrap" @@ -1443,8 +1445,8 @@ inline Tensor tile(const Tensor& x, ffi::Array reps, std::string name = if (is_empty_shape(new_shape)) { return compute( - new_shape, [&](const ffi::Array& indices) { return tvm::cast(x->dtype, 0); }, name, - tag); + new_shape, [&](const ffi::Array& indices) { return tvm::cast(PrimType(x->dtype), 
0); }, + name, tag); } else { return compute( new_shape, @@ -1478,8 +1480,8 @@ inline Tensor dyn_tile(const Tensor& x, ffi::Array new_shape, size_t r size_t ndim = x->shape.size(); if (is_empty_shape(new_shape)) { return compute( - new_shape, [&](const ffi::Array& indices) { return tvm::cast(x->dtype, 0); }, name, - tag); + new_shape, [&](const ffi::Array& indices) { return tvm::cast(PrimType(x->dtype), 0); }, + name, tag); } else { return compute( new_shape, @@ -1526,7 +1528,9 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); TVM_FFI_ICHECK_GE(indices_dim_i, 1); } - TVM_FFI_ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()); + // Index tensors are validated by integer element kind; vector lane encoding is irrelevant here. + PrimType indices_ty = indices->dtype; + TVM_FFI_ICHECK(indices_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)); ffi::Array out_shape; for (size_t i = 0; i < ndim_i; ++i) { @@ -1593,10 +1597,13 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim } for (size_t i = 0; i < indices_dim0; ++i) { indices_position.Set(0, IntImm::Int32(i)); - if (indices->dtype.is_int() || indices->dtype.is_uint()) { + // Index tensors are validated by integer element kind; vector lane encoding is + // irrelevant for choosing whether an index cast is needed. 
+ PrimType indices_ty = indices->dtype; + if (indices_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { real_indices.push_back(indices(indices_position)); } else { - real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position))); + real_indices.push_back(tvm::cast(tvm::PrimType::Int(32), indices(indices_position))); } } if (real_indices.size() == ndim_d) { @@ -1740,10 +1747,15 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, ffi::ArrayCanProveGreaterEqual(step, 1)) { // fast path for integer arange when step is positive num_elem = tvm::floordiv((stop - start + step - 1), step); @@ -1752,8 +1764,8 @@ inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr num_elem = tvm::floordiv((start - stop - step - 1), -step); } else { // fallback path for non-integer or step of unknown sign - num_elem = tvm::cast(DefaultIndexType(), - tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step)); + num_elem = tvm::cast(PrimType(DefaultIndexType()), + tvm::ceil(tvm::cast(tvm::PrimType::Float(32), stop - start) / step)); } num_elem = analyzer->Simplify(num_elem); @@ -1845,7 +1857,8 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, for (size_t i = 0; i < src.ndim(); ++i) { in_range = in_range && (src_indices[i] < src->shape[i]); } - return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0))); + return if_then_else(in_range, src(src_indices), + tvm::cast(PrimType(src->dtype), PrimExpr(0))); }, name, tag, attrs); } @@ -1960,7 +1973,7 @@ inline Tensor meta_schedule_layout_transform( ffi::Array iter_domain; iter_domain.reserve(src->shape.size()); for (const PrimExpr& e : src->shape) { - iter_domain.push_back(Range::FromMinExtent(IntImm(e->dtype, 0), e)); + iter_domain.push_back(Range::FromMinExtent(IntImm(e.ty(), 0), e)); } ffi::Array post_transform_shape = index_map->MapShape(src->shape, analyzer); return compute( @@ -1980,7 +1993,7 @@ 
inline Tensor meta_schedule_layout_transform( * \param tag output tensor tag. * \return Tensor of input shape. */ -inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape", +inline Tensor shape(const Tensor& src, PrimType dtype, const std::string name = "T_shape", const std::string tag = kInjective) { int ndim = static_cast(src->shape.size()); ffi::Array out_shape{ndim}; @@ -1997,6 +2010,11 @@ inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = name, tag); } +inline Tensor shape(const Tensor& src, DLDataType dtype, const std::string name = "T_shape", + const std::string tag = kInjective) { + return shape(src, PrimType(dtype), name, tag); +} + /*! * \brief Get the size of input tensor. * \param src the input tensor. @@ -2005,7 +2023,7 @@ inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = * \param tag output tensor tag. * \return Tensor of input shape. */ -inline te::Tensor tensor_size(const te::Tensor& src, const DataType& dtype, +inline te::Tensor tensor_size(const te::Tensor& src, PrimType dtype, const std::string& name = "tensor_size", const std::string& tag = kInjective) { int ndim = static_cast(src->shape.size()); @@ -2022,6 +2040,12 @@ inline te::Tensor tensor_size(const te::Tensor& src, const DataType& dtype, name, tag); } +inline te::Tensor tensor_size(const te::Tensor& src, DLDataType dtype, + const std::string& name = "tensor_size", + const std::string& tag = kInjective) { + return tensor_size(src, PrimType(dtype), name, tag); +} + /*! * \brief Returns a one-hot tensor where the locations repsented by indices take value on_value, other locations take value off_value. @@ -2037,7 +2061,7 @@ inline te::Tensor tensor_size(const te::Tensor& src, const DataType& dtype, * \return one-hot tensor. 
*/ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value, - int depth, int axis, const DataType& dtype, + int depth, int axis, PrimType dtype, ffi::Array oshape = ffi::Array(), const std::string name = "T_one_hot", const std::string tag = kInjective) { int true_axis = (axis == -1) ? indices->shape.size() : axis; @@ -2073,6 +2097,14 @@ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const Prim name, tag); } +inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value, + int depth, int axis, DLDataType dtype, + ffi::Array oshape = ffi::Array(), + const std::string name = "T_one_hot", const std::string tag = kInjective) { + return one_hot(indices, on_value, off_value, depth, axis, PrimType(dtype), std::move(oshape), + name, tag); +} + /*! * \brief Get a dense tensor. * \param sparse_indices sparse_indices[i] contains sparse_values[i] will be placed. @@ -2088,7 +2120,9 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const PrimExpr& default_value, const std::string name = "T_sparse_to_dense", const std::string tag = kInjective) { - TVM_FFI_ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values"; + // Sparse indices are validated by signed integer element kind; lane encoding is irrelevant here. + TVM_FFI_ICHECK_EQ(sparse_indices->dtype.code(), DLDataTypeCode::kDLInt) + << "sparse_indices only accepts integer values"; TVM_FFI_ICHECK_LE(sparse_indices->shape.size(), 3) << "sparse_indices tensor should be 0D, 1D, or 2D only"; TVM_FFI_ICHECK_LE(sparse_values->shape.size(), 2) diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 4fbebeddd0f5..dd463150fd51 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -43,7 +43,18 @@ class PrimExpr(BaseExpr): optimizations and integer analysis. """ - dtype: str + @property + def dtype(self): + """Compatibility alias for the runtime dtype of scalar PrimExpr. 
+ + New code should inspect ``expr.ty`` directly. For scalar primitive + expressions, use ``expr.ty.dtype``. + """ + if self.ty is None: + return None + if hasattr(self.ty, "dtype"): + return self.ty.dtype + return "handle" @tvm_ffi.register_object("ir.RelaxExpr") diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 567ebafa2d5c..96548439d70e 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -53,6 +53,35 @@ class PrimType(Type): def __init__(self, dtype): self.__init_handle_by_constructor__(_ffi_api.PrimType, dtype) + def __eq__(self, other): + if isinstance(other, str): + return self.dtype == other + return super().__eq__(other) + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + dtype = self.dtype + return hash((dtype.type_code, dtype.bits, dtype.lanes)) + + def __str__(self): + return str(self.dtype) + + def matches_code(self, *codes) -> bool: + """Return whether this type has any of the given DLPack dtype codes.""" + type_code = self.dtype.type_code + return any(type_code == int(code) for code in codes) + + def matches_element_type(self, code, bits: int) -> bool: + """Return whether this type has the given scalar element code and bits.""" + dtype = self.dtype + return dtype.type_code == int(code) and dtype.bits == bits + + def is_scalar(self) -> bool: + """Return whether this type has exactly one fixed lane.""" + return self.dtype.lanes == 1 + @tvm_ffi.register_object("ir.PointerType") class PointerType(Type): diff --git a/python/tvm/relax/frontend/nn/extern.py b/python/tvm/relax/frontend/nn/extern.py index 9c8efce690f1..6c7f3dc72c9f 100644 --- a/python/tvm/relax/frontend/nn/extern.py +++ b/python/tvm/relax/frontend/nn/extern.py @@ -145,7 +145,7 @@ def shape_dtype_inference(a, b): // those headers are guaranteed to be available #include - #include + #include #include namespace { diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index f987f48d4251..b9ab88da0b43 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -29,6 +29,7 @@ import tvm_ffi from tvm import relax, tirx +from tvm.runtime import DataTypeCode class BaseFXGraphImporter(metaclass=abc.ABCMeta): @@ -566,7 +567,7 @@ def _pow(self, node: fx.Node) -> relax.Var: if ( isinstance(lhs, relax.Expr) and isinstance(lhs.ty, relax.TensorType) - and "int" in lhs.ty.dtype + and lhs.ty.dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT) and isinstance(rhs, int) and not isinstance(rhs, bool) and rhs >= 0 @@ -1607,7 +1608,7 @@ def transpose_and_reshape_back(tensor): if attn_mask is not None: attn_mask = self.env[attn_mask] msg = "Only a float mask is supported for the attn_mask input." - assert "float" in attn_mask.ty.dtype, msg + assert attn_mask.ty.dtype.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT), msg attention_output = self.block_builder.emit( relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py index 5aae26b75f20..809a2c19ee9e 100644 --- a/python/tvm/relax/op/_op_gradient.py +++ b/python/tvm/relax/op/_op_gradient.py @@ -22,6 +22,7 @@ from tvm import relax from tvm.arith import Analyzer +from tvm.ir import PrimType from tvm.relax.type import ShapeType from ...tirx import PrimExpr @@ -81,6 +82,8 @@ def _get_dtype(expr: Expr) -> str: raise RuntimeError( f"Get the dtype of {expr} failed. Please normalize it first and ensure it is a Tensor." 
) from error + if isinstance(dtype, PrimType): + dtype = dtype.dtype return dtype diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py index 9d28ed92f9c5..1bbeeee8f272 100644 --- a/python/tvm/relax/op/create.py +++ b/python/tvm/relax/op/create.py @@ -17,6 +17,7 @@ """Creation operators.""" from tvm import DataType, DataTypeCode +from tvm.ir import PrimType from tvm.ir.expr import PrimExpr from ..expr import Expr, PrimValue, ShapeExpr @@ -267,7 +268,12 @@ def is_int(expr): return True if isinstance(expr, PrimValue): expr = expr.value - return isinstance(expr, PrimExpr) and DataType(expr.dtype).type_code == DataTypeCode.INT # type: ignore + if isinstance(expr, PrimExpr): + dtype = expr.dtype # type: ignore + if isinstance(dtype, PrimType): + dtype = dtype.dtype + return DataType(dtype).type_code == DataTypeCode.INT + return False if dtype is None: args = (start, end, step) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 4b787c265bc3..43a2bd400351 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -19,6 +19,7 @@ from collections.abc import Callable from tvm.ir.expr import PrimExpr +from tvm.runtime import DataTypeCode from tvm.tirx import FloatImm, IndexMap, IntImm from ..expr import Expr, PrimValue, ShapeExpr @@ -151,10 +152,12 @@ def layout_transform( if pad_value is None: pass elif not isinstance(pad_value, PrimValue): - if "int" in x_dtype and isinstance(pad_value, int): - pad_value = IntImm(x_dtype, pad_value) - elif "float" in x_dtype and (isinstance(pad_value, int | float)): - pad_value = FloatImm(x_dtype, float(pad_value)) + if x_dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT) and isinstance(pad_value, int): + pad_value = IntImm(x_dtype.dtype, pad_value) + elif x_dtype.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT) and ( + isinstance(pad_value, int | float) + ): + pad_value = FloatImm(x_dtype.dtype, float(pad_value)) pad_value = 
PrimValue(pad_value) if axis_separators is None: diff --git a/python/tvm/relax/transform/legalize_ops/common.py b/python/tvm/relax/transform/legalize_ops/common.py index 1b7d1179a521..f464c248e363 100644 --- a/python/tvm/relax/transform/legalize_ops/common.py +++ b/python/tvm/relax/transform/legalize_ops/common.py @@ -20,6 +20,7 @@ import tvm from tvm import te +from tvm.runtime import DataTypeCode from tvm.tirx import FloatImm, IntImm from ...block_builder import BlockBuilder @@ -38,9 +39,6 @@ LegalizeFunc = Callable[[BlockBuilder, Call], Expr] -##################### Utilities ##################### - - def _try_convert_to_scalar_const( expr: Expr, python_native: bool = False ) -> Expr | FloatImm | IntImm | bool | float | int: @@ -69,13 +67,14 @@ def _try_convert_to_scalar_const( # get the value of the scalar constant value = expr.data.numpy()[()].item() dtype = expr.ty.dtype + dtype_str = str(dtype.dtype) if python_native: return value # preserve the data type of the constant - if dtype.startswith("float"): - return tvm.tirx.FloatImm(dtype, value) - elif dtype.startswith("int") or dtype.startswith("uint") or dtype.startswith("bool"): - return tvm.tirx.IntImm(dtype, value) + if dtype.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT): + return tvm.tirx.FloatImm(dtype_str, value) + elif dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT, DataTypeCode.BOOL): + return tvm.tirx.IntImm(dtype_str, value) return expr diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index f0cc8977d4ef..a59b1f9fe52e 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -19,7 +19,7 @@ """Default legalization function for manipulate operators.""" import tvm -from tvm import relax, s_tir, te, tirx, topi +from tvm import DataTypeCode, relax, s_tir, te, tirx, topi from tvm.relax.op.base import call_tir from tvm.relax.type import TensorType 
from tvm.relax.utils import gen_call_tir_inputs @@ -337,7 +337,7 @@ def set_axis_sep(axis_sep: list, sch: s_tir.schedule, buffer_type: str): if pad_value is not None: pad_value = pad_value.value else: - if "int" in call.args[0].ty.dtype: + if call.args[0].ty.dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT): pad_value = 0 else: pad_value = 0.0 diff --git a/python/tvm/relax/transform/legalize_ops/qdq.py b/python/tvm/relax/transform/legalize_ops/qdq.py index aa86f6fca2c3..7a825e300e40 100644 --- a/python/tvm/relax/transform/legalize_ops/qdq.py +++ b/python/tvm/relax/transform/legalize_ops/qdq.py @@ -19,6 +19,7 @@ import tvm from tvm import te, tirx +from tvm.runtime import DataTypeCode from ...block_builder import BlockBuilder from ...expr import Call, Expr @@ -140,7 +141,11 @@ def dequantize_compute(*indices): zp_value = zp[(0,) * len(zp.shape)] else: zp_value = zp[indices[axis]] - dtype = "float32" if "float" in data.dtype else "int32" + dtype = ( + "float32" + if data.dtype.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT) + else "int32" + ) sub = te.subtract(data[indices].astype(dtype), zp_value) out = te.multiply(sub, scale_value.astype("float32")) if out_dtype == "float32": diff --git a/python/tvm/relax/type.py b/python/tvm/relax/type.py index ad8f469826ef..305f01750306 100644 --- a/python/tvm/relax/type.py +++ b/python/tvm/relax/type.py @@ -21,7 +21,7 @@ import tvm_ffi from tvm_ffi import Array -from tvm.ir import EnvFunc, PrimExpr, Span, TupleType, VDevice +from tvm.ir import EnvFunc, PrimExpr, PrimType, Span, TupleType, VDevice from . 
import _ffi_api from .expr import Expr, ShapeExpr, Type @@ -92,7 +92,7 @@ class TensorType(Type): """ shape: Expr | None - dtype: str + dtype: PrimType vdevice: VDevice | None ndim: int span: Span @@ -100,13 +100,15 @@ class TensorType(Type): def __init__( self, shape: Expr | None | list[PrimExpr] = None, - dtype: str = "float32", + dtype: str | PrimType | None = "float32", vdevice: VDevice | None | str = None, ndim: int = -1, span: Span = None, ) -> None: if isinstance(shape, list | tuple | Array): shape = ShapeExpr(shape) + if dtype is not None and not isinstance(dtype, PrimType): + dtype = PrimType(dtype) self.__init_handle_by_constructor__( _ffi_api.TensorType, shape, diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 51c8805f9445..505613d0372e 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -66,5 +66,9 @@ def const(value, dtype=None, span=None): if dtype is None: dtype = _scalar_type_inference(value) if dtype == "uint64" and value >= (1 << 63): - return _ffi_node_api.LargeUIntImm(dtype, value & ((1 << 32) - 1), value >> 32, span) + from tvm.ir import PrimType # pylint: disable=import-outside-toplevel + + return _ffi_node_api.LargeUIntImm( + PrimType(dtype), value & ((1 << 32) - 1), value >> 32, span + ) return _ffi_node_api._const(value, dtype, span) diff --git a/python/tvm/s_tir/schedule/schedule.py b/python/tvm/s_tir/schedule/schedule.py index 7f191df98d84..25b81239189d 100644 --- a/python/tvm/s_tir/schedule/schedule.py +++ b/python/tvm/s_tir/schedule/schedule.py @@ -24,7 +24,7 @@ from tvm.error import register_error from tvm.ir import GlobalVar, IRModule, PrimExpr -from tvm.runtime import Object +from tvm.runtime import DataTypeCode, Object from tvm.tirx import Buffer, FloatImm, For, IntImm, PrimFunc, SBlock from tvm.tirx.function import IndexMap @@ -3465,10 +3465,14 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> # buffer's type. 
If the default `tvm.runtime.convert` # behavior is applied, these would be converted to # int32/float32, which may not match the buffer's type. - if "int" in buffer_obj.dtype and isinstance(pad_value, int): - pad_value = IntImm(buffer_obj.dtype, pad_value) - elif "float" in buffer_obj.dtype and isinstance(pad_value, float): - pad_value = FloatImm(buffer_obj.dtype, pad_value) + if buffer_obj.dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT) and isinstance( + pad_value, int + ): + pad_value = IntImm(buffer_obj.dtype.dtype, pad_value) + elif buffer_obj.dtype.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT) and ( + isinstance(pad_value, float) + ): + pad_value = FloatImm(buffer_obj.dtype.dtype, pad_value) pad_value = IndexMap.from_func( lambda *indices: pad_value, ndim=len(index_map.final_indices), diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 4d38292b9b56..0461e56ec984 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -396,7 +396,11 @@ def _eval_if_exp(self, fields: dict[str, Any]) -> Any: orelse = self._eval_expr(fields["orelse"]) if isinstance(test, bool): return body if test else orelse - elif isinstance(test, tvm.tirx.PrimExpr) and test.dtype == "bool": + elif ( + isinstance(test, tvm.tirx.PrimExpr) + and isinstance(test.ty, tvm.ir.PrimType) + and test.ty.matches_code(tvm.DataTypeCode.BOOL) + ): return tvm.tirx.op.if_then_else(test, body, orelse) else: raise TypeError(f"Expected Python bool or TIR bool, but got {type(test)}") diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 531915c6798a..b7238cf07eda 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -19,6 +19,7 @@ # pylint: disable=invalid-name import tvm_ffi +from tvm.ir import PrimType from tvm.runtime import Object, ObjectConvertible from tvm.tirx import DataProducer from tvm.tirx import expr as _expr @@ -49,6 +50,10 @@ def dtype(self): """Data 
content of the tensor.""" return self.tensor.dtype + def expr_ty(self): + """Compile-time element type of the tensor.""" + return self.tensor.expr_ty() + @tvm_ffi.register_object("te.Tensor") class Tensor(DataProducer, _expr.ExprOp): @@ -86,6 +91,15 @@ def ndim(self): """Dimension of the tensor.""" return len(self.shape) + @property + def dtype(self): + """Data content of the tensor.""" + return PrimType(_ffi_api.TensorDType(self)) + + def expr_ty(self): + """Compile-time element type of the tensor.""" + return self.dtype + @property def name(self): op = self.op diff --git a/python/tvm/tirx/buffer.py b/python/tvm/tirx/buffer.py index 4caf154547fa..43023b4c3cb9 100644 --- a/python/tvm/tirx/buffer.py +++ b/python/tvm/tirx/buffer.py @@ -352,7 +352,7 @@ def _infer_shape(shape): shape = args assert all( isinstance(arg, int) - or (isinstance(arg, PrimExpr) and arg.dtype in ["int32", "int64"]) + or (isinstance(arg, PrimExpr) and arg.ty.dtype in ["int32", "int64"]) for arg in shape ), "shape must be a list of integers or PrimExprs with dtype int32 or int64" # Safely get optional keyword arguments @@ -462,7 +462,7 @@ def permute(self, *dims) -> "Buffer": def __getitem__(self, indices): from ..arith import Analyzer # pylint: disable=import-outside-toplevel - from .expr import BufferLoad, Ramp, const # pylint: disable=import-outside-toplevel + from .expr import BufferLoad, Ramp # pylint: disable=import-outside-toplevel from .stmt import BufferRegion # pylint: disable=import-outside-toplevel if not isinstance(indices, tuple | list): @@ -483,7 +483,8 @@ def __getitem__(self, indices): else: region.append( Range.from_min_extent( - index, const(1, index.dtype) if isinstance(index, PrimExpr) else 1 + index, + tvm.tirx.expr.IntImm(index.ty, 1) if isinstance(index, PrimExpr) else 1, ) ) if has_implicit_slice: @@ -499,7 +500,7 @@ def __getitem__(self, indices): step = 1 if index.step is None else index.step # We should ensure the dtype of start is the same with that of step. 
if isinstance(start, tvm.tirx.expr.PrimExpr) and isinstance(step, int): - step = tvm.tirx.expr.IntImm(start.dtype, step) + step = tvm.tirx.expr.IntImm(start.ty, step) lanes = analyzer.simplify((stop - start + step - 1) // step) if lanes == 1: expr_indices.append(start) @@ -540,11 +541,11 @@ def decl_buffer( layout = TileLayout(S[tuple(shape)]) if shape else None if offset_factor != 0 and elem_offset is None: - shape_dtype = shape[0].dtype if shape and hasattr(shape[0], "dtype") else "int32" - elem_offset = Var(f"{name}_elem_offset", shape_dtype) + shape_ty = shape[0].ty if shape and isinstance(shape[0], PrimExpr) else "int32" + elem_offset = Var(f"{name}_elem_offset", shape_ty) if data is None: # Bool is represented as uint1 in the IR, but stored as int8 - storage_type = PrimType(dtype) + storage_type = dtype if isinstance(dtype, PrimType) else PrimType(dtype) storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type data = Var(name, PointerType(storage_type, scope), span) return _ffi_api.Buffer( # type: ignore diff --git a/python/tvm/tirx/expr.py b/python/tvm/tirx/expr.py index a97171e436ae..ec744acf5093 100644 --- a/python/tvm/tirx/expr.py +++ b/python/tvm/tirx/expr.py @@ -34,7 +34,7 @@ from tvm import ir from tvm.ir import Op, PrimExpr from tvm.ir.base import Span -from tvm.runtime import DataType, DataTypeCode, Object, ObjectConvertible, Scriptable, const +from tvm.runtime import DataTypeCode, Object, ObjectConvertible, Scriptable, const from . import _ffi_api from . 
import generic as _generic @@ -56,13 +56,17 @@ def div_ambiguity_error() -> RuntimeError: def _dtype_is_int(value): if isinstance(value, int): return True - return isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.INT # type: ignore + if isinstance(value, ExprOp): + return value.expr_ty().matches_code(DataTypeCode.INT) + return False def _dtype_is_float(value): if isinstance(value, float): return True - return isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.FLOAT # type: ignore + if isinstance(value, ExprOp): + return value.expr_ty().matches_code(DataTypeCode.FLOAT) + return False class ExprOp: @@ -70,6 +74,13 @@ class ExprOp: # TODO(tkonolige): use inspect to add source information to these objects + def expr_ty(self) -> ir.PrimType: + """Return the compile-time primitive type for expression operators.""" + ty = getattr(self, "ty", None) + if isinstance(ty, ir.PrimType): + return ty + raise TypeError(f"Cannot determine PrimType for {type(self).__name__}") + def __add__(self, other: PrimExpr) -> PrimExpr: return _generic.add(self, other) @@ -121,7 +132,7 @@ def __rmod__(self, other: PrimExpr) -> PrimExpr: return _ffi_api._OpFloorMod(other, self, None) # type: ignore def __neg__(self) -> PrimExpr: - neg_one = const(-1, self.dtype) # type: ignore + neg_one = const(-1, self.expr_ty().dtype) return self.__mul__(neg_one) def __lshift__(self, other: PrimExpr) -> PrimExpr: @@ -204,7 +215,7 @@ def equal(self, other: PrimExpr, span: Span | None = None) -> bool: """ return _ffi_api._OpEQ(self, other, span) # type: ignore - def astype(self, dtype: str, span: Span | None = None) -> PrimExpr: + def astype(self, dtype: str | ir.PrimType, span: Span | None = None) -> PrimExpr: """Cast the expression to other type. 
Parameters @@ -259,6 +270,10 @@ def asobject(self) -> PrimExpr: """Convert object.""" return _ffi_api._OpEQ(self.a, self.b, self.span) # type: ignore + def expr_ty(self) -> ir.PrimType: + """Compile-time type of the equality result.""" + return ir.PrimType("bool") + def __repr__(self) -> str: return f"EqualOp({self.a!r}, {self.b!r})" @@ -299,6 +314,10 @@ def asobject(self) -> PrimExpr: """Convert object.""" return _ffi_api._OpNE(self.a, self.b, self.span) # type: ignore + def expr_ty(self) -> ir.PrimType: + """Compile-time type of the inequality result.""" + return ir.PrimType("bool") + def __repr__(self) -> str: return f"NotEqualOp({self.a!r}, {self.b!r})" @@ -458,12 +477,10 @@ def __init__( raise TypeError("dom need to be Range") name = var if var is not None else "iter" - dtype = "int32" if dom is None else dom.extent.dtype + dtype = "int32" if dom is None else dom.extent.ty var = Var(name, dtype=dtype, span=span) if not isinstance(var, Var) else var if dom is not None: - assert var.dtype == dom.extent.dtype, ( - "IterVar's Var dtype must match its domain's extent's dtype" - ) + assert var.ty == dom.extent.ty, "IterVar's Var type must match its domain's extent type" self.__init_handle_by_constructor__( _ffi_api.IterVar, dom, @@ -473,6 +490,10 @@ def __init__( span, # type: ignore ) + def expr_ty(self) -> ir.PrimType: + """Compile-time type of the iteration variable.""" + return self.var.ty + @tvm_ffi.register_object("tirx.CommReducer") class CommReducer(Object, Scriptable): @@ -595,7 +616,9 @@ class FloatImm(ConstExpr): value: float - def __init__(self, dtype: str, value: float, span: Span | None = None) -> None: + def __init__(self, dtype: str | ir.PrimType, value: float, span: Span | None = None) -> None: + if isinstance(dtype, ir.PrimType): + dtype = dtype.dtype self.__init_handle_by_constructor__( tvm.ir._ffi_api.FloatImm, dtype, @@ -625,7 +648,9 @@ class IntImm(ConstExpr): value: int - def __init__(self, dtype: str, value: int, span: Span | None = None) -> 
None: + def __init__(self, dtype: str | ir.PrimType, value: int, span: Span | None = None) -> None: + if isinstance(dtype, ir.PrimType): + dtype = dtype.dtype self.__init_handle_by_constructor__( tvm.ir._ffi_api.IntImm, dtype, @@ -702,7 +727,9 @@ class Cast(PrimExprWithOp): value: PrimExpr - def __init__(self, dtype, value, span: Span | None = None) -> None: + def __init__(self, dtype: str | ir.PrimType, value, span: Span | None = None) -> None: + if isinstance(dtype, ir.PrimType): + dtype = dtype.dtype self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) # type: ignore @@ -1313,7 +1340,7 @@ class Call(PrimExprWithOp): def __init__( self, - dtype: str, + dtype: str | ir.PrimType, op: Op | str, args: list[PrimExpr], attrs: ir.Attrs | dict | None = None, @@ -1332,6 +1359,8 @@ def __init__( op = Op.get(op) if isinstance(attrs, dict): attrs = ir.make_node("ir.DictAttrs", **attrs) + if not isinstance(dtype, ir.PrimType): + dtype = ir.PrimType(dtype) if attrs: self.__init_handle_by_constructor__( # type: ignore _ffi_api.CallWithAttrs, dtype, op, args, attrs, span diff --git a/python/tvm/tirx/layout.py b/python/tvm/tirx/layout.py index 29a19d746dee..11d1e140ae16 100644 --- a/python/tvm/tirx/layout.py +++ b/python/tvm/tirx/layout.py @@ -332,10 +332,10 @@ def _get_default_strides(data: list[int | PrimExpr], stride: int = 1) -> tuple: # produce for int64-shaped buffers (otherwise the last stride stays a # Python ``int`` -> int32 IntImm and breaks structural-equal). for t in data: - if isinstance(t, PrimExpr) and t.dtype != "int32": + if isinstance(t, PrimExpr) and t.ty.dtype != "int32": from .expr import IntImm # pylint: disable=import-outside-toplevel - stride = IntImm(t.dtype, stride) + stride = IntImm(t.ty, stride) break res = list() for t in reversed(data): diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py index a7a2889c444b..9a54e915bb0b 100644 --- a/python/tvm/tirx/op.py +++ b/python/tvm/tirx/op.py @@ -31,7 +31,7 @@ from . 
import _ffi_api from .buffer import Buffer -from .expr import BufferLoad, Call, CommReducer, IntImm, PrimExprWithOp, Var +from .expr import BufferLoad, Call, CommReducer, ExprOp, IntImm, PrimExprWithOp, Var tir = tirx # alias for backward compat with upstream tir.convert() calls @@ -57,6 +57,24 @@ def _canonical_device_intrin_name(func_name: str) -> str: return func_name +def _primexpr_ty(expr): + """Return the runtime primitive type of an expression.""" + ty = getattr(expr, "ty", None) + if isinstance(ty, tvm.ir.PrimType): + return ty + if isinstance(expr, ExprOp): + return expr.expr_ty() + raise TypeError(f"Cannot determine PrimExpr type for {type(expr).__name__}") + + +def _primexpr_dtype(expr): + """Return the runtime dtype of a primitive expression without using PrimExpr.dtype.""" + ty = _primexpr_ty(expr) + if not isinstance(ty, tvm.ir.PrimType): + raise TypeError(f"Expected PrimType for {type(expr).__name__}, but got {ty}") + return ty.dtype + + def _pack_buffer(buf, span=None): """Build intrinsics that packs the buffer.""" shape = Call("handle", "tirx.tvm_stack_make_shape", buf.shape, span=span) @@ -187,7 +205,7 @@ def call_cpacked(*args, span=None): return Call("int32", Op.get("tirx.tvm_call_cpacked"), call_args, span=span) -def call_intrin(dtype, func_name, *args, attrs=None, span=None): +def call_intrin(dtype: str | tvm.ir.PrimType, func_name, *args, attrs=None, span=None): """Build expression by calling an intrinsic function. 
Intrinsics can be overloaded with multiple data types via @@ -272,8 +290,9 @@ def call_extern(dtype, func_name, *args, span=None): def _require_float_arg(op_name, x): x = tirx.convert(x) - if "float" not in x.dtype and "bfloat" not in x.dtype: - raise TypeError(f"tirx.{op_name} only supports floating-point inputs, but got {x.dtype}") + dtype = _primexpr_dtype(x) + if "float" not in dtype and "bfloat" not in dtype: + raise TypeError(f"tirx.{op_name} only supports floating-point inputs, but got {dtype}") return x @@ -476,8 +495,8 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args): dtype = "void" if global_var.ty is not None: ret_ty = global_var.ty.ret - if hasattr(ret_ty, "dtype"): - dtype = ret_ty.dtype + if isinstance(ret_ty, tvm.ir.PrimType): + dtype = ret_ty return Call(dtype=dtype, op=global_var, args=args) @@ -680,7 +699,7 @@ def tvm_thread_invariant(cond): The call expression. """ assert isinstance(cond, PrimExpr) - return call_intrin(cond.dtype, "tirx.tvm_thread_invariant", cond) + return call_intrin(_primexpr_ty(cond), "tirx.tvm_thread_invariant", cond) def tvm_storage_sync(storage_scope, is_load=False, num_blocks=-1): @@ -742,7 +761,9 @@ def tvm_warp_shuffle(mask, value, warp_id, width, warp_size): call : PrimExpr The call expression. """ - return call_intrin(value.dtype, "tirx.tvm_warp_shuffle", mask, value, warp_id, width, warp_size) + return call_intrin( + _primexpr_ty(value), "tirx.tvm_warp_shuffle", mask, value, warp_id, width, warp_size + ) def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): @@ -768,7 +789,7 @@ def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): The call expression. """ return call_intrin( - value.dtype, "tirx.tvm_warp_shuffle_up", mask, value, offset, width, warp_size + _primexpr_ty(value), "tirx.tvm_warp_shuffle_up", mask, value, offset, width, warp_size ) @@ -795,7 +816,7 @@ def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): The call expression. 
""" return call_intrin( - value.dtype, "tirx.tvm_warp_shuffle_down", mask, value, offset, width, warp_size + _primexpr_ty(value), "tirx.tvm_warp_shuffle_down", mask, value, offset, width, warp_size ) @@ -821,7 +842,7 @@ def tvm_warp_shuffle_xor(mask, value, lane_mask, width, warp_size): The call expression. """ return call_intrin( - value.dtype, "tirx.tvm_warp_shuffle_xor", mask, value, lane_mask, width, warp_size + _primexpr_ty(value), "tirx.tvm_warp_shuffle_xor", mask, value, lane_mask, width, warp_size ) @@ -1208,7 +1229,8 @@ def trace(args, trace_action="tvm.default_trace_action"): raise Exception("tvm.tirx.trace consumes the args as list type") call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] call_args.insert(0, trace_action) - return tvm.tirx.Call(args[-1].dtype, Op.get("tirx.tvm_call_trace_packed"), call_args) + dtype = _primexpr_ty(args[-1]) if isinstance(args[-1], PrimExpr) else args[-1].dtype + return tvm.tirx.Call(dtype, Op.get("tirx.tvm_call_trace_packed"), call_args) def min_value(dtype, span=None): @@ -1304,7 +1326,7 @@ def exp(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.exp", x) + return call_intrin(_primexpr_ty(x), "tirx.exp", x) def exp2(x): @@ -1321,7 +1343,7 @@ def exp2(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.exp2", x) + return call_intrin(_primexpr_ty(x), "tirx.exp2", x) def exp10(x): @@ -1338,7 +1360,7 @@ def exp10(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.exp10", x) + return call_intrin(_primexpr_ty(x), "tirx.exp10", x) def fma(x, y, z): @@ -1363,7 +1385,7 @@ def fma(x, y, z): x = tir.convert(x) y = tir.convert(y) z = tir.convert(z) - return call_intrin(x.dtype, "tirx.fma", x, y, z) + return call_intrin(_primexpr_ty(x), "tirx.fma", x, y, z) def erf(x): @@ -1380,7 +1402,7 @@ def erf(x): The result. 
""" x = tir.convert(x) - return call_intrin(x.dtype, "tirx.erf", x) + return call_intrin(_primexpr_ty(x), "tirx.erf", x) def tanh(x): @@ -1397,7 +1419,7 @@ def tanh(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.tanh", x) + return call_intrin(_primexpr_ty(x), "tirx.tanh", x) def sigmoid(x): @@ -1414,7 +1436,7 @@ def sigmoid(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.sigmoid", x) + return call_intrin(_primexpr_ty(x), "tirx.sigmoid", x) def log(x): @@ -1431,7 +1453,7 @@ def log(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.log", x) + return call_intrin(_primexpr_ty(x), "tirx.log", x) def log2(x): @@ -1448,7 +1470,7 @@ def log2(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.log2", x) + return call_intrin(_primexpr_ty(x), "tirx.log2", x) def log10(x): @@ -1465,7 +1487,7 @@ def log10(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.log10", x) + return call_intrin(_primexpr_ty(x), "tirx.log10", x) def log1p(x): @@ -1482,7 +1504,7 @@ def log1p(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.log1p", x) + return call_intrin(_primexpr_ty(x), "tirx.log1p", x) def tan(x): @@ -1499,7 +1521,7 @@ def tan(x): The result. """ x = _require_float_arg("tan", x) - return call_intrin(x.dtype, "tirx.tan", x) + return call_intrin(_primexpr_ty(x), "tirx.tan", x) def cos(x): @@ -1516,7 +1538,7 @@ def cos(x): The result. """ x = _require_float_arg("cos", x) - return call_intrin(x.dtype, "tirx.cos", x) + return call_intrin(_primexpr_ty(x), "tirx.cos", x) def cosh(x): @@ -1533,7 +1555,7 @@ def cosh(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.cosh", x) + return call_intrin(_primexpr_ty(x), "tirx.cosh", x) def acos(x): @@ -1550,7 +1572,7 @@ def acos(x): The result. 
""" x = tir.convert(x) - return call_intrin(x.dtype, "tirx.acos", x) + return call_intrin(_primexpr_ty(x), "tirx.acos", x) def acosh(x): @@ -1567,7 +1589,7 @@ def acosh(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.acosh", x) + return call_intrin(_primexpr_ty(x), "tirx.acosh", x) def sin(x): @@ -1584,7 +1606,7 @@ def sin(x): The result. """ x = _require_float_arg("sin", x) - return call_intrin(x.dtype, "tirx.sin", x) + return call_intrin(_primexpr_ty(x), "tirx.sin", x) def sinh(x): @@ -1601,7 +1623,7 @@ def sinh(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.sinh", x) + return call_intrin(_primexpr_ty(x), "tirx.sinh", x) def asin(x): @@ -1618,7 +1640,7 @@ def asin(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.asin", x) + return call_intrin(_primexpr_ty(x), "tirx.asin", x) def asinh(x): @@ -1635,7 +1657,7 @@ def asinh(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.asinh", x) + return call_intrin(_primexpr_ty(x), "tirx.asinh", x) def atan(x): @@ -1652,7 +1674,7 @@ def atan(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.atan", x) + return call_intrin(_primexpr_ty(x), "tirx.atan", x) def atanh(x): @@ -1669,7 +1691,7 @@ def atanh(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.atanh", x) + return call_intrin(_primexpr_ty(x), "tirx.atanh", x) def atan2(x1, x2): @@ -1690,7 +1712,7 @@ def atan2(x1, x2): """ x1 = tir.convert(x1) x2 = tir.convert(x2) - return call_intrin(x1.dtype, "tirx.atan2", x1, x2) + return call_intrin(_primexpr_ty(x1), "tirx.atan2", x1, x2) def sqrt(x): @@ -1707,7 +1729,7 @@ def sqrt(x): The result. """ x = tir.convert(x) - return call_intrin(x.dtype, "tirx.sqrt", x) + return call_intrin(_primexpr_ty(x), "tirx.sqrt", x) def rsqrt(x): @@ -1724,7 +1746,7 @@ def rsqrt(x): The result. 
""" x = tir.convert(x) - return call_intrin(x.dtype, "tirx.rsqrt", x) + return call_intrin(_primexpr_ty(x), "tirx.rsqrt", x) def clz(x): @@ -1971,7 +1993,7 @@ def nextafter(x1, x2): """ x1 = tir.convert(x1) x2 = tir.convert(x2) - return call_intrin(x1.dtype, "tirx.nextafter", x1, x2) # type: ignore + return call_intrin(_primexpr_ty(x1), "tirx.nextafter", x1, x2) # type: ignore def hypot(x1, x2): @@ -1992,7 +2014,7 @@ def hypot(x1, x2): """ x1 = tir.convert(x1) x2 = tir.convert(x2) - return call_intrin(x1.dtype, "tirx.hypot", x1, x2) # type: ignore + return call_intrin(_primexpr_ty(x1), "tirx.hypot", x1, x2) # type: ignore def copysign(x1, x2): @@ -2013,7 +2035,7 @@ def copysign(x1, x2): """ x1 = tir.convert(x1) x2 = tir.convert(x2) - return call_intrin(x1.dtype, "tirx.copysign", x1, x2) # type: ignore + return call_intrin(_primexpr_ty(x1), "tirx.copysign", x1, x2) # type: ignore def ldexp(x1, x2): @@ -2034,7 +2056,7 @@ def ldexp(x1, x2): """ x1 = tir.convert(x1) x2 = tir.convert(x2) - return call_intrin(x1.dtype, "tirx.ldexp", x1, x2) # type: ignore + return call_intrin(_primexpr_ty(x1), "tirx.ldexp", x1, x2) # type: ignore def likely(cond, span=None): @@ -2086,7 +2108,7 @@ def selector(var, pred, span=None): active domain for which ``pred`` is true. It is intended for compiler metadata and should not survive to executable codegen. """ - return call_intrin(var.dtype, "tirx.selector", var, pred, span=span) + return call_intrin(_primexpr_ty(var), "tirx.selector", var, pred, span=span) def isnan(x, span=None): @@ -2223,7 +2245,7 @@ def popcount(x): The result. 
""" x = tir.convert(x) - return call_intrin(x.dtype, "tirx.popcount", x) + return call_intrin(_primexpr_ty(x), "tirx.popcount", x) def q_multiply_shift(x, y, q, s): @@ -2356,7 +2378,7 @@ def fmod(x, y): """ x = tir.convert(x) y = tir.convert(y) - return call_intrin(x.dtype, "tirx.fmod", x, y) + return call_intrin(_primexpr_ty(x), "tirx.fmod", x, y) def if_then_else(cond, t, f, span=None): @@ -2667,7 +2689,7 @@ def _make_reduce(expr, axis, where=None, init=None): rhs = [] dtypes = [] for i in range(size): - dtype = expr[i].dtype + dtype = _primexpr_dtype(expr[i]) dtypes.append(dtype) lname = code.co_varnames[0] + "_" + str(i) lhs.append(Var(lname, dtype)) @@ -2680,7 +2702,7 @@ def _make_reduce(expr, axis, where=None, init=None): else: assert isinstance(expr, tvm.ir.PrimExpr) size = 1 - dtype = expr.dtype + dtype = _primexpr_dtype(expr) lvar = Var(code.co_varnames[0], dtype) rvar = Var(code.co_varnames[1], dtype) result = [fcombine(lvar, rvar)] diff --git a/python/tvm/tirx/script/builder/external_kernel.py b/python/tvm/tirx/script/builder/external_kernel.py index d56ed9ea0384..68e597d3f8ff 100644 --- a/python/tvm/tirx/script/builder/external_kernel.py +++ b/python/tvm/tirx/script/builder/external_kernel.py @@ -28,6 +28,7 @@ from tvm import __version__ as tvm_version from tvm import tirx +from tvm.ir import PrimExpr from tvm.runtime import Module, const from tvm.support import nvcc @@ -136,8 +137,10 @@ def compile_to_device_module( # pylint: disable=arguments-differ "threadIdx.y", "threadIdx.z", ][: len(grid[1])] - runtime_args = [arg if hasattr(arg, "dtype") else const(arg) for arg in args] - kernel_arg_types = [arg.dtype for arg in runtime_args] + runtime_args = [arg if isinstance(arg, PrimExpr) else const(arg) for arg in args] + kernel_arg_types = [ + str(arg.ty.dtype) if isinstance(arg, PrimExpr) else arg.dtype for arg in runtime_args + ] runtime_args = runtime_args + list(grid[0]) + list(grid[1]) # Reuse compilation path from SourceModule diff --git 
a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py index 2c18c61136b8..12db12aa99db 100644 --- a/python/tvm/tirx/script/builder/ir.py +++ b/python/tvm/tirx/script/builder/ir.py @@ -520,7 +520,7 @@ def match_buffer( raise ValueError("Shape must be specified when binding input param") shape = (shape,) if isinstance(shape, PrimExpr | Integral) else shape if strides is not None: - idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else "int32" + idx_dtype = shape[0].ty if isinstance(shape[0], PrimExpr) else "int32" strides = [Var(s, idx_dtype) if isinstance(s, str) else s for s in strides] else: strides = [] @@ -1012,8 +1012,8 @@ def _as_range(dom: ir.Range | list[PrimExpr]) -> ir.Range: if isinstance(extent, tir.IntImm): return ir.Range.from_min_extent(dom[0], extent) return ir.Range(dom[0], dom[1]) - if hasattr(dom, "dtype"): - return ir.Range(IntImm(dom.dtype, 0), dom) + if isinstance(dom, PrimExpr): + return ir.Range(IntImm(dom.ty, 0), dom) return ir.Range(0, dom) @@ -1204,8 +1204,8 @@ def serial( annotations["disable_unroll"] = True if stop is None: stop = start - if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + if isinstance(start, PrimExpr): + start = IntImm(start.ty, 0) else: start = 0 return _ffi_api.Serial(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member @@ -1241,8 +1241,8 @@ def parallel( """ if stop is None: stop = start - if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + if isinstance(start, PrimExpr): + start = IntImm(start.ty, 0) else: start = 0 return _ffi_api.Parallel(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member @@ -1278,8 +1278,8 @@ def vectorized( """ if stop is None: stop = start - if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + if isinstance(start, PrimExpr): + start = IntImm(start.ty, 0) else: start = 0 return _ffi_api.Vectorized(start, stop, annotations, step) # type: 
ignore[attr-defined] # pylint: disable=no-member @@ -1315,8 +1315,8 @@ def unroll( """ if stop is None: stop = start - if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + if isinstance(start, PrimExpr): + start = IntImm(start.ty, 0) else: start = 0 return _ffi_api.Unroll(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member @@ -1355,14 +1355,14 @@ def thread_binding( raise ValueError("Thread cannot be None for thread_binding") thread = stop stop = start - if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + if isinstance(start, PrimExpr): + start = IntImm(start.ty, 0) else: start = 0 elif stop is None: stop = start - if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + if isinstance(start, PrimExpr): + start = IntImm(start.ty, 0) else: start = 0 return _ffi_api.ThreadBinding( # type: ignore[attr-defined] # pylint: disable=no-member @@ -1502,7 +1502,8 @@ def as_var(self, rhs_dtype=None): else: raise TypeError(f"Invalid type for T.let: {self.type_spec}") elif rhs_dtype is not None: - return Var("", ir.PrimType(rhs_dtype)) + rhs_ty = rhs_dtype if isinstance(rhs_dtype, Type) else ir.PrimType(rhs_dtype) + return Var("", rhs_ty) else: raise TypeError("T.let requires either a type or an RHS value") @@ -2799,7 +2800,7 @@ def comm_reducer(combiner: Callable, identity: list[PrimExpr]) -> CommReducer: if isinstance(i, int): args.append(Var(name, "int32")) else: - args.append(Var(name, i.dtype)) + args.append(Var(name, i.ty)) res = combiner(*args) if not isinstance(res, tuple): res = (res,) @@ -2986,19 +2987,19 @@ class WebGPUNamespace: def subgroup_shuffle(var, lane): if isinstance(var, Buffer): var = var[0] - return _tir_op.call_intrin(var.dtype, "tirx.webgpu.subgroup_shuffle", var, lane) + return _tir_op.call_intrin(var.ty, "tirx.webgpu.subgroup_shuffle", var, lane) @staticmethod def subgroup_shuffle_up(var, delta): if isinstance(var, Buffer): var = var[0] - return _tir_op.call_intrin(var.dtype, 
"tirx.webgpu.subgroup_shuffle_up", var, delta) + return _tir_op.call_intrin(var.ty, "tirx.webgpu.subgroup_shuffle_up", var, delta) @staticmethod def subgroup_shuffle_down(var, delta): if isinstance(var, Buffer): var = var[0] - return _tir_op.call_intrin(var.dtype, "tirx.webgpu.subgroup_shuffle_down", var, delta) + return _tir_op.call_intrin(var.ty, "tirx.webgpu.subgroup_shuffle_down", var, delta) webgpu = WebGPUNamespace() diff --git a/python/tvm/tirx/script/parser/operation.py b/python/tvm/tirx/script/parser/operation.py index dac8f06ebf80..4f362b7d3acf 100644 --- a/python/tvm/tirx/script/parser/operation.py +++ b/python/tvm/tirx/script/parser/operation.py @@ -17,7 +17,8 @@ """The tirx expression operation registration""" from tvm import tirx -from tvm.runtime import DataType, DataTypeCode +from tvm.ir import PrimType +from tvm.runtime import DataTypeCode from tvm.script.parser._core import OpMethod, doc, register_op from tvm.tirx import IntImm from tvm.tirx.expr import FloatImm @@ -26,12 +27,20 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name ty._dispatch_type = ty # pylint: disable=protected-access + def _expr_ty(expr): + ty = expr.ty if isinstance(expr, tirx.PrimExpr) else None + if not isinstance(ty, PrimType): + ty = expr.expr_ty() + if not isinstance(ty, PrimType): + raise TypeError(f"Expected a PrimType expression, but got {ty}") + return ty + def _and(a, b): if isinstance(a, bool): a = IntImm("bool", a) if isinstance(b, bool): b = IntImm("bool", b) - if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1: + if not _expr_ty(a).is_scalar() or not _expr_ty(b).is_scalar(): return a & b else: return tirx.And(a, b) @@ -41,58 +50,56 @@ def _or(a, b): a = IntImm("bool", a) if isinstance(b, bool): b = IntImm("bool", b) - if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1: + if not _expr_ty(a).is_scalar() or not _expr_ty(b).is_scalar(): return a | b else: return tirx.Or(a, b) - def _get_type_str(dtype: str): - if 
DataType(dtype).lanes == 1: - return dtype - index = dtype.find("x") - return dtype[0:index] + def _get_type_str(ty: PrimType): + dtype_str = str(ty.dtype) + if ty.is_scalar(): + return dtype_str + index = dtype_str.find("x") + return dtype_str[0:index] def _auto_broadcast(a, b, op): if isinstance(a, int): - if hasattr(b, "dtype"): - if ( - DataType(b.dtype).type_code == DataTypeCode.INT - or DataType(b.dtype).type_code == DataTypeCode.UINT - or DataType(b.dtype).type_code == DataTypeCode.BOOL - ): - a = IntImm(_get_type_str(b.dtype), a) - elif DataType(b.dtype).type_code == DataTypeCode.FLOAT: - a = FloatImm(_get_type_str(b.dtype), a) + if isinstance(b, tirx.PrimExpr) or hasattr(b, "expr_ty"): + b_ty = _expr_ty(b) + if b_ty.matches_code(DataTypeCode.INT, DataTypeCode.UINT, DataTypeCode.BOOL): + a = IntImm(_get_type_str(b_ty), a) + elif b_ty.matches_code(DataTypeCode.FLOAT): + a = FloatImm(_get_type_str(b_ty), a) elif isinstance(b, float): a = FloatImm("float32", a) else: a = IntImm("int32", a) elif isinstance(a, float): - if DataType(b.dtype).type_code == DataTypeCode.FLOAT: - a = FloatImm(_get_type_str(b.dtype), a) + b_ty = _expr_ty(b) + if b_ty.matches_code(DataTypeCode.FLOAT): + a = FloatImm(_get_type_str(b_ty), a) else: a = FloatImm("float32", a) assert isinstance(a, tirx.PrimExpr), "Operand should be a PrimExpr." 
if isinstance(b, int): - if ( - DataType(a.dtype).type_code == DataTypeCode.INT - or DataType(a.dtype).type_code == DataTypeCode.UINT - or DataType(a.dtype).type_code == DataTypeCode.BOOL - ): - b = IntImm(_get_type_str(a.dtype), b) - elif DataType(a.dtype).type_code == DataTypeCode.FLOAT: - b = FloatImm(_get_type_str(a.dtype), b) + a_ty = _expr_ty(a) + if a_ty.matches_code(DataTypeCode.INT, DataTypeCode.UINT, DataTypeCode.BOOL): + b = IntImm(_get_type_str(a_ty), b) + elif a_ty.matches_code(DataTypeCode.FLOAT): + b = FloatImm(_get_type_str(a_ty), b) elif isinstance(b, float): - b = FloatImm(_get_type_str(a.dtype), b) + b = FloatImm(_get_type_str(_expr_ty(a)), b) - if DataType(a.dtype).lanes == DataType(b.dtype).lanes: + a_ty = _expr_ty(a) + b_ty = _expr_ty(b) + if a_ty.dtype.lanes == b_ty.dtype.lanes: return op(a, b) - elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: - broadcast_a = tirx.Broadcast(a, DataType(b.dtype).lanes) + elif a_ty.is_scalar() and a_ty.dtype.lanes != b_ty.dtype.lanes: + broadcast_a = tirx.Broadcast(a, b_ty.dtype.lanes) return op(broadcast_a, b) - elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: - broadcast_b = tirx.Broadcast(b, DataType(a.dtype).lanes) + elif b_ty.is_scalar() and a_ty.dtype.lanes != b_ty.dtype.lanes: + broadcast_b = tirx.Broadcast(b, a_ty.dtype.lanes) return op(a, broadcast_b) else: raise TypeError("do not know how to deal with it.") diff --git a/python/tvm/tirx/script/parser/parser.py b/python/tvm/tirx/script/parser/parser.py index 54c18db374d8..b2f2b30063a8 100644 --- a/python/tvm/tirx/script/parser/parser.py +++ b/python/tvm/tirx/script/parser/parser.py @@ -225,13 +225,13 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - value = tvm.tirx.const(value) if not isinstance(value, tvm.tirx.StringImm): # x = expr -> scalar (auto-typed from value) - scalar = T.local_scalar(dtype=str(value.dtype)) + scalar = 
T.local_scalar(dtype=str(value.ty.dtype)) IRBuilder.name(var_name, scalar.scalar.buffer) T.buffer_store(scalar.scalar.buffer, value, [0]) return scalar.scalar else: # StringImm: x = expr -> immutable Bind var - ann_var = tvm.tirx.Var(var_name, value.dtype) + ann_var = tvm.tirx.Var(var_name, value.ty) IRBuilder.name(var_name, ann_var) T.Bind(value, var=ann_var) return ann_var @@ -539,7 +539,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: if raw_ann.type_spec is not None: ann_var = raw_ann.as_var() else: - ann_var = raw_ann.as_var(rhs_dtype=rhs.dtype) + ann_var = raw_ann.as_var(rhs_dtype=rhs.ty) if not isinstance(ann_var, Var): self.report_error(node.annotation, "Annotation should resolve to Var") self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value) @@ -619,7 +619,7 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: if node.returns is not None: ret_type = self.eval_expr(node.returns) if callable(ret_type): - ret_type = PrimType(ret_type().dtype) + ret_type = ret_type().ty T.func_ret(ret_type) with self.with_dispatch_token("tirx"): # TODO: handle different types of arguments: @@ -888,7 +888,7 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar if node.returns is not None: ret_type = self.eval_expr(node.returns) if callable(ret_type): - ret_type = PrimType(ret_type().dtype) + ret_type = ret_type().ty arg_annotations = [] for arg in node.args.args: diff --git a/python/tvm/tirx/stmt.py b/python/tvm/tirx/stmt.py index 532bf35b254a..543ff99fed66 100644 --- a/python/tvm/tirx/stmt.py +++ b/python/tvm/tirx/stmt.py @@ -35,7 +35,7 @@ from tvm.ir import Op, PrimExpr, Range, Span from tvm.runtime import Object, Scriptable, const -from tvm.tirx import FloatImm +from tvm.tirx import FloatImm, IntImm from . 
import _ffi_api from .buffer import Buffer @@ -656,7 +656,7 @@ def __getitem__(self, indices): new_min = old_range.min + index new_region.append( Range.from_min_extent( - new_min, const(1, index.dtype) if isinstance(index, PrimExpr) else 1 + new_min, IntImm(index.ty, 1) if isinstance(index, PrimExpr) else 1 ) ) # Fill remaining dimensions with their original ranges diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py index d3e8991c85c7..6088c4baa800 100644 --- a/python/tvm/topi/math.py +++ b/python/tvm/topi/math.py @@ -18,7 +18,7 @@ # pylint: disable=redefined-builtin,unused-argument import tvm -from tvm import DataType, DataTypeCode, te +from tvm import DataTypeCode, te from tvm.tirx import PrimExpr from . import cpp, tag @@ -26,11 +26,15 @@ def _require_float_tensor(op_name, x): - if DataType(x.dtype).type_code not in (DataTypeCode.FLOAT, DataTypeCode.BFLOAT): + if not x.dtype.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT): raise TypeError(f"topi.{op_name} only supports floating-point inputs, but got {x.dtype}") return x +def _is_integer_tensor(x): + return x.dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT) + + @tvm.te.tag_scope(tag=tag.ELEMWISE) def identity(x): """Take identity of input x. @@ -478,7 +482,7 @@ def log(x): y : tvm.te.Tensor The result. """ - if x.dtype.startswith("int"): + if x.dtype.matches_code(DataTypeCode.INT): x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) return te.compute(x.shape, lambda *i: te.log(x(*i)), tag=tag.ELEMWISE) @@ -496,7 +500,7 @@ def log2(x): y : tvm.te.Tensor The result. """ - if x.dtype.startswith("int"): + if x.dtype.matches_code(DataTypeCode.INT): x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) return te.compute(x.shape, lambda *i: te.log2(x(*i)), tag=tag.ELEMWISE) @@ -514,7 +518,7 @@ def log10(x): y : tvm.te.Tensor The result. 
""" - if x.dtype.startswith("int"): + if x.dtype.matches_code(DataTypeCode.INT): x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) return te.compute(x.shape, lambda *i: te.log10(x(*i)), tag=tag.ELEMWISE) @@ -533,7 +537,7 @@ def sqrt(x): y : tvm.te.Tensor The result. """ - if x.dtype.startswith("int"): + if x.dtype.matches_code(DataTypeCode.INT): x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) return te.compute(x.shape, lambda *i: te.sqrt(x(*i))) @@ -552,7 +556,7 @@ def rsqrt(x): y : tvm.te.Tensor The result. """ - if x.dtype.startswith("int"): + if x.dtype.matches_code(DataTypeCode.INT): x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) return te.compute(x.shape, lambda *i: te.rsqrt(x(*i))) @@ -798,7 +802,7 @@ def fast_exp(x): y : tvm.te.Tensor The result. """ - if x.dtype.startswith("int") or x.dtype.startswith("uint"): + if _is_integer_tensor(x): x = cast(x, "float32") return cpp.fast_exp(x, x.dtype, tag.ELEMWISE) @@ -816,7 +820,7 @@ def fast_tanh(x): y : tvm.te.Tensor The result. 
""" - if x.dtype.startswith("int") or x.dtype.startswith("uint"): + if _is_integer_tensor(x): x = cast(x, "float32") return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE) @@ -855,24 +859,26 @@ def ceil_log2(x): if not isinstance(x, tvm.tirx.PrimExpr): x = tvm.tirx.const(x) - if "float" in x.dtype: + if x.ty.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT): return tvm.tirx.ceil(tvm.tirx.log2(x)) target = tvm.target.Target.current() - if "vulkan" in target.kind.name: - clz = tvm.tirx.clz(x) - bits = int(x.dtype[-2:]) - res = tvm.tirx.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz) - if res.dtype != x.dtype: - return cast(res, x.dtype) - return res - - if "adreno" in str(target.attrs.get("device", "")) or target.kind.name in [ - "metal", - "rocm", - "webgpu", - ]: - return cast(tvm.tirx.ceil(tvm.tirx.log2(cast(x, "float32"))), x.dtype) + if target is not None: + target_name = target.kind.name + if "vulkan" in target_name: + clz = tvm.tirx.clz(x) + bits = x.ty.dtype.bits + res = tvm.tirx.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz) + if res.dtype != x.dtype: + return cast(res, x.dtype) + return res + + if "adreno" in str(target.attrs.get("device", "")) or target_name in [ + "metal", + "rocm", + "webgpu", + ]: + return cast(tvm.tirx.ceil(tvm.tirx.log2(cast(x, "float32"))), x.dtype) return cast(tvm.tirx.ceil(tvm.tirx.log2(cast(x, "float64"))), x.dtype) diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index bf5b86599854..de35577c4d85 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -18,7 +18,7 @@ # ruff: noqa: E741 """ScatterND operator""" -from tvm import te, tirx # hide redefinition of min and max +from tvm import DataTypeCode, te, tirx # hide redefinition of min and max from tvm.arith.analyzer import Analyzer from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import tirx as T @@ -49,7 +49,7 @@ def _verify_scatter_nd_inputs(data, indices, updates): f"of out_shape[{i}] 
({data.shape[i]})." ) - assert "int" in indices.dtype, ( + assert indices.dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT), ( f"Indices must be a tensor of integers, but its elements are {indices.dtype}." ) diff --git a/python/tvm/topi/sort.py b/python/tvm/topi/sort.py index 81821e462dcf..846573db5036 100644 --- a/python/tvm/topi/sort.py +++ b/python/tvm/topi/sort.py @@ -110,7 +110,7 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): f = tvm.compile(s, [data, out], "llvm") dev = tvm.cpu() tvm_data = tvm.runtime.tensor(np_data, dev) - tvm_out = tvm.runtime.tensor(np.zeros(dshape, dtype=data.dtype), dev) + tvm_out = tvm.runtime.tensor(np.zeros(dshape, dtype=data.dtype.dtype), dev) f(tvm_data, tvm_out) """ data_buf = tvm.tirx.decl_buffer( diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index fc59f891e1bf..94eb8788846b 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -73,7 +73,8 @@ void AnalyzerObj::Bind(const Var& var, const Range& range, bool allow_override) void AnalyzerObj::MarkGlobalNonNegValue(const PrimExpr& value) { // decompose value as symbol * scale + offset int64_t offset = 0; - PrimExpr symbol_scale = tirx::MakeConst(value.dtype(), 0); + PrimType value_ty = value.ty(); + PrimExpr symbol_scale = tirx::MakeConst(value_ty, 0); auto fcollect_sum = [&](PrimExpr val, int sign) { if (const auto* intimm = val.as()) { @@ -90,7 +91,7 @@ void AnalyzerObj::MarkGlobalNonNegValue(const PrimExpr& value) { // split out the symbol and non-symbolic part int64_t cscale = 1; - PrimExpr symbol = tirx::MakeConst(value.dtype(), 1); + PrimExpr symbol = tirx::MakeConst(value_ty, 1); auto fcollect_prod = [&](PrimExpr val) { if (const auto* intimm = val.as()) { cscale *= intimm->value; @@ -110,7 +111,7 @@ void AnalyzerObj::MarkGlobalNonNegValue(const PrimExpr& value) { Var var = ffi::GetRef(var_ptr); // skip non-index type, keep it to be compatible // with any_dim that do not represent any value - if 
(!IsIndexType(var.dtype())) return; + if (!IsIndexTypedExpr(var)) return; bool allow_override = true; // mark the constant bound is sufficient // we cannot mark interval set as that will cause relaxation of the var @@ -169,7 +170,7 @@ bool AnalyzerObj::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { const auto* clhs = lhs.as(); const auto* crhs = rhs.as(); if (clhs && crhs) return clhs->value == crhs->value; - if (lhs->dtype.is_handle() || rhs->dtype.is_handle()) { + if (lhs->ty().IsHandle() || rhs->ty().IsHandle()) { return lhs.same_as(rhs); } return CanProve(lhs - rhs == 0); @@ -189,7 +190,7 @@ bool AnalyzerObj::CanProveLessEqualThanSymbolicShapeValue(const PrimExpr& lhs, } }; UnpackReduction(shape, fcollect); - PrimExpr const_shape_bound = IntImm(shape.dtype(), std::abs(cscale)); + PrimExpr const_shape_bound = IntImm(shape.ty(), std::abs(cscale)); if (this->CanProve(lhs <= const_shape_bound, ProofStrength::kSymbolicBound)) return true; return false; } diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index 475a687cd462..01d50da56e41 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -96,7 +96,8 @@ class BoundDeducer : public ExprFunctor { void VisitExprDefault_(const ffi::Object* op) final { success_ = false; } SignType GetSignType(const PrimExpr& e) { - if (e.dtype().is_uint()) { + PrimType e_ty = e.ty(); + if (e_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { return kPositive; } return expr_map_[e].GetSignType(); diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 12344cffd1d8..17a6ba022e2b 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -83,14 +83,14 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { * \param analyzer The analyzer * \return whether value fits in dtype */ -bool CastIsSafe(DataType dtype, PrimExpr value, AnalyzerObj* analyzer) { - if (!IsIndexType(dtype)) { +bool CastIsSafe(PrimType dtype, PrimExpr value, 
AnalyzerObj* analyzer) { + if (!IsIndexType(dtype->dtype)) { return false; } ConstIntBound bound = analyzer->const_int_bound(value); int64_t ubound = max_value(dtype).as_or_throw()->value; int64_t lbound = min_value(dtype).as_or_throw()->value; - if (value.dtype().bits() <= dtype.bits() || // upcast is safe + if (value.ty().bits() <= dtype.bits() || // upcast is safe (bound->max_value <= ubound && bound->min_value >= lbound)) { return true; } @@ -128,7 +128,7 @@ class SplitExprNode : public CanonicalExprNode { PrimExpr NormalizeWithScale(int64_t sscale) const { PrimExpr res = this->index; - DataType dtype = this->dtype; + PrimType dtype = this->ty(); if (this->scale == 0) { return IntImm(dtype, 0); } @@ -140,7 +140,7 @@ class SplitExprNode : public CanonicalExprNode { } sscale *= this->scale; if (sscale != 1) { - TVM_FFI_ICHECK(!dtype.is_uint() || sscale > 0); + TVM_FFI_ICHECK(dtype.code() != DLDataTypeCode::kDLUInt || sscale > 0); res = res * MakeConst(dtype, sscale); } return res; @@ -156,12 +156,12 @@ class SplitExprNode : public CanonicalExprNode { * \param analyzer The analyzer * \return whether the cast can be safely pushed to children */ - bool CanPushCastToChildren(DataType dtype, AnalyzerObj* analyzer) const { + bool CanPushCastToChildren(PrimType dtype, AnalyzerObj* analyzer) const { // cast(dtype, index % upper_factor / lower_factor * scale) == // cast(dtype, index) % upper_factor / lower_factor * scale // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of // its intermediate results fit in the range of dtype - if (dtype.bits() >= this->dtype.bits()) { + if (dtype.bits() >= this->ty().bits()) { return true; // upcast is safe } PrimExpr res = this->index; @@ -172,20 +172,20 @@ class SplitExprNode : public CanonicalExprNode { return false; } if (this->upper_factor != SplitExprNode::kPosInf) { - res = ModImpl(res, MakeConst(this->dtype, this->upper_factor), div_mode); + res = ModImpl(res, MakeConst(this->ty(), this->upper_factor), div_mode); if 
(!CastIsSafe(dtype, res, analyzer)) { return false; } } if (this->lower_factor != 1) { - res = DivImpl(res, MakeConst(this->dtype, this->lower_factor), div_mode); + res = DivImpl(res, MakeConst(this->ty(), this->lower_factor), div_mode); if (!CastIsSafe(dtype, res, analyzer)) { return false; } } if (this->scale != 1) { - TVM_FFI_ICHECK(!this->dtype.is_uint() || this->scale > 0); - res = res * MakeConst(this->dtype, this->scale); + TVM_FFI_ICHECK(this->ty().code() != DLDataTypeCode::kDLUInt || this->scale > 0); + res = res * MakeConst(this->ty(), this->scale); if (!CastIsSafe(dtype, res, analyzer)) { return false; } @@ -197,9 +197,9 @@ class SplitExprNode : public CanonicalExprNode { * \brief self = cast(dtype, self) * \param dtype The target datatype */ - void PushCastToChildren(DataType dtype) { + void PushCastToChildren(PrimType dtype) { this->index = cast(dtype, this->index); - this->dtype = dtype; + this->BaseExprNode::ty = dtype; } inline bool IndexEqual(const SplitExpr& other) const; @@ -252,9 +252,9 @@ class SumExprNode : public CanonicalExprNode { PrimExpr Normalize() const final { // quick path 1. if (this->args.size() == 0) { - return MakeConst(this->dtype, this->base); + return MakeConst(this->ty(), this->base); } - return Normalize_(this->dtype, SimplifySplitExprs(args), base); + return Normalize_(this->ty(), SimplifySplitExprs(args), base); } /*! * \brief Whether self is divisible by scale. @@ -334,14 +334,14 @@ class SumExprNode : public CanonicalExprNode { * \param analyzer The analyzer * \return whether the cast can be safely pushed to children */ - bool CanPushCastToChildren(DataType dtype, AnalyzerObj* analyzer) const { + bool CanPushCastToChildren(PrimType dtype, AnalyzerObj* analyzer) const { bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits::lowest() : base == -(1LL << (dtype.bits() - 1)); // cast(dtype, arg_1 + arg_2 + ... arg_n) == // cast(dtype, arg_1) + ... 
+ cast(dtype, arg_n) // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of // its intermediate results fit in the range of dtype - if (dtype.bits() >= this->dtype.bits()) { + if (dtype.bits() >= this->ty().bits()) { return true; // upcast is safe } PrimExpr res = IntImm(dtype, 0); @@ -386,11 +386,11 @@ class SumExprNode : public CanonicalExprNode { * \brief self = cast(dtype, self) * \param dtype The target datatype */ - void PushCastToChildren(DataType dtype) { + void PushCastToChildren(PrimType dtype) { for (auto& arg : args) { arg.CopyOnWrite()->PushCastToChildren(dtype); } - this->dtype = dtype; + this->BaseExprNode::ty = dtype; } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.SumExpr", SumExprNode, CanonicalExprNode); @@ -496,7 +496,7 @@ class SumExprNode : public CanonicalExprNode { std::stable_sort(args.begin(), args.end(), fcompare); return args; } - static PrimExpr Normalize_(DataType dtype, const std::vector& args, int64_t base) { + static PrimExpr Normalize_(PrimType dtype, const std::vector& args, int64_t base) { bool is_min_value = dtype.bits() == 64 ? 
base == std::numeric_limits::lowest() : base == -(1LL << (dtype.bits() - 1)); // Positive scales first @@ -648,7 +648,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { expr = op->Normalize(); } ffi::ObjectPtr n = ffi::make_object(); - n->dtype = expr.dtype(); + n->BaseExprNode::ty = expr.ty(); n->index = std::move(expr); n->div_mode = kTruncDiv; return SplitExpr(n); @@ -685,7 +685,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { return op.value(); } ffi::ObjectPtr n = ffi::make_object(); - n->dtype = expr.dtype(); + n->BaseExprNode::ty = expr.ty(); if (const auto* op = expr.as()) { n->base = op->value; return SumExpr(n); @@ -699,7 +699,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { }; PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } // normalize @@ -723,7 +723,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } // normalize @@ -747,7 +747,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } // normalize @@ -794,8 +794,8 @@ void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, SumExpr* out_non_divisible) { auto divisible = ffi::make_object(); auto non_divisible = ffi::make_object(); - divisible->dtype = psum->dtype; - non_divisible->dtype = psum->dtype; + divisible->BaseExprNode::ty = psum->ty(); + non_divisible->BaseExprNode::ty = psum->ty(); if (psum->base % coeff == 0) { divisible->base = psum->base; @@ -834,11 +834,11 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr 
lhs, int64_t cval, return lhs; } else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) { // (x % c1) / c2 => 0 when c2 >= c1 - return ToSplitExpr(IntImm(lhs.dtype(), 0)); + return ToSplitExpr(IntImm(lhs.ty(), 0)); } else { // move the upper_factor modular into index. lhs.CopyOnWrite()->index = - ModImpl(lhs->index, MakeConst(lhs.dtype(), lhs->upper_factor), div_mode); + ModImpl(lhs->index, MakeConst(lhs.ty(), lhs->upper_factor), div_mode); lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf; lhs.CopyOnWrite()->scale = 1; lhs.CopyOnWrite()->lower_factor *= scaled_cval; @@ -862,8 +862,9 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, if (prhs->as()) return false; // collect lhs products and try to eliminate by matching them to prod in rhs ffi::Array> lhs_prods; - PrimExpr new_rhs = MakeConst(prhs->dtype(), 1); - PrimExpr new_common_scale = MakeConst(prhs->dtype(), 1); + PrimType rhs_ty = prhs->ty(); + PrimExpr new_rhs = MakeConst(rhs_ty, 1); + PrimExpr new_common_scale = MakeConst(rhs_ty, 1); int64_t lhs_cscale = 1, rhs_cscale = 1; int num_elimination = 0; @@ -905,18 +906,19 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, if (num_elimination == 0 && cscale_gcd == 1) return false; // construct prod via canonical form - PrimExpr new_lhs = MakeConst(plhs->dtype(), 1); + PrimType lhs_ty = plhs->ty(); + PrimExpr new_lhs = MakeConst(lhs_ty, 1); for (ffi::Optional val : lhs_prods) { if (val.defined()) new_lhs = new_lhs * val.value(); } - *plhs = new_lhs * MakeConst(plhs->dtype(), lhs_cscale); - *prhs = new_rhs * MakeConst(prhs->dtype(), rhs_cscale); - *common_scale = new_common_scale * MakeConst(prhs->dtype(), cscale_gcd); + *plhs = new_lhs * MakeConst(lhs_ty, lhs_cscale); + *prhs = new_rhs * MakeConst(rhs_ty, rhs_cscale); + *common_scale = new_common_scale * MakeConst(rhs_ty, cscale_gcd); return true; } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { - if 
(!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } @@ -958,7 +960,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { // if a >= 0 && a < cval, then result == 0 auto cbound = analyzer_->const_int_bound(Normalize(a)); if (cbound->min_value >= 0 && cbound->max_value < cval) { - return IntImm(a.dtype(), 0); + return IntImm(a.ty(), 0); } } return SplitDivConst(ToSplitExpr(std::move(a)), cval, kTruncDiv); @@ -980,7 +982,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } PrimExpr a = this->CanonicalMutate(op->a); @@ -1019,7 +1021,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { // if a >= 0 && a < cval, then result == 0 auto cbound = analyzer_->const_int_bound(Normalize(a)); if (cbound->min_value >= 0 && cbound->max_value < cval) { - return IntImm(a.dtype(), 0); + return IntImm(a.ty(), 0); } } // Identity: floordiv(floormod(index, m*n), n) = floormod(floordiv(index, n), m) @@ -1049,7 +1051,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { } // Apply floormod(floordiv_result, m) to complete the identity PrimExpr div_result = Normalize(lhs); - return this->VisitExpr(floormod(div_result, MakeConst(a.dtype(), new_mod))); + return this->VisitExpr(floormod(div_result, MakeConst(a.ty(), new_mod))); } } } @@ -1095,8 +1097,8 @@ SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, // Perhaps there are more chances in simplifying the index // Do a recursive call to simplify the mod with the new factor. 
if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) { - auto updated = ToSplitExpr(this->VisitExpr( - ModImpl(lhs->index, MakeConst(lhs.dtype(), new_upper_factor), div_mode))); + auto updated = ToSplitExpr( + this->VisitExpr(ModImpl(lhs->index, MakeConst(lhs.ty(), new_upper_factor), div_mode))); // re-apply the lower_factor if (lhs->lower_factor != 1) { auto ret = SplitDivConst(updated, lhs->lower_factor, div_mode); @@ -1126,7 +1128,7 @@ SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } // normalize @@ -1144,7 +1146,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { SumExpr lhs, extra; SeparateDivisibleParts(psum, cval, &lhs, &extra); if (extra->IsZero()) { - return IntImm(a.dtype(), 0); + return IntImm(a.ty(), 0); } // both lhs and extra are non-negative if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && @@ -1200,7 +1202,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } // normalize @@ -1362,7 +1364,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) { } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } // normalize @@ -1370,15 +1372,15 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { // PushCastToChildren if (value.as()) { SumExpr se = value.as_or_throw(); - if (se->CanPushCastToChildren(op->dtype, analyzer_)) { - se.CopyOnWrite()->PushCastToChildren(op->dtype); + if (se->CanPushCastToChildren(op->ty(), analyzer_)) { + 
se.CopyOnWrite()->PushCastToChildren(op->ty()); return se; } } if (value.as()) { SplitExpr se = value.as_or_throw(); - if (se->CanPushCastToChildren(op->dtype, analyzer_)) { - se.CopyOnWrite()->PushCastToChildren(op->dtype); + if (se->CanPushCastToChildren(op->ty(), analyzer_)) { + se.CopyOnWrite()->PushCastToChildren(op->ty()); return se; } } @@ -1411,8 +1413,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) { } SumExpr divisible, extra; SeparateDivisibleParts(lhs, gcd, &divisible, &extra); - DataType dtype = divisible->dtype; - TVM_FFI_ICHECK(extra->dtype == dtype); + PrimType dtype = divisible->ty(); + TVM_FFI_ICHECK(extra->ty()->dtype == dtype->dtype); PrimExpr normal_extra = extra->Normalize(); if (this->analyzer_->CanProve(normal_extra < MakeConst(dtype, gcd)) && this->analyzer_->CanProve(normal_extra >= IntImm(dtype, 0))) { diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index fb1055660e3b..4793538316a3 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -72,18 +72,29 @@ inline ffi::Optional TryConstFold(PrimExpr a); * \param type The type to represent index. * \return the checked result. */ -inline bool IsIndexType(const DataType& type) { - return type.is_int() && !type.is_scalable_or_fixed_length_vector() && - (type.bits() == 32 || type.bits() == 64); +inline bool IsIndexType(DLDataType type) { + return type.code == static_cast(DLDataTypeCode::kDLInt) && + (type.bits == 32 || type.bits == 64) && type.lanes == 1; +} + +inline bool IsIndexTypedExpr(const PrimExprNode* expr) { + TVM_FFI_DCHECK(expr != nullptr); + TVM_FFI_DCHECK(expr->BaseExprNode::ty.defined()); + const auto* prim_ty = expr->BaseExprNode::ty.as(); + TVM_FFI_DCHECK(prim_ty != nullptr); + return IsIndexType(prim_ty->dtype); +} + +inline bool IsIndexTypedExpr(const PrimExpr& expr) { + return IsIndexTypedExpr(static_cast(expr.get())); } /*! \brief Helper to get const folding result repr in int64. 
*/ -inline int64_t GetFoldResultInt64Repr(int64_t x, const DataType& dtype) { +inline int64_t GetFoldResultInt64Repr(int64_t x, const PrimType& dtype) { if (dtype.bits() < 64) { x &= (1LL << dtype.bits()) - 1; } - if (dtype.is_int()) { - // get sign extended value of integer with specified bits + if (dtype.MatchesCode(DLDataTypeCode::kDLInt)) { int64_t m = 1LL << (dtype.bits() - 1); x = (x ^ m) - m; } @@ -118,32 +129,30 @@ inline double GetFoldResultDoubleRepr(float x) { const FloatImmNode* fb = b.as(); \ BODY; -#define TVM_INDEX_CONST_PROPAGATION(BODY) \ - const IntImmNode* pa = a.as(); \ - const IntImmNode* pb = b.as(); \ - const DataType& ta = a.dtype(); \ - const DataType& tb = b.dtype(); \ - if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \ - BODY; \ +#define TVM_INDEX_CONST_PROPAGATION(BODY) \ + const IntImmNode* pa = a.as(); \ + const IntImmNode* pb = b.as(); \ + if (arith::IsIndexTypedExpr(a) && arith::IsIndexTypedExpr(b)) { \ + BODY; \ } // specialization of constant folders. 
template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { int64_t res = pa->value + pb->value; - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pa && pa->value == 0) return b; if (pb && pb->value == 0) return a; if (fa && fb) { - if (rtype.bits() == 32) { - return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) + - static_cast(fb->value))); - } else if (rtype.bits() == 64) { - return FloatImm(rtype, fa->value + fb->value); + if (result_ty.bits() == 32) { + return FloatImm(result_ty, GetFoldResultDoubleRepr(static_cast(fa->value) + + static_cast(fb->value))); + } else if (result_ty.bits() == 64) { + return FloatImm(result_ty, fa->value + fb->value); } } if (fa && fa->value == 0) return b; @@ -155,22 +164,22 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - TVM_FFI_ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) && - (pb && pb->dtype.is_uint() && pb->value > 0U))) + TVM_FFI_ICHECK(!((pa && pa->ty().MatchesCode(DLDataTypeCode::kDLUInt) && pa->value == 0U) && + (pb && pb->ty().MatchesCode(DLDataTypeCode::kDLUInt) && pb->value > 0U))) << "Checked failed. 
Minuend 's value is 0U and it's dtype is uint " << "while Subtrahend's dtype is uint; which will cause a negative uint"; - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { int64_t res = pa->value - pb->value; - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pb && pb->value == 0) return a; if (fa && fb) { - if (rtype.bits() == 32) { - return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) - - static_cast(fb->value))); - } else if (rtype.bits() == 64) { - return FloatImm(rtype, fa->value - fb->value); + if (result_ty.bits() == 32) { + return FloatImm(result_ty, GetFoldResultDoubleRepr(static_cast(fa->value) - + static_cast(fb->value))); + } else if (result_ty.bits() == 64) { + return FloatImm(result_ty, fa->value - fb->value); } } if (fb && fb->value == 0) return a; @@ -181,10 +190,10 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { int64_t res = pa->value * pb->value; - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pa) { if (pa->value == 1) return b; @@ -195,11 +204,11 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pb->value == 0) return b; } if (fa && fb) { - if (rtype.bits() == 32) { - return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) * - static_cast(fb->value))); - } else if (rtype.bits() == 64) { - return FloatImm(rtype, fa->value * fb->value); + if (result_ty.bits() == 32) { + return FloatImm(result_ty, GetFoldResultDoubleRepr(static_cast(fa->value) * + static_cast(fb->value))); + } else if (result_ty.bits() == 64) { + return FloatImm(result_ty, fa->value * fb->value); } } if (fa) { @@ -217,13 +226,13 @@ inline 
ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { // due to division and mod can have different modes // NOTE: this will assumes truc div. TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = pa->value / pb->value; - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pa) { if (pa->value == 0) return a; @@ -234,11 +243,11 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } if (fa && fb) { TVM_FFI_ICHECK_NE(fb->value, 0) << "Divide by zero"; - if (rtype.bits() == 32) { - return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) / - static_cast(fb->value))); - } else if (rtype.bits() == 64) { - return FloatImm(rtype, fa->value / fb->value); + if (result_ty.bits() == 32) { + return FloatImm(result_ty, GetFoldResultDoubleRepr(static_cast(fa->value) / + static_cast(fb->value))); + } else if (result_ty.bits() == 64) { + return FloatImm(result_ty, fa->value / fb->value); } } if (fa && fa->value == 0) return a; @@ -253,18 +262,18 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = pa->value % pb->value; - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pa) { if (pa->value == 0) return a; } if (pb) { // MakeConst can handle both vector and scalar types. 
- if (pb->value == 1) return tirx::MakeConst(rtype, 0); + if (pb->value == 1) return tirx::MakeConst(result_ty, 0); TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); @@ -274,11 +283,11 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = arith::floordiv(pa->value, pb->value); - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pa) { if (pa->value == 0) return a; @@ -288,11 +297,12 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } if (fa && fb && fb->value != 0) { - if (rtype.bits() == 32) { - return FloatImm(rtype, GetFoldResultDoubleRepr(std::floor(static_cast(fa->value) / - static_cast(fb->value)))); - } else if (rtype.bits() == 64) { - return FloatImm(rtype, std::floor(fa->value / fb->value)); + if (result_ty.bits() == 32) { + return FloatImm(result_ty, + GetFoldResultDoubleRepr(std::floor(static_cast(fa->value) / + static_cast(fb->value)))); + } else if (result_ty.bits() == 64) { + return FloatImm(result_ty, std::floor(fa->value / fb->value)); } else { return std::nullopt; } @@ -309,18 +319,18 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = arith::floormod(pa->value, pb->value); - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pa) { if (pa->value == 0) return a; } if (pb) { // MakeConst can handle both vector 
and scalar types. - if (pb->value == 1) return tirx::MakeConst(rtype, 0); + if (pb->value == 1) return tirx::MakeConst(result_ty, 0); TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); @@ -330,9 +340,9 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); - if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value)); + PrimType result_ty = a.ty(); + if (pa && pb) return IntImm(result_ty, std::min(pa->value, pb->value)); + if (fa && fb) return FloatImm(result_ty, std::min(fa->value, fb->value)); }); if (a.same_as(b)) return a; return std::nullopt; @@ -341,9 +351,9 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); - if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value)); + PrimType result_ty = a.ty(); + if (pa && pb) return IntImm(result_ty, std::max(pa->value, pb->value)); + if (fa && fb) return FloatImm(result_ty, std::max(fa->value, fb->value)); }); if (a.same_as(b)) return a; return std::nullopt; diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 4d700564ea05..3e8087af0eff 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -151,7 +151,7 @@ class ConstIntBoundAnalyzer::Impl // Override visitor behaviors Entry VisitExprDefault_(const ffi::Object* op) final { - return Everything(static_cast(op)->dtype); + return Everything(static_cast(op)->ty()); } Entry VisitExpr(const PrimExpr& expr) final { @@ -167,7 +167,7 @@ class ConstIntBoundAnalyzer::Impl if (bound_) { auto val = bound_->find(expr); if (val != bound_->end()) { - auto everything = 
Everything(expr->dtype); + auto everything = Everything(expr->ty()); TVM_FFI_ICHECK( (val->second->min_value == res.min_value && val->second->max_value == res.max_value) || (val->second->min_value == everything.min_value && @@ -203,7 +203,7 @@ class ConstIntBoundAnalyzer::Impl a = VisitExpr(op->value); } - Entry b = Everything(op->dtype); + Entry b = Everything(op->ty()); return Intersect(a, b); } @@ -263,7 +263,7 @@ class ConstIntBoundAnalyzer::Impl Entry VisitExpr_(const DivNode* op) final { Entry a = VisitExpr(op->a); Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); - return HandleDivision(a, b, op->dtype, InfAwareDiv); + return HandleDivision(a, b, op->ty(), InfAwareDiv); } Entry VisitExpr_(const ModNode* op) final { @@ -312,14 +312,14 @@ class ConstIntBoundAnalyzer::Impl TVM_FFI_ICHECK(!b.is_const(0)) << "mod by zero"; // mod by negative value is rare, // and we just use the simpliest rule. - return Everything(op->dtype); + return Everything(op->ty()); } } Entry VisitExpr_(const FloorDivNode* op) final { Entry a = VisitExpr(op->a); Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); - return HandleDivision(a, b, op->dtype, InfAwareFloorDiv); + return HandleDivision(a, b, op->ty(), InfAwareFloorDiv); } Entry VisitExpr_(const FloorModNode* op) final { @@ -385,7 +385,7 @@ class ConstIntBoundAnalyzer::Impl int64_t b_max_cap = InfAwareAdd(b.max_value, -1); return Intersect(MakeBound(std::min(static_cast(0), b_min_cap), std::max(static_cast(0), b_max_cap)), - Everything(op->dtype)); + Everything(op->ty())); } } @@ -424,7 +424,7 @@ class ConstIntBoundAnalyzer::Impl } else if (op->op.same_as(tirx::builtin::bitwise_and())) { return VisitBitwiseAnd(op); } else { - return Everything(op->dtype); + return Everything(op->ty()); } } @@ -434,7 +434,7 @@ class ConstIntBoundAnalyzer::Impl if (it != var_map_.end()) { return it->second; } else { - return Everything(op->dtype); + return Everything(op->ty()); } } @@ -456,7 +456,7 @@ class ConstIntBoundAnalyzer::Impl // If either 
operand can negative, we may run into undefined // behavior for some targets. In these cases, avoid making any // assumptions about the result. - return Everything(op->dtype); + return Everything(op->ty()); } return BinaryOpBoundary(a, b, InfAwareLeftShift); @@ -481,7 +481,7 @@ class ConstIntBoundAnalyzer::Impl if (a.min_value >= 0) { return MakeBound(0, a.max_value); } - return Everything(op->dtype); + return Everything(op->ty()); } } @@ -549,7 +549,7 @@ class ConstIntBoundAnalyzer::Impl * \return The result. */ template - static Entry HandleDivision(Entry a, Entry b, DataType dt, const F& op) { + static Entry HandleDivision(Entry a, Entry b, PrimType dt, const F& op) { // Here we have a / b. // The largest value of the division will be for the smallest (with // respect to the absolute value) value of b. If the range of b starts @@ -557,7 +557,7 @@ class ConstIntBoundAnalyzer::Impl // be closer to 0, because BinaryOpBoundary only checks end-points of // the domain ranges. // If the range of b contains 0, then some infinity will be involved - if (b.min_value <= 0 && 0 <= b.max_value && dt.is_int()) { + if (b.min_value <= 0 && 0 <= b.max_value && dt.code() == DLDataTypeCode::kDLInt) { Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt); Entry b_pos = b.max_value > 0 ? MakeBound(1, b.max_value) : Everything(dt); @@ -566,7 +566,7 @@ class ConstIntBoundAnalyzer::Impl return MakeBound(std::min(e_neg.min_value, e_pos.min_value), std::max(e_neg.max_value, e_pos.max_value)); - } else if (b.min_value == 0 && dt.is_uint()) { + } else if (b.min_value == 0 && dt.code() == DLDataTypeCode::kDLUInt) { // uints only have one sided bounds Entry assumed_b = MakeBound(1, b.max_value); return BinaryOpBoundary(a, assumed_b, op); @@ -727,16 +727,17 @@ class ConstIntBoundAnalyzer::Impl * \param dtype The data type. * \return Bound that represent everything dtype can represent. 
*/ - static Entry Everything(DataType dtype) { - if (!dtype.is_int() && !dtype.is_uint() && !dtype.is_bool()) { + static Entry Everything(PrimType dtype) { + if (dtype.code() != DLDataTypeCode::kDLInt && dtype.code() != DLDataTypeCode::kDLUInt && + dtype.code() != DLDataTypeCode::kDLBool) { return MakeBound(kNegInf, kPosInf); } - if (dtype.is_bool()) { + if (dtype.code() == DLDataTypeCode::kDLBool) { return MakeBound(0, 1); } Entry ret; - int64_t vbits = dtype.bits() - static_cast(dtype.is_int()); - if (dtype.is_uint()) { + int64_t vbits = dtype.bits() - static_cast(dtype.code() == DLDataTypeCode::kDLInt); + if (dtype.code() == DLDataTypeCode::kDLUInt) { ret.min_value = 0; } else { if (vbits >= 63) { @@ -800,7 +801,7 @@ class ConstIntBoundAnalyzer::Impl static ffi::Optional FindCeilLog2Arg(const CastNode* op) { static const Op& ceil_op = Op::Get("tirx.ceil"); static const Op& log2_op = Op::Get("tirx.log2"); - if (op->dtype.is_int()) { + if (op->ty().code() == DLDataTypeCode::kDLInt) { if (auto as_call = op->value.as()) { if (as_call->op.same_as(ceil_op)) { PrimExpr ceil_arg = as_call->args[0]; diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 5e77dca59405..f7e04ee0ebf5 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -54,10 +54,10 @@ class LinearEqDetector : public ExprFunctorbase.defined()) { - ret->base = IntImm(var_.dtype(), 0); + ret->base = IntImm(var_.ty(), 0); } if (!ret->coeff.defined()) { - ret->coeff = IntImm(var_.dtype(), 0); + ret->coeff = IntImm(var_.ty(), 0); } return true; } @@ -101,8 +101,8 @@ class LinearEqDetector : public ExprFunctordtype; - ret.coeff = MakeConst(DataType::Int(dtype.bits(), dtype.lanes()), 1); + PrimType dtype = op->ty(); + ret.coeff = MakeConst(PrimType::Int(dtype.bits(), dtype.lanes()), 1); } else { ret.base = e; } @@ -194,19 +194,21 @@ bool DetectClipBound(const PrimExpr& cond, bool is_eq = false; PrimExpr canonical; if (const LTNode* op = 
cond.as()) { - if (!op->a.dtype().is_int()) return false; - canonical = op->b - op->a - MakeConst(op->a.dtype(), 1); + PrimType a_ty = op->a.ty(); + if (a_ty.code() != DLDataTypeCode::kDLInt) return false; + canonical = op->b - op->a - MakeConst(a_ty, 1); } else if (const LENode* op = cond.as()) { - if (!op->a.dtype().is_int()) return false; + if (op->a.ty().code() != DLDataTypeCode::kDLInt) return false; canonical = op->b - op->a; } else if (const GTNode* op = cond.as()) { - if (!op->a.dtype().is_int()) return false; - canonical = op->a - op->b - MakeConst(op->a.dtype(), 1); + PrimType a_ty = op->a.ty(); + if (a_ty.code() != DLDataTypeCode::kDLInt) return false; + canonical = op->a - op->b - MakeConst(a_ty, 1); } else if (const GENode* op = cond.as()) { - if (!op->a.dtype().is_int()) return false; + if (op->a.ty().code() != DLDataTypeCode::kDLInt) return false; canonical = op->a - op->b; } else if (const EQNode* op = cond.as()) { - if (!op->a.dtype().is_int()) return false; + if (op->a.ty().code() != DLDataTypeCode::kDLInt) return false; canonical = op->a - op->b; is_eq = true; } else { diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 55db4fc774b6..b517324f378d 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -74,7 +74,8 @@ ffi::Array AsConditions(const ffi::Array& variables, IntGroupBounds::IntGroupBounds(PrimExpr coef, ffi::Array lower, ffi::Array equal, ffi::Array upper) { - TVM_FFI_ICHECK(coef.dtype().is_int() || coef.dtype().is_uint()) + PrimType coef_ty = coef.ty(); + TVM_FFI_ICHECK(coef_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) << "Coefficient in IntGroupBounds must be integers"; ffi::ObjectPtr node = ffi::make_object(); node->coef = std::move(coef); @@ -86,7 +87,7 @@ IntGroupBounds::IntGroupBounds(PrimExpr coef, ffi::Array lower, IntGroupBounds IntGroupBounds::FromRange(const Range& r) { Analyzer analyzer; - PrimExpr coef = tirx::MakeConst(r->min.dtype(), 1); + PrimExpr 
coef = tirx::MakeConst(r->min.ty(), 1); ffi::Array equal; ffi::Array lower; ffi::Array upper; @@ -232,7 +233,8 @@ IntConstraints::IntConstraints(ffi::Array variables, ffi::Map r } TVM_FFI_ICHECK(relations.defined()); for (const auto& var : variables) { - TVM_FFI_ICHECK(var.dtype().is_int() || var.dtype().is_uint()) + PrimType var_ty = var.ty(); + TVM_FFI_ICHECK(var_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) << "Variables in IntConstraints must be integers"; } node->variables = std::move(variables); diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index d7bf32442497..b3d111ffa7a8 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -50,8 +50,8 @@ using tirx::MakeConst; TVM_FFI_STATIC_INIT_BLOCK() { IntervalSetNode::RegisterReflection(); } -PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); -PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); +PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", PrimType::Handle()); +PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", PrimType::Handle()); IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) { auto node = ffi::make_object(); @@ -72,8 +72,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { IntervalSet Intersect(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b) { PrimExpr max_value = min(a->max_value, b->max_value); PrimExpr min_value = max(a->min_value, b->min_value); - if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) && - (min_value.dtype().is_int() || min_value.dtype().is_uint()) && + PrimType max_ty = max_value.ty(); + PrimType min_ty = min_value.ty(); + if (max_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt) && + min_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt) && analyzer->CanProve(max_value < min_value)) { return IntervalSet::Empty(); } else { @@ -121,7 +123,7 @@ TVM_DECLARE_LOGICAL_OP(Not); */ template inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b, const 
OpNode* op) { - DataType dtype = op->dtype; + PrimType dtype = op->ty(); if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr expr; if (auto res = TryConstFold(a->min_value, b->min_value)) { @@ -195,7 +197,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, Inte return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using tirx::Select; - PrimExpr sign = b->min_value >= IntImm(b->min_value.dtype().element_of(), 0); + PrimExpr sign = b->min_value >= IntImm(b->min_value.ty().WithLanes(1), 0); PrimExpr e1 = a->min_value * b->min_value; PrimExpr e2 = a->max_value * b->min_value; return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); @@ -229,7 +231,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, Inte return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using tirx::Select; - PrimExpr sign = b->min_value >= IntImm(b->min_value.dtype().element_of(), 0); + PrimExpr sign = b->min_value >= IntImm(b->min_value.ty().WithLanes(1), 0); PrimExpr e1 = a->min_value / b->min_value; PrimExpr e2 = a->max_value / b->min_value; return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); @@ -258,7 +260,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, Inte // is the case of our application. // TODO(tqchen): add bound constraints for a. 
if (analyzer->CanProveGreaterEqual(divisor, 0)) { - return IntervalSet(IntImm(divisor.dtype(), 0), divisor - 1); + return IntervalSet(IntImm(divisor.ty(), 0), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; return IntervalSet(-bound, bound); @@ -292,7 +294,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using tirx::Select; - PrimExpr sign = b->min_value >= IntImm(b->min_value.dtype().element_of(), 0); + PrimExpr sign = b->min_value >= IntImm(b->min_value.ty().WithLanes(1), 0); PrimExpr e1 = floordiv(a->min_value, b->min_value); PrimExpr e2 = floordiv(a->max_value, b->min_value); return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); @@ -323,7 +325,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) : neg_inf(); // We can compare +/- inf against each other, but cannot use // operator== between the symbolic limits and an integer. 
- bool compatible_dtypes = !(qmin.dtype().is_handle() ^ qmax.dtype().is_handle()); + bool compatible_dtypes = !(qmin.ty().IsHandle() ^ qmax.ty().IsHandle()); if (compatible_dtypes && analyzer->CanProve(qmax == qmin)) { auto tmax = a->max_value - divisor * qmin; auto tmin = a->min_value - divisor * qmin; @@ -348,12 +350,13 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, int64_t max_mod_result = max_quotient * gcd + (dividend_mod->base % gcd); if (max_mod_result >= 0 && max_mod_result < div_val) { - return IntervalSet(IntImm(op->dtype, 0), IntImm(op->dtype, max_mod_result)); + PrimType result_ty = ffi::GetRef(op).ty(); + return IntervalSet(IntImm(result_ty, 0), IntImm(result_ty, max_mod_result)); } } } } - return IntervalSet(IntImm(divisor.dtype(), 0), divisor - 1); + return IntervalSet(IntImm(divisor.ty(), 0), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; return IntervalSet(-bound, bound); @@ -522,7 +525,7 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet base = Eval(op->base); PVar stride; if (stride.Match(op->stride)) { - DataType t = op->base.dtype(); + PrimType t = op->base.ty(); int64_t vstride = stride.Eval()->value; if (op->lanes->IsInstance()) { int lanes = static_cast(op->lanes.as_or_throw()->value); @@ -569,18 +572,19 @@ class IntervalSetEvaluator : public ExprFunctor { // short cut for the int set. if (value_set->min_value.same_as(value_set->max_value)) { if (value_set->IsEmpty()) return value_set; - return IntervalSet::SinglePoint(cast(op->dtype, value_set->min_value)); + return IntervalSet::SinglePoint(cast(op->ty(), value_set->min_value)); } PrimExpr min_value = - value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf(); + value_set->HasLowerBound() ? cast(op->ty(), value_set->min_value) : neg_inf(); PrimExpr max_value = - value_set->HasUpperBound() ? cast(op->dtype, value_set->max_value) : pos_inf(); + value_set->HasUpperBound() ? 
cast(op->ty(), value_set->max_value) : pos_inf(); return IntervalSet(min_value, max_value); } IntervalSet VisitExpr_(const BufferLoadNode* op) final { - if (!(op->dtype.is_int() || op->dtype.is_uint())) { - DLOG(WARNING) << "cannot evaluate set BufferLoad which loads from a " << op->dtype + PrimType op_ty = op->ty(); + if (!op_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { + DLOG(WARNING) << "cannot evaluate set BufferLoad which loads from a " << op_ty->dtype << " buffer"; return IntervalSet::Everything(); } @@ -1048,7 +1052,7 @@ IntSet EvalSet(PrimExpr e, const ffi::Map& dom_map) { IntSet IntSet::Vector(PrimExpr x) { // short cut: simply get single point - if (!x.dtype().is_scalable_or_fixed_length_vector()) { + if (!x.ty().IsScalableVector() && !x.ty().IsFixedLengthVector()) { return IntSet::SinglePoint(x); } else { // vector case. @@ -1068,7 +1072,9 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom IntSet EvalSet(Range r, const ffi::Map& dom_map) { Analyzer ana; - if ((r->min->dtype.is_int() || r->min->dtype.is_uint()) && ana->CanProveEqual(r->extent, 1)) { + PrimType min_ty = r->min.ty(); + if (min_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt) && + ana->CanProveEqual(r->extent, 1)) { return EvalSet(r->min, dom_map); } IntervalSetEvaluator m(ana.get(), dom_map); diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 8dcef7a75a80..d6a264288b16 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -54,7 +54,7 @@ void AppendFloorDivConstraints(const FloorDivNode* div, int64_t value, CompareKi int64_t divisor_value = 0; if (!TryGetIntImm(div->b, &divisor_value) || divisor_value <= 0) return; - DataType dtype = div->a.dtype(); + PrimType dtype = div->a.ty(); PrimExpr divisor = MakeConst(dtype, divisor_value); PrimExpr k = MakeConst(dtype, value); PrimExpr lo = k * divisor; @@ -117,7 +117,8 @@ void CollectDerivedConstraintFacts(const 
PrimExpr& condition, std::vector()) { if (call->op.same_as(tirx::builtin::bitwise_and()) && call->args.size() == 2 && - call->args[0].dtype().is_bool() && call->args[1].dtype().is_bool()) { + call->args[0].ty().MatchesElementType(DLDataTypeCode::kDLBool, 8) && + call->args[1].ty().MatchesElementType(DLDataTypeCode::kDLBool, 8)) { CollectDerivedConstraintFacts(call->args[0], out); CollectDerivedConstraintFacts(call->args[1], out); return; @@ -260,7 +261,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { IterVar iv = op->node.as_or_throw(); TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); - Range dom = Range::FromMinExtent(IntImm(op->value.dtype(), 0), op->value); + Range dom = Range::FromMinExtent(IntImm(op->value.ty(), 0), op->value); analyzer_->Bind(iv->var, dom); iter_vars_.Set(iv->var, dom); } @@ -313,7 +314,8 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { false_value.same_as(op->args[2])) { return ffi::GetRef(op); } else { - return Call(op->dtype, op->op, {cond, true_value, false_value}, op->attrs, op->span); + return Call(ffi::GetRef(op).ty(), op->op, {cond, true_value, false_value}, + op->attrs, op->span); } } return StmtExprMutator::VisitExpr_(op); diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index ffe9c73bd6f2..0313dbfe4271 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -79,7 +79,7 @@ void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { IterVar iv = op->node.as_or_throw(); TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); - analyzer_->Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value)); + analyzer_->Bind(iv->var, Range::FromMinExtent(IntImm(op->value.ty(), 0), op->value)); } 
StmtExprVisitor::VisitStmt_(op); }); diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index c7f8819f944f..430a4ec5c839 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -66,8 +66,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { IterSplitExpr::IterSplitExpr(IterMark source) { auto n = ffi::make_object(); - auto one = MakeConst(source->source->dtype, 1); - n->dtype = source->source->dtype; + auto one = MakeConst(source->source.ty(), 1); + n->BaseExprNode::ty = source->source.ty(); n->source = std::move(source); n->extent = n->source->extent; n->lower_factor = one; @@ -77,8 +77,8 @@ IterSplitExpr::IterSplitExpr(IterMark source) { IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) { auto n = ffi::make_object(); - auto one = MakeConst(source->source->dtype, 1); - n->dtype = source->source->dtype; + auto one = MakeConst(source->source.ty(), 1); + n->BaseExprNode::ty = source->source.ty(); n->source = std::move(source); n->extent = n->source->extent; n->lower_factor = one; @@ -89,7 +89,7 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) { IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) { auto n = ffi::make_object(); - n->dtype = source->source->dtype; + n->BaseExprNode::ty = source->source.ty(); n->source = std::move(source); n->lower_factor = std::move(lower_factor); n->extent = std::move(extent); @@ -109,7 +109,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { IterSumExpr::IterSumExpr(ffi::Array args, PrimExpr base) { auto n = ffi::make_object(); - n->dtype = base->dtype; + n->BaseExprNode::ty = base.ty(); n->args = std::move(args); n->base = std::move(base); data_ = std::move(n); @@ -563,7 +563,7 @@ class IterMapRewriter : public ExprMutator { IterMapLevel check_level) { std::vector used(splits.size(), false); std::vector iters; - PrimExpr expected_lower_factor = MakeConst(mark->source->dtype, 1); + PrimExpr expected_lower_factor = MakeConst(mark->source.ty(), 1); 
for (size_t i = 0; i < splits.size(); ++i) { size_t j = 0; @@ -694,7 +694,7 @@ class IterMapRewriter : public ExprMutator { PrimExpr iter_min = mark_offset; PrimExpr iter_max = iter_min + mark->extent; // the delta of iter_min when it is updated when the lower bound predicate is present - PrimExpr iter_min_delta = IntImm(iter_min.dtype(), 0); + PrimExpr iter_min_delta = IntImm(iter_min.ty(), 0); if (predicate_induced_min.defined()) { iter_min_delta = max(predicate_induced_min.value(), iter_min) - iter_min; iter_min = max(predicate_induced_min.value(), iter_min); @@ -788,7 +788,7 @@ class IterMapRewriter : public ExprMutator { for (IterSplitExpr split : expr->args) { int64_t symbol_prod_count = 0; int64_t cscale = 1; - PrimExpr res = tirx::MakeConst(split.dtype(), 1); + PrimExpr res = tirx::MakeConst(split.ty(), 1); auto fcollect = [&](PrimExpr val) { if (const auto* intimm = val.as()) { cscale *= intimm->value; @@ -799,7 +799,7 @@ class IterMapRewriter : public ExprMutator { }; UnpackReduction(split->scale, fcollect); if (cscale != 1) { - res = res * tirx::MakeConst(res.dtype(), cscale); + res = res * tirx::MakeConst(res.ty(), cscale); } split.CopyOnWrite()->scale = res; items.emplace_back(Item{cscale, symbol_prod_count, split}); @@ -830,7 +830,7 @@ class IterMapRewriter : public ExprMutator { if (auto op = expr.as()) { return op.value(); } else if (auto op = expr.as()) { - return IterSumExpr({op.value()}, IntImm(expr->dtype, 0)); + return IterSumExpr({op.value()}, IntImm(expr.ty(), 0)); } else { TVM_FFI_ICHECK(!expr->IsInstance()); return IterSumExpr({}, expr); @@ -1103,8 +1103,8 @@ class IterMapRewriter : public ExprMutator { std::vector flattened_iters, grouped_iters; // check if it can be remapped into a fused pattern. 
- PrimExpr expected_extra_base = IntImm(expr.dtype(), 0); - PrimExpr tail_extent = IntImm(expr.dtype(), 0); + PrimExpr expected_extra_base = IntImm(expr.ty(), 0); + PrimExpr tail_extent = IntImm(expr.ty(), 0); PrimExpr expected_scale = base_scale; int first_possible_unit_extent_pos = FindFirstPossibleUnitExtentIndex(expr); @@ -1200,10 +1200,10 @@ class IterMapRewriter : public ExprMutator { IterSumExpr structured_form = expr, flattened_form = expr; flattened_form.CopyOnWrite()->args = ffi::Array(flattened_iters.rbegin(), flattened_iters.rend()); - flattened_form.CopyOnWrite()->base = IntImm(expr.dtype(), 0); + flattened_form.CopyOnWrite()->base = IntImm(expr.ty(), 0); structured_form.CopyOnWrite()->args = ffi::Array(grouped_iters.rbegin(), grouped_iters.rend()); - structured_form.CopyOnWrite()->base = IntImm(expr.dtype(), 0); + structured_form.CopyOnWrite()->base = IntImm(expr.ty(), 0); auto it = sum_fuse_map_.find(flattened_form); if (it != sum_fuse_map_.end()) { // old iter @@ -1245,7 +1245,7 @@ class IterMapRewriter : public ExprMutator { if (sign > 0) { lhs->args.push_back(rhs); } else { - rhs.CopyOnWrite()->scale = IntImm(rhs->scale.dtype(), 0) - rhs->scale; + rhs.CopyOnWrite()->scale = IntImm(rhs->scale.ty(), 0) - rhs->scale; lhs->args.push_back(rhs); } } @@ -1332,8 +1332,10 @@ bool MatchBoundConstraints(PrimExpr pred, ffi::Map* input_iters, PrimExpr lhs_expr = lhs.Eval(); PrimExpr rhs_expr = rhs.Eval(); // we only accept predicate of integers - if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) && - (rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) { + PrimType lhs_ty = lhs_expr.ty(); + PrimType rhs_ty = rhs_expr.ty(); + if (!((lhs_ty.code() == DLDataTypeCode::kDLInt || lhs_ty.code() == DLDataTypeCode::kDLUInt) && + (rhs_ty.code() == DLDataTypeCode::kDLInt || rhs_ty.code() == DLDataTypeCode::kDLUInt))) { return false; } // determine iter and bound, if we can not distinguish them simply, @@ -1563,7 +1565,7 @@ PrimExpr 
IterMapRewriter::VisitExpr_(const VarNode* op) { } PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Parent::VisitExpr_(op); } PrimExpr a = this->DirectMutate(op->a); @@ -1596,7 +1598,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { } PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Parent::VisitExpr_(op); } @@ -1631,7 +1633,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) { } PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Parent::VisitExpr_(op); } // normalize @@ -1677,7 +1679,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend) { if (dividend->IsInstance()) { auto split = dividend.as_or_throw(); - return IterSumExpr({split}, IntImm(split.dtype(), 0)); + return IterSumExpr({split}, IntImm(split.ty(), 0)); } else if (dividend->IsInstance()) { auto sum = dividend.as_or_throw(); if (sum->args.empty()) { @@ -1880,12 +1882,12 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) { // floordiv(x*c1, c1*c2) = floordiv(x, c2), c2=rhs/scale rhs = floordiv(rhs, lhs->scale); - lhs.CopyOnWrite()->scale = MakeConst(rhs->dtype, 1); + lhs.CopyOnWrite()->scale = MakeConst(rhs.ty(), 1); } else if (CanProveDivisible(rhs, lhs->scale) && CanProveDivisible(base, lhs->scale)) { // floordiv(x*c1 + y*c1, c1*c2) = floordiv(x+y, c2), c2=rhs/scale base = floordiv(base, lhs->scale); rhs = floordiv(rhs, lhs->scale); - lhs.CopyOnWrite()->scale = MakeConst(rhs->dtype, 1); + lhs.CopyOnWrite()->scale = MakeConst(rhs.ty(), 1); } else { // mark as unresolved. 
ErrorLogger(this) << "Cannot represent as IterMap: the numerator's scaling factor, " @@ -1931,7 +1933,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P new_split = IterSplitExpr(IterMark(padded, padded->extent), /* lower_factor = */ rhs, /* extent = */ analyzer_->Simplify(ceildiv(padded->extent, rhs)), - /* scale = */ MakeConst(rhs->dtype, 1)); + /* scale = */ MakeConst(rhs.ty(), 1)); } auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6); @@ -1944,7 +1946,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P } PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Parent::VisitExpr_(op); } @@ -1987,13 +1989,13 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P if (is_one(rhs)) { // floormod(x, 1) = 0 - return IntImm(lhs->dtype, 0); + return IntImm(lhs.ty(), 0); } if (!is_one(lhs->scale)) { if (CanProveDivisible(lhs->scale, rhs) && CanProveDivisible(base, rhs)) { // floormod(x*c1*c2, c1) = 0 - return IntImm(lhs->dtype, 0); + return IntImm(lhs.ty(), 0); } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) { // floormod(x*c1, c1*c2) = (floormod(x, c2)) * c1, where c2 = rhs/scale rhs = floordiv(rhs, lhs->scale); @@ -2028,7 +2030,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P } PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Parent::VisitExpr_(op); } @@ -2113,7 +2115,7 @@ class IterMapToExprNormalizer : public ExprMutator { // simplify trivial iters like `vi \in [0, 1)`, which can be useful for subsequent analysis // like tensorization. 
if (is_one(expr->extent) && !is_one(expr->source->extent)) { - return IntImm(expr->extent->dtype, 0); + return IntImm(expr->extent.ty(), 0); } return floordiv(source, expr->lower_factor) * expr->scale; } else { @@ -2255,13 +2257,13 @@ class SubspaceDivider { IterSplitExpr GetInnerAsSplit() const { return GetAsSplit(inner, inner_extent); } static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) { - auto dtype = iter.dtype(); + PrimType dtype = iter.ty(); return DivisionResult(IterSumExpr({}, IntImm(dtype, 0)), IntImm(dtype, 1), iter, extent, Kind::kInner); } static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) { - auto dtype = iter.dtype(); + PrimType dtype = iter.ty(); return DivisionResult(iter, extent, IterSumExpr({}, IntImm(dtype, 0)), IntImm(dtype, 1), Kind::kOuter); } @@ -2285,7 +2287,7 @@ class SubspaceDivider { // Divide an IterSumExpr DivisionResult DivideIterSumExpr(const IterSumExpr& expr, const PrimExpr& mark_extent) { - auto dtype = expr.dtype(); + PrimType dtype = expr.ty(); if (expr->args.empty()) { // base return DivisionResult(IterSumExpr({}, IntImm(dtype, 0)), IntImm(dtype, 1), @@ -2377,7 +2379,7 @@ class SubspaceDivider { // args are sorted from inner to outer static IterMark MarkFromArgsAndBase(const std::vector& args, PrimExpr base) { std::vector res; - PrimExpr extent = MakeConst(base.dtype(), 1); + PrimExpr extent = MakeConst(base.ty(), 1); for (const IterSplitExpr& it : args) { IterSplitExpr arg = it; arg.CopyOnWrite()->scale = extent; @@ -2431,7 +2433,7 @@ class SubspaceDivider { bool encountered_boundary = mark_division.IsOuter(); std::vector used(splits.size(), false); std::vector inner_iters, outer_iters; - PrimExpr expected_lower_factor = MakeConst(expr->source->source->dtype, 1); + PrimExpr expected_lower_factor = MakeConst(expr->source->source.ty(), 1); // find the boundary of outer and inner, like case 1 above for (size_t i = 0; i < splits.size(); ++i) { size_t j = 0; diff --git 
a/src/arith/pattern_match.h b/src/arith/pattern_match.h index bb1ebd54cca7..dda8e704cfed 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -199,7 +199,10 @@ class PVar : public Pattern> { // Store PVars by reference in the expression. using Nested = const PVar&; - void InitMatch_() const { filled_ = false; } + void InitMatch_() const { + value_ = nullptr; + filled_ = false; + } bool Match_(const T& value) const { if (!filled_) { @@ -207,7 +210,7 @@ class PVar : public Pattern> { filled_ = true; return true; } else { - return PEqualChecker()(value_, value); + return PEqualChecker()(value_.value(), value); } } @@ -223,14 +226,14 @@ class PVar : public Pattern> { T Eval() const { TVM_FFI_ICHECK(filled_); - return value_; + return value_.value(); } - T EvalOr(const T& default_value) const { return filled_ ? value_ : default_value; } + T EvalOr(const T& default_value) const { return filled_ ? value_.value() : default_value; } protected: /*! \brief The matched value */ - mutable T value_; + mutable ffi::Optional value_; /*! \brief whether the variable has been filled */ mutable bool filled_{false}; }; @@ -282,7 +285,7 @@ class PVarWithDataType : public PVarWithCheck, T> { public: explicit PVarWithDataType(const DType& dtype) : dtype_(dtype) {} - bool Match_(const T& value) const { return dtype_.Match_(value->dtype); } + bool Match_(const T& value) const { return dtype_.Match_(value.ty()); } protected: typename DType::Nested dtype_; @@ -291,15 +294,15 @@ class PVarWithDataType : public PVarWithCheck, T> { /*! * \brief Pattern variable container for data type with lanes. */ -class PVecDataType : public PVarWithCheck { +class PVecDataType : public PVarWithCheck { public: /*! 
\brief construct vector dtype placeholder with element type check */ - explicit PVecDataType(const DataType& elem_dtype) : elem_dtype_(elem_dtype) {} + explicit PVecDataType(PrimType elem_dtype) : elem_dtype_(elem_dtype) {} - bool Match_(const DataType& dtype) const { return dtype.code() == elem_dtype_.code(); } + bool Match_(PrimType dtype) const { return dtype.code() == elem_dtype_.code(); } protected: - DataType elem_dtype_; + PrimType elem_dtype_; }; /*! @@ -377,7 +380,7 @@ class PConstWithTypeLike : public Pattern> { } } - PrimExpr Eval() const { return tirx::MakeConst(ref_.Eval().dtype(), value_); } + PrimExpr Eval() const { return tirx::MakeConst(ref_.Eval().ty(), value_); } private: typename TA::Nested ref_; @@ -540,7 +543,7 @@ class PCastExpr : public Pattern> { bool Match_(const ffi::ObjectRef& node) const { if (const tirx::CastNode* ptr = node.as()) { - if (!dtype_.Match_(ptr->dtype)) return false; + if (!dtype_.Match_(ptr->ty())) return false; if (!value_.Match_(ptr->value)) return false; return true; } else { @@ -558,7 +561,7 @@ class PCastExpr : public Pattern> { /*! * \brief Construct a cast pattern. * - * \param dtype The target data type, can be PVar or PConst. + * \param dtype The target data type, can be PVar or PConst. * \param value The input type. * * \return The result pattern. 
@@ -780,7 +783,7 @@ class PCallExpr : public Pattern> { #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ static PrimExpr Eval(ffi::Array args) { \ - return tirx::Call(args[0].dtype(), GetOp(), args); \ + return tirx::Call(args[0].ty(), GetOp(), args); \ } \ static const Op& GetOp() { return tirx::builtin::IntrinOpName(); } \ }; \ @@ -799,7 +802,7 @@ TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor); #define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ static PrimExpr Eval(ffi::Array args) { \ - return tirx::Call(args[0].dtype(), GetOp(), args); \ + return tirx::Call(args[0].ty(), GetOp(), args); \ } \ static const Op& GetOp() { return tirx::builtin::IntrinOpName(); } \ }; \ @@ -813,7 +816,7 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not); // if_then_else struct PIfThenElseOp { static PrimExpr Eval(ffi::Array args) { - return tirx::Call(args[1].dtype(), GetOp(), args); + return tirx::Call(args[1].ty(), GetOp(), args); } static const Op& GetOp() { return tirx::builtin::if_then_else(); } }; @@ -841,7 +844,7 @@ inline PCallExpr if_then_else(const Pattern // vscale struct PVscaleOp { - static PrimExpr Eval() { return tirx::Call(DataType::Int(32), GetOp(), {}); } + static PrimExpr Eval() { return tirx::Call(PrimType::Int(32), GetOp(), {}); } static const Op& GetOp() { return tirx::builtin::vscale(); } }; diff --git a/src/arith/product_normal_form.h b/src/arith/product_normal_form.h index 40d02c1952b7..79e040287fa7 100644 --- a/src/arith/product_normal_form.h +++ b/src/arith/product_normal_form.h @@ -79,7 +79,8 @@ inline void UnpackSum(const PrimExpr& value, FLeaf fleaf, int sign = 1) { */ inline PrimExpr MulAndNormalize(const PrimExpr& lhs, const PrimExpr& rhs) { int64_t cscale = 1; - PrimExpr res = tirx::MakeConst(lhs.dtype(), 1); + PrimType lhs_ty = lhs.ty(); + PrimExpr res = tirx::MakeConst(lhs_ty, 1); auto fcollect = [&](PrimExpr val) { if (const auto* intimm = 
val.as()) { cscale *= intimm->value; @@ -90,7 +91,7 @@ inline PrimExpr MulAndNormalize(const PrimExpr& lhs, const PrimExpr& rhs) { UnpackReduction(lhs, fcollect); UnpackReduction(rhs, fcollect); if (cscale != 1) { - res = res * tirx::MakeConst(res.dtype(), cscale); + res = res * tirx::MakeConst(res.ty(), cscale); } return res; } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index fa3ba0b519d6..07ea2c7a7778 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -425,7 +425,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), ramp(b1 + b2, s1 + s2, lanes)); TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes)); @@ -433,7 +433,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE_IF(x + broadcast(c4, lanes), x, c4.Eval()->value == 0.0f); } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // Index rules // cancelation rules TVM_TRY_REWRITE((x - y) + y, x); @@ -535,7 +535,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c if (SideEffect(subconstraint) <= CallEffectKind::kPure) { literal_constraints_.push_back(subconstraint); PrimExpr negation; - if (subconstraint.dtype().is_bool()) { + if (subconstraint.ty().MatchesElementType(DLDataTypeCode::kDLBool, 8)) { // We could apply NormalizeBooleanOperators during // TryMatchLiteralConstraint, but that would require // performing a rewrite of each expression being checked. @@ -543,7 +543,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c // applied. 
negation = NormalizeBooleanOperators(Not(subconstraint)); } else { - negation = subconstraint == IntImm(subconstraint.dtype(), 0); + negation = subconstraint == IntImm(subconstraint.ty(), 0); } literal_constraints_.push_back(Not(negation)); } @@ -575,14 +575,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PVar lanes; // Vector rules - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes)); TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), ramp(b1 - x, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), ramp(x - b1, 0 - s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), broadcast(x - y, lanes)); } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // Index rules // cancelation rules TVM_TRY_REWRITE(matches_one_of((x + y) - y, (y + x) - y), x); @@ -765,7 +765,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); TVM_TRY_REWRITE(matches_one_of(ramp(b1, s1, lanes) * broadcast(x, lanes), broadcast(x, lanes) * ramp(b1, s1, lanes)), @@ -773,7 +773,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { TVM_TRY_REWRITE_IF(broadcast(c3, lanes) * x, broadcast(c3, lanes), c3.Eval()->value == 0.0f); } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // constant simplification rule TVM_TRY_REWRITE((x + c1) * c2, x * c2 + c1 * c2); TVM_TRY_REWRITE((x * c1) * c2, x * (c1 * c2)); @@ -803,7 +803,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { PVar lanes; // Vector rules - if 
(op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { // NOTE: use div as the pattern also works for float. TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), broadcast(div(x, y), lanes)); // ramp / bcast @@ -827,7 +827,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { } } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // Be-aware of the division rules: // We adopt the default C division uses truncation instead of floordiv. // This means most rules need to check non-negativeness of the operands. @@ -839,7 +839,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { if (truncdiv(c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - return MakeConst(op->dtype, truncdiv(c1val, c2val)); + return MakeConst(op->ty(), truncdiv(c1val, c2val)); } // while it is always true for trunc div @@ -957,7 +957,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { PVar lanes; // Vector rules - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(truncmod(broadcast(x, lanes), broadcast(y, lanes)), broadcast(truncmod(x, y), lanes)); @@ -994,7 +994,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { } } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // Be-aware of the division rules: // We adopt the default C division uses truncation instead of floordiv. // This means most rules need to check non-negativeness of the operands. 
@@ -1019,7 +1019,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { // canonicalization: x % c == x % (-c) for truncated division // NOTE: trunc div required TVM_TRY_RECURSIVE_REWRITE_IF( - truncmod(x, c1), truncmod(x, PConst(MakeConst(op->dtype, -c1.Eval()->value))), + truncmod(x, c1), truncmod(x, PConst(MakeConst(op->ty(), -c1.Eval()->value))), c1.Eval()->value < 0); // try modular analysis @@ -1046,7 +1046,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PVar lanes; // Vector rules - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(floordiv(broadcast(x, lanes), broadcast(y, lanes)), broadcast(floordiv(x, y), lanes)); // ramp // bcast @@ -1077,7 +1077,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { } } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // Be-aware of the division rules: this is floor division. 
TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1), c2), floordiv(x, c1 * c2), c1.Eval()->value > 0 && c2.Eval()->value > 0); @@ -1198,7 +1198,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PVar lanes; // Vector rules - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(floormod(broadcast(x, lanes), broadcast(y, lanes)), broadcast(floormod(x, y), lanes)); @@ -1238,7 +1238,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { } } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // Be-aware of the division rules: we use floordiv/floormod here TVM_TRY_REWRITE_IF(floormod(x * c1, c2), floormod(x * floormod(c1, c2), c2), c2.Eval()->value != 0); @@ -1314,12 +1314,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PVar lanes; // vector rule - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), broadcast(min(x, y), lanes)); TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)), min(x, broadcast(min(y, z), lanes))); } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { TVM_TRY_REWRITE(min(x, x), x); // constant int bound @@ -1498,12 +1498,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { PVar lanes; // vector rule - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), broadcast(max(x, y), lanes)); TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)), max(x, broadcast(max(y, z), lanes))); } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { TVM_TRY_REWRITE(max(x, x), x); // constant int bound @@ -1686,10 +1686,10 @@ ffi::Optional 
RewriteSimplifier::Impl::TryMatchLiteralConstraint( ExprDeepEqual expr_equal; for (const auto& constraint : literal_constraints_) { if (expr_equal(constraint, expr)) { - return MakeConst(expr->dtype, true); + return MakeConst(expr->ty(), true); } if (expr_equal(constraint, negation)) { - return MakeConst(expr->dtype, false); + return MakeConst(expr->ty(), false); } } return std::nullopt; @@ -1715,20 +1715,20 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { // Pattern var match IntImm PVar c1, c2; PVar lanes; - PConst ctrue(MakeConst(ret->dtype, true)); + PConst ctrue(MakeConst(ret->ty(), true)); // vector rule - if (ret->dtype.is_scalable_or_fixed_length_vector()) { + if (ret->ty().IsScalableVector() || ret->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes)); } - if (IsIndexType(ret->a.dtype())) { + if (IsIndexTypedExpr(ret->a)) { CompareResult result = TryCompare(ret->a, ret->b); if (result == CompareResult::kEQ) { - return MakeConst(ret->dtype, true); + return MakeConst(ret->ty(), true); } else if (result == CompareResult::kNE || result == CompareResult::kGT || result == CompareResult::kLT) { - return MakeConst(ret->dtype, false); + return MakeConst(ret->ty(), false); } TVM_TRY_REWRITE(c1 == x, x == c1); @@ -1758,13 +1758,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); - if (IsIndexType(op->a.dtype())) { + if (IsIndexTypedExpr(op->a)) { CompareResult result = TryCompare(op->a, op->b); if (result == CompareResult::kNE || result == CompareResult::kGT || result == CompareResult::kLT) { - return MakeConst(op->dtype, true); + return MakeConst(op->ty(), true); } else if (result == CompareResult::kEQ) { - return MakeConst(op->dtype, false); + return MakeConst(op->ty(), false); } else if (result == CompareResult::kGE) { // 
Known: a >= b // @@ -1802,13 +1802,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) { // (floordiv(A,B)b, op->a)))); - if (auto op = ret.as(); op && IsIndexType(op->a.dtype())) { + if (auto op = ret.as(); op && IsIndexTypedExpr(op->a)) { CompareResult result = TryCompare(op->a, op->b); if (result == CompareResult::kLE || result == CompareResult::kLT || result == CompareResult::kEQ) { - return MakeConst(op->dtype, true); + return MakeConst(op->ty(), true); } else if (result == CompareResult::kGT) { - return MakeConst(op->dtype, false); + return MakeConst(op->ty(), false); } else if (result == CompareResult::kNE) { // Known: a != b // @@ -1857,19 +1857,19 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { PVar lanes; // vector rule - if (ret->dtype.is_scalable_or_fixed_length_vector()) { + if (ret->ty().IsScalableVector() || ret->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), broadcast(x < y, lanes)); TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y, lanes)); } - if (IsIndexType(ret->a.dtype())) { + if (IsIndexTypedExpr(ret->a)) { CompareResult result = TryCompare(ret->a, ret->b); if (result == CompareResult::kLT) { - return MakeConst(ret->dtype, true); + return MakeConst(ret->ty(), true); } if (result == CompareResult::kEQ || result == CompareResult::kGT || result == CompareResult::kGE) { - return MakeConst(ret->dtype, false); + return MakeConst(ret->ty(), false); } // clang-format off @@ -1987,9 +1987,9 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { } else if (diff == 1) { return lhs <= rhs; } else if (diff < 0 && rhs_offset != 0) { - return lhs + MakeConst(lhs.dtype(), -diff) < rhs; + return lhs + MakeConst(lhs.ty(), -diff) < rhs; } else if (diff > 0 && lhs_offset != 0) { - return lhs < rhs + MakeConst(rhs.dtype(), diff); + return lhs < rhs + MakeConst(rhs.ty(), diff); } return std::nullopt; @@ -2024,7 +2024,7 @@ PrimExpr 
RewriteSimplifier::Impl::ApplyRewriteRules(Not ret) { // Pattern var to match any expression PVar x, y; PVar lanes; - if (ret->dtype.is_scalable_or_fixed_length_vector()) { + if (ret->ty().IsScalableVector() || ret->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes)); } @@ -2100,11 +2100,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PVar c1, c2, c3; PVar lanes; - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); } - auto cfalse = PConst(MakeConst(op->dtype, false)); + auto cfalse = PConst(MakeConst(op->ty(), false)); TVM_TRY_REWRITE(x == y && x != y, cfalse); TVM_TRY_REWRITE(x != y && x == y, cfalse); TVM_TRY_REWRITE(x && !x, cfalse); @@ -2248,11 +2248,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { PVar c1, c2; PVar lanes; - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); } - auto ctrue = PConst(MakeConst(op->dtype, true)); + auto ctrue = PConst(MakeConst(op->ty(), true)); TVM_TRY_REWRITE(x == y || x != y, ctrue); TVM_TRY_REWRITE(x != y || x == y, ctrue); @@ -2319,12 +2319,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { static const Op& ceil_op = Op::Get("tirx.ceil"); static const Op& log2_op = Op::Get("tirx.log2"); static const Op& clz_op = Op::Get("tirx.clz"); + PrimType ret_ty = ffi::GetRef(op).ty(); if (op->op.same_as(ceil_op)) { PrimExpr ceil_arg = op->args[0]; if (auto arg_int = op->args[0].as()) { - return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value)); + return cast(ret_ty, IntImm(ffi::GetRef(arg_int).ty(), arg_int->value)); } else if (auto arg_float = ceil_arg.as()) { - return cast(op->dtype, 
FloatImm(arg_float->dtype, std::ceil(arg_float->value))); + return cast(ret_ty, + FloatImm(ffi::GetRef(arg_float).ty(), std::ceil(arg_float->value))); } else if (auto arg_call = ceil_arg.as()) { // ceil(log2(cast(n,"float64"))) is used as the implementation of // topi.math.ceil_log2, and appears in iteration bounds. @@ -2334,17 +2336,17 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { // ceil(log2(n)) can be simplified, and should produce the // same integer result regardless of the target's rounding // conventions. - return FloatImm(op->dtype, std::ceil(std::log2(as_float->value))); + return FloatImm(ret_ty, std::ceil(std::log2(as_float->value))); } } } } else if (op->op.same_as(clz_op)) { if (const auto* arg_int = op->args[0].as()) { - int bits = arg_int->dtype.bits(); - if (arg_int->value == 0) return MakeConst(op->dtype, bits); + int bits = arg_int->ty().bits(); + if (arg_int->value == 0) return MakeConst(ret_ty, bits); for (int i = bits - 1; i >= 0; --i) { if ((int64_t(1) << i) & arg_int->value) { - return IntImm(op->dtype, bits - i - 1); + return IntImm(ret_ty, bits - i - 1); } } TVM_FFI_THROW(InternalError) << "Should not reach here"; @@ -2373,7 +2375,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { // Only check constant cases to avoid recursion if (is_const_number(inner_else_expr) && is_const_number(else_expr) && analyzer_->CanProve(inner_else_expr == else_expr)) { - return Call(op->dtype, op->op, {cond && inner_cond, inner_then_expr, else_expr}, op->attrs, + return Call(ret_ty, op->op, {cond && inner_cond, inner_then_expr, else_expr}, op->attrs, op->span); } } @@ -2384,7 +2386,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { Var var = ffi::GetRef(op); - if (op->dtype == DataType::Bool()) { + PrimType op_ty = op->ty(); + if (op_ty.MatchesElementType(DLDataTypeCode::kDLBool, 8) && !op_ty.IsScalableVector() && + 
!op_ty.IsFixedLengthVector()) { if (auto match = TryMatchLiteralConstraint(var)) { return match.value(); } @@ -2400,7 +2404,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - return cast(op->dtype, op->value); + return cast(ret.ty(), op->value); } bool RewriteSimplifier::Impl::CanInlineLet(const LetNode* op) { diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 27144c674b9f..fd507ccdd658 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -24,9 +24,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -133,10 +133,10 @@ void SmithNormalFormDiag(std::vector>* S, std::vector>* S, std::vector()) { name_hint += "_" + v_old->name_hint; } - Var v = Var(name_hint, V_inv_x[j].dtype()); + Var v = Var(name_hint, V_inv_x[j].ty()); solution_for_V_inv_x.push_back(v); new_vars.push_back(v); new_to_old_map.Set(v, to_old); @@ -403,12 +403,12 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // The j-th variable is just a single value, don't create a tvm variable // S^{-1}_{nxm} Uy_{mxn} if (S[j][j] >= 0) { - PrimExpr a = tirx::MakeConst(Uy[j].dtype(), S[j][j]); + PrimExpr a = tirx::MakeConst(Uy[j].ty(), S[j][j]); solution_for_V_inv_x.push_back(analyzer_problem->Simplify(floordiv(Uy[j], a))); } else { // This is required because some simplifiers // have problems with dividing by negative numbers - PrimExpr a = tirx::MakeConst(Uy[j].dtype(), -S[j][j]); + PrimExpr a = tirx::MakeConst(Uy[j].ty(), -S[j][j]); solution_for_V_inv_x.push_back(analyzer_problem->Simplify(floordiv(-Uy[j], a))); } } @@ -416,9 +416,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // V V^{-1} x = x for (size_t i = 0; i < num_vars; ++i) { - PrimExpr e = 
IntImm(system_to_solve->variables[i].dtype(), 0); + PrimExpr e = IntImm(system_to_solve->variables[i].ty(), 0); for (size_t j = 0; j < num_vars; ++j) { - e = e + tirx::MakeConst(e.dtype(), V[i][j]) * solution_for_V_inv_x[j]; + e = e + tirx::MakeConst(e.ty(), V[i][j]) * solution_for_V_inv_x[j]; } e = analyzer_problem->Simplify(e); old_to_new_map.Set(system_to_solve->variables[i], e); diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 80d064f71157..14b1affb9927 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -24,9 +24,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -91,10 +91,12 @@ class NormalizeComparisons : public ExprMutator { template PrimExpr Make(const PrimExpr& a, const PrimExpr& b) { // rewrite LT to LE for ints - if (std::is_same::value && (a.dtype().is_int() || a.dtype().is_uint())) { - return LE(analyzer_->Simplify(a - b + 1), IntImm(a.dtype(), 0)); + PrimType a_ty = a.ty(); + if (std::is_same::value && + (a_ty.code() == DLDataTypeCode::kDLInt || a_ty.code() == DLDataTypeCode::kDLUInt)) { + return LE(analyzer_->Simplify(a - b + 1), IntImm(a.ty(), 0)); } - return T(analyzer_->Simplify(a - b), IntImm(a.dtype(), 0)); + return T(analyzer_->Simplify(a - b), IntImm(a.ty(), 0)); } arith::Analyzer analyzer_; }; @@ -248,11 +250,12 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t for (const auto& pos : coef_pos) { for (const auto& neg : coef_neg) { auto first_gcd = ExtendedEuclidean(pos.first, -neg.first, &gcd_x, &gcd_y); - PrimExpr c_pos = MakeConst(v.dtype(), neg.first / first_gcd); - PrimExpr c_neg = MakeConst(v.dtype(), pos.first / first_gcd); + PrimType v_ty = v.ty(); + PrimExpr c_pos = MakeConst(v_ty, neg.first / first_gcd); + PrimExpr c_neg = MakeConst(v_ty, pos.first / first_gcd); // eliminate the current variable PrimExpr new_lhs = c_neg * neg.second - c_pos * pos.second; - 
PrimExpr new_ineq = LE(new_lhs, IntImm(pos.second.dtype(), 0)); + PrimExpr new_ineq = LE(new_lhs, IntImm(pos.second.ty(), 0)); // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 // with steps = 2 it's (y*2) - 10 <= 0 @@ -281,7 +284,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t lower_bounds.reserve(coef_neg.size()); for (const auto& pos : coef_pos) { - PrimExpr bound = MakeConst(v.dtype(), -coef_lcm / pos.first) * pos.second; + PrimExpr bound = MakeConst(v.ty(), -coef_lcm / pos.first) * pos.second; bound = analyzer->Simplify(bound, kSimplifyRewriteCanonicalRewrite); // Don't add if any of the existing bounds is better if (std::any_of(upper_bounds.begin(), upper_bounds.end(), @@ -302,7 +305,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t upper_bounds.push_back(bound); } for (const auto& neg : coef_neg) { - PrimExpr bound = MakeConst(v.dtype(), -coef_lcm / neg.first) * neg.second; + PrimExpr bound = MakeConst(v.ty(), -coef_lcm / neg.first) * neg.second; bound = analyzer->Simplify(bound, kSimplifyRewriteCanonicalRewrite); // Don't add if any of the existing bounds is better if (std::any_of(lower_bounds.begin(), lower_bounds.end(), @@ -330,7 +333,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t std::sort(equal_list.begin(), equal_list.end(), ExprLess()); // Write it to the result. 
- IntGroupBounds bnds(MakeConst(v.dtype(), coef_lcm), + IntGroupBounds bnds(MakeConst(v.ty(), coef_lcm), ffi::Array(lower_bounds.begin(), lower_bounds.end()), ffi::Array(equal_list.begin(), equal_list.end()), ffi::Array(upper_bounds.begin(), upper_bounds.end())); @@ -509,7 +512,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ analyzer->Simplify(var - Substitute(best_range->min, res_dst_to_src))); // Add the new var to the resulting axis - auto range = Range(IntImm(new_var.dtype(), 0), best_range->extent); + auto range = Range(IntImm(new_var.ty(), 0), best_range->extent); res_variables.push_back(new_var); res_ranges.Set(new_var, range); diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index 20fd05169f43..7b740d6229c2 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ b/src/arith/transitive_comparison_analyzer.cc @@ -615,7 +615,8 @@ CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs const PrimExpr& rhs_expr, bool propagate_inequalities) const { // Currently only supports integer checks - if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) { + if (!lhs_expr.ty().MatchesCode(DLDataTypeCode::kDLInt) || + !rhs_expr.ty().MatchesCode(DLDataTypeCode::kDLInt)) { return CompareResult::kUnknown; } diff --git a/src/arith/unwrap_vector_expr.cc b/src/arith/unwrap_vector_expr.cc index e9245c48a102..dfe7a3cf404b 100644 --- a/src/arith/unwrap_vector_expr.cc +++ b/src/arith/unwrap_vector_expr.cc @@ -58,14 +58,16 @@ class Scalarizer : public ExprMutator { } } PrimExpr VisitExpr_(const LetNode* op) final { - if (op->value.dtype().lanes() == 1) { + PrimType value_ty = op->value.ty(); + if (value_ty.lanes() == 1) { return ExprMutator::VisitExpr_(op); } auto it = let_var_remap_.find(op->var.get()); TVM_FFI_ICHECK(it == let_var_remap_.end()) << "Duplicate binding of variable " << op->var; - Var new_var(op->var->name_hint + "_scalar", 
op->var.dtype().element_of()); + PrimType var_ty = op->var.ty(); + Var new_var(op->var->name_hint + "_scalar", var_ty.WithLanes(1)); let_var_remap_[op->var.get()] = new_var; PrimExpr value = this->VisitExpr(op->value); diff --git a/src/arith/z3_prover.cc b/src/arith/z3_prover.cc index 604815c97955..9ceb156dead8 100644 --- a/src/arith/z3_prover.cc +++ b/src/arith/z3_prover.cc @@ -50,10 +50,10 @@ #include #include "tvm/ffi/cast.h" +#include "tvm/ffi/dtype.h" #include "tvm/ffi/object.h" #include "tvm/ffi/string.h" #include "tvm/ir/expr.h" -#include "tvm/runtime/data_type.h" #include "z3++.h" namespace tvm::arith { @@ -147,14 +147,14 @@ class Z3Prover::Impl : ExprFunctor { /// @brief Create a Free z3 expression from PrimExprNode z3::expr Create(const PrimExprNode* op) { auto ref = ffi::GetRef(op); - auto dtype = op->dtype; + PrimType dtype = op->ty(); std::string name = ns.GetNewName(ref); /// TVM max_val can't handle uint64 max correctly, so we special case it here - if (dtype.is_bool()) { + if (dtype.MatchesCode(DLDataTypeCode::kDLBool)) { return ctx->bool_const(name.c_str()); } else { z3::expr e = ctx->int_const(name.c_str()); - if (dtype.is_uint() && dtype.bits() == 64) { + if (dtype.MatchesCode(DLDataTypeCode::kDLUInt) && dtype.bits() == 64) { solver.add(ctx->int_val(0) <= e && e <= ctx->int_val((uint64_t)UINT64_MAX)); } else { auto min_val = min_value(dtype).as_or_throw()->value; @@ -249,7 +249,7 @@ class Z3Prover::Impl : ExprFunctor { // solver) must degrade to "cannot prove" instead of escaping to the caller. 
try { if (CheckTrivilBadCases(expr)) return false; - if (!IsValidDType(expr->dtype)) return false; + if (!IsValidType(expr.ty())) return false; z3::expr_vector constr(*ctx); constr.push_back(!ConvertBool(expr)); auto result = solver.check(constr); @@ -263,7 +263,7 @@ class Z3Prover::Impl : ExprFunctor { /// @brief Binded /// @brief Bind a variable to a value or a range void Bind(const Var& var, const PrimExpr& value, bool allow_override = false) { - if (!IsValidDType(var->dtype)) return; + if (!IsValidType(var.ty())) return; scope_stack_.back().push_back(Scope{Scope::BindValue, var, value}); // we add the binding whenever the value is pure, // because non-pure parts are handling by creating free variables in VisitExpr @@ -272,7 +272,7 @@ class Z3Prover::Impl : ExprFunctor { /// @brief Bind a variable to a range void Bind(const Var& var, const Range& range, bool allow_override = false) { - if (!IsValidDType(var->dtype)) return; + if (!IsValidType(var.ty())) return; scope_stack_.back().push_back( Scope{Scope::BindRange, var, PrimExpr(), range->min, range->extent}); // 1. 
Create a placeholder for the var, and save it in the memo @@ -427,7 +427,7 @@ class Z3Prover::Impl : ExprFunctor { * \return Number of satisfying values, -1 on error, -2 if min_consecutive constraint not met */ int64_t CountSatisfyingValues(const Var& var, int64_t max_count, int64_t min_consecutive = 1) { - if (!IsValidDType(var->dtype)) { + if (!IsValidType(var.ty())) { return -1; } @@ -550,12 +550,14 @@ class Z3Prover::Impl : ExprFunctor { } return e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || - (e->IsInstance() && !IsValidDType(e.as_or_throw()->value->dtype)); + (e->IsInstance() && !IsValidType(e.as_or_throw()->value.ty())); } /// @brief Check if the dtype is valid for z3 integer operations - static bool IsValidDType(const DataType& dtype) { - return (dtype.is_int() || dtype.is_uint() || dtype.is_bool()) && dtype.lanes() == 1; + static bool IsValidType(const PrimType& dtype) { + return dtype.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt, + DLDataTypeCode::kDLBool) && + dtype.lanes() == 1; } /// @brief Visit the expression and convert it into z3 integer expression @@ -581,7 +583,7 @@ class Z3Prover::Impl : ExprFunctor { /// @brief Helper function to visit binary arithmetic operations z3::expr VisitArith(Z3BinOp signed_op, const PrimExprNode* op, const PrimExpr& a, const PrimExpr& b) { - if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + if (IsValidType(a.ty()) && IsValidType(b.ty())) { return signed_op(VisitInt(a), VisitInt(b)); } else { return Create(op); @@ -589,14 +591,14 @@ class Z3Prover::Impl : ExprFunctor { } z3::expr VisitExpr_(const LetNode* op) override { - if (IsValidDType(op->var->dtype)) { + if (IsValidType(op->var.ty())) { memo_.emplace(op->var, VisitInt(op->value)); } return VisitExpr(op->body); } z3::expr VisitExpr_(const CastNode* op) override { // if the inner dtype is valid, we just visit it - if (IsValidDType(op->value->dtype) && IsValidDType(op->dtype)) { + if (IsValidType(op->value.ty()) && 
IsValidType(op->ty())) { return VisitInt(op->value); } else { // otherwise, we create a new free z3 variable @@ -696,7 +698,7 @@ class Z3Prover::Impl : ExprFunctor { } else if (op->op.same_as(tirx::builtin::shift_right())) { return VisitShiftOp(z3::ashr, op); } else if (op->op.same_as(tirx::builtin::if_then_else()) && op->args.size() == 3 && - IsValidDType(op->args[1]->dtype) && IsValidDType(op->args[2]->dtype)) { + IsValidType(op->args[1].ty()) && IsValidType(op->args[2].ty())) { // tir.if_then_else(cond, a, b) is a select-like ternary. return z3::ite(VisitBool(op->args[0]), VisitInt(op->args[1]), VisitInt(op->args[2])); } else { @@ -715,9 +717,9 @@ class Z3Prover::Impl : ExprFunctor { const PrimExpr& a = op->args[0]; const PrimExpr& b = op->args[1]; - unsigned bit_width = std::max(op->args[0].dtype().bits(), op->args[1].dtype().bits()); + unsigned bit_width = std::max(op->args[0].ty().bits(), op->args[1].ty().bits()); - if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + if (IsValidType(a.ty()) && IsValidType(b.ty())) { return z3::bv2int( op_func(z3::int2bv(bit_width, VisitInt(a)), z3::int2bv(bit_width, VisitInt(b))), true); } else { @@ -734,9 +736,9 @@ class Z3Prover::Impl : ExprFunctor { const PrimExpr& a = op->args[0]; - if (IsValidDType(a->dtype)) { + if (IsValidType(a.ty())) { // Cast integer to bit-vector, apply bitwise not, then cast back. 
- unsigned bit_width = a.dtype().bits(); + unsigned bit_width = a.ty().bits(); z3::expr a_int = VisitInt(a); z3::expr a_bv = z3::int2bv(bit_width, a_int); return z3::bv2int(~a_bv, true); @@ -756,7 +758,7 @@ class Z3Prover::Impl : ExprFunctor { const PrimExpr& b = op->args[1]; // Shift operations require integer types for both operands - if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + if (IsValidType(a.ty()) && IsValidType(b.ty())) { z3::expr a_expr = VisitInt(a); z3::expr b_expr = VisitInt(b); @@ -765,7 +767,7 @@ class Z3Prover::Impl : ExprFunctor { // matching push/pop in this path, so the assertion would permanently // poison the shared solver and make all subsequent unrelated proofs about // `b` unsound. - unsigned bit_width = std::max(a.dtype().bits(), b.dtype().bits()); + unsigned bit_width = std::max(a.ty().bits(), b.ty().bits()); z3::expr a_bv = z3::int2bv(bit_width, a_expr); z3::expr b_bv = z3::int2bv(bit_width, b_expr); diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index 0f2838014b28..0d70d9aef3fd 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -56,13 +56,32 @@ bool IsOp(const tirx::CallNode* call, const Op& compat_op, const char* canonical return op_node != nullptr && op_node->name == canonical_name; } +bool IsCUDAFloat8(DLDataTypeCode code) { + return code == DLDataTypeCode::kDLFloat8_e3m4 || code == DLDataTypeCode::kDLFloat8_e4m3 || + code == DLDataTypeCode::kDLFloat8_e4m3b11fnuz || + code == DLDataTypeCode::kDLFloat8_e4m3fn || code == DLDataTypeCode::kDLFloat8_e4m3fnuz || + code == DLDataTypeCode::kDLFloat8_e5m2 || code == DLDataTypeCode::kDLFloat8_e5m2fnuz || + code == DLDataTypeCode::kDLFloat8_e8m0fnu; +} + +bool IsCUDAFloat6(DLDataTypeCode code) { + return code == DLDataTypeCode::kDLFloat6_e2m3fn || code == DLDataTypeCode::kDLFloat6_e3m2fn; +} + +bool IsCUDAFloat4(DLDataTypeCode code) { return code == DLDataTypeCode::kDLFloat4_e2m1fn; 
} + +bool IsCUDAPackedFloat(DLDataTypeCode code) { + return IsCUDAFloat8(code) || IsCUDAFloat6(code) || IsCUDAFloat4(code); +} + } // namespace -std::string GetFP8Type(DataType type) { +std::string GetFP8Type(DLDataType type) { + PrimType type_ty(type); std::stringstream stream; - int32_t lanes = type.lanes(); + int32_t lanes = type_ty.lanes(); std::string vec; - if (type.is_scalar()) { + if (type_ty.IsScalar()) { vec = ""; } else if (lanes == 2) { vec = "x2"; @@ -78,11 +97,12 @@ std::string GetFP8Type(DataType type) { } stream << "__nv_fp8"; std::string suffix; - if (type.code() == DataType::kFloat8_e4m3fn) { + DLDataTypeCode code = type_ty.code(); + if (code == DLDataTypeCode::kDLFloat8_e4m3fn) { suffix = "_e4m3"; - } else if (type.code() == DataType::kFloat8_e5m2) { + } else if (code == DLDataTypeCode::kDLFloat8_e5m2) { suffix = "_e5m2"; - } else if (type.code() == DataType::kFloat8_e8m0fnu) { + } else if (code == DLDataTypeCode::kDLFloat8_e8m0fnu) { suffix = "_e8m0"; } else { TVM_FFI_THROW(InternalError) << "Unsupported FP8 type in CUDA codegen"; @@ -91,11 +111,12 @@ std::string GetFP8Type(DataType type) { return stream.str(); } -std::string GetFP6Type(DataType type) { +std::string GetFP6Type(DLDataType type) { + PrimType type_ty(type); std::stringstream stream; - int32_t lanes = type.lanes(); + int32_t lanes = type_ty.lanes(); std::string vec; - if (type.is_scalar()) { + if (type_ty.IsScalar()) { vec = ""; } else if (lanes == 2) { vec = "x2"; @@ -110,9 +131,10 @@ std::string GetFP6Type(DataType type) { } stream << "__nv_fp6"; std::string suffix; - if (type.code() == DataType::kFloat6_e2m3fn) { + DLDataTypeCode code = type_ty.code(); + if (code == DLDataTypeCode::kDLFloat6_e2m3fn) { suffix = "_e2m3"; - } else if (type.code() == DataType::kFloat6_e3m2fn) { + } else if (code == DLDataTypeCode::kDLFloat6_e3m2fn) { suffix = "_e3m2"; } else { TVM_FFI_THROW(InternalError) << "Unsupported FP6 type in CUDA codegen"; @@ -121,11 +143,12 @@ std::string GetFP6Type(DataType 
type) { return stream.str(); } -std::string GetFP4Type(DataType type) { +std::string GetFP4Type(DLDataType type) { + PrimType type_ty(type); std::stringstream stream; - int32_t lanes = type.lanes(); + int32_t lanes = type_ty.lanes(); std::string vec; - if (type.is_scalar()) { + if (type_ty.IsScalar()) { vec = ""; } else if (lanes == 2) { vec = "x2"; @@ -140,7 +163,8 @@ std::string GetFP4Type(DataType type) { } stream << "__nv_fp4"; std::string suffix; - if (type.code() == DataType::kFloat4_e2m1fn) { + DLDataTypeCode code = type_ty.code(); + if (code == DLDataTypeCode::kDLFloat4_e2m1fn) { suffix = "_e2m1"; } else { TVM_FFI_THROW(InternalError) << "Unsupported FP4 type in CUDA codegen"; @@ -299,31 +323,34 @@ void CodeGenCUDA::BindThreadIndex(const IterVar& iv) { ";\" : \"=r\"(ctaid) :);\n" " return ctaid;\n" "}\n"); - var_idmap_[iv->var.get()] = CastFromTo(func_name + "()", DataType::UInt(32), iv->var.dtype()); + var_idmap_[iv->var.get()] = + CastFromTo(func_name + "()", DLDataType{kDLUInt, 32, 1}, iv->var.ty()->dtype); } else { - var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); + var_idmap_[iv->var.get()] = + CastFromTo(iv->thread_tag, DLDataType{kDLUInt, 32, 1}, iv->var.ty()->dtype); } } -void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintType(DLDataType raw_t, std::ostream& os) { // NOLINT(*) + PrimType t(raw_t); int lanes = t.lanes(); - if (t.is_handle()) { - TVM_FFI_ICHECK(t.is_scalar()) << "do not yet support vector types"; + if (t.IsHandle()) { + TVM_FFI_ICHECK(t.IsScalar()) << "do not yet support vector types"; os << "void*"; return; } - if (t.is_void()) { + if (t.IsVoid()) { os << "void"; return; } bool fail = false; - if (t.is_float()) { + if (t.code() == DLDataTypeCode::kDLFloat) { switch (t.bits()) { case 16: codegen_tags_.insert("fp16"); - if (t.is_scalar()) { + if (t.IsScalar()) { os << "half"; } else if (lanes <= 8) { TVM_FFI_ICHECK_EQ(lanes % 2, 0) << "Only 
support an even number of lanes for half type"; @@ -360,15 +387,15 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) fail = true; break; } - if (!fail && (t.is_scalar() || t.bits() == 16)) return; + if (!fail && (t.IsScalar() || t.bits() == 16)) return; if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return; if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; return; } - } else if (t.is_bfloat16()) { + } else if (t.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { codegen_tags_.insert("bf16"); - if (t.is_scalar()) { + if (t.IsScalar()) { os << "nv_bfloat16"; } else if (lanes <= 8) { TVM_FFI_ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type"; @@ -381,57 +408,65 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) fail = true; } if (!fail) return; - } else if (t.is_float8()) { + } else if (t.code() == DLDataTypeCode::kDLFloat8_e3m4 || + t.code() == DLDataTypeCode::kDLFloat8_e4m3 || + t.code() == DLDataTypeCode::kDLFloat8_e4m3b11fnuz || + t.code() == DLDataTypeCode::kDLFloat8_e4m3fn || + t.code() == DLDataTypeCode::kDLFloat8_e4m3fnuz || + t.code() == DLDataTypeCode::kDLFloat8_e5m2 || + t.code() == DLDataTypeCode::kDLFloat8_e5m2fnuz || + t.code() == DLDataTypeCode::kDLFloat8_e8m0fnu) { codegen_tags_.insert("fp8"); - if (t.lanes() <= 4) { - os << GetFP8Type(t); + if (lanes <= 4) { + os << GetFP8Type(raw_t); } else { - os << "uint" << t.lanes() / 4; + os << "uint" << lanes / 4; } return; - } else if (t.is_float6()) { + } else if (t.code() == DLDataTypeCode::kDLFloat6_e2m3fn || + t.code() == DLDataTypeCode::kDLFloat6_e3m2fn) { codegen_tags_.insert("fp6"); - if (t.lanes() <= 4) { - os << GetFP6Type(t); + if (lanes <= 4) { + os << GetFP6Type(raw_t); } else { fail = true; } return; - } else if (t.is_float4()) { + } else if (t.code() == DLDataTypeCode::kDLFloat4_e2m1fn) { codegen_tags_.insert("fp4"); - if (t.lanes() <= 4) { - os << GetFP4Type(t); + if (lanes <= 4) { + os << GetFP4Type(raw_t); 
} else { fail = true; } return; - } else if (t == DataType::Bool()) { + } else if (raw_t == DLDataType{kDLBool, 8, 1}) { os << "bool"; return; - } else if (t.is_vector_bool()) { + } else if (t.code() == DLDataTypeCode::kDLBool && lanes > 1) { // CUDA does not support bool vectors. // Use ushort vectors to represent instead. - int n = t.lanes(); + int n = lanes; if (n <= 4) { os << "ushort" << n; return; } - } else if (t.is_uint() || t.is_int()) { - if (t.is_uint()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLUInt, DLDataTypeCode::kDLInt)) { + if (t.MatchesCode(DLDataTypeCode::kDLUInt)) { os << "u"; } switch (t.bits()) { case 1: { - if (t.is_scalar()) { + if (t.IsScalar()) { os << "int"; return; - } else if (t.lanes() == 8) { + } else if (lanes == 8) { os << "int8_t"; return; - } else if (t.lanes() == 16) { + } else if (lanes == 16) { os << "int16_t"; return; - } else if (t.lanes() == 32) { + } else if (lanes == 32) { os << "int"; return; } else { @@ -439,23 +474,23 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } } case 4: { - if (t.is_scalar()) { + if (t.IsScalar()) { os << "int"; return; - } else if (t.lanes() == 4) { + } else if (lanes == 4) { os << "int16_t"; return; - } else if (t.lanes() == 8) { + } else if (lanes == 8) { // directly 8 4-bit int in integer. os << "int"; return; - } else if (t.lanes() == 16) { + } else if (lanes == 16) { os << "int2"; return; - } else if (t.lanes() == 32) { + } else if (lanes == 32) { os << "int4"; return; - } else if (t.lanes() == 64) { + } else if (lanes == 64) { os << "int8"; return; } else { @@ -463,7 +498,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } } case 8: { - if (t.lanes() == 4) { + if (lanes == 4) { // directly 4 8 bit int in integer. codegen_tags_.insert("int8"); @@ -472,15 +507,15 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // into 32-bit data. 
os << "int"; return; - } else if (t.lanes() == 8) { + } else if (lanes == 8) { codegen_tags_.insert("int8"); os << "int2"; return; - } else if (t.lanes() == 16) { + } else if (lanes == 16) { codegen_tags_.insert("int8"); os << "int4"; return; - } else if (!t.is_uint() && t.is_scalar()) { + } else if (!t.MatchesCode(DLDataTypeCode::kDLUInt) && t.IsScalar()) { os << "signed char"; break; } else { @@ -489,11 +524,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } } case 16: { - if (t.is_scalar()) { + if (t.IsScalar()) { os << "short"; - } else if (t.lanes() <= 4) { + } else if (lanes <= 4) { os << "short" << lanes; - } else if (t.lanes() <= 8) { + } else if (lanes <= 8) { // Emit CUDA code to access int16 vector elements. // // short4 is stored as int2 @@ -503,9 +538,8 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // s4.z is emitted as *(short2*)(&(i2.y)).x // s4.w is emitted as *(short2*)(&(i2.y)).y // - TVM_FFI_ICHECK_EQ(t.lanes() % 2, 0) - << "only support even lane for shorT type with lanes > 4"; - os << "int" << t.lanes() / 2; + TVM_FFI_ICHECK_EQ(lanes % 2, 0) << "only support even lane for shorT type with lanes > 4"; + os << "int" << lanes / 2; } else { fail = true; } @@ -515,11 +549,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) break; } case 32: { - if (t.is_scalar()) { + if (t.IsScalar()) { os << "int"; - } else if (t.lanes() <= 4) { - os << "int" << t.lanes(); - } else if (t.lanes() <= 8) { + } else if (lanes <= 4) { + os << "int" << lanes; + } else if (lanes <= 8) { // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8. 
// // int8 is stored as longlong4 @@ -538,13 +572,13 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) break; } case 64: { - if (t.is_scalar()) { + if (t.IsScalar()) { os << "int64_t"; - } else if (t.lanes() == 2) { + } else if (lanes == 2) { os << "longlong2"; - } else if (t.lanes() == 3) { + } else if (lanes == 3) { os << "longlong3"; - } else if (t.lanes() == 4) { + } else if (lanes == 4) { os << "longlong4"; } return; @@ -561,15 +595,16 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) return; } } - TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to CUDA type"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << ffi::DLDataTypeToString(raw_t) + << " to CUDA type"; } -void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) { +void CodeGenCUDA::PrintVecConstructor(DLDataType t, std::ostream& os) { os << "make_"; PrintType(t, os); } -void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, +void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DLDataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) // Declare the result. std::string sret = name_supply_->FreshName("_"); @@ -579,22 +614,22 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l int ssa_scope = BeginScope(); { // Unpack into individual ops. 
- std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); - std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); + std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.ty()->dtype); + std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.ty()->dtype); - for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { + for (int i = 0, lanes = PrimType(t).lanes(); i < lanes; ++i) { std::ostringstream value_temp; if (isalpha(op[0])) { value_temp << op << "("; - PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + PrintVecElemLoad(vlhs, lhs.ty()->dtype, i, value_temp); value_temp << ", "; - PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + PrintVecElemLoad(vrhs, rhs.ty()->dtype, i, value_temp); value_temp << ")"; } else { value_temp << "("; - PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + PrintVecElemLoad(vlhs, lhs.ty()->dtype, i, value_temp); value_temp << op; - PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + PrintVecElemLoad(vrhs, rhs.ty()->dtype, i, value_temp); value_temp << ")"; } PrintVecElemStore(sret, t, i, value_temp.str()); @@ -604,55 +639,58 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l os << sret; } -void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, +void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DLDataType t, int i, std::ostream& os) { // NOLINT(*) - if (t.is_scalar()) { + PrimType t_ty(t); + int lanes = t_ty.lanes(); + if (t_ty.IsScalar()) { os << vec; return; } static const char access[] = {'x', 'y', 'z', 'w'}; - TVM_FFI_ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); - if (t.bits() == 8 && (t.is_int() || t.is_uint())) { - std::string type_name = t.is_int() ? "signed char" : "unsigned char"; - if (t.lanes() == 2 || t.lanes() == 3) { - os << vec << "." << access[i % t.lanes()]; + TVM_FFI_ICHECK(i >= 0 && i < (t.bits == 8 ? 16 : (t.bits == 16 || t.bits == 32) ? 
8 : 4)); + if (t.bits == 8 && (t_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt))) { + std::string type_name = + t_ty.MatchesCode(DLDataTypeCode::kDLInt) ? "signed char" : "unsigned char"; + if (lanes == 2 || lanes == 3) { + os << vec << "." << access[i % lanes]; } else { - std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); + std::string ac = lanes == 4 ? vec : (vec + "." + access[i / 4]); os << "(reinterpret_cast(&(" << ac << "))[" << (i % 4) << "])"; } - } else if (t.is_float16()) { - if (t.lanes() <= 4) { + } else if (t_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16)) { + if (lanes <= 4) { os << vec << "." << access[i]; } else { os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } - } else if (t.is_bfloat16()) { - if (t.lanes() <= 4) { + } else if (t_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { + if (lanes <= 4) { os << vec << "." << access[i]; } else { os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } - } else if (t.lanes() > 4 && t.lanes() <= 8) { + } else if (lanes > 4 && lanes <= 8) { std::string type_name; - if (t.bits() == 16) { - if (t.is_int()) { + if (t.bits == 16) { + if (t_ty.MatchesCode(DLDataTypeCode::kDLInt)) { type_name = "short"; - } else if (t.is_uint()) { + } else if (t_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { type_name = "ushort"; } - } else if (t.bits() == 32) { - if (t.is_int()) { + } else if (t.bits == 32) { + if (t_ty.MatchesCode(DLDataTypeCode::kDLInt)) { type_name = "int"; - } else if (t.is_uint()) { + } else if (t_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { type_name = "uint"; - } else if (t.is_float()) { + } else if (t_ty.code() == DLDataTypeCode::kDLFloat) { type_name = "float"; } } TVM_FFI_ICHECK(!type_name.empty()); os << "((" << type_name << "2*)(&(" << vec << "." 
<< access[i / 2] << ")))->" << access[i % 2]; - } else if (t.is_float4_e2m1fn()) { + } else if (t_ty.code() == DLDataTypeCode::kDLFloat4_e2m1fn) { os << "([](__nv_fp4_storage_t v) { __nv_fp4_e2m1 t; t.__x = v; return t; })((" << vec << ".__x >> " << i * 4 << ") & 0xF)"; } else { @@ -660,50 +698,53 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, } } -void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, +void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DLDataType t, int i, const std::string& value) { + PrimType t_ty(t); + int lanes = t_ty.lanes(); this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; - TVM_FFI_ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); - if (t.bits() == 8 && (t.is_int() || t.is_uint())) { - if (t.lanes() == 2 || t.lanes() == 3) { - stream << vec << '.' << access[i % t.lanes()] << "=" + TVM_FFI_ICHECK(i >= 0 && i < (t.bits == 8 ? 16 : (t.bits == 16 || t.bits == 32) ? 8 : 4)); + if (t.bits == 8 && (t_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt))) { + if (lanes == 2 || lanes == 3) { + stream << vec << '.' << access[i % lanes] << "=" << "(" << value << ");\n"; } else { - std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); - std::string type_name = t.is_int() ? "signed char" : "unsigned char"; + std::string ac = lanes == 4 ? vec : (vec + "." + access[i / 4]); + std::string type_name = + t_ty.MatchesCode(DLDataTypeCode::kDLInt) ? "signed char" : "unsigned char"; stream << "reinterpret_cast<" << type_name << "*>(&(" << ac << "))[" << (i % 4) << "] = (" << type_name << ")(" << value << ");\n"; } - } else if (t.is_float16()) { - if (t.lanes() <= 4) { + } else if (t_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16)) { + if (lanes <= 4) { stream << vec << "." << access[i] << " = " << value << ";\n"; } else { stream << "((half2*)(&(" << vec << "." 
<< access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; } - } else if (t.is_bfloat16()) { - if (t.lanes() <= 4) { + } else if (t_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { + if (lanes <= 4) { stream << vec << "." << access[i] << " = " << value << ";\n"; } else { stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; } - } else if (t.lanes() > 4 && t.lanes() <= 8) { + } else if (lanes > 4 && lanes <= 8) { std::string type_name; - if (t.bits() == 16) { - if (t.is_int()) { + if (t.bits == 16) { + if (t_ty.MatchesCode(DLDataTypeCode::kDLInt)) { type_name = "short"; - } else if (t.is_uint()) { + } else if (t_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { type_name = "ushort"; } - } else if (t.bits() == 32) { - if (t.is_int()) { + } else if (t.bits == 32) { + if (t_ty.MatchesCode(DLDataTypeCode::kDLInt)) { type_name = "int"; - } else if (t.is_uint()) { + } else if (t_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { type_name = "uint"; - } else if (t.is_float()) { + } else if (t_ty.code() == DLDataTypeCode::kDLFloat) { type_name = "float"; } } @@ -766,15 +807,19 @@ void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) } } -std::string CodeGenCUDA::CastFromTo(std::string value, DataType from, DataType target) { +std::string CodeGenCUDA::CastFromTo(std::string value, DLDataType from, DLDataType target) { if (from == target) return value; + PrimType from_ty(from); + PrimType target_ty(target); std::ostringstream os; os << "(("; this->PrintType(target, os); os << ")"; - if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) { + if (from_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) && + (target_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) && + target.bits == 8) { os << "("; - if (target.is_uint()) { + if (target_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { os << "u"; } os << "int)"; @@ -794,33 +839,22 @@ void 
CodeGenCUDA::AddUtilFunction(const std::string& func_name, const std::strin } void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { - DataType from_ty = op->value.dtype(); - DataType target_ty = op->dtype; + DLDataType from_dtype = op->value.ty()->dtype; + DLDataType target_dtype = op->ty()->dtype; + PrimType from_ty(from_dtype); + PrimType target_ty(target_dtype); TVM_FFI_ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); // Emit simple C-style type conversion. - if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); - - if (target_ty.code() == DataType::kFloat8_e3m4 || target_ty.code() == DataType::kFloat8_e4m3 || - target_ty.code() == DataType::kFloat8_e4m3b11fnuz || - target_ty.code() == DataType::kFloat8_e4m3fn || - target_ty.code() == DataType::kFloat8_e4m3fnuz || - target_ty.code() == DataType::kFloat8_e5m2 || - target_ty.code() == DataType::kFloat8_e5m2fnuz || - target_ty.code() == DataType::kFloat8_e8m0fnu || - target_ty.code() == DataType::kFloat4_e2m1fn || - - from_ty.code() == DataType::kFloat8_e3m4 || from_ty.code() == DataType::kFloat8_e4m3 || - from_ty.code() == DataType::kFloat8_e4m3b11fnuz || - from_ty.code() == DataType::kFloat8_e4m3fn || from_ty.code() == DataType::kFloat8_e4m3fnuz || - from_ty.code() == DataType::kFloat8_e5m2 || from_ty.code() == DataType::kFloat8_e5m2fnuz || - from_ty.code() == DataType::kFloat8_e8m0fnu || from_ty.code() == DataType::kFloat4_e2m1fn) { + if (from_ty.IsScalar()) return CodeGenC::VisitExpr_(op, os); + + if (IsCUDAPackedFloat(target_ty.code()) || IsCUDAPackedFloat(from_ty.code())) { std::ostringstream val; - if (target_ty.code() == DataType::kBFloat && target_ty.lanes() == 2) { + if (target_ty.code() == DLDataTypeCode::kDLBfloat && target_ty.lanes() == 2) { val << "cast_to_nv_bfloat162(" << PrintExpr(op->value) << ")"; } else { val << "("; - PrintType(target_ty, val); + PrintType(target_dtype, val); val << ")(" << PrintExpr(op->value) << ")"; } os << val.str(); @@ -831,18 +865,18 @@ void 
CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { // too compact to read. Emit this as vectorized unary ops. std::string sret = name_supply_->FreshName("_"); this->PrintIndent(); - this->PrintType(target_ty, stream); + this->PrintType(target_dtype, stream); stream << ' ' << sret << ";\n"; { - std::string src = SSAGetID(PrintExpr(op->value), from_ty); + std::string src = SSAGetID(PrintExpr(op->value), from_dtype); for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { std::ostringstream val; val << "("; - PrintType(target_ty.element_of(), val); + PrintType(DLDataType{target_dtype.code, target_dtype.bits, 1}, val); val << ")("; - PrintVecElemLoad(src, from_ty, i, val); + PrintVecElemLoad(src, from_dtype, i, val); val << ")"; - PrintVecElemStore(sret, target_ty, i, val.str()); + PrintVecElemStore(sret, target_dtype, i, val.str()); } } os << sret; @@ -851,8 +885,9 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { void CodeGenCUDA::PrintCallExtern(Type ret_type, ffi::String global_symbol, const ffi::Array& args, bool skip_first_arg, std::ostream& os) { // NOLINT(*) - DataType ret_dtype = GetRuntimeDataType(ret_type); - if (ret_dtype.is_fixed_length_vector()) { + DLDataType ret_dtype = GetRuntimeDataType(ret_type); + PrimType ret_ty(ret_dtype); + if (ret_ty.IsFixedLengthVector()) { // // Emit an unsupported vector call // @@ -881,17 +916,17 @@ void CodeGenCUDA::PrintCallExtern(Type ret_type, ffi::String global_symbol, std::vector sargs; size_t arg_begin = static_cast(skip_first_arg); for (size_t i = arg_begin; i < args.size(); ++i) { - std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype()); + std::string val = SSAGetID(PrintExpr(args[i]), args[i].ty()->dtype); sargs.push_back(std::move(val)); } // Emit a scalar call for each lane. 
- for (int i = 0; i < ret_dtype.lanes(); ++i) { + for (int i = 0; i < ret_ty.lanes(); ++i) { std::ostringstream scall; scall << global_symbol << "("; for (size_t j = 0; j < sargs.size(); ++j) { if (j > 0) scall << ", "; - PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall); + PrintVecElemLoad(sargs[j], args[arg_begin + j].ty()->dtype, i, scall); } scall << ")"; PrintVecElemStore(sret, ret_dtype, i, scall.str()); @@ -1196,7 +1231,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string local_ptr = this->PrintExpr(op->args[3]); std::string local_offset = this->PrintExpr(op->args[4]); std::string smem_ptr = this->PrintExpr(op->args[5]); - if (trans && op->dtype.bits() == 8) { + if (trans && op->ty()->dtype.bits == 8) { // ldmatrix can't transpose 8-bit elements (it assumes 16-bit), so // synthesize the equivalent manual gather loop. args[6] is the // shared-memory stride for this fallback. @@ -1317,39 +1352,46 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { << guard << ")\n"; stream << ");\n"; } else if (op->op.same_as(builtin::reinterpret())) { - DataType tgt_dtype = op->dtype; - DataType src_dtype = op->args[0]->dtype; + DLDataType tgt_dtype = op->ty()->dtype; + DLDataType src_dtype = op->args[0].ty()->dtype; + PrimType tgt_ty(tgt_dtype); + PrimType src_ty(src_dtype); PrimExpr value = op->args[0]; - if (src_dtype.is_handle() && tgt_dtype.is_scalar() && - (tgt_dtype.is_uint() || tgt_dtype.is_int()) && tgt_dtype.bits() == 64) { + if (src_ty.IsHandle() && tgt_ty.IsScalar() && + tgt_ty.MatchesCode(DLDataTypeCode::kDLUInt, DLDataTypeCode::kDLInt) && + tgt_dtype.bits == 64) { os << "reinterpret_cast<"; this->PrintType(tgt_dtype, os); os << ">(" << PrintExpr(value) << ")"; return; } - if (tgt_dtype.is_handle() && src_dtype.is_scalar() && - (src_dtype.is_uint() || src_dtype.is_int()) && src_dtype.bits() == 64) { + if (tgt_ty.IsHandle() && src_ty.IsScalar() && + src_ty.MatchesCode(DLDataTypeCode::kDLUInt, 
DLDataTypeCode::kDLInt) && + src_dtype.bits == 64) { os << "reinterpret_cast(" << PrintExpr(value) << ")"; return; } // Handle float4_e2m1fn reinterpret - if (!src_dtype.is_float4_e2m1fn() && !tgt_dtype.is_float4_e2m1fn()) { + if (!IsCUDAFloat4(src_ty.code()) && !IsCUDAFloat4(tgt_ty.code())) { return CodeGenC::VisitExpr_(op, os); } if (src_dtype == tgt_dtype || - tgt_dtype.lanes() * tgt_dtype.bits() == src_dtype.lanes() * src_dtype.bits()) { + tgt_ty.lanes() * tgt_dtype.bits == src_ty.lanes() * src_dtype.bits) { return CodeGenC::VisitExpr_(op, os); } - TVM_FFI_ICHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes()) + TVM_FFI_ICHECK_EQ(tgt_ty.lanes(), src_ty.lanes()) << "E2M1 float4 reinterpret expects source and target to have the same number of lanes. " - << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype; - TVM_FFI_ICHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes()) + << "Source dtype: " << ffi::DLDataTypeToString(src_dtype) + << ", Target dtype: " << ffi::DLDataTypeToString(tgt_dtype); + TVM_FFI_ICHECK_EQ((tgt_ty.lanes() * tgt_dtype.bits + 7) / 8, + (src_ty.lanes() * src_dtype.bits + 7) / 8) << "E2M1 float4 reinterpret expects source and target to have the same number of bytes. " - << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype; + << "Source dtype: " << ffi::DLDataTypeToString(src_dtype) + << ", Target dtype: " << ffi::DLDataTypeToString(tgt_dtype); - int lanes = tgt_dtype.lanes(); + int lanes = tgt_ty.lanes(); int ssa_scope = BeginScope(); if (lanes == 1) { @@ -1360,47 +1402,47 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintType(tgt_dtype, os); os << " *)(&(" << rhs << ")))"; } else if (lanes == 2) { - if (tgt_dtype.is_float4_e2m1fn()) { + if (IsCUDAFloat4(tgt_ty.code())) { // We view the source as an uint16, and then extract bits of two fp4 numbers, // and finally reinterpret the result as fp4x2. 
- value = tirx::Call(DataType::UInt(16), tirx::builtin::reinterpret(), {value}); - tirx::Var temp_var("temp_var", DataType::UInt(16)); + value = tirx::Call(PrimType::UInt(16), tirx::builtin::reinterpret(), {value}); + tirx::Var temp_var("temp_var", PrimType::UInt(16)); value = tirx::Let(temp_var, value, - tirx::Cast(DataType::UInt(8), - (temp_var & IntImm(DataType::UInt(16), 0xF)) | - ((temp_var >> 4) & IntImm(DataType::UInt(16), 0xF0)))); + tirx::Cast(PrimType::UInt(8), + (temp_var & IntImm(PrimType::UInt(16), 0xF)) | + ((temp_var >> 4) & IntImm(PrimType::UInt(16), 0xF0)))); } else { - value = tirx::Cast(DataType::UInt(16), - tirx::Call(DataType::UInt(8), tirx::builtin::reinterpret(), {value})); - tirx::Var temp_var("temp_var", DataType::UInt(16)); + value = tirx::Cast(PrimType::UInt(16), + tirx::Call(PrimType::UInt(8), tirx::builtin::reinterpret(), {value})); + tirx::Var temp_var("temp_var", PrimType::UInt(16)); value = tirx::Let(temp_var, value, - (temp_var & IntImm(DataType::UInt(16), 0xF)) | - ((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4)); + (temp_var & IntImm(PrimType::UInt(16), 0xF)) | + ((temp_var & IntImm(PrimType::UInt(16), 0xF0)) << 4)); } - os << PrintExpr(tirx::Call(tgt_dtype, tirx::builtin::reinterpret(), {value})); + os << PrintExpr(tirx::Call(PrimType(tgt_dtype), tirx::builtin::reinterpret(), {value})); } else if (lanes == 4) { - if (tgt_dtype.is_float4_e2m1fn()) { + if (IsCUDAFloat4(tgt_ty.code())) { // We view the source as an uint32, and then extract bits of four fp4 numbers, // and finally reinterpret the result as fp4x4. 
- value = tirx::Call(DataType::UInt(32), tirx::builtin::reinterpret(), {value}); - tirx::Var temp_var("temp_var", DataType::UInt(32)); + value = tirx::Call(PrimType::UInt(32), tirx::builtin::reinterpret(), {value}); + tirx::Var temp_var("temp_var", PrimType::UInt(32)); value = tirx::Let(temp_var, value, - tirx::Cast(DataType::UInt(16), - (temp_var & IntImm(DataType::UInt(32), 0xF)) | - ((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) | - ((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) | - ((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000)))); + tirx::Cast(PrimType::UInt(16), + (temp_var & IntImm(PrimType::UInt(32), 0xF)) | + ((temp_var >> 4) & IntImm(PrimType::UInt(32), 0xF0)) | + ((temp_var >> 8) & IntImm(PrimType::UInt(32), 0xF00)) | + ((temp_var >> 12) & IntImm(PrimType::UInt(32), 0xF000)))); } else { - value = tirx::Cast(DataType::UInt(32), - tirx::Call(DataType::UInt(16), tirx::builtin::reinterpret(), {value})); - tirx::Var temp_var("temp_var", DataType::UInt(32)); + value = tirx::Cast(PrimType::UInt(32), + tirx::Call(PrimType::UInt(16), tirx::builtin::reinterpret(), {value})); + tirx::Var temp_var("temp_var", PrimType::UInt(32)); value = tirx::Let(temp_var, value, - (temp_var & IntImm(DataType::UInt(32), 0xF)) | - ((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) | - ((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) | - ((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12)); + (temp_var & IntImm(PrimType::UInt(32), 0xF)) | + ((temp_var & IntImm(PrimType::UInt(32), 0xF0)) << 4) | + ((temp_var & IntImm(PrimType::UInt(32), 0xF00)) << 8) | + ((temp_var & IntImm(PrimType::UInt(32), 0xF000)) << 12)); } - os << PrintExpr(tirx::Call(tgt_dtype, tirx::builtin::reinterpret(), {value})); + os << PrintExpr(tirx::Call(PrimType(tgt_dtype), tirx::builtin::reinterpret(), {value})); } else { TVM_FFI_THROW(InternalError) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes; @@ -1411,7 +1453,8 @@ void CodeGenCUDA::VisitExpr_(const 
CallNode* op, std::ostream& os) { const PrimExpr& arg = op->args[0]; const auto* var_node = arg.as(); - DataType dtype = op->dtype; + DLDataType dtype = op->ty()->dtype; + PrimType dtype_ty(dtype); bool is_string = op->args[2].as()->value; bool is_scalar = op->args[3].as()->value; int num_dims = op->args[4].as()->value; @@ -1432,22 +1475,23 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { if (is_scalar) { // Scalar printing logic std::string format_specifier; - bool is_float16 = dtype.is_float() && dtype.bits() == 16; - if (dtype.is_float()) + bool is_float16 = dtype_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16); + if (dtype_ty.code() == DLDataTypeCode::kDLFloat) format_specifier = "%f"; - else if (dtype.is_int()) + else if (dtype_ty.MatchesCode(DLDataTypeCode::kDLInt)) format_specifier = "%d"; - else if (dtype.is_uint()) + else if (dtype_ty.MatchesCode(DLDataTypeCode::kDLUInt)) format_specifier = "%u"; else - TVM_FFI_THROW(InternalError) << "Unsupported data type for scalar print: " << dtype; + TVM_FFI_THROW(InternalError) + << "Unsupported data type for scalar print: " << ffi::DLDataTypeToString(dtype); std::string print_arg = var_node ? ("*" + GetVarID(var_node)) : PrintExpr(arg); os << "// print_buffer starts (scalar)\n" << "if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {\n" - << " printf(\"Scalar (dtype: " << dtype << "): " << format_specifier << "\\n\\n\", " - << (is_float16 ? "static_cast(" : "") << print_arg << (is_float16 ? ")" : "") - << ");\n" + << " printf(\"Scalar (dtype: " << ffi::DLDataTypeToString(dtype) + << "): " << format_specifier << "\\n\\n\", " << (is_float16 ? "static_cast(" : "") + << print_arg << (is_float16 ? 
")" : "") << ");\n" << "}\n" << "// print_buffer ends\n"; return; @@ -1460,19 +1504,20 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string format_specifier; bool is_float16 = false; - if (dtype.is_float()) { - if (dtype.bits() == 16) { + if (dtype_ty.code() == DLDataTypeCode::kDLFloat) { + if (dtype.bits == 16) { format_specifier = "%f"; is_float16 = true; } else { format_specifier = "%f"; } - } else if (dtype.is_int()) { + } else if (dtype_ty.MatchesCode(DLDataTypeCode::kDLInt)) { format_specifier = "%d"; - } else if (dtype.is_uint()) { + } else if (dtype_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { format_specifier = "%u"; } else { - TVM_FFI_THROW(InternalError) << "Unsupported data type for print: " << dtype; + TVM_FFI_THROW(InternalError) + << "Unsupported data type for print: " << ffi::DLDataTypeToString(dtype); } TVM_FFI_ICHECK(var_node) << "Formatted print is only supported for buffer variables."; @@ -1485,7 +1530,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { for (int i = 0; i < num_dims; ++i) { os << PrintExpr(shape[i]) << (i < num_dims - 1 ? 
"," : ""); } - os << "), dtype=" << dtype << "):\\n\");\n"; + os << "), dtype=" << ffi::DLDataTypeToString(dtype) << "):\\n\");\n"; std::vector loop_vars; for (int i = 0; i < num_dims; ++i) { @@ -1572,7 +1617,7 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { << "For CUDA, the index of an async queue must be 0."; this->VisitStmt(op->body); static const Op& ptx_cp_async_commit_group_op = Op::Get("tirx.ptx.cp_async_commit_group"); - auto commit_group = Call(DataType::Void(), ptx_cp_async_commit_group_op, {}); + auto commit_group = Call(PrimType::Void(), ptx_cp_async_commit_group_op, {}); this->PrintIndent(); this->VisitExpr(commit_group, this->stream); this->stream << ";\n"; @@ -1584,7 +1629,7 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { << "For CUDA, the index of an async queue must be 0."; auto wait_cnt = wait_attrs.second; static const Op& ptx_cp_async_wait_group_op = Op::Get("tirx.ptx.cp_async_wait_group"); - auto wait_group = Call(DataType::Void(), ptx_cp_async_wait_group_op, {wait_cnt}); + auto wait_group = Call(PrimType::Void(), ptx_cp_async_wait_group_op, {wait_cnt}); this->PrintIndent(); this->VisitExpr(wait_group, this->stream); this->stream << ";\n"; @@ -1614,19 +1659,23 @@ void CodeGenCUDA::VisitStmt_(const AllocBufferNode* op) { this->PrintIndent(); std::string scope = GetPtrStorageScope(op->buffer->data); const VarNode* buffer = op->buffer->data.as(); - DataType dtype = op->buffer->dtype; + DLDataType dtype = op->buffer->dtype->dtype; if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { - TVM_FFI_ICHECK(dtype == DataType::Float(16) || dtype == DataType::Int(8) || - dtype == DataType::UInt(8) || dtype == DataType::Int(4) || - dtype == DataType::UInt(4) || dtype == DataType::Int(1) || - dtype == DataType::BFloat(16)) + bool supported_wmma_input_dtype = + dtype == DLDataType{kDLFloat, 16, 1} || dtype == DLDataType{kDLInt, 8, 1} || + dtype == DLDataType{kDLUInt, 8, 1} || dtype == 
DLDataType{kDLInt, 4, 1} || + dtype == DLDataType{kDLUInt, 4, 1} || dtype == DLDataType{kDLInt, 1, 1} || + dtype == DLDataType{kDLBfloat, 16, 1}; + TVM_FFI_ICHECK(supported_wmma_input_dtype) << "Matrix_a and matrix_b only support half or char or unsigned char " << "or uint4 or int4 or int1 type for now"; } else { - TVM_FFI_ICHECK(dtype == DataType::Float(16) || dtype == DataType::Float(32) || - dtype == DataType::Int(32)) + bool supported_wmma_accumulator_dtype = dtype == DLDataType{kDLFloat, 16, 1} || + dtype == DLDataType{kDLFloat, 32, 1} || + dtype == DLDataType{kDLInt, 32, 1}; + TVM_FFI_ICHECK(supported_wmma_accumulator_dtype) << "Accumulator only support half, float and int type for now"; } PrintWmmaScope(scope, dtype, buffer, stream); @@ -1662,9 +1711,11 @@ void CodeGenCUDA::VisitStmt_(const AllocBufferNode* op) { if (scope.find("wmma.") == 0) { constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); } - if ((dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) && - scope == "shared") { - constant_size = constant_size / (32 / dtype.bits()); + bool is_packed_integer_dtype = dtype == DLDataType{kDLInt, 4, 1} || + dtype == DLDataType{kDLUInt, 4, 1} || + dtype == DLDataType{kDLInt, 1, 1}; + if (is_packed_integer_dtype && scope == "shared") { + constant_size = constant_size / (32 / dtype.bits); } stream << ' ' << vid << '[' << constant_size << "];\n"; } @@ -1693,9 +1744,10 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { } void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { - int lanes = op->dtype.lanes(); + PrimType op_ty = op->ty(); + int lanes = op_ty.lanes(); if (lanes <= 4) { - PrintVecConstructor(op->dtype, os); + PrintVecConstructor(op->ty()->dtype, os); os << "("; for (int i = 0; i < lanes; i++) { os << "(" << PrintExpr(op->base) << ")" @@ -1710,16 +1762,16 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { // constructor argument layout does not match TIR vector 
lane layout. std::string sret = name_supply_->FreshName("_"); this->PrintIndent(); - this->PrintType(op->dtype, stream); + this->PrintType(op->ty()->dtype, stream); stream << ' ' << sret << ";\n"; int ssa_scope = BeginScope(); { - std::string vbase = SSAGetID(PrintExpr(op->base), op->base.dtype()); - std::string vstride = SSAGetID(PrintExpr(op->stride), op->stride.dtype()); + std::string vbase = SSAGetID(PrintExpr(op->base), op->base.ty()->dtype); + std::string vstride = SSAGetID(PrintExpr(op->stride), op->stride.ty()->dtype); for (int i = 0; i < lanes; ++i) { std::ostringstream value_temp; value_temp << "(" << vbase << ")+(" << vstride << "*" << i << ")"; - PrintVecElemStore(sret, op->dtype, i, value_temp.str()); + PrintVecElemStore(sret, op->ty()->dtype, i, value_temp.str()); } } EndScope(ssa_scope); @@ -1727,14 +1779,16 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { } void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) - int lanes = op->dtype.lanes(); - if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) { + PrimType op_ty = op->ty(); + int lanes = op_ty.lanes(); + if ((op_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) && op_ty.bits() == 8 && + lanes == 4) { // make_int8x4 const int64_t* p = as_const_int(op->value); TVM_FFI_ICHECK(p); int64_t v = *p & 0xFF; v = (v << 24) | (v << 16) | (v << 8) | v; - if (op->dtype.is_uint()) { + if (op_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { os << "(uint)" << v; } else { os << "(int)" << v; @@ -1742,9 +1796,9 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO return; } - if (op->dtype.is_float16()) { + if (op_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16)) { std::string v = PrintExpr(op->value); - PrintVecConstructor(op->dtype, os); + PrintVecConstructor(op->ty()->dtype, os); os << '('; if (lanes <= 4) { for (int i = 0; i < lanes / 2; ++i) { @@ -1761,9 +1815,9 @@ void 
CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO return; } - if (op->dtype.is_bfloat16()) { + if (op_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { std::string v = PrintExpr(op->value); - PrintVecConstructor(op->dtype, os); + PrintVecConstructor(op->ty()->dtype, os); os << '('; if (lanes > 4) { for (int i = 0; i < lanes / 2; ++i) { @@ -1780,12 +1834,11 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO return; } - if (op->dtype.is_float8() || op->dtype.is_float4()) { - int lanes = op->dtype.lanes(); + if (IsCUDAFloat8(op_ty.code()) || IsCUDAFloat4(op_ty.code())) { TVM_FFI_ICHECK(lanes == 1 || lanes == 2 || lanes == 4); std::string v = PrintExpr(op->value); // Implicit conversion from float back to fp8 - PrintType(op->dtype, os); + PrintType(op->ty()->dtype, os); os << "(make_float" << lanes << "("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -1795,7 +1848,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO return; } - if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { + if ((op_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) && op_ty.bits() == 4) { bool fail = false; const int64_t* p = as_const_int(op->value); TVM_FFI_ICHECK(p); @@ -1803,7 +1856,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO if (lanes == 4) { v = (v << 12) | (v << 8) | (v << 4) | v; - if (op->dtype.is_uint()) { + if (op_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { os << "(uint16_t)" << v; } else { os << "(int16_t)" << v; @@ -1811,17 +1864,17 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO } else { v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; if (lanes == 8) { - if (op->dtype.is_uint()) { + if (op_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { os << "(uint)" << v; } else { os << "(int)" << v; } } else if (lanes == 16 || lanes == 32) { 
- PrintVecConstructor(op->dtype, os); + PrintVecConstructor(op->ty()->dtype, os); os << '('; for (int i = 0; i < lanes / 8; ++i) { if (i != 0) os << ", "; - if (op->dtype.is_uint()) { + if (op_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { os << "(uint)" << v; } else { os << "(int)" << v; @@ -1839,7 +1892,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO } std::string v = PrintExpr(op->value); - PrintVecConstructor(op->dtype, os); + PrintVecConstructor(op->ty()->dtype, os); os << '('; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -1849,47 +1902,49 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO } void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) { + PrimType op_ty = op->ty(); // Non-vector cases. - if (!op->dtype.is_fixed_length_vector()) { + if (!op_ty.IsFixedLengthVector()) { CodeGenC::VisitExpr_(op, os); return; } // Codegen vector condition case by serializing the select op. - TVM_FFI_ICHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype && - op->dtype.lanes() == op->condition.dtype().lanes()); + TVM_FFI_ICHECK(op->false_value.ty() == op_ty && op->true_value.ty() == op_ty && + op_ty.lanes() == op->condition.ty().lanes()); std::string r_var = name_supply_->FreshName("_"); this->PrintIndent(); - this->PrintType(op->dtype, stream); + this->PrintType(op->ty()->dtype, stream); stream << ' ' << r_var << ";\n"; { - std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype); - std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype); - std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype); + std::string c_var = SSAGetID(PrintExpr(op->condition), op->ty()->dtype); + std::string t_var = SSAGetID(PrintExpr(op->true_value), op->ty()->dtype); + std::string f_var = SSAGetID(PrintExpr(op->false_value), op->ty()->dtype); // The condition is stored as an ushort vector. 
- int lanes = op->dtype.lanes(); - DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes); + int lanes = op_ty.lanes(); + DLDataType memory_dtype{kDLUInt, 16, static_cast(lanes)}; for (int i = 0; i < lanes; ++i) { std::ostringstream item; item << "(bool("; - PrintVecElemLoad(c_var, memory_ty, i, item); + PrintVecElemLoad(c_var, memory_dtype, i, item); item << ")?"; - PrintVecElemLoad(t_var, op->dtype, i, item); + PrintVecElemLoad(t_var, op->ty()->dtype, i, item); item << ':'; - PrintVecElemLoad(f_var, op->dtype, i, item); + PrintVecElemLoad(f_var, op->ty()->dtype, i, item); item << ')'; - PrintVecElemStore(r_var, op->dtype, i, item.str()); + PrintVecElemStore(r_var, op->ty()->dtype, i, item.str()); } } os << r_var; } inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) + PrimType op_ty = op->ty(); // Type code is kBFloat - if (op->dtype.is_bfloat16()) { + if (op_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { os << "__float2bfloat16_rn"; os << '(' << std::hexfloat << op->value << 'f'; os << "/*" << std::scientific << op->value << "*/"; @@ -1897,15 +1952,15 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) return; } // Type code is kFloat8_e5m2 or kE4M4Float - if (op->dtype.is_float8() || op->dtype.is_float4()) { - p->PrintType(op->dtype, os); + if (IsCUDAFloat8(op_ty.code()) || IsCUDAFloat4(op_ty.code())) { + p->PrintType(op->ty()->dtype, os); os << '(' << std::hexfloat << op->value << 'f'; os << "/*" << std::scientific << op->value << "*/"; os << ')'; return; } // Type code is kFloat - switch (op->dtype.bits()) { + switch (op_ty.bits()) { case 64: { std::ostringstream temp; if (std::isinf(op->value)) { @@ -1945,13 +2000,14 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) } case 16: { os << "__float2half_rn" << '('; - FloatImm const_f32 = FloatImm(DataType::Float(32), op->value); + FloatImm const_f32 = FloatImm(PrimType::Float(32), op->value); 
PrintConst(const_f32.get(), os, p); os << ')'; break; } default: - TVM_FFI_THROW(InternalError) << "Bad bit-width for float: " << op->dtype << "\n"; + TVM_FFI_THROW(InternalError) + << "Bad bit-width for float: " << ffi::DLDataTypeToString(op->ty()->dtype) << "\n"; } } @@ -1959,25 +2015,27 @@ void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOL PrintConst(op, os, this); } -void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, +void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DLDataType t, const VarNode* variable, std::ostream& os) { + PrimType t_ty(t); std::stringstream type; PrintType(t, type); TVM_FFI_ICHECK(fragment_shapes.count(variable)) << "Cannot find shape of the wmma fragment " << variable->name_hint; std::string shape_str = fragment_shapes.at(variable); - if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) { + if ((t_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) && t.bits < 8 && + t_ty.lanes() == 1) { type.str(std::string()); - if (t.is_int()) { - if (t.bits() == 4) { + if (t_ty.MatchesCode(DLDataTypeCode::kDLInt)) { + if (t.bits == 4) { type << "nvcuda::wmma::experimental::precision::s4"; - } else if (t.bits() == 1) { + } else if (t.bits == 1) { type << "nvcuda::wmma::experimental::precision::b1"; } else { TVM_FFI_THROW(InternalError) << "Unhandled interger type for wmma fragment!"; } - } else if (t.is_uint()) { - if (t.bits() == 4) { + } else if (t_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { + if (t.bits == 4) { type << "nvcuda::wmma::experimental::precision::u4"; } else { TVM_FFI_THROW(InternalError) << "Unhandled interger type for wmma fragment!"; @@ -2029,20 +2087,25 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoad // Cast away volatile qualifier for fp16 types. That is, only loads and // stores are volatile. The loaded objects are not marked as volatile. 
// - if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) { + PrimType op_ty = op->ty(); + if ((op_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) || + op_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) && + IsVolatile(op->buffer->data.get())) { os << "("; - PrintType(op->dtype, os); + PrintType(op->ty()->dtype, os); os << ")(" << value << ")"; } else { os << value; } } -void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, +void CodeGenCUDA::PrintVecElemLoadExpr(DLDataType t, int i, const std::string& value, std::ostream& os) { - TVM_FFI_ICHECK_GT(t.lanes(), 1); - if (t.bits() == 8 && (t.is_int() || t.is_uint())) { - if (!(t.lanes() == 2 || t.lanes() == 3)) { + PrimType t_ty(t); + int lanes = t_ty.lanes(); + TVM_FFI_ICHECK_GT(lanes, 1); + if (t.bits == 8 && (t_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt))) { + if (!(lanes == 2 || lanes == 3)) { if (i != 0) { os << "|"; } @@ -2051,12 +2114,12 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val } } - if (t.is_float16()) { + if (t_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16)) { if (i == 0) { PrintVecConstructor(t, os); os << '('; } - if (i == t.lanes() - 1) { + if (i == lanes - 1) { os << value << ")"; } else { os << value << ","; @@ -2064,12 +2127,12 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val return; } - if (t.is_bfloat16()) { + if (t_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { if (i == 0) { PrintVecConstructor(t, os); os << '('; } - if (i == t.lanes() - 1) { + if (i == lanes - 1) { os << value << ")"; } else { os << value << ","; @@ -2082,7 +2145,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << "("; } os << value; - if (i != t.lanes() - 1) { + if (i != lanes - 1) { os << ","; } else { os << ")"; diff --git a/src/backend/cuda/codegen/codegen_cuda.h 
b/src/backend/cuda/codegen/codegen_cuda.h index 92ca3cab34a4..94f86614e45e 100644 --- a/src/backend/cuda/codegen/codegen_cuda.h +++ b/src/backend/cuda/codegen/codegen_cuda.h @@ -56,16 +56,17 @@ class CodeGenCUDA final : public CodeGenC { void VisitStmt_(const WhileNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) - void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, - std::ostream& os) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - void PrintVecConstructor(DataType t, std::ostream& os) final; - void PrintVecElemLoad(const std::string& vec, DataType t, int i, + void PrintVecBinaryOp(const std::string& op, DLDataType t, PrimExpr lhs, PrimExpr rhs, + std::ostream& os) final; // NOLINT(*) + void PrintType(DLDataType t, std::ostream& os) final; // NOLINT(*) + void PrintVecConstructor(DLDataType t, std::ostream& os) final; + void PrintVecElemLoad(const std::string& vec, DLDataType t, int i, std::ostream& os) final; // NOLINT(*) - void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; + void PrintVecElemStore(const std::string& vec, DLDataType t, int i, + const std::string& value) final; void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) - void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final; - std::string CastFromTo(std::string value, DataType from, DataType target) final; + void PrintVecElemLoadExpr(DLDataType t, int i, const std::string& value, std::ostream& os) final; + std::string CastFromTo(std::string value, DLDataType from, DLDataType target) final; void AddUtilFunction(const std::string& name, const std::string& code); // overload visitor void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) @@ -129,7 +130,7 @@ class CodeGenCUDA final : public CodeGenC { std::unordered_map 
fragment_shapes; std::unordered_map fragment_layouts; friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p); - void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, + void PrintWmmaScope(const std::string& scope, DLDataType t, const VarNode* variable, std::ostream& os); int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size); }; diff --git a/src/backend/cuda/codegen/intrin_rule_cuda.cc b/src/backend/cuda/codegen/intrin_rule_cuda.cc index dc8d4a020e1e..ea2d0abfa80e 100644 --- a/src/backend/cuda/codegen/intrin_rule_cuda.cc +++ b/src/backend/cuda/codegen/intrin_rule_cuda.cc @@ -34,8 +34,8 @@ namespace intrin { using tirx::FLowerIntrinsic; struct CUDAMath { - std::string operator()(DataType t, std::string name) const { - if (t.is_float()) { + std::string operator()(PrimType t, std::string name) const { + if (t.code() == DLDataTypeCode::kDLFloat) { switch (t.bits()) { case 64: // Use nearbyint (ties-to-even) for round to match constant-folding semantics. 
@@ -56,7 +56,7 @@ struct CUDAMath { default: return ""; } - } else if (t.is_bfloat16()) { + } else if (t.code() == DLDataTypeCode::kDLBfloat && t.bits() == 16) { if (name == "fabs") { return "__habs"; } else if (name == "round") { @@ -64,7 +64,7 @@ struct CUDAMath { } else { return "h" + name; } - } else if (t.is_int() || t.is_uint()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { switch (t.bits()) { case 32: return "__" + name; @@ -79,8 +79,8 @@ struct CUDAMath { }; struct CUDAFastMath : public CUDAMath { - std::string operator()(DataType t, std::string name) const { - if (t.is_float() && t.bits() == 32) { + std::string operator()(PrimType t, std::string name) const { + if (t.code() == DLDataTypeCode::kDLFloat && t.bits() == 32) { return "__" + name + 'f'; } else { return CUDAMath::operator()(t, name); @@ -90,8 +90,8 @@ struct CUDAFastMath : public CUDAMath { }; struct CUDAFastMathTan : public CUDAMath { - std::string operator()(DataType t, std::string name) const { - if (t.is_float()) { + std::string operator()(PrimType t, std::string name) const { + if (t.code() == DLDataTypeCode::kDLFloat) { switch (t.bits()) { case 64: return name; @@ -110,8 +110,8 @@ struct CUDAFastMathTan : public CUDAMath { }; struct CUDAPopcount { - std::string operator()(DataType t, std::string name) const { - if (t.is_uint()) { + std::string operator()(PrimType t, std::string name) const { + if (t.MatchesCode(DLDataTypeCode::kDLUInt)) { switch (t.bits()) { case 32: return "__popc"; @@ -126,7 +126,7 @@ struct CUDAPopcount { }; struct CUDAWarpIntrinsic { - const Op operator()(DataType t, const Op& orig_op) const { + const Op operator()(PrimType t, const Op& orig_op) const { if (orig_op.same_as(builtin::tvm_warp_shuffle())) { static const Op& cuda_shfl_sync_op = Op::Get("tirx.cuda.__shfl_sync"); return cuda_shfl_sync_op; @@ -147,7 +147,7 @@ struct CUDAWarpIntrinsic { static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr& e) { const CallNode* call = 
e.as(); static const Op& cuda_active_mask_op = Op::Get("tirx.cuda.__activemask"); - return Call(call->dtype, cuda_active_mask_op, call->args); + return Call(e.ty(), cuda_active_mask_op, call->args); } template @@ -156,7 +156,7 @@ static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) { TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size ffi::Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; - return Call(call->dtype, T()(call->dtype, call->op.as_or_throw()), cuda_args); + return Call(e.ty(), T()(e.ty(), call->op.as_or_throw()), cuda_args); } void RegisterCudaIntrinRules() { diff --git a/src/backend/cuda/codegen/llvm/codegen_nvptx.cc b/src/backend/cuda/codegen/llvm/codegen_nvptx.cc index e523e2b22aab..eb84f10fda10 100644 --- a/src/backend/cuda/codegen/llvm/codegen_nvptx.cc +++ b/src/backend/cuda/codegen/llvm/codegen_nvptx.cc @@ -87,7 +87,7 @@ class CodeGenNVPTX : public CodeGenLLVM { } auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); - DataType dtype = op->buffer->dtype; + PrimType dtype = op->buffer->dtype; if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { // Shared memory: address space == 3 @@ -230,7 +230,8 @@ class CodeGenNVPTX : public CodeGenLLVM { // corresponding nvvm intrinsic. Return true if the match is successful. static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) { // Only 32 bit data type is supported. 
- if (op->dtype.is_fixed_length_vector() || op->dtype.bits() != 32) { + PrimType op_ty = op->ty(); + if (op_ty.IsFixedLengthVector() || op_ty.bits() != 32) { return false; } @@ -253,7 +254,7 @@ static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) return false; } - *id = ids[offset + op->dtype.is_float()]; + *id = ids[offset + (op_ty.code() == DLDataTypeCode::kDLFloat)]; return true; } @@ -279,10 +280,11 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true); return builder_->CreateCall(val); } else if (op->op.same_as(builtin::atomic_add())) { - TVM_FFI_ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now"; + PrimType value_ty = op->args[1].ty(); + TVM_FFI_ICHECK(value_ty.bits() == 32) << "Only supports 32 bit atomic for now"; llvm::Value* v0 = MakeValue(op->args[0]); llvm::Value* v1 = MakeValue(op->args[1]); - if (op->args[1]->dtype.is_float()) { + if (value_ty.code() == DLDataTypeCode::kDLFloat) { return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1, llvm::MaybeAlign(), llvm::AtomicOrdering::Monotonic); } diff --git a/src/backend/cuda/codegen/llvm/intrin_rule_nvptx.cc b/src/backend/cuda/codegen/llvm/intrin_rule_nvptx.cc index d8706a94b181..13d6f7d95a3b 100644 --- a/src/backend/cuda/codegen/llvm/intrin_rule_nvptx.cc +++ b/src/backend/cuda/codegen/llvm/intrin_rule_nvptx.cc @@ -38,7 +38,8 @@ inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { using namespace tirx; const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - TVM_FFI_ICHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) + PrimType call_ty = call->ty(); + TVM_FFI_ICHECK(call_ty.bits() == 32 || call_ty.bits() == 64) << "Only support float32 or float64."; const OpNode* op = call->op.as(); @@ -48,13 +49,13 @@ inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { std::ostringstream intrinsic_name; intrinsic_name << "__nv_" 
<< name.substr(5); - if (call->dtype.bits() == 32) intrinsic_name << "f"; + if (call_ty.bits() == 32) intrinsic_name << "f"; ffi::Array new_args = {StringImm(intrinsic_name.str())}; for (auto arg : call->args) { new_args.push_back(arg); } - return Call(call->dtype, builtin::call_pure_extern(), new_args); + return Call(call->ty(), builtin::call_pure_extern(), new_args); } namespace llvm { @@ -73,7 +74,7 @@ TVM_REGISTER_OP("tirx.round") const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); static const Op& nearbyint_op = Op::Get("tirx.nearbyint"); - auto new_call = Call(call->dtype, nearbyint_op, call->args); + auto new_call = Call(call->ty(), nearbyint_op, call->args); return DispatchPureExternLibDevice(new_call); }); diff --git a/src/backend/cuda/runtime/cuda_device_api.cc b/src/backend/cuda/runtime/cuda_device_api.cc index 68ae39de56bf..44d1acff4937 100644 --- a/src/backend/cuda/runtime/cuda_device_api.cc +++ b/src/backend/cuda/runtime/cuda_device_api.cc @@ -426,7 +426,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_ICHECK_GE(args.size(), 4) << "init_cuTensorMap expects at least 4 arguments"; size_t arg_cnt = 0; CUtensorMap* tensor_map = static_cast(args[arg_cnt++].cast()); - runtime::DataType tensor_dtype = args[arg_cnt++].cast(); + DLDataType tensor_dtype = args[arg_cnt++].cast(); int32_t raw_tensor_rank = args[arg_cnt++].cast(); TVM_FFI_ICHECK_GT(raw_tensor_rank, 0) << "tensorRank must be non-zero"; TVM_FFI_ICHECK_LE(raw_tensor_rank, 5) @@ -478,13 +478,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { auto l2_promotion_kind = static_cast(args[arg_cnt++].cast()); auto oob_fill_kind = static_cast(args[arg_cnt++].cast()); - TVM_FFI_ICHECK_EQ(tensor_dtype.lanes(), 1) + TVM_FFI_ICHECK_EQ(tensor_dtype.lanes, 1) << "Expect tensor_dtype to have lanes=1, but get " << tensor_dtype; + uint64_t tensor_dtype_bytes = (static_cast(tensor_dtype.bits) + 7) / 8; CUtensorMapDataType cu_dtype; - switch (tensor_dtype.code()) { - case DataType::kInt: + switch (tensor_dtype.code) { + case 
kDLInt: // int - switch (tensor_dtype.bits()) { + switch (tensor_dtype.bits) { case 8: cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; break; @@ -499,9 +500,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } break; - case DataType::kUInt: + case kDLUInt: // unsigned int - switch (tensor_dtype.bits()) { + switch (tensor_dtype.bits) { case 8: cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; break; @@ -519,9 +520,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } break; - case DataType::kFloat: + case kDLFloat: // float - switch (tensor_dtype.bits()) { + switch (tensor_dtype.bits) { case 16: cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; break; @@ -536,9 +537,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } break; - case DataType::kBFloat: + case kDLBfloat: // bfloat - switch (tensor_dtype.bits()) { + switch (tensor_dtype.bits) { case 16: cu_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; break; @@ -547,15 +548,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } break; - case DataType::kFloat8_e4m3fn: + case kDLFloat8_e4m3fn: // NV float8 e4m3 cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; break; - case DataType::kFloat8_e5m2: + case kDLFloat8_e5m2: // NV float8 e5m2 cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; break; - case DataType::kFloat4_e2m1fn: + case kDLFloat4_e2m1fn: #if (CUDA_VERSION >= 12080) // Packed FP4 in GMEM, unpacked into SMEM/TMEM-facing tiles. 
cu_dtype = CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; @@ -674,7 +675,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { << "globalDim[0] must be a multiple of 2 for packed 16U4 align8 format"; } if (interleaved_kind == CU_TENSOR_MAP_INTERLEAVE_NONE && !is_packed_dtype) { - uint64_t inner_box_bytes = static_cast(box_dim[0]) * tensor_dtype.bytes(); + uint64_t inner_box_bytes = static_cast(box_dim[0]) * tensor_dtype_bytes; TVM_FFI_ICHECK_EQ(inner_box_bytes % 16, 0) << "boxDim[0] * elementSizeInBytes(tensorDataType) must be a multiple of 16 bytes"; } @@ -694,15 +695,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (interleaved_kind == CU_TENSOR_MAP_INTERLEAVE_NONE && !is_packed_dtype && swizzle_kind == CU_TENSOR_MAP_SWIZZLE_32B) { - TVM_FFI_ICHECK_LE(box_dim[0] * tensor_dtype.bytes(), 32) + TVM_FFI_ICHECK_LE(box_dim[0] * tensor_dtype_bytes, 32) << "CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will be <= 32."; } else if (interleaved_kind == CU_TENSOR_MAP_INTERLEAVE_NONE && !is_packed_dtype && swizzle_kind == CU_TENSOR_MAP_SWIZZLE_64B) { - TVM_FFI_ICHECK_LE(box_dim[0] * tensor_dtype.bytes(), 64) + TVM_FFI_ICHECK_LE(box_dim[0] * tensor_dtype_bytes, 64) << "CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will be <= 64."; } else if (interleaved_kind == CU_TENSOR_MAP_INTERLEAVE_NONE && !is_packed_dtype && is_128b_swizzle) { - TVM_FFI_ICHECK_LE(box_dim[0] * tensor_dtype.bytes(), 128) + TVM_FFI_ICHECK_LE(box_dim[0] * tensor_dtype_bytes, 128) << "CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will be <= " "128."; } diff --git a/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc b/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc index 017796918444..17aba2d3fc40 100644 --- a/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc +++ b/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc @@ -66,6 +66,11 @@ namespace tvm { namespace codegen { +TVM_FFI_INLINE int GetVectorBytes(const PrimType& dtype) { + TVM_FFI_ICHECK(dtype.IsFixedLengthVector() || 
dtype.IsScalar()); + return static_cast(dtype.StorageBytes()); +} + // Hexagon code generation class CodeGenHexagon final : public CodeGenCPU { public: @@ -97,12 +102,12 @@ class CodeGenHexagon final : public CodeGenCPU { void CreatePrintf(const std::string& format, llvm::ArrayRef format_args) final; private: - TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype, - llvm::ArrayRef indices, DataType value_dtype) final; + TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, PrimType buffer_element_dtype, + llvm::ArrayRef indices, PrimType value_dtype) final; bool IsQHLFunction(const std::string& func); - llvm::Value* VectorLookupLoad(Buffer buffer, DataType buffer_type, ffi::Array indices); + llvm::Value* VectorLookupLoad(Buffer buffer, PrimType buffer_type, ffi::Array indices); llvm::Value* Intrinsic(llvm::Intrinsic::ID, llvm::ArrayRef args); std::vector fqhl_list_ = { "tvm_vect_qhmath_hvx_cos_ahf", "tvm_vect_qhmath_hvx_tanh_ahf", @@ -149,8 +154,9 @@ void CodeGenHexagon::InitTarget() { llvm::Value* CodeGenHexagon::CreateCallExternQHL(Type ret_type, ffi::String global_symbol, const ffi::Array& args, bool skip_first_arg) { - int num_lanes = args[1].dtype().lanes(); - int vector_length = native_vector_bits_ / args[1].dtype().bits(); + PrimType arg_ty = args[1].ty(); + int num_lanes = arg_ty.lanes(); + int vector_length = native_vector_bits_ / arg_ty.bits(); num_lanes = ((num_lanes + vector_length - 1) / vector_length) * vector_length; std::vector vect_split; for (int i = 0; i < num_lanes / vector_length; ++i) { @@ -181,8 +187,9 @@ bool CodeGenHexagon::IsQHLFunction(const std::string& func) { llvm::Value* CodeGenHexagon::CreateCallExtern(Type ret_type, ffi::String global_symbol, const ffi::Array& args, bool skip_first_arg) { - int num_lanes = args[1].dtype().lanes(); - int vector_length = native_vector_bits_ / args[1].dtype().bits(); + PrimType arg_ty = args[1].ty(); + int num_lanes = arg_ty.lanes(); + int vector_length = 
native_vector_bits_ / arg_ty.bits(); if (IsQHLFunction(global_symbol) && (num_lanes > vector_length)) return CreateCallExternQHL(ret_type, global_symbol, args, skip_first_arg); return CodeGenCPU::CreateCallExtern(ret_type, global_symbol, args, skip_first_arg); @@ -192,7 +199,7 @@ llvm::Value* CodeGenHexagon::VisitExpr_(const BufferLoadNode* op) { if (!op->buffer.same_as(op->buffer->data)) { // Check if we can generate a vector lookup. if (!op->indices[0].as()) { - if (auto* vlut = VectorLookupLoad(op->buffer, op->dtype, op->indices)) { + if (auto* vlut = VectorLookupLoad(op->buffer, PrimType(op->ty()->dtype), op->indices)) { return vlut; } } @@ -261,9 +268,9 @@ void CodeGenHexagon::CreatePrintf(const std::string& format, } CodeGenLLVM::TypedPointer CodeGenHexagon::CreateBufferPtr(llvm::Value* buffer_ptr, - DataType buffer_element_dtype, + PrimType buffer_element_dtype, llvm::ArrayRef indices, - DataType value_dtype) { + PrimType value_dtype) { // Flat indices get delegated to the LLVM codegen. if (indices.size() == 1) { return CodeGenCPU::CreateBufferPtr(buffer_ptr, buffer_element_dtype, indices, value_dtype); @@ -274,7 +281,7 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateBufferPtr(llvm::Value* buffer_pt << "-d buffer indices"; // Use the first index to identify the pointer. 
- DataType dtype_void_ptr = DataType::Handle(); + PrimType dtype_void_ptr = PrimType::Handle(); CodeGenLLVM::TypedPointer buffer_chunk_ptr_ptr = CodeGenCPU::CreateBufferPtr(buffer_ptr, dtype_void_ptr, {indices[0]}, dtype_void_ptr); llvm::Value* buffer_chunk_ptr = @@ -317,10 +324,11 @@ llvm::Value* CodeGenHexagon::Intrinsic(llvm::Intrinsic::ID IntID, return builder_->CreateCall(intf_callee, conv_args); } -llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_type, +llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, PrimType buffer_type, ffi::Array indices) { PrimExpr index = indices[0]; - if (!index.dtype().is_fixed_length_vector()) { + PrimType index_ty = index.ty(); + if (!index_ty.IsFixedLengthVector()) { return nullptr; } @@ -329,16 +337,16 @@ llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_typ int table_elem_count = arith::Analyzer()->Simplify(buffer->shape[0]).as()->value; if (table_elem_count <= 0 || table_elem_count > 256) return nullptr; - auto int32 = DataType::Int(32); + auto int32 = PrimType::Int(32); auto native_vector_bytes = native_vector_bits_ / 8; // Indexes - llvm::Value* trunc = MakeValue(Cast(index.dtype().with_bits(8), index)); + llvm::Value* trunc = MakeValue(Cast(index_ty.WithBits(8), index)); llvm::Value* index_pad = CreateVecPad(trunc, native_vector_bytes); // Values std::vector vloads; - DataType table_type = buffer_type.with_lanes(table_elem_count); + PrimType table_type = buffer_type.WithLanes(table_elem_count); auto table_all = MakeValue(BufferLoad(buffer, { @@ -347,7 +355,7 @@ llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_typ // The number of value vectors should be a power of 2. 
int table_vec_count = llvm::PowerOf2Ceil(GetVectorBytes(table_type) / native_vector_bytes); - int table_vec_length = native_vector_bytes / buffer_type.bytes(); + int table_vec_length = native_vector_bytes / GetVectorBytes(buffer_type); for (int i = 0; i != table_vec_count; ++i) { // CreateVecSlice will generate undefs for elements outside the source vector. vloads.push_back(CreateVecSlice(table_all, i * table_vec_length, table_vec_length)); diff --git a/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc b/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc index 3e46e322a881..928df03f38aa 100644 --- a/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc +++ b/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc @@ -50,7 +50,7 @@ inline PrimExpr TVMExternCall(const tirx::CallNode* call, const std::string& fna for (PrimExpr arg : call->args) { new_args.push_back(arg); } - return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), new_args); + return tirx::Call(call->ty(), tirx::builtin::call_pure_extern(), new_args); } template @@ -72,14 +72,16 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { // Enable QHL library for FP16 data type const PrimExpr& x = call->args[0]; - if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { + PrimType x_ty = x.ty(); + if (x_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) && + (x_ty.IsFixedLengthVector() || x_ty.IsScalableVector()) && useqhl) { return TVMExternCall(call, tvm_wrapper); } #endif - new_args.push_back(IntImm(DataType::UInt(32), id)); - new_args.push_back(IntImm(DataType::UInt(32), num_sign)); + new_args.push_back(IntImm(PrimType::UInt(32), id)); + new_args.push_back(IntImm(PrimType::UInt(32), num_sign)); new_args.insert(new_args.end(), call->args.begin(), call->args.end()); - return tirx::Call(call->dtype, tirx::builtin::call_llvm_pure_intrin(), new_args); + return tirx::Call(call->ty(), tirx::builtin::call_llvm_pure_intrin(), new_args); } void RegisterHexagonIntrinRules() 
{ @@ -117,6 +119,7 @@ TVM_REGISTER_OP("tirx.tanh") const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; + PrimType x_ty = x.ty(); #if ENABLE_QHL // Check target for qfloat enablement @@ -130,14 +133,15 @@ TVM_REGISTER_OP("tirx.tanh") } // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { + if (x_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) && + (x_ty.IsFixedLengthVector() || x_ty.IsScalableVector()) && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf"); return TVMExternCall(call, tvm_wrapper); } #endif - PrimExpr one = tirx::MakeConst(x.dtype(), 1); - PrimExpr two = tirx::MakeConst(x.dtype(), 2); - PrimExpr neg_two = tirx::MakeConst(x.dtype(), -2); + PrimExpr one = tirx::MakeConst(x_ty, 1); + PrimExpr two = tirx::MakeConst(x_ty, 2); + PrimExpr neg_two = tirx::MakeConst(x_ty, -2); PrimExpr exp_neg2x = exp(neg_two * x); PrimExpr exp_pos2x = exp(two * x); @@ -145,7 +149,7 @@ TVM_REGISTER_OP("tirx.tanh") PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); // MakeConst can handle both vector and scalar types. 
- PrimExpr tanh_x = tirx::Select(x >= tirx::MakeConst(x.dtype(), 0), tanh_pos, tanh_neg); + PrimExpr tanh_x = tirx::Select(x >= tirx::MakeConst(x_ty, 0), tanh_pos, tanh_neg); return tanh_x; }); @@ -154,6 +158,7 @@ TVM_REGISTER_OP("tirx.tan") const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; + PrimType x_ty = x.ty(); #if ENABLE_QHL // Check target for qfloat enablement const auto f = tvm::ffi::Function::GetGlobal("target.TargetCurrent"); @@ -166,7 +171,8 @@ TVM_REGISTER_OP("tirx.tan") } // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { + if (x_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) && + (x_ty.IsFixedLengthVector() || x_ty.IsScalableVector()) && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf"); return TVMExternCall(call, tvm_wrapper); } @@ -184,6 +190,7 @@ TVM_REGISTER_OP("tirx.sigmoid") const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; + PrimType x_ty = x.ty(); #if ENABLE_QHL // Check target for qfloat enablement const auto f = tvm::ffi::Function::GetGlobal("target.TargetCurrent"); @@ -195,21 +202,22 @@ TVM_REGISTER_OP("tirx.sigmoid") useqhl = tstring.find("+hvx-qfloat") != std::string::npos; } - PrimExpr MinBound = tirx::MakeConst(x.dtype(), -8); - PrimExpr MaxBound = tirx::MakeConst(x.dtype(), 8); + PrimExpr MinBound = tirx::MakeConst(x_ty, -8); + PrimExpr MaxBound = tirx::MakeConst(x_ty, 8); const PrimExpr v1 = tirx::Max(x, MinBound); const PrimExpr v2 = tirx::Min(v1, MaxBound); ffi::Array new_args = {v2}; - const tirx::Call new_call = tirx::Call(call->dtype, call->op, new_args); + const tirx::Call new_call = tirx::Call(call->ty(), call->op, new_args); // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { + if (x_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) && + (x_ty.IsFixedLengthVector() || 
x_ty.IsScalableVector()) && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf"); return TVMExternCall(new_call.get(), tvm_wrapper); } #endif - PrimExpr one = tirx::MakeConst(x.dtype(), 1); + PrimExpr one = tirx::MakeConst(x_ty, 1); return one / (one + exp(-x)); }); diff --git a/src/backend/hexagon/runtime/ops/conv2d_fp16_hvx.cc b/src/backend/hexagon/runtime/ops/conv2d_fp16_hvx.cc index d555fb77cfae..c063ae62b1bd 100644 --- a/src/backend/hexagon/runtime/ops/conv2d_fp16_hvx.cc +++ b/src/backend/hexagon/runtime/ops/conv2d_fp16_hvx.cc @@ -21,8 +21,8 @@ #include #include #include +#include #include -#include #include #include @@ -469,7 +469,7 @@ int conv2d_packed_fp16(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val) // Prepare zero_block int64_t block_nbytes = 2048; void* zero_block = device_api->AllocDataSpace(conv_utils::hexagon_device, 1, &block_nbytes, - tvm::runtime::DataType::UInt(8), vtcm_scope); + DLDataType{kDLUInt, 8, 1}, vtcm_scope); memset(zero_block, 0, 2048); // FIXME: Setting bias to zero_block: this works for up to 256 output channels. 
diff --git a/src/backend/metal/codegen/codegen_metal.cc b/src/backend/metal/codegen/codegen_metal.cc index 3f483f79aaed..e6ef1647e5bf 100644 --- a/src/backend/metal/codegen/codegen_metal.cc +++ b/src/backend/metal/codegen/codegen_metal.cc @@ -46,7 +46,7 @@ void CodeGenMetal::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); // analyze the data; for (Var arg : f->params) { - if (arg.dtype().is_handle()) { + if (arg.ty().IsHandle()) { alloc_storage_scope_[arg.get()] = "global"; } } @@ -97,7 +97,7 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { } for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) { Var v = func->params[i]; - if (!v.dtype().is_handle()) break; + if (!v.ty().IsHandle()) break; this->stream << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); @@ -126,24 +126,24 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { decl_stream << "struct " << arg_buf_type << " {\n"; for (size_t i = num_buffer; i < func->params.size(); ++i) { Var v = func->params[i]; - TVM_FFI_ICHECK(!v.dtype().is_handle()); + TVM_FFI_ICHECK(!v.ty().IsHandle()); std::string vid = AllocVarID(v.get()); std::ostringstream vref; - if (v.dtype().bits() == 32) { + if (v.ty().bits() == 32) { decl_stream << " "; - PrintType(v.dtype(), decl_stream); + PrintType(v.ty()->dtype, decl_stream); decl_stream << " " << vid << "[2];\n"; vref << varg << "." << vid << "[0]"; - } else if (v.dtype().bits() == 64) { + } else if (v.ty().bits() == 64) { decl_stream << " "; - PrintType(v.dtype(), decl_stream); + PrintType(v.ty()->dtype, decl_stream); decl_stream << " " << vid << ";\n"; vref << varg << "." << vid; } else { // For non 32bit type, ref through arg union. decl_stream << " __TVMArgUnion " << vid << ";\n"; vref << varg << "." 
<< vid << ".v_"; - PrintType(v.dtype(), vref); + PrintType(v.ty()->dtype, vref); } var_idmap_[v.get()] = vref.str(); } @@ -165,10 +165,14 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { if (work_dim != 0) { // use ushort by default for now stream << " "; - PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); + PrintType(DLDataType{kDLUInt, static_cast(thread_index_bits_), + static_cast(work_dim)}, + stream); stream << " blockIdx [[threadgroup_position_in_grid]],\n"; stream << " "; - PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); + PrintType(DLDataType{kDLUInt, static_cast(thread_index_bits_), + static_cast(work_dim)}, + stream); stream << " threadIdx [[thread_position_in_threadgroup]]\n"; } thread_work_dim_ = work_dim; @@ -190,28 +194,29 @@ void CodeGenMetal::BindThreadIndex(const IterVar& iv) { if (thread_work_dim_ <= 1) { vname = vname.substr(0, iv->thread_tag.length() - 2); } - var_idmap_[iv->var.get()] = - CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype()); + var_idmap_[iv->var.get()] = CastFromTo( + vname, DLDataType{kDLUInt, static_cast(thread_index_bits_), 1}, iv->var.ty()->dtype); } -void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) +void CodeGenMetal::PrintType(DLDataType raw_t, std::ostream& os) { // NOLINT(*) + PrimType t(raw_t); int lanes = t.lanes(); - if (t.is_handle()) { + if (t.IsHandle()) { TVM_FFI_ICHECK_EQ(lanes, 1) << "do not yet support vector types"; os << "void*"; return; } - if (t.is_void()) { + if (t.IsVoid()) { os << "void"; return; } - if (t == DataType::Bool()) { + if (raw_t == DLDataType{kDLBool, 8, 1}) { os << "bool"; return; } bool fail = false; - if (t.is_float()) { + if (t.code() == DLDataTypeCode::kDLFloat) { // Need to care about sizes and alignment of half3/float3 because tirx representation might not // be aware of Metal half3/float3 details and can treat them as just three elements, // while sizes and alignmnents of 
half3/float3 are one element more (half3-8 bytes/ @@ -239,8 +244,8 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << lanes; return; } - } else if (t.is_uint() || t.is_int()) { - if (t.is_uint()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLUInt, DLDataTypeCode::kDLInt)) { + if (t.MatchesCode(DLDataTypeCode::kDLUInt)) { os << 'u'; } switch (t.bits()) { @@ -268,11 +273,12 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << lanes; return; } - } else if (t.is_bfloat16()) { + } else if (t.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { os << "bfloat"; return; } - TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to Metal type"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << ffi::DLDataTypeToString(raw_t) + << " to Metal type"; } void CodeGenMetal::PrintStorageSync(const CallNode* op) { @@ -288,12 +294,12 @@ void CodeGenMetal::PrintStorageSync(const CallNode* op) { } } -void CodeGenMetal::PrintVecElemLoad(const std::string& vec, DataType t, int i, +void CodeGenMetal::PrintVecElemLoad(const std::string& vec, DLDataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << "[" << i << "]"; } -void CodeGenMetal::PrintVecElemStore(const std::string& vec, DataType t, int i, +void CodeGenMetal::PrintVecElemStore(const std::string& vec, DLDataType t, int i, const std::string& value) { this->PrintIndent(); stream << vec << "[" << i << "]" @@ -328,11 +334,14 @@ void CodeGenMetal::VisitStmt_(const AllocBufferNode* op) { auto scope = GetPtrStorageScope(op->buffer->data); alloc_storage_scope_[op->buffer->data.get()] = scope; - DataType dtype = op->buffer->dtype; + DLDataType dtype = op->buffer->dtype->dtype; if (scope == "metal.simdgroup") { - TVM_FFI_ICHECK(dtype == DataType::Float(16) || dtype == DataType::Float(32) || - dtype == DataType::BFloat(16)) - << "Only float16, float32, and bfloat16 are supported, but got " << dtype; + bool supported_simdgroup_dtype = dtype == 
DLDataType{kDLFloat, 16, 1} || + dtype == DLDataType{kDLFloat, 32, 1} || + dtype == DLDataType{kDLBfloat, 16, 1}; + TVM_FFI_ICHECK(supported_simdgroup_dtype) + << "Only float16, float32, and bfloat16 are supported, but got " + << ffi::DLDataTypeToString(dtype); TVM_FFI_ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got " << constant_size << " bytes\n"; @@ -360,8 +369,8 @@ void CodeGenMetal::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLI void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); - int lanes = op->dtype.lanes(); - PrintType(op->dtype, os); + int lanes = op->ty().lanes(); + PrintType(op->ty()->dtype, os); os << "("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -422,7 +431,7 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT } else if (op->op.same_as(builtin::reinterpret())) { // generate as_type(ARG) os << "(as_type<"; - this->PrintType(op->dtype, os); + this->PrintType(op->ty()->dtype, os); os << ">("; this->PrintExpr(op->args[0], os); os << "))"; @@ -442,9 +451,9 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NO temp << "NAN"; } else { temp << std::scientific << op->value; - if (op->dtype.bits() == 32) + if (op->ty().bits() == 32) temp << 'f'; - else if (op->dtype.bits() == 16) + else if (op->ty().bits() == 16) temp << 'h'; } MarkConst(temp.str()); diff --git a/src/backend/metal/codegen/codegen_metal.h b/src/backend/metal/codegen/codegen_metal.h index b92608aecfa1..ffa9a321aa43 100644 --- a/src/backend/metal/codegen/codegen_metal.h +++ b/src/backend/metal/codegen/codegen_metal.h @@ -43,13 +43,14 @@ class CodeGenMetal final : public CodeGenC { void InitFuncState(const PrimFunc& f) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) - void PrintType(DataType t, 
std::ostream& os) final; // NOLINT(*) + void PrintType(DLDataType t, std::ostream& os) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) // print load of single element - void PrintVecElemLoad(const std::string& vec, DataType t, int i, + void PrintVecElemLoad(const std::string& vec, DLDataType t, int i, std::ostream& os) final; // NOLINT(*) // print store of single element. - void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; + void PrintVecElemStore(const std::string& vec, DLDataType t, int i, + const std::string& value) final; // overload visitor void VisitStmt_(const AllocBufferNode* op) final; // NOLINT(*) void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/backend/metal/codegen/intrin_rule_metal.cc b/src/backend/metal/codegen/intrin_rule_metal.cc index c807ac4c2e8a..999fe526f04e 100644 --- a/src/backend/metal/codegen/intrin_rule_metal.cc +++ b/src/backend/metal/codegen/intrin_rule_metal.cc @@ -31,7 +31,7 @@ namespace intrin { using tirx::FLowerIntrinsic; struct MetalWarpIntrinsic { - const Op operator()(DataType t, const Op& orig_op) const { + const Op operator()(PrimType t, const Op& orig_op) const { if (orig_op.same_as(builtin::tvm_warp_shuffle())) { static const Op& metal_simd_shuffle_op = Op::Get("tirx.metal.simd_shuffle"); return metal_simd_shuffle_op; @@ -52,7 +52,7 @@ static PrimExpr DispatchMetalShuffle(const PrimExpr& e) { TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size ffi::Array metal_args{{call->args[1], call->args[2]}}; - return Call(call->dtype, T()(call->dtype, call->op.as_or_throw()), metal_args); + return Call(e.ty(), T()(e.ty(), call->op.as_or_throw()), metal_args); } void RegisterMetalIntrinRules() { @@ -81,7 +81,7 @@ TVM_REGISTER_OP("tirx.round") for (auto arg : call->args) { new_args.push_back(arg); } - return tirx::Call(call->dtype, 
tirx::builtin::call_pure_extern(), new_args); + return tirx::Call(e.ty(), tirx::builtin::call_pure_extern(), new_args); }); TVM_REGISTER_OP("tirx.nearbyint") diff --git a/src/backend/opencl/codegen/codegen_opencl.cc b/src/backend/opencl/codegen/codegen_opencl.cc index 51719785195b..001d4a33b081 100644 --- a/src/backend/opencl/codegen/codegen_opencl.cc +++ b/src/backend/opencl/codegen/codegen_opencl.cc @@ -84,7 +84,7 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) { // Storage scope qualifiers for textures are inferred // and set prior to function codegen. continue; - } else if (arg.dtype().is_handle()) { + } else if (arg.ty().IsHandle()) { alloc_storage_scope_[arg.get()] = "global"; } } @@ -189,26 +189,27 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { } else { os << "get_group_id(" << ts.dim_index << ")"; } - var_idmap_[iv->var.get()] = CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype()); + var_idmap_[iv->var.get()] = CastFromTo(os.str(), DLDataType{kDLUInt, 64, 1}, iv->var.ty()->dtype); } -void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::PrintType(DLDataType raw_t, std::ostream& os) { // NOLINT(*) + PrimType t(raw_t); int lanes = t.lanes(); - if (t.is_handle()) { + if (t.IsHandle()) { TVM_FFI_ICHECK_EQ(lanes, 1) << "do not yet support vector types"; os << "void*"; return; } - if (t.is_void()) { + if (t.IsVoid()) { os << "void"; return; } - if (t == DataType::Bool()) { + if (raw_t == DLDataType{kDLBool, 8, 1}) { os << "bool"; return; } bool fail = false; - if (t.is_float()) { + if (t.code() == DLDataTypeCode::kDLFloat) { switch (t.bits()) { case 16: os << "half"; @@ -230,14 +231,14 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << lanes; return; } - } else if (t.is_bool()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLBool)) { os << "uint"; if (!fail && ((lanes >= 2 && lanes <= 4) || lanes == 8 || lanes == 16)) { os << lanes; return; } - } else if 
(t.is_uint() || t.is_int()) { - if (t.is_uint()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLUInt, DLDataTypeCode::kDLInt)) { + if (t.MatchesCode(DLDataTypeCode::kDLUInt)) { os << 'u'; } switch (t.bits()) { @@ -266,7 +267,8 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) return; } } - TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to OpenCL type"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << ffi::DLDataTypeToString(raw_t) + << " to OpenCL type"; } void CodeGenOpenCL::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) @@ -286,41 +288,44 @@ void CodeGenOpenCL::PrintType(const Type& type, std::ostream& os) { // NOLINT(* } } -void CodeGenOpenCL::PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base, +void CodeGenOpenCL::PrintVecAddr(const BufferNode* buffer, DLDataType t, PrimExpr base, std::ostream& os) { // NOLINT(*) const VarNode* buffer_var = buffer->data.get(); - if (!HandleTypeMatch(buffer_var, t.element_of())) { + DLDataType elem_type{t.code, t.bits, 1}; + if (!HandleTypeMatch(buffer_var, elem_type)) { os << '('; auto it = alloc_storage_scope_.find(buffer_var); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, os); } - PrintType(t.element_of(), os); + PrintType(elem_type, os); os << "*)"; } os << GetVarID(buffer_var) << " + "; PrintExpr(base, os); } -std::string CodeGenOpenCL::GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) { +std::string CodeGenOpenCL::GetVecLoad(DLDataType t, const BufferNode* buffer, PrimExpr base) { std::ostringstream os; - os << "vload" << t.lanes() << "(0, "; + os << "vload" << PrimType(t).lanes() << "(0, "; PrintVecAddr(buffer, t, base, os); os << ")"; return os.str(); } -void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, +void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DLDataType t, PrimExpr base, const std::string& value) { this->PrintIndent(); - stream << "vstore" << 
t.lanes() << "(" << value << ", 0, "; + stream << "vstore" << PrimType(t).lanes() << "(" << value << ", 0, "; PrintVecAddr(buffer, t, base, stream); stream << ");\n"; } -void CodeGenOpenCL::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, +void CodeGenOpenCL::PrintVecElemLoadExpr(DLDataType t, int i, const std::string& value, std::ostream& os) { // NOLINT(*) - TVM_FFI_ICHECK_GT(t.lanes(), 1); - if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + PrimType t_ty(t); + int lanes = t_ty.lanes(); + TVM_FFI_ICHECK_GT(lanes, 1); + if (t.bits == 8 && (t_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt))) { if (i != 0) { os << "|"; } @@ -334,7 +339,7 @@ void CodeGenOpenCL::PrintVecElemLoadExpr(DataType t, int i, const std::string& v os << ")("; } os << value; - if (i != t.lanes() - 1) { + if (i != lanes - 1) { os << ","; } else { os << "))"; @@ -376,14 +381,14 @@ void CodeGenOpenCL::PrintRestrict(const Var& v, std::ostream& os) { } } -std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType target) { +std::string CodeGenOpenCL::CastFromTo(std::string value, DLDataType from, DLDataType target) { if (from == target) return value; return CastTo(value, target); } -std::string CodeGenOpenCL::CastTo(std::string value, DataType target) { +std::string CodeGenOpenCL::CastTo(std::string value, DLDataType target) { std::ostringstream os; - if (target == DataType::Bool()) { + if (target == DLDataType{kDLBool, 8, 1}) { os << "("; os << "("; this->PrintType(target, os); @@ -422,7 +427,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, os); } - this->PrintType(load->dtype.element_of(), os); + this->PrintType(DLDataType{load->ty()->dtype.code, load->ty()->dtype.bits, 1}, os); os << " *)" << this->GetVarID(load->buffer->data.get()) << " + "; this->PrintExpr(load->indices[0], os); os << ')'; @@ -434,13 +439,14 @@ void 
CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { const int channel_size = op->args[4].as_or_throw()->value; TVM_FFI_ICHECK(channel_size == 64 || channel_size == 128) << "Unsupported Channel Size: " << channel_size; - DataType channel_type = runtime::GetChannelType(channel_size); + DLDataType channel_type = runtime::GetChannelType(channel_size); - DataType buffer_type = ptr_type->element_type.as()->dtype; + DLDataType buffer_type = ptr_type->element_type.as()->dtype; std::stringstream ss; this->PrintExpr(op->args[5], ss); std::string value; - value = this->SSAGetID(ss.str(), buffer_type.with_lanes(channel_size / buffer_type.bits())); + value = this->SSAGetID(ss.str(), + PrimType(buffer_type).WithLanes(channel_size / buffer_type.bits)->dtype); if (channel_size == 64) { os << "write_imageh("; } else if (channel_size == 128) { @@ -467,11 +473,11 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { enable_compliant_texture_reads_ = true; std::stringstream ss; const int channel_size = op->args[4].as_or_throw()->value; - const int data_lanes = channel_size / op->dtype.bits(); + const int data_lanes = channel_size / op->ty().bits(); TVM_FFI_ICHECK(channel_size == 64 || channel_size == 128) << "Unsupported Channel Size: " << channel_size; ss << "as_"; - this->PrintType(op->dtype.with_lanes(data_lanes), ss); + this->PrintType(op->ty().WithLanes(data_lanes)->dtype, ss); ss << "("; if (channel_size == 64) { ss << "READ_IMAGEH("; @@ -493,7 +499,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(IntImm::Int32(0), ss); ss << "))))"; - std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(data_lanes)); + std::string rhs = SSAGetID(ss.str(), op->ty().WithLanes(data_lanes)->dtype); if (auto ramp = op->args.back().as()) { if (ramp->base.as() && *tirx::as_const_int(ramp->base) == 0 && *tirx::as_const_int(ramp->lanes) == data_lanes && @@ -501,10 +507,10 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* 
op, std::ostream& os) { os << rhs; } else if (*tirx::as_const_int(ramp->stride) == 1) { os << "(*("; - this->PrintType(op->dtype.with_lanes(*tirx::as_const_int(ramp->lanes)), os); + this->PrintType(op->ty().WithLanes(*tirx::as_const_int(ramp->lanes))->dtype, os); os << "*)"; os << "(("; - this->PrintType(op->dtype.with_lanes(1), os); + this->PrintType(op->ty().WithLanes(1)->dtype, os); os << "*)&" << rhs << " + "; this->PrintExpr(ramp->base, os); os << "))"; @@ -513,7 +519,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { } } else { os << "(("; - this->PrintType(op->dtype.with_lanes(1), os); + this->PrintType(op->ty().WithLanes(1)->dtype, os); os << "*)&" << rhs << ")["; this->PrintExpr(op->args.back(), os); os << "]"; @@ -521,7 +527,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { auto func = op->args[0].as_or_throw(); // Enable atomics extension if used. 
- if (func->value == "atomic_add" && op->dtype.is_float()) { + if (func->value == "atomic_add" && op->ty().code() == DLDataTypeCode::kDLFloat) { enable_atomics_ = true; this->PrintCallExtern(GetType(ffi::GetRef(op)), "atomic_add_float_emu", op->args, true, os); @@ -540,9 +546,9 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); - int lanes = op->dtype.lanes(); + int lanes = op->ty().lanes(); os << "(("; - PrintType(op->dtype, os); + PrintType(op->ty()->dtype, os); os << ")("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -553,9 +559,9 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // void CodeGenOpenCL::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) os << "(("; - PrintType(op->dtype, os); + PrintType(op->ty()->dtype, os); os << ")("; - int lanes = op->dtype.lanes(); + int lanes = op->ty().lanes(); for (int i = 0; i < lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; @@ -579,18 +585,18 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N template inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, CodeGenOpenCL* p) { - if (op->dtype.lanes() == 1) { + if (op->ty().lanes() == 1) { os << opstr << "(("; - p->PrintType(op->a->dtype, os); + p->PrintType(op->a.ty()->dtype, os); os << ")"; p->PrintExpr(op->a, os); os << ", ("; - p->PrintType(op->b->dtype, os); + p->PrintType(op->b.ty()->dtype, os); os << ")"; p->PrintExpr(op->b, os); os << ')'; } else { - p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os); + p->PrintVecBinaryOp(opstr, op->ty()->dtype, op->a, op->b, os); } } @@ -604,14 +610,16 @@ void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) { void CodeGenOpenCL::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) 
std::string opstr; - if (op->dtype.is_int() || op->dtype.is_uint()) { + PrimType op_ty = op->ty(); + if (op_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { opstr = "%"; } else { - TVM_FFI_ICHECK(op->dtype.is_float()) - << "Expected floating point or integer dtype in Mod, but got " << op->dtype; + TVM_FFI_ICHECK(op_ty.code() == DLDataTypeCode::kDLFloat) + << "Expected floating point or integer dtype in Mod, but got " + << ffi::DLDataTypeToString(op->ty()->dtype); opstr = "fmod"; } - if (op->dtype.lanes() == 1) { + if (op_ty.lanes() == 1) { if (isalpha(opstr.c_str()[0])) { os << opstr.c_str() << '('; this->PrintExpr(op->a, os); @@ -626,7 +634,7 @@ void CodeGenOpenCL::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT os << ')'; } } else { - this->PrintVecBinaryOp(opstr.c_str(), op->dtype, op->a, op->b, os); + this->PrintVecBinaryOp(opstr.c_str(), op->ty()->dtype, op->a, op->b, os); } } @@ -634,11 +642,11 @@ void CodeGenOpenCL::VisitExpr_(const AndNode* op, std::ostream& os) { std::ostringstream oss; os << "("; this->PrintExpr(op->a, oss); - os << CastTo(oss.str(), op->dtype); + os << CastTo(oss.str(), op->ty()->dtype); oss.str(""); os << " && "; this->PrintExpr(op->b, oss); - os << CastTo(oss.str(), op->dtype); + os << CastTo(oss.str(), op->ty()->dtype); os << ")"; } @@ -646,11 +654,11 @@ void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) { std::ostringstream oss; os << "("; this->PrintExpr(op->a, oss); - os << CastTo(oss.str(), op->dtype); + os << CastTo(oss.str(), op->ty()->dtype); oss.str(""); os << " || "; this->PrintExpr(op->b, oss); - os << CastTo(oss.str(), op->dtype); + os << CastTo(oss.str(), op->ty()->dtype); os << ")"; } @@ -658,18 +666,19 @@ void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { std::ostringstream oss; os << "select("; PrintExpr(op->false_value, oss); - os << CastFromTo(oss.str(), op->false_value.dtype(), op->dtype); + os << CastFromTo(oss.str(), op->false_value.ty()->dtype, 
op->ty()->dtype); oss.str(""); os << ", "; PrintExpr(op->true_value, oss); - os << CastFromTo(oss.str(), op->true_value.dtype(), op->dtype); + os << CastFromTo(oss.str(), op->true_value.ty()->dtype, op->ty()->dtype); oss.str(""); os << ", "; PrintExpr(op->condition, oss); - if (op->dtype.is_float()) { - os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes())); + if (op->ty().code() == DLDataTypeCode::kDLFloat) { + os << CastTo(oss.str(), DLDataType{kDLInt, static_cast(op->ty().bits()), + static_cast(op->ty().lanes())}); } else { - os << CastFromTo(oss.str(), op->condition.dtype(), op->dtype); + os << CastFromTo(oss.str(), op->condition.ty()->dtype, op->ty()->dtype); } os << ")"; } diff --git a/src/backend/opencl/codegen/codegen_opencl.h b/src/backend/opencl/codegen/codegen_opencl.h index d588a18c2029..47667e30663a 100644 --- a/src/backend/opencl/codegen/codegen_opencl.h +++ b/src/backend/opencl/codegen/codegen_opencl.h @@ -46,20 +46,20 @@ class CodeGenOpenCL final : public CodeGenC { void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintType(DLDataType t, std::ostream& os) final; // NOLINT(*) void PrintType(const Type& type, std::ostream& os) final; // NOLINT(*) - std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) final; - void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, + std::string GetVecLoad(DLDataType t, const BufferNode* buffer, PrimExpr base) final; + void PrintVecStore(const BufferNode* buffer, DLDataType t, PrimExpr base, const std::string& value) final; // NOLINT(*) - void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, + void PrintVecElemLoadExpr(DLDataType t, int i, const std::string& value, std::ostream& os) final; // NOLINT(*) // the 
address of load/store - void PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base, - std::ostream& os); // NOLINT(*) - void PrintRestrict(const Var& v, std::ostream& os) final; // NOLINT(*) - std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) - std::string CastTo(std::string value, DataType target); // NOLINT(*) - void SetTextureScope(const std::unordered_map&); // NOLINT(*) + void PrintVecAddr(const BufferNode* buffer, DLDataType t, PrimExpr base, + std::ostream& os); // NOLINT(*) + void PrintRestrict(const Var& v, std::ostream& os) final; // NOLINT(*) + std::string CastFromTo(std::string value, DLDataType from, DLDataType target); // NOLINT(*) + std::string CastTo(std::string value, DLDataType target); // NOLINT(*) + void SetTextureScope(const std::unordered_map&); // NOLINT(*) // overload visitor void VisitStmt_(const AllocBufferNode* op) final; // NOLINT(*) diff --git a/src/backend/opencl/codegen/intrin_rule_opencl.cc b/src/backend/opencl/codegen/intrin_rule_opencl.cc index f0f58be84d10..669fd1863b39 100644 --- a/src/backend/opencl/codegen/intrin_rule_opencl.cc +++ b/src/backend/opencl/codegen/intrin_rule_opencl.cc @@ -42,7 +42,7 @@ static PrimExpr DispatchIntelShuffle(const PrimExpr& e) { << "Intel warp shuffle dose not support width != warp_size"; ffi::Array opencl_args{ {StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; - return Call(call->dtype, builtin::call_pure_extern(), opencl_args); + return Call(e.ty(), builtin::call_pure_extern(), opencl_args); } void RegisterOpenCLIntrinRules() { @@ -75,7 +75,7 @@ TVM_REGISTER_OP("tirx.round") for (auto arg : call->args) { new_args.push_back(arg); } - return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), new_args); + return tirx::Call(e.ty(), tirx::builtin::call_pure_extern(), new_args); }); TVM_REGISTER_OP("tirx.nearbyint") diff --git a/src/backend/opencl/runtime/opencl_common.h b/src/backend/opencl/runtime/opencl_common.h index 
3b99fa166def..4fc7ce85e383 100644 --- a/src/backend/opencl/runtime/opencl_common.h +++ b/src/backend/opencl/runtime/opencl_common.h @@ -186,24 +186,25 @@ inline const char* CLGetErrorString(cl_int error) { } inline cl_channel_type DTypeToOpenCLChannelType(DLDataType data_type) { - DataType dtype(data_type); - dtype = dtype.with_lanes(1); + DLDataType dtype = data_type; + // OpenCL image channel type depends on the scalar element type, not vector lanes. + dtype.lanes = 1; - if (dtype == DataType::Float(32)) { + if (dtype == DLDataType{kDLFloat, 32, 1}) { return CL_FLOAT; - } else if (dtype == DataType::Float(16)) { + } else if (dtype == DLDataType{kDLFloat, 16, 1}) { return CL_HALF_FLOAT; - } else if (dtype == DataType::Int(8)) { + } else if (dtype == DLDataType{kDLInt, 8, 1}) { return CL_SIGNED_INT8; - } else if (dtype == DataType::Int(16)) { + } else if (dtype == DLDataType{kDLInt, 16, 1}) { return CL_SIGNED_INT16; - } else if (dtype == DataType::Int(32)) { + } else if (dtype == DLDataType{kDLInt, 32, 1}) { return CL_SIGNED_INT32; - } else if (dtype == DataType::UInt(8)) { + } else if (dtype == DLDataType{kDLUInt, 8, 1}) { return CL_UNSIGNED_INT8; - } else if (dtype == DataType::UInt(16)) { + } else if (dtype == DLDataType{kDLUInt, 16, 1}) { return CL_UNSIGNED_INT16; - } else if (dtype == DataType::UInt(32)) { + } else if (dtype == DLDataType{kDLUInt, 32, 1}) { return CL_UNSIGNED_INT32; } TVM_FFI_THROW(InternalError) << "data type is not supported in OpenCL runtime yet: " << dtype; diff --git a/src/backend/opencl/runtime/opencl_device_api.cc b/src/backend/opencl/runtime/opencl_device_api.cc index eeb8e95ad543..0b53a1915192 100644 --- a/src/backend/opencl/runtime/opencl_device_api.cc +++ b/src/backend/opencl/runtime/opencl_device_api.cc @@ -779,14 +779,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { int64_t height = shape[1]; int64_t depth = shape[2]; int64_t channel_size = args[7].cast(); - DataType channel_type = GetChannelType(channel_size); + DLDataType channel_type = 
GetChannelType(channel_size); Device dev; dev.device_type = static_cast(device_type); dev.device_id = device_id; DLDataType type_hint; - type_hint.code = channel_type.code(); - type_hint.bits = channel_type.bits(); - type_hint.lanes = channel_type.lanes(); + type_hint = channel_type; *rv = OpenCLWorkspace::Global()->AllocDataSpace( dev, static_cast(width), static_cast(height), diff --git a/src/backend/opencl/runtime/texture.h b/src/backend/opencl/runtime/texture.h index a8711805cbfa..3aa2d3681142 100644 --- a/src/backend/opencl/runtime/texture.h +++ b/src/backend/opencl/runtime/texture.h @@ -120,15 +120,13 @@ size_t GetTextureMemorySize(T shape, int bits, int lanes, std::string mem_scope, /*! * \brief Returns the standard channel datatype for any given type. * \param channel_size The Number of bits in a Channel - * \return DataType to be used in the codegen. + * \return DLDataType to be used in the codegen. */ -inline DataType GetChannelType(size_t channel_size) { - DataType channel_type; - +inline DLDataType GetChannelType(size_t channel_size) { if (channel_size == 128) - return DataType::Float(32, 4); + return DLDataType{kDLFloat, 32, 4}; else if (channel_size == 64) - return DataType::Float(16, 4); + return DLDataType{kDLFloat, 16, 4}; TVM_FFI_THROW(InternalError) << "Unsupported Channel Size: " << channel_size; } diff --git a/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc b/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc index 22ce75cddade..6f70343f46a4 100644 --- a/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc +++ b/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc @@ -100,7 +100,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Value* buf = nullptr; StorageInfo& info = alloc_storage_info_[op->buffer->data.get()]; auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); - DataType dtype = op->buffer->dtype; + PrimType dtype = op->buffer->dtype; if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == 
".dyn") { LOG(WARNING) << "Dynamic shared memory support for rocm is experimental."; @@ -188,7 +188,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id); #endif llvm::Value* result = builder_->CreateCall(f, {}); - return this->CreateCast(DataType::Int(32), iv->var->dtype, result); + return this->CreateCast(PrimType::Int(32), iv->var.ty(), result); } llvm::Value* CreateStorageSync(const CallNode* op) final { @@ -220,10 +220,11 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Value* CreateIntrinsic(const CallNode* op) final { if (op->op.same_as(builtin::atomic_add())) { - TVM_FFI_ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now"; + PrimType value_ty = op->args[1].ty(); + TVM_FFI_ICHECK(value_ty.bits() == 32) << "Only supports 32 bit atomic for now"; llvm::Value* v0 = MakeValue(op->args[0]); llvm::Value* v1 = MakeValue(op->args[1]); - if (op->args[1]->dtype.is_float()) { + if (value_ty.MatchesCode(DLDataTypeCode::kDLFloat)) { return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1, llvm::MaybeAlign(), llvm::AtomicOrdering::Monotonic); } diff --git a/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc b/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc index 4859fd5f4a24..db0f113b9c8b 100644 --- a/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc +++ b/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc @@ -50,14 +50,14 @@ inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) { TVM_FFI_ICHECK_EQ(name.substr(0, 5), "tirx."); std::ostringstream intrinsic_name; - intrinsic_name << "__ocml_" << name.substr(5) << "_f" << call->dtype.bits(); + intrinsic_name << "__ocml_" << name.substr(5) << "_f" << call->ty().bits(); ffi::Array new_args = {StringImm(intrinsic_name.str())}; for (auto arg : call->args) { new_args.push_back(arg); } - return Call(call->dtype, builtin::call_pure_extern(), new_args); + return Call(call->ty(), builtin::call_pure_extern(), 
new_args); } inline PrimExpr DispatchShuffle(const PrimExpr& e) { @@ -66,15 +66,17 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size PrimExpr var = call->args[1]; - TVM_FFI_ICHECK_EQ(var.dtype().bits(), 32); + PrimType var_ty = var.ty(); + TVM_FFI_ICHECK_EQ(var_ty.bits(), 32); // get own lane in self (__lane_id) PrimExpr minus_one = IntImm::Int32(-1); PrimExpr zero = IntImm::Int32(0); - PrimExpr lo = Call(DataType::Int(32), builtin::call_pure_extern(), + PrimType i32_ty = PrimType::Int(32); + PrimExpr lo = Call(i32_ty, builtin::call_pure_extern(), {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}); - PrimExpr self = Call(DataType::Int(32), builtin::call_pure_extern(), - {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}); + PrimExpr self = + Call(i32_ty, builtin::call_pure_extern(), {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}); // compute lane to get from PrimExpr width = call->args[3]; @@ -93,12 +95,12 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { index = Select((self & (width - 1)) + delta >= width, self, index); } // reinterprete var as int32 - bool is_int32 = var.dtype().is_int() && var.dtype().bits() == 32; - PrimExpr source = is_int32 ? var : reinterpret(DataType::Int(32), var); - PrimExpr res = Call(DataType::Int(32), builtin::call_pure_extern(), + bool is_int32 = var_ty.MatchesElementType(DLDataTypeCode::kDLInt, 32); + PrimExpr source = is_int32 ? 
var : reinterpret(PrimType::Int(32), var); + PrimExpr res = Call(i32_ty, builtin::call_pure_extern(), {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source}); if (!is_int32) { - res = reinterpret(var.dtype(), res); + res = reinterpret(var_ty, res); } return res; } diff --git a/src/backend/trn/codegen/codegen_trn.cc b/src/backend/trn/codegen/codegen_trn.cc index eb9d7ca4b437..631df21f8b08 100644 --- a/src/backend/trn/codegen/codegen_trn.cc +++ b/src/backend/trn/codegen/codegen_trn.cc @@ -110,7 +110,7 @@ void CodeGenTrainium::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { size_t num_buffer = 0; for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) { Var v = func->params[i]; - if (!v.dtype().is_handle()) { + if (!v.ty().IsHandle()) { LOG(FATAL) << "Trainium codegen currently only support buffer arguments"; }; std::string vid = AllocVarID(v.get()); @@ -137,16 +137,17 @@ void CodeGenTrainium::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { this->EndScope(func_scope); } -void CodeGenTrainium::PrintType(DataType t, std::ostream& os) { // NOLINT(*) +void CodeGenTrainium::PrintType(DLDataType raw_t, std::ostream& os) { // NOLINT(*) + PrimType t(raw_t); int lanes = t.lanes(); TVM_FFI_ICHECK(lanes == 1) << "Trainium codegen does not support vector types"; - TVM_FFI_ICHECK(!t.is_handle()) << "Trainium codegen does not support handle type"; - TVM_FFI_ICHECK(!t.is_void()) << "Trainium codegen does not support void type"; - if (t == DataType::Bool()) { + TVM_FFI_ICHECK(!t.IsHandle()) << "Trainium codegen does not support handle type"; + TVM_FFI_ICHECK(!t.IsVoid()) << "Trainium codegen does not support void type"; + if (t.MatchesCode(DLDataTypeCode::kDLBool)) { os << "np.bool"; return; } - if (t.is_float()) { + if (t.code() == DLDataTypeCode::kDLFloat) { switch (t.bits()) { case 16: os << "np.float16"; @@ -160,13 +161,13 @@ void CodeGenTrainium::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } return; } - if (t.is_uint() || t.is_int()) { 
+ if (t.MatchesCode(DLDataTypeCode::kDLUInt, DLDataTypeCode::kDLInt)) { if (t.bits() == 1) { os << "np.bool"; return; } os << "np."; - if (t.is_uint()) { + if (t.MatchesCode(DLDataTypeCode::kDLUInt)) { os << 'u'; } switch (t.bits()) { @@ -188,11 +189,11 @@ void CodeGenTrainium::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } return; } - if (t.is_bfloat16()) { + if (t.code() == DLDataTypeCode::kDLBfloat && t.bits() == 16) { os << "nl.bfloat16"; return; } - LOG(FATAL) << "Cannot convert type " << t << " to Trainium type"; + LOG(FATAL) << "Cannot convert type " << raw_t << " to Trainium type"; } std::string CodeGenTrainium::GetStorageScopeStr(const std::string& scope) { // NOLINT(*) @@ -215,7 +216,7 @@ void CodeGenTrainium::VisitStmt_(const AllocBufferNode* op) { this->PrintIndent(); auto scope = GetPtrStorageScope(op->buffer->data); std::ostringstream dtype_os; - PrintType(op->buffer->dtype, dtype_os); + PrintType(op->buffer->dtype->dtype, dtype_os); std::string dtype_str = dtype_os.str(); if (scope == "trn.psum") { stream << vid << " = nl.ndarray(shape=["; @@ -589,7 +590,7 @@ void CodeGenTrainium::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLI } void CodeGenTrainium::VisitExpr_(const CastNode* op, std::ostream& os) { - ctx_.dst_dtype = op->dtype; + ctx_.dst_dtype = op->ty(); CodeGenTrainium::VisitExpr(op->value, os); } diff --git a/src/backend/trn/codegen/codegen_trn.h b/src/backend/trn/codegen/codegen_trn.h index 2c3b5fd37393..ec4eaad29cce 100644 --- a/src/backend/trn/codegen/codegen_trn.h +++ b/src/backend/trn/codegen/codegen_trn.h @@ -41,7 +41,7 @@ struct NKIInstructionCtx { bool is_matmul_input = false; int buffer_index = -1; int used_var_cnt = 0; - DataType dst_dtype; + PrimType dst_dtype = PrimType::Void(); PrimExpr mask; bool tensorizing = false; }; @@ -57,7 +57,7 @@ class CodeGenTrainium final : public CodeGenC { void InitFuncState(const PrimFunc& f) final; std::string GetStorageScopeStr(const std::string& scope); // NOLINT(*) void 
VisitExpr_(const VarNode* op, std::ostream& os) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintType(DLDataType t, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const AllocBufferNode* op) final; // NOLINT(*) void VisitStmt_(const AttrStmtNode* op) final; // NOLINT(*) void VisitStmt_(const ForNode* op) final; // NOLINT(*) diff --git a/src/backend/trn/transform/lower_trainium_layout.cc b/src/backend/trn/transform/lower_trainium_layout.cc index ad4b206a48b2..fb1d92c5215d 100644 --- a/src/backend/trn/transform/lower_trainium_layout.cc +++ b/src/backend/trn/transform/lower_trainium_layout.cc @@ -176,8 +176,8 @@ class TrainiumLayoutApplier : public arith::IRMutatorWithAnalyzer { flattened = buf.GetFlattenedBuffer(); writer = flattened.CopyOnWrite(); } - if (flattened->dtype == DataType::Bool()) { - writer->dtype = DataType::Int(8); + if (flattened->dtype->dtype == DLDataType{kDLBool, 8, 1}) { + writer->dtype = PrimType::Int(8); } for (size_t i = 0; i < flattened->shape.size(); ++i) { writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i])); @@ -191,28 +191,30 @@ class TrainiumLayoutApplier : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = StmtExprMutator::VisitStmt_(op).as_or_throw(); - bool store_returns_bool = (op->value.dtype() == DataType::Bool()); + PrimType store_value_ty = op->value.ty(); + bool store_returns_bool = store_value_ty.MatchesCode(DLDataTypeCode::kDLBool); store = VisitBufferAccess(store); if (store_returns_bool) { - TVM_FFI_ICHECK_EQ(store->buffer->dtype, DataType::Int(8)) + TVM_FFI_ICHECK_EQ(store->buffer->dtype->dtype, (DLDataType{kDLInt, 8, 1})) << "Expected int8 backing array for boolean tensor"; auto writer = store.CopyOnWrite(); - writer->value = tvm::cast(DataType::Int(8), store->value); + writer->value = tvm::cast(PrimType::Int(8), store->value); return std::move(store); } return std::move(store); } PrimExpr 
VisitExpr_(const BufferLoadNode* op) final { - bool load_returns_bool = (op->dtype == DataType::Bool()); + PrimType load_ty = op->ty(); + bool load_returns_bool = load_ty.MatchesCode(DLDataTypeCode::kDLBool); BufferLoad load = StmtExprMutator::VisitExpr_(op).as_or_throw(); load = VisitBufferAccess(load); if (load_returns_bool) { - TVM_FFI_ICHECK_EQ(load->buffer->dtype, DataType::Int(8)) + TVM_FFI_ICHECK_EQ(load->buffer->dtype->dtype, (DLDataType{kDLInt, 8, 1})) << "Expected int8 backing array for boolean tensor"; - load.CopyOnWrite()->dtype = DataType::Int(8); - return tvm::cast(DataType::Bool(), load); + load.CopyOnWrite()->BaseExprNode::ty = PrimType::Int(8); + return tvm::cast(PrimType::Bool(), load); } else { return std::move(load); } diff --git a/src/backend/vulkan/codegen/codegen_spirv.cc b/src/backend/vulkan/codegen/codegen_spirv.cc index 5737c60da9dc..094e31370481 100644 --- a/src/backend/vulkan/codegen/codegen_spirv.cc +++ b/src/backend/vulkan/codegen/codegen_spirv.cc @@ -52,8 +52,8 @@ runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s const uint32_t descriptor_set = 0; for (Var arg : f->params) { - DataType t = arg.dtype(); - if (t.is_handle()) { + PrimType t = PrimType(arg.ty()->dtype); + if (t.IsHandle()) { auto* ptr = arg->type_annotation.as(); TVM_FFI_ICHECK(ptr) << "All handles passed to the Vulkan codegen must have a type_annotation as a " @@ -64,11 +64,11 @@ runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s << "All handles passed to the Vulkan codegen must have a type_annotation as a " "PointerType, " << "and must point to a PrimType"; - DataType value_storage_type = prim->dtype; - if (value_storage_type == DataType::Bool()) { + PrimType value_storage_type(prim->dtype); + if (value_storage_type == PrimType::Bool()) { // We need a physically addressable buffer type to support boolean tensors. // The loaded byte is cast to bool inside the LoadNode visitor below. 
- value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes()); + value_storage_type = boolean_storage_type_.WithLanes(value_storage_type.lanes()); } spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type), descriptor_set, i_buffer++); @@ -87,7 +87,7 @@ runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s if (pod_args.size() != 0) { std::vector value_types; for (size_t i = 0; i < pod_args.size(); ++i) { - value_types.push_back(builder_->GetSType(pod_args[i].dtype())); + value_types.push_back(builder_->GetSType(PrimType(pod_args[i].ty()->dtype))); } if (pod_args.size() * sizeof(runtime::ArgUnion64) <= runtime::vulkan::kMaxPushConstantsBytes) { spirv::Value ptr = builder_->DeclarePushConstant(value_types); @@ -150,7 +150,7 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& ext } else { v = builder_->GetWorkgroupID(ts.dim_index); } - return builder_->Cast(builder_->GetSType(iv->var.dtype()), v); + return builder_->Cast(builder_->GetSType(PrimType(iv->var.ty()->dtype)), v); } spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { @@ -179,7 +179,7 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { TVM_FFI_THROW(InternalError) << "Do not support sync " << sync; } - auto type_int = builder_->GetSType(DataType::Int(32)); + auto type_int = builder_->GetSType(PrimType::Int(32)); builder_->MakeInst(spv::OpControlBarrier, builder_->IntImm(type_int, sync_scope), builder_->IntImm(type_int, sync_scope), builder_->IntImm(type_int, memory_semantics)); @@ -194,11 +194,11 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const VarNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) { - return builder_->IntImm(builder_->GetSType(op->dtype), op->value); + return builder_->IntImm(builder_->GetSType(PrimType(op->ty()->dtype)), op->value); } spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) { - return 
builder_->FloatImm(builder_->GetSType(op->dtype), op->value); + return builder_->FloatImm(builder_->GetSType(PrimType(op->ty()->dtype)), op->value); } spirv::Value CodeGenSPIRV::VisitExpr_(const StringImmNode* op) { @@ -206,7 +206,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const StringImmNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const CastNode* op) { - return builder_->Cast(builder_->GetSType(op->dtype), MakeValue(op->value)); + return builder_->Cast(builder_->GetSType(PrimType(op->ty()->dtype)), MakeValue(op->value)); } spirv::Value CodeGenSPIRV::VisitExpr_(const AddNode* op) { @@ -308,7 +308,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); } - return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values); + return builder_->CallGLSL450(builder_->GetSType(PrimType(op->ty()->dtype)), inst_id, values); } else if (op->op.same_as(builtin::bitwise_and())) { TVM_FFI_ICHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); @@ -337,20 +337,20 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { TVM_FFI_ICHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); - if (op->args[0].dtype().is_int()) { + if (PrimType(op->args[0].ty()->dtype).MatchesCode(DLDataTypeCode::kDLInt)) { return builder_->MakeValue(spv::OpShiftRightArithmetic, a.stype, a, b); } else { return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b); } } else if (op->op.same_as(builtin::reinterpret())) { - return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype), + return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(PrimType(op->ty()->dtype)), MakeValue(op->args[0])); } else if (op->op.same_as(builtin::large_uint_imm())) { TVM_FFI_ICHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(op->args[0].as_or_throw()->value); uint64_t high = 
static_cast(op->args[1].as_or_throw()->value); uint64_t val = (high << 32U) | low; - return builder_->UIntImm(builder_->GetSType(op->dtype), val); + return builder_->UIntImm(builder_->GetSType(PrimType(op->ty()->dtype)), val); } else if (op->op.same_as(builtin::tvm_storage_sync())) { return this->CreateStorageSync(op); } else if (op->op.same_as(builtin::if_then_else())) { @@ -378,7 +378,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { phi.SetIncoming(1, else_value, else_value_label); return phi; } else if (op->op.same_as(builtin::popcount())) { - return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype), + return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(PrimType(op->ty()->dtype)), MakeValue(op->args[0])); } else if (op->op.same_as(builtin::call_pure_extern())) { TVM_FFI_ICHECK_GE(op->args.size(), 1U); @@ -388,7 +388,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); } - return builder_->CallKHRIntegerDotProduct(builder_->GetSType(op->dtype), values, op->dtype); + PrimType op_dtype(op->ty()->dtype); + return builder_->CallKHRIntegerDotProduct(builder_->GetSType(op_dtype), values, op_dtype); } else { TVM_FFI_THROW(InternalError) << "SPIR-V shader cannot make extern calls. 
Graph contains extern \"" @@ -412,8 +413,9 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { TVM_FFI_ICHECK_EQ(op->args.size(), 6U); const VarNode* buffer_node = op->args[0].as(); TVM_FFI_ICHECK(buffer_node && fragment_info_.count(buffer_node)); - DataType ele_dtype = GetElementDataType(buffer_node); - TVM_FFI_ICHECK(ele_dtype.is_float()) << "Only floating point fragment accumulator is supported"; + PrimType ele_dtype = GetElementDataType(buffer_node); + TVM_FFI_ICHECK(ele_dtype.MatchesCode(DLDataTypeCode::kDLFloat)) + << "Only floating point fragment accumulator is supported"; spirv::SType ele_stype = builder_->GetSType(ele_dtype); spirv::SType& fragment_type = fragment_info_[buffer_node].stype; double init = static_cast(op->args[5].as_or_throw()->value); @@ -435,7 +437,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { PrimExpr dst_index = op->args[4]; PrimExpr src_ptr_expr = op->args[5]; int stride = static_cast(op->args[6].as_or_throw()->value); - auto type_int = builder_->GetSType(DataType::Int(32)); + auto type_int = builder_->GetSType(PrimType::Int(32)); spirv::Value stride_val = builder_->IntImm(type_int, stride); std::string layout = (op->args[7].as())->value; spirv::SType dst_ptr_type = @@ -443,7 +445,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Value dst_ptr = builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], MakeValue(dst_index)); spirv::Value src_ptr = VisitExpr(op->args[5]); - spirv::SType type_bool = builder_->GetSType(DataType::Bool()); + spirv::SType type_bool = builder_->GetSType(PrimType::Bool()); spirv::Value t_val = builder_->UIntImm(type_bool, 1); spirv::Value f_val = builder_->UIntImm(type_bool, 0); spirv::Value loaded = @@ -494,7 +496,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { PrimExpr index = op->args[4]; PrimExpr buffer_ptr = op->args[5]; int stride = static_cast(op->args[6].as_or_throw()->value); - auto type_int = 
builder_->GetSType(DataType::Int(32)); + auto type_int = builder_->GetSType(PrimType::Int(32)); spirv::Value stride_val = builder_->IntImm(type_int, stride); std::string layout = (op->args[7].as())->value; spirv::Value dst_ptr = VisitExpr(op->args[5]); @@ -505,7 +507,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index)); uint32_t mask = spv::MemoryAccessMaskNone; spirv::Value loaded = builder_->MakeValue(spv::OpLoad, fragment_type, ptr, mask); - spirv::SType type_bool = builder_->GetSType(DataType::Bool()); + spirv::SType type_bool = builder_->GetSType(PrimType::Bool()); spirv::Value t_val = builder_->UIntImm(type_bool, 1); spirv::Value f_val = builder_->UIntImm(type_bool, 0); builder_->MakeInst(spv::OpCooperativeMatrixStoreNV, dst_ptr, loaded, stride_val, @@ -516,7 +518,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { Var buffer_var = load->buffer->data; const VarNode* buffer_node = buffer_var.get(); PrimExpr index = load->indices[0]; - DataType ele_dtype = GetElementDataType(buffer_node); + PrimType ele_dtype = GetElementDataType(buffer_node); spirv::SType ele_stype = builder_->GetSType(ele_dtype); spirv::Value buffer_val = MakeValue(buffer_var); spirv::SType ptr_type = builder_->GetPointerType(ele_stype, buffer_val.stype.storage_class); @@ -532,11 +534,11 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { std::vector values; spirv::Value base = MakeValue(op->base); - int lanes = op->dtype.lanes(); + int lanes = op->ty().lanes(); for (int i = 0; i < lanes; ++i) { spirv::Value v = base; if (i != 0) { - spirv::Value offset = MakeValue(MakeConst(op->stride.dtype(), i) * op->stride); + spirv::Value offset = MakeValue(MakeConst(op->stride.ty(), i) * op->stride); v = builder_->Add(v, offset); } values.push_back(v); @@ -547,7 +549,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { 
spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) { std::vector values; spirv::Value v = MakeValue(op->value); - int lanes = op->dtype.lanes(); + int lanes = op->ty().lanes(); for (int i = 0; i < lanes; i++) { values.push_back(v); } @@ -560,15 +562,15 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { Var buffer_var = op->buffer->data; PrimExpr prim_index = op->indices[0]; - DataType desired_read_type = op->dtype; - if (desired_read_type == DataType::Bool()) { - desired_read_type = boolean_storage_type_.with_lanes(desired_read_type.lanes()); + PrimType desired_read_type(op->ty()->dtype); + if (desired_read_type == PrimType::Bool()) { + desired_read_type = boolean_storage_type_.WithLanes(desired_read_type.lanes()); } auto it = storage_info_.find(buffer_var.get()); TVM_FFI_ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; - info.CheckContentType(desired_read_type, prim_index.dtype().lanes()); + info.CheckContentType(desired_read_type, PrimType(prim_index.ty()->dtype).lanes()); spirv::SType content_type = builder_->GetSType(info.element_type); spirv::Value buffer = MakeValue(buffer_var); @@ -588,13 +590,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); // OpTypeBool have no physical address/storage. Here, cast from // the storage type to an OpTypeBool. - if (op->dtype == DataType::Bool()) { - auto spirv_bool = builder_->GetSType(DataType::Bool()); + if (PrimType(op->ty()->dtype) == PrimType::Bool()) { + auto spirv_bool = builder_->GetSType(PrimType::Bool()); loaded = builder_->Cast(spirv_bool, loaded); } return loaded; - } else if (desired_read_type.element_of() == info.element_type) { + } else if (desired_read_type.WithLanes(1) == info.element_type) { // Requested several elements returned as an array. Read out each // element and concatenate into the result. 
std::vector values; @@ -609,21 +611,22 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { TVM_FFI_THROW(InternalError) << "Cannot perform buffer access of buffer variable '" << buffer_var->name_hint << "' with element type " << info.element_type << " using index of type " - << prim_index->dtype << " to produce output of type " << op->dtype; + << PrimType(prim_index.ty()->dtype) + << " to produce output of type " << PrimType(op->ty()->dtype); return spirv::Value(); } } void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function f) { if (const RampNode* ramp = e.as()) { - for (int i = 0; i < ramp->dtype.lanes(); ++i) { + for (int i = 0; i < ramp->ty().lanes(); ++i) { PrimExpr offset = ramp->base + ramp->stride * i; f(i, MakeValue(offset)); } } else { - spirv::SType etype = builder_->GetSType(e.dtype().element_of()); + spirv::SType etype = builder_->GetSType(PrimType(e.ty()->dtype).WithLanes(1)); spirv::Value value = MakeValue(e); - for (int i = 0; i < e.dtype().lanes(); ++i) { + for (int i = 0; i < PrimType(e.ty()->dtype).lanes(); ++i) { f(i, builder_->MakeValue(spv::OpCompositeExtract, etype, value, i)); } } @@ -635,7 +638,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const ShuffleNode* op) { << "of one vector with one index"; spirv::Value vector = MakeValue(op->vectors[0]); int index = op->indices[0].as_or_throw()->value; - spirv::SType etype = builder_->GetSType(op->dtype); + spirv::SType etype = builder_->GetSType(PrimType(op->ty()->dtype)); spirv::Value element = builder_->MakeValue(spv::OpCompositeExtract, etype, vector, index); return element; } @@ -649,7 +652,7 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { auto it = storage_info_.find(buffer_var.get()); TVM_FFI_ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; - info.CheckContentType(op->value.dtype(), prim_index.dtype().lanes()); + info.CheckContentType(PrimType(op->value.ty()->dtype), PrimType(prim_index.ty()->dtype).lanes()); spirv::SType content_type = 
builder_->GetSType(info.element_type); spirv::Value buffer = MakeValue(buffer_var); @@ -661,16 +664,16 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { mask |= spv::MemoryAccessVolatileMask; } - if (op->value.dtype() == info.element_type) { + if (PrimType(op->value.ty()->dtype) == info.element_type) { // Requested store of a single value. This may be a scalar store // or a vectorized store, based on the array element type. - TVM_FFI_ICHECK_EQ(info.element_type, op->value.dtype()) + TVM_FFI_ICHECK_EQ(info.element_type, PrimType(op->value.ty()->dtype)) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(prim_index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, value, mask); - } else if (op->value.dtype().element_of() == info.element_type) { + } else if (PrimType(op->value.ty()->dtype).WithLanes(1) == info.element_type) { // Requested store of several arbitrarily located values. Extract // each value from the composite, then assign to the buffer. auto f = [&](int i, spirv::Value index) { @@ -681,10 +684,10 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { this->Scalarize(prim_index, f); } else { - TVM_FFI_THROW(InternalError) << "Cannot store value of type " << op->value.dtype() + TVM_FFI_THROW(InternalError) << "Cannot store value of type " << PrimType(op->value.ty()->dtype) << " into buffer variable '" << buffer_var->name_hint << "' with element type " << info.element_type - << " using index of type " << prim_index->dtype; + << " using index of type " << PrimType(prim_index.ty()->dtype); } } @@ -697,10 +700,11 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // loop step spirv::Value step; if (op->HasTrivialStep()) { - step = op->loop_var.dtype().is_int() ? builder_->IntImm(init_value.stype, 1) - : builder_->UIntImm(init_value.stype, 1); + step = PrimType(op->loop_var.ty()->dtype).MatchesCode(DLDataTypeCode::kDLInt) + ? 
builder_->IntImm(init_value.stype, 1) + : builder_->UIntImm(init_value.stype, 1); } else { - step = MakeValue(tvm::cast(end->dtype, *op->step)); + step = MakeValue(tvm::cast(end.ty(), *op->step)); } // Must get init label after making value(to make sure they are correct) @@ -807,7 +811,7 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { } void CodeGenSPIRV::VisitStmt_(const AllocBufferNode* op) { - TVM_FFI_ICHECK(!op->buffer->dtype.is_handle()); + TVM_FFI_ICHECK(!op->buffer->dtype.IsHandle()); const IntImmNode* dim_imm = op->buffer->shape[0].as(); TVM_FFI_ICHECK(dim_imm) << "Can only handle constant size stack allocation in GPU"; size_t constant_size = static_cast(dim_imm->value); @@ -848,7 +852,7 @@ void CodeGenSPIRV::VisitStmt_(const AllocBufferNode* op) { int32_t aligned_constant_size = ((constant_size + 3) & ~0x3); buf = builder_->Allocate(etype, static_cast(aligned_constant_size), storage_class); - size_t num_bytes = op->buffer->dtype.bytes() * op->buffer->dtype.lanes() * + size_t num_bytes = ((op->buffer->dtype.bits() + 7) / 8) * op->buffer->dtype.lanes() * static_cast(aligned_constant_size); shared_memory_bytes_used_ += num_bytes; } break; @@ -897,7 +901,7 @@ void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) { void CodeGenSPIRV::VisitStmt_(const BindNode* op) { TVM_FFI_ICHECK(!var_map_.count(op->var.get())); - TVM_FFI_ICHECK(!op->var.dtype().is_handle()); + TVM_FFI_ICHECK(!PrimType(op->var.ty()->dtype).IsHandle()); var_map_[op->var.get()] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); } @@ -910,18 +914,18 @@ void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } -spirv::SType CodeGenSPIRV::GetFragmentSType(const VarNode* buffer, const DataType& dtype) { +spirv::SType CodeGenSPIRV::GetFragmentSType(const VarNode* buffer, const PrimType& dtype) { TVM_FFI_ICHECK(fragment_info_.count(buffer)); const std::string& scope = fragment_info_[buffer].scope; 
const std::string& shape_str = fragment_info_.at(buffer).shape; std::pair dim = GetWmmaFragmentDimSize(shape_str, scope); int64_t size = dim.first * dim.second; - spirv::SType stype = builder_->GetSType(dtype.with_lanes(size), dim.first, dim.second); + spirv::SType stype = builder_->GetSType(dtype.WithLanes(size), dim.first, dim.second); fragment_info_[buffer].stype = stype; return stype; } -DataType CodeGenSPIRV::GetElementDataType(const VarNode* buffer) { +PrimType CodeGenSPIRV::GetElementDataType(const VarNode* buffer) { auto it = storage_info_.find(buffer); TVM_FFI_ICHECK(it != storage_info_.end()); return it->second.element_type; diff --git a/src/backend/vulkan/codegen/codegen_spirv.h b/src/backend/vulkan/codegen/codegen_spirv.h index 46fbcb696b6f..5ade6e383908 100644 --- a/src/backend/vulkan/codegen/codegen_spirv.h +++ b/src/backend/vulkan/codegen/codegen_spirv.h @@ -142,7 +142,7 @@ class CodeGenSPIRV : public ExprFunctor, * buffer variable (AllocBufferNode) or of the parameter (shader * arguments). */ - DataType element_type{DataType()}; + PrimType element_type{PrimType::Void()}; /* \brief Check that the access type matches the known type * @@ -156,10 +156,10 @@ class CodeGenSPIRV : public ExprFunctor, * product of the number of lanes of the buffer element type and * the number of lanes of the index. 
*/ - void CheckContentType(DataType type, int index_lanes = 1) const { + void CheckContentType(PrimType type, int index_lanes = 1) const { TVM_FFI_ICHECK(element_type_known) << "Cannot check element type of buffer " << name_hint << " no previous element type defined"; - DataType expected_type = element_type.with_lanes(index_lanes * element_type.lanes()); + PrimType expected_type = element_type.WithLanes(index_lanes * element_type.lanes()); TVM_FFI_ICHECK_EQ(type, expected_type) << "Attempted to access buffer " << name_hint << " as element type " << type << " using an index of size " << index_lanes << " when the element type is " @@ -167,7 +167,7 @@ class CodeGenSPIRV : public ExprFunctor, } // Update content type if it hasn't been updated. - void SetContentType(DataType type, std::string name_hint) { + void SetContentType(PrimType type, std::string name_hint) { TVM_FFI_ICHECK(!element_type_known) << "Cannot set element type of buffer " << name_hint << " a second time."; this->element_type = type; @@ -191,8 +191,8 @@ class CodeGenSPIRV : public ExprFunctor, spirv::Value CreateStorageSync(const CallNode* op); void Scalarize(const PrimExpr& e, std::function f); - spirv::SType GetFragmentSType(const VarNode* buffer, const DataType& dtype); - DataType GetElementDataType(const VarNode* buffer); + spirv::SType GetFragmentSType(const VarNode* buffer, const PrimType& dtype); + PrimType GetElementDataType(const VarNode* buffer); // SPIRV-related capabilities of the target SPIRVSupport spirv_support_; @@ -213,7 +213,7 @@ class CodeGenSPIRV : public ExprFunctor, * integer type supported by the device, as not all Vulkan * implementations support int8. 
*/ - DataType boolean_storage_type_{DataType::Int(8)}; + PrimType boolean_storage_type_{PrimType::Int(8)}; // the storage scope of allocation std::unordered_map storage_info_; diff --git a/src/backend/vulkan/codegen/intrin_rule_spirv.cc b/src/backend/vulkan/codegen/intrin_rule_spirv.cc index 14287562d9e4..6deb6e0a9b61 100644 --- a/src/backend/vulkan/codegen/intrin_rule_spirv.cc +++ b/src/backend/vulkan/codegen/intrin_rule_spirv.cc @@ -39,12 +39,12 @@ PrimExpr CallGLSLIntrin(PrimExpr e, const ffi::Array& args) { TVM_FFI_ICHECK(call != nullptr); ffi::Array cargs; // intrin id. - cargs.push_back(IntImm(DataType::UInt(32), id)); + cargs.push_back(IntImm(PrimType::UInt(32), id)); for (PrimExpr arg : args) { cargs.push_back(arg); } - return tirx::Call(call->dtype, tirx::builtin::call_spirv_pure_glsl450(), cargs); + return tirx::Call(call->ty(), tirx::builtin::call_spirv_pure_glsl450(), cargs); } template @@ -166,21 +166,22 @@ TVM_REGISTER_OP("tirx.clz") TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 1); PrimExpr arg = call->args[0]; + PrimType arg_ty = arg.ty(); PrimExpr msb; - if (arg.dtype().bits() == 64) { + if (arg_ty.bits() == 64) { // SPIR-V FindUMsb intrinsic only supports 32 bit input - auto int32 = DataType::Int(32); + auto int32 = PrimType::Int(32); PrimExpr arg_hi32 = tvm::tirx::Cast(int32, arg >> 32); PrimExpr arg_lo32 = tvm::tirx::Cast(int32, arg); PrimExpr msb_hi = CallGLSLIntrin(e, {arg_hi32}); PrimExpr msb_lo = CallGLSLIntrin(e, {arg_lo32}); msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32); - } else if (arg.dtype().bits() == 32) { + } else if (arg_ty.bits() == 32) { msb = CallGLSLIntrin(e); } else { TVM_FFI_THROW(InternalError) << "SPIR-V clz only supports a 32 bit or 64 bit integer."; } - return PrimExpr(arg.dtype().bits() - 1) - msb; + return PrimExpr(arg_ty.bits() - 1) - msb; }); // clang-format on } diff --git a/src/backend/vulkan/codegen/ir_builder.cc b/src/backend/vulkan/codegen/ir_builder.cc index 
f912e482761c..e986454a7f75 100644 --- a/src/backend/vulkan/codegen/ir_builder.cc +++ b/src/backend/vulkan/codegen/ir_builder.cc @@ -74,10 +74,10 @@ void IRBuilder::InitHeader() { void IRBuilder::InitPreDefs() { ext_glsl450_ = ExtInstImport("GLSL.std.450"); - t_int32_ = DeclareType(DataType::Int(32)); - t_uint32_ = DeclareType(DataType::UInt(32)); - t_bool_ = DeclareType(DataType::Bool()); - t_fp32_ = DeclareType(DataType::Float(32)); + t_int32_ = DeclareType(PrimType::Int(32)); + t_uint32_ = DeclareType(PrimType::UInt(32)); + t_bool_ = DeclareType(PrimType::Bool()); + t_fp32_ = DeclareType(PrimType::Float(32)); const_i32_zero_ = IntImm(t_int32_, 0); // declare void, and void functions @@ -112,14 +112,14 @@ std::vector IRBuilder::Finalize() { return data; } -SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) { - if (dtype == DataType::Int(32)) { +SType IRBuilder::GetSType(const PrimType& dtype, uint32_t row, uint32_t col) { + if (dtype == PrimType::Int(32)) { return t_int32_; - } else if (dtype == DataType::Bool()) { + } else if (dtype == PrimType::Bool()) { return t_bool_; - } else if (dtype == DataType::Float(32)) { + } else if (dtype == PrimType::Float(32)) { return t_fp32_; - } else if (dtype == DataType::UInt(32)) { + } else if (dtype == PrimType::UInt(32)) { return t_uint32_; } uint64_t type_key; @@ -151,7 +151,7 @@ SType IRBuilder::GetPointerType(const SType& value_type, spv::StorageClass stora } SType t; t.id = id_counter_++; - t.type = DataType::Handle(); + t.type = PrimType::Handle(); t.element_type_id = value_type.id; t.storage_class = storage_class; ib_.Begin(spv::OpTypePointer).AddSeq(t, storage_class, value_type).Commit(&global_); @@ -169,11 +169,11 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems, SType arr_type; arr_type.id = id_counter_++; - arr_type.type = DataType::Handle(); + arr_type.type = PrimType::Handle(); arr_type.element_type_id = value_type.id; if (num_elems != 0) { - Value 
length = UIntImm(GetSType(DataType::UInt(32)), num_elems); + Value length = UIntImm(GetSType(PrimType::UInt(32)), num_elems); ib_.Begin(spv::OpTypeArray).AddSeq(arr_type, value_type, length).Commit(&global_); } else { ib_.Begin(spv::OpTypeRuntimeArray).AddSeq(arr_type, value_type).Commit(&global_); @@ -188,7 +188,7 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems, // declare struct of array SType struct_type; struct_type.id = id_counter_++; - struct_type.type = DataType::Handle(); + struct_type.type = PrimType::Handle(); struct_type.element_type_id = value_type.id; ib_.Begin(spv::OpTypeStruct).AddSeq(struct_type, arr_type).Commit(&global_); @@ -241,7 +241,7 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { if (data == 0) return GetConst_(dtype, &data); else - return Cast(dtype, FloatImm(GetSType(DataType::Float(32)), value)); + return Cast(dtype, FloatImm(GetSType(PrimType::Float(32)), value)); } } @@ -270,7 +270,7 @@ Value IRBuilder::DeclareStorageVariable(const std::vector& value_types, spv::StorageClass storage_class, ValueKind kind) { SType struct_type; struct_type.id = id_counter_++; - struct_type.type = DataType::Handle(); + struct_type.type = PrimType::Handle(); ib_.Begin(spv::OpTypeStruct).Add(struct_type); for (const SType& vtype : value_types) { ib_.Add(vtype); @@ -282,7 +282,7 @@ Value IRBuilder::DeclareStorageVariable(const std::vector& value_types, ib_.Begin(spv::OpMemberDecorate) .AddSeq(struct_type, i, spv::DecorationOffset, offset) .Commit(&decorate_); - DataType t = value_types[i].type; + PrimType t = value_types[i].type; uint32_t nbits = t.bits() * t.lanes(); TVM_FFI_ICHECK_EQ(nbits % 8, 0); uint32_t bytes = (nbits / 8); @@ -394,13 +394,11 @@ Value IRBuilder::GetBuiltInValue(spv::BuiltIn built_in, uint32_t index, const st } } - DataType data_type; - DataType global_arr_type; + PrimType data_type = PrimType::Int(32); + PrimType global_arr_type = data_type.WithLanes(3); switch (built_in) { case 
spv::BuiltInLocalInvocationId: case spv::BuiltInWorkgroupId: - data_type = DataType::Int(32); - global_arr_type = data_type.with_lanes(3); break; default: @@ -468,7 +466,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { } TVM_FFI_ICHECK_LE(dtype.type.bits(), 64); Value ret = NewValue(dtype, kConstant); - if (dtype.type == DataType::Bool()) { + if (dtype.type == PrimType::Bool()) { // bool types. if (*pvalue) { ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret); @@ -481,7 +479,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { uint64_t mask = 0xFFFFFFFFUL; ib_.Add(static_cast(pvalue[0] & mask)); if (dtype.type.bits() > 32) { - if (dtype.type.is_int()) { + if (dtype.type.MatchesCode(DLDataTypeCode::kDLInt)) { int64_t sign_mask = 0xFFFFFFFFL; const int64_t* sign_ptr = reinterpret_cast(pvalue); ib_.Add(static_cast((sign_ptr[0] >> 32L) & sign_mask)); @@ -495,20 +493,20 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { return ret; } -SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) { +SType IRBuilder::DeclareType(const PrimType& dtype, uint32_t row, uint32_t col) { AddCapabilityFor(dtype); if (dtype.lanes() == 1) { SType t; t.id = id_counter_++; t.type = dtype; - if (dtype.is_bool()) { + if (dtype.MatchesCode(DLDataTypeCode::kDLBool)) { ib_.Begin(spv::OpTypeBool).Add(t).Commit(&global_); - } else if (dtype.is_int()) { + } else if (dtype.MatchesCode(DLDataTypeCode::kDLInt)) { ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 1).Commit(&global_); - } else if (dtype.is_uint()) { + } else if (dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 0).Commit(&global_); - } else if (dtype.is_float()) { + } else if (dtype.MatchesCode(DLDataTypeCode::kDLFloat)) { ib_.Begin(spv::OpTypeFloat).AddSeq(t, dtype.bits()).Commit(&global_); } else { TVM_FFI_THROW(InternalError) << "declare type do not support handle"; @@ -518,15 +516,15 @@ 
SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) SType t; t.id = id_counter_++; t.type = dtype; - SType base_type = GetSType(dtype.element_of()); + SType base_type = GetSType(dtype.WithLanes(1)); if (row * col == 0) { TVM_FFI_ICHECK((row == 0) && (col == 0)); ib_.Begin(spv::OpTypeVector).AddSeq(t, base_type, dtype.lanes()).Commit(&global_); } else { - Value v_row = GetSpecConst(GetSType(DataType::UInt(32)), row); - Value v_col = GetSpecConst(GetSType(DataType::UInt(32)), col); - Value scope = UIntImm(GetSType(DataType::UInt(32)), spv::ScopeSubgroup); + Value v_row = GetSpecConst(GetSType(PrimType::UInt(32)), row); + Value v_col = GetSpecConst(GetSType(PrimType::UInt(32)), col); + Value scope = UIntImm(GetSType(PrimType::UInt(32)), spv::ScopeSubgroup); ib_.Begin(spv::OpTypeCooperativeMatrixNV) .AddSeq(t, base_type, scope, v_row, v_col) .Commit(&global_); @@ -535,9 +533,9 @@ SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) } } -void IRBuilder::AddCapabilityFor(const DataType& dtype) { +void IRBuilder::AddCapabilityFor(const PrimType& dtype) { // Declare appropriate capabilities for int/float types - if (dtype.is_int() || dtype.is_uint()) { + if (dtype.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { if (dtype.bits() == 8) { TVM_FFI_ICHECK(spirv_support_.supports_int8) << "Vulkan target does not support Int8 capability. " @@ -561,7 +559,7 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { capabilities_used_.insert(spv::CapabilityInt64); } - } else if (dtype.is_float()) { + } else if (dtype.MatchesCode(DLDataTypeCode::kDLFloat)) { if (dtype.bits() == 16) { TVM_FFI_ICHECK(spirv_support_.supports_float16) << "Vulkan target does not support Float16 capability. " @@ -584,7 +582,7 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { // future. 
Requiring StorageBuffer8BitAccess in order to declare an // Int8 prevents use of an 8-bit loop iterator on a device that // supports Int8 but doesn't support 8-bit buffer access. - if (dtype.bits() == 8 && !dtype.is_bool()) { + if (dtype.bits() == 8 && !dtype.MatchesCode(DLDataTypeCode::kDLBool)) { TVM_FFI_ICHECK(spirv_support_.supports_storage_buffer_8bit_access) << "Vulkan target does not support StorageBuffer8BitAccess. " << "If your device supports 8-bit buffer access, " @@ -642,7 +640,7 @@ Value IRBuilder::CallGLSL450(const SType& ret_type, uint32_t inst_id, } Value IRBuilder::CallKHRIntegerDotProduct(const SType& ret_type, const std::vector& args, - const DataType& dtype) { + const PrimType& dtype) { if (args.size() != 3) { TVM_FFI_THROW(InternalError) << "Unresolved arguments in SPIRV_KHR_integer_dot_product"; } @@ -653,9 +651,9 @@ Value IRBuilder::CallKHRIntegerDotProduct(const SType& ret_type, const std::vect << "If your device supports integer dot product operations, " << "please either add -mattr=+dotprod to the target, " << "or query all device parameters by adding -from_device=0."; - if (dtype.is_int()) { + if (dtype.MatchesCode(DLDataTypeCode::kDLInt)) { ib_.Begin(spv::OpSDotAccSatKHR).AddSeq(ret_type, val); - } else if (dtype.is_uint()) { + } else if (dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { ib_.Begin(spv::OpUDotAccSatKHR).AddSeq(ret_type, val); } else { TVM_FFI_THROW(InternalError) << "Unsupported type"; @@ -674,15 +672,15 @@ Value IRBuilder::CallKHRIntegerDotProduct(const SType& ret_type, const std::vect Value IRBuilder::Concat(const std::vector& vec) { bool is_const = vec[0].flag == kConstant; - DataType etype = vec[0].stype.type; + PrimType etype = vec[0].stype.type; int lanes = etype.lanes(); for (size_t i = 1; i < vec.size(); ++i) { - TVM_FFI_ICHECK_EQ(etype, vec[i].stype.type.element_of()) + TVM_FFI_ICHECK_EQ(etype, vec[i].stype.type.WithLanes(1)) << "Cannot concat vector of different element type"; lanes += vec[i].stype.type.lanes(); 
is_const = is_const && (vec[i].flag == kConstant); } - Value ret = NewValue(GetSType(etype.with_lanes(lanes)), kNormal); + Value ret = NewValue(GetSType(etype.WithLanes(lanes)), kNormal); if (is_const && vec.size() == static_cast(lanes)) { ib_.Begin(spv::OpConstantComposite); ib_.AddSeq(ret.stype, ret); @@ -704,53 +702,56 @@ Value IRBuilder::Concat(const std::vector& vec) { Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { TVM_FFI_ICHECK_NE(value.stype.id, 0U); if (value.stype.id == dst_type.id) return value; - const tvm::DataType& from = value.stype.type; - const tvm::DataType& to = dst_type.type; + const tvm::PrimType& from = value.stype.type; + const tvm::PrimType& to = dst_type.type; TVM_FFI_ICHECK_EQ(from.lanes(), to.lanes()); - if (from == DataType::Bool()) { - if (to.is_int()) { + if (from == PrimType::Bool()) { + if (to.MatchesCode(DLDataTypeCode::kDLInt)) { return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0)); - } else if (to.is_uint()) { + } else if (to.MatchesCode(DLDataTypeCode::kDLUInt)) { return Select(value, UIntImm(dst_type, 1), UIntImm(dst_type, 0)); - } else if (to.is_float()) { + } else if (to.MatchesCode(DLDataTypeCode::kDLFloat)) { return MakeValue(spv::OpConvertUToF, dst_type, Select(value, UIntImm(t_uint32_, 1), UIntImm(t_uint32_, 0))); } else { TVM_FFI_THROW(InternalError) << "cannot cast from " << from << " to " << to; return Value(); } - } else if (to == DataType::Bool()) { - if (from.is_int()) { + } else if (to == PrimType::Bool()) { + if (from.MatchesCode(DLDataTypeCode::kDLInt)) { return NE(value, IntImm(value.stype, 0)); - } else if (to.is_uint()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLUInt)) { return NE(value, UIntImm(value.stype, 0)); } else { TVM_FFI_THROW(InternalError) << "cannot cast from " << from << " to " << to; return Value(); } - } else if (from.is_int() && to.is_int()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLInt) && to.MatchesCode(DLDataTypeCode::kDLInt)) { return 
MakeValue(spv::OpSConvert, dst_type, value); - } else if (from.is_uint() && to.is_uint()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLUInt) && to.MatchesCode(DLDataTypeCode::kDLUInt)) { return MakeValue(spv::OpUConvert, dst_type, value); - } else if (from.is_uint() && to.is_int()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLUInt) && to.MatchesCode(DLDataTypeCode::kDLInt)) { if (from.bits() != to.bits()) { - value = MakeValue(spv::OpUConvert, GetSType(from.with_bits(to.bits())), value); + value = MakeValue(spv::OpUConvert, GetSType(from.WithBits(to.bits())), value); } return MakeValue(spv::OpBitcast, dst_type, value); - } else if (from.is_int() && to.is_uint()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLInt) && to.MatchesCode(DLDataTypeCode::kDLUInt)) { if (from.bits() != to.bits()) { - value = MakeValue(spv::OpSConvert, GetSType(from.with_bits(to.bits())), value); + value = MakeValue(spv::OpSConvert, GetSType(from.WithBits(to.bits())), value); } return MakeValue(spv::OpBitcast, dst_type, value); - } else if (from.is_float() && to.is_int()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLFloat) && to.MatchesCode(DLDataTypeCode::kDLInt)) { return MakeValue(spv::OpConvertFToS, dst_type, value); - } else if (from.is_float() && to.is_uint()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLFloat) && + to.MatchesCode(DLDataTypeCode::kDLUInt)) { return MakeValue(spv::OpConvertFToU, dst_type, value); - } else if (from.is_int() && to.is_float()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLInt) && to.MatchesCode(DLDataTypeCode::kDLFloat)) { return MakeValue(spv::OpConvertSToF, dst_type, value); - } else if (from.is_uint() && to.is_float()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLUInt) && + to.MatchesCode(DLDataTypeCode::kDLFloat)) { return MakeValue(spv::OpConvertUToF, dst_type, value); - } else if (from.is_float() && to.is_float()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLFloat) && + 
to.MatchesCode(DLDataTypeCode::kDLFloat)) { return MakeValue(spv::OpFConvert, dst_type, value); } else { TVM_FFI_THROW(InternalError) << "do not support type cast from " << from << " to " << to; @@ -782,28 +783,28 @@ Value IRBuilder::GetSpecConst(const SType& dtype, uint64_t value) { return ret; } -#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI##_Op, a.stype, a, b); \ - } else { \ - TVM_FFI_ICHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpF##_Op, a.stype, a, b); \ - } \ +#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ + if (a.stype.type.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { \ + return MakeValue(spv::OpI##_Op, a.stype, a, b); \ + } else { \ + TVM_FFI_ICHECK(a.stype.type.MatchesCode(DLDataTypeCode::kDLFloat)); \ + return MakeValue(spv::OpF##_Op, a.stype, a, b); \ + } \ } -#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ - if (a.stype.type.is_int()) { \ - return MakeValue(spv::OpS##_Op, a.stype, a, b); \ - } else if (a.stype.type.is_uint()) { \ - return MakeValue(spv::OpU##_Op, a.stype, a, b); \ - } else { \ - TVM_FFI_ICHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpF##_Op, a.stype, a, b); \ - } \ +#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ + if (a.stype.type.MatchesCode(DLDataTypeCode::kDLInt)) { \ + return MakeValue(spv::OpS##_Op, a.stype, a, b); \ + } else if (a.stype.type.MatchesCode(DLDataTypeCode::kDLUInt)) { \ + return MakeValue(spv::OpU##_Op, a.stype, a, b); \ + } else { \ + 
TVM_FFI_ICHECK(a.stype.type.MatchesCode(DLDataTypeCode::kDLFloat)); \ + return MakeValue(spv::OpF##_Op, a.stype, a, b); \ + } \ } DEFINE_BUILDER_BINARY_USIGN_OP(Add, Add); @@ -813,29 +814,29 @@ DEFINE_BUILDER_BINARY_SIGN_OP(Div, Div); Value IRBuilder::Mod(Value a, Value b) { TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); - if (a.stype.type.is_int()) { + if (a.stype.type.MatchesCode(DLDataTypeCode::kDLInt)) { return MakeValue(spv::OpSRem, a.stype, a, b); - } else if (a.stype.type.is_uint()) { + } else if (a.stype.type.MatchesCode(DLDataTypeCode::kDLUInt)) { return MakeValue(spv::OpUMod, a.stype, a, b); } else { - TVM_FFI_ICHECK(a.stype.type.is_float()); + TVM_FFI_ICHECK(a.stype.type.MatchesCode(DLDataTypeCode::kDLFloat)); return MakeValue(spv::OpFRem, a.stype, a, b); } } -#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ - TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ - const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int()) { \ - return MakeValue(spv::OpS##_Op, bool_type, a, b); \ - } else if (a.stype.type.is_uint()) { \ - return MakeValue(spv::OpU##_Op, bool_type, a, b); \ - } else { \ - TVM_FFI_ICHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ +#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ + TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + const auto& bool_type = this->GetSType(PrimType::Bool().WithLanes(a.stype.type.lanes())); \ + if (a.stype.type.MatchesCode(DLDataTypeCode::kDLInt)) { \ + return MakeValue(spv::OpS##_Op, bool_type, a, b); \ + } else if (a.stype.type.MatchesCode(DLDataTypeCode::kDLUInt)) { \ + return MakeValue(spv::OpU##_Op, bool_type, a, b); \ + } else { \ + 
TVM_FFI_ICHECK(a.stype.type.MatchesCode(DLDataTypeCode::kDLFloat)); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_OP(LT, LessThan); @@ -843,17 +844,17 @@ DEFINE_BUILDER_CMP_OP(LE, LessThanEqual); DEFINE_BUILDER_CMP_OP(GT, GreaterThan); DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual); -#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ - TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ - const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI##_Op, bool_type, a, b); \ - } else { \ - TVM_FFI_ICHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ +#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ + TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + const auto& bool_type = this->GetSType(PrimType::Bool().WithLanes(a.stype.type.lanes())); \ + if (a.stype.type.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { \ + return MakeValue(spv::OpI##_Op, bool_type, a, b); \ + } else { \ + TVM_FFI_ICHECK(a.stype.type.MatchesCode(DLDataTypeCode::kDLFloat)); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_UOP(EQ, Equal); @@ -861,7 +862,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual); Value IRBuilder::Select(Value cond, Value a, Value b) { TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); - TVM_FFI_ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool()); + TVM_FFI_ICHECK_EQ(cond.stype.type.WithLanes(1), PrimType::Bool()); return MakeValue(spv::OpSelect, a.stype, cond, a, b); } diff --git a/src/backend/vulkan/codegen/ir_builder.h b/src/backend/vulkan/codegen/ir_builder.h index 3cca1b4cfe33..7e8844682c4e 100644 --- 
a/src/backend/vulkan/codegen/ir_builder.h +++ b/src/backend/vulkan/codegen/ir_builder.h @@ -50,7 +50,7 @@ struct SType { /*! \brief The Id to represent type */ uint32_t id{0}; /*! \brief corresponding TVM type */ - tvm::DataType type; + tvm::PrimType type{tvm::PrimType::Void()}; /*! \brief content type id if it is a pointer/struct-array class */ uint32_t element_type_id{0}; /*! \brief The storage class, if it is a pointer */ @@ -430,7 +430,7 @@ class IRBuilder { * \return The result value. */ Value CallKHRIntegerDotProduct(const SType& ret_type, const std::vector& args, - const DataType& dtype); + const PrimType& dtype); /*! * \brief Build vector by concatenating components @@ -444,7 +444,7 @@ class IRBuilder { * \param dtype The data type. * \return The corresponding spirv type. */ - SType GetSType(const tvm::DataType& dtype, uint32_t row = 0, uint32_t col = 0); + SType GetSType(const tvm::PrimType& dtype, uint32_t row = 0, uint32_t col = 0); /*! * \brief Get the pointer type that points to value_type * \param value_type. @@ -656,11 +656,11 @@ class IRBuilder { Value GetConst_(const SType& dtype, const uint64_t* pvalue); // declare type - SType DeclareType(const DataType& dtype, uint32_t row = 0, uint32_t col = 0); + SType DeclareType(const PrimType& dtype, uint32_t row = 0, uint32_t col = 0); // Declare the appropriate SPIR-V capabilities and extensions to use // this data type. - void AddCapabilityFor(const DataType& dtype); + void AddCapabilityFor(const PrimType& dtype); /*! 
\brief SPIRV-related capabilities of the target * diff --git a/src/backend/webgpu/codegen/codegen_webgpu.cc b/src/backend/webgpu/codegen/codegen_webgpu.cc index 440f1f04b95e..7129aa23d2ee 100644 --- a/src/backend/webgpu/codegen/codegen_webgpu.cc +++ b/src/backend/webgpu/codegen/codegen_webgpu.cc @@ -68,7 +68,7 @@ class WebGPUWorkgroupInfoCollector : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { StmtExprVisitor::VisitExpr_(op); Var buffer_var = ffi::GetRef(op); - if (buffer_var.dtype().is_handle()) { + if (buffer_var.ty().IsHandle()) { info_.write_access_set.insert(buffer_var); } } @@ -119,7 +119,7 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); // analyze the data; for (Var arg : f->params) { - if (arg.dtype().is_handle()) { + if (arg.ty().IsHandle()) { alloc_storage_scope_[arg.get()] = "global"; } } @@ -174,10 +174,10 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re os_param_access << "paramWriteAccess:["; // setup buffer argumemts for (Var arg : f->params) { - DataType t = arg.dtype(); - func_arg_types.push_back(t); + PrimType t = arg.ty(); + func_arg_types.push_back(t->dtype); - if (t.is_handle()) { + if (t.IsHandle()) { auto* ptr = arg->type_annotation.as(); TVM_FFI_ICHECK(ptr) << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " @@ -188,11 +188,11 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " "PointerType, " << "and must point to a PrimType"; - DataType value_storage_type = prim->dtype; - if (value_storage_type == DataType::Bool()) { + PrimType value_storage_type(prim->dtype); + if (value_storage_type.MatchesCode(DLDataTypeCode::kDLBool)) { // We need a physically addressable buffer type to support boolean tensors. // The loaded byte is cast to bool inside the LoadNode visitor below. 
- value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes()); + value_storage_type = boolean_storage_type_.WithLanes(value_storage_type.lanes()); } std::string vid = AllocVarID(arg.get()); std::string access_mode; @@ -209,7 +209,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re // add extra access mode info to launch params this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") " << "var " << vid << " : array<"; - this->PrintType(value_storage_type, this->decl_stream); + this->PrintType(value_storage_type->dtype, this->decl_stream); this->decl_stream << ">;\n"; } else { pod_args.push_back(arg); @@ -228,17 +228,17 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re for (size_t i = 0; i < pod_args.size(); ++i) { Var v = pod_args[i]; - TVM_FFI_ICHECK(!v.dtype().is_handle()); + TVM_FFI_ICHECK(!v.ty().IsHandle()); std::string vid = AllocVarID(v.get()); - if (v.dtype() == DataType::Int(32)) { + if (v.ty() == PrimType::Int(32)) { this->decl_stream << " " << vid << ": i32"; - } else if (v.dtype() == DataType::UInt(32)) { + } else if (v.ty() == PrimType::UInt(32)) { this->decl_stream << " " << vid << ": u32"; - } else if (v.dtype() == DataType::Float(32)) { + } else if (v.ty() == PrimType::Float(32)) { this->decl_stream << " " << vid << ": f32"; } else { - TVM_FFI_THROW(InternalError) << "Do not support pod argument type " << v.dtype(); + TVM_FFI_THROW(InternalError) << "Do not support pod argument type " << v.ty()->dtype; } this->decl_stream << ",\n"; // value ref @@ -289,13 +289,13 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re void CodeGenWebGPU::BindThreadIndex(const IterVar& iv) { TVM_FFI_ICHECK(!var_idmap_.count(iv->var.get())); std::ostringstream os; - PrintType(iv->var.dtype(), os); + PrintType(iv->var.ty()->dtype, os); if (iv->thread_tag == "blockIdx.x") { // WebGPU have restriction to limit the maximum size of blockId.x to be 
65535 // We allow runtime to spread the load out to blockIdx.z so it can be a large number. os << "(blockIdx.z * gridDim.x + blockIdx.x)"; std::string tidx = os.str(); - std::string aggregated_bidx = SSAGetID(os.str(), iv->var.dtype()); + std::string aggregated_bidx = SSAGetID(os.str(), iv->var.ty()->dtype); var_idmap_[iv->var.get()] = aggregated_bidx; } else { os << "(" << iv->thread_tag << ")"; @@ -305,16 +305,17 @@ void CodeGenWebGPU::BindThreadIndex(const IterVar& iv) { } } -void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) +void CodeGenWebGPU::PrintType(DLDataType raw_t, std::ostream& os) { // NOLINT(*) + PrimType t(raw_t); int lanes = t.lanes(); - if (t.is_handle()) { + if (t.IsHandle()) { TVM_FFI_THROW(InternalError) << "Cannot print handle type in WebGPU"; } - if (t.is_void()) { + if (t.IsVoid()) { os << "void"; return; } - if (t == DataType::Bool()) { + if (raw_t == DLDataType{kDLBool, 8, 1}) { os << "bool"; return; } @@ -323,28 +324,29 @@ void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) TVM_FFI_ICHECK(lanes >= 2 && lanes <= 4) << "CodeGenWebGPU: only allows vector with lanes in {2, 3, 4}"; // Currently WebGPU doesn't support `i8` and an `int8x4` is represented as a `u32`. 
- if (t.is_int() && t.bits() == 8 && lanes == 4) { + if (t.MatchesCode(DLDataTypeCode::kDLInt) && t.bits() == 8 && lanes == 4) { os << "u32"; return; } os << "vec" << lanes << "<"; } - if (t.is_float()) { + if (t.code() == DLDataTypeCode::kDLFloat) { TVM_FFI_ICHECK(t.bits() == 16 || t.bits() == 32) << "CodeGenWebGPU: only support f16 or f32"; if (t.bits() == 16) { // Using f16 requires enable directive enable_fp16_ = true; } os << "f" << t.bits(); - } else if (t.is_uint()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLUInt)) { TVM_FFI_ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support u64"; os << "u" << t.bits(); - } else if (t.is_int()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLInt)) { TVM_FFI_ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support i64"; os << "i" << t.bits(); } else { - TVM_FFI_THROW(InternalError) << "CodeGenWebGPU: Cannot convert type " << t << " to WebGPU type"; + TVM_FFI_THROW(InternalError) << "CodeGenWebGPU: Cannot convert type " + << ffi::DLDataTypeToString(raw_t) << " to WebGPU type"; } if (lanes != 1) { os << ">"; @@ -365,18 +367,18 @@ void CodeGenWebGPU::PrintStorageSync(const CallNode* op) { } void CodeGenWebGPU::PrintSSAAssign(const std::string& target, const std::string& src, - DataType type) { + PrimType type) { stream << "let " << target << " : "; - PrintType(type, stream); + PrintType(type->dtype, stream); stream << " = " << src << ";\n"; } -void CodeGenWebGPU::PrintVecElemLoad(const std::string& vec, DataType t, int i, +void CodeGenWebGPU::PrintVecElemLoad(const std::string& vec, DLDataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << "[" << i << "]"; } -void CodeGenWebGPU::PrintVecElemStore(const std::string& vec, DataType t, int i, +void CodeGenWebGPU::PrintVecElemStore(const std::string& vec, DLDataType t, int i, const std::string& value) { this->PrintIndent(); stream << vec << "[" << i << "] = " << value << ";\n"; @@ -384,8 +386,8 @@ void CodeGenWebGPU::PrintVecElemStore(const std::string& vec, 
DataType t, int i, void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); - int lanes = op->dtype.lanes(); - PrintType(op->dtype, os); + int lanes = op->ty().lanes(); + PrintType(op->ty()->dtype, os); os << "("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -395,14 +397,14 @@ void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // } PrimExpr CodeGenWebGPU::EnforceU32(PrimExpr value) { - return cast(DataType::UInt(32, value.dtype().lanes()), value); + return cast(PrimType::UInt(32, value.ty().lanes()), value); } void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->op.same_as(builtin::reinterpret())) { // generate bitcast(ARG) os << "bitcast<"; - this->PrintType(op->dtype, os); + this->PrintType(op->ty()->dtype, os); os << ">("; this->PrintExpr(op->args[0], os); os << ")"; @@ -426,7 +428,7 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN std::string cond = PrintExpr(op->args[0]); this->PrintIndent(); this->stream << "var " << result << " : "; - PrintType(op->dtype, this->stream); + PrintType(op->ty()->dtype, this->stream); this->stream << ";\n"; this->PrintIndent(); this->stream << "if (" << cond << ") {\n"; @@ -459,7 +461,7 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN } void CodeGenWebGPU::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*) - PrintType(op->dtype, os); + PrintType(op->ty()->dtype, os); os << "(" << PrintExpr(op->value) << ")"; } @@ -478,7 +480,7 @@ void CodeGenWebGPU::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT PrintIndent(); std::string value = PrintExpr(op->value); this->stream << "let " << AllocVarID(op->var.get()) << " : "; - PrintType(op->var.dtype(), this->stream); + PrintType(op->var.ty()->dtype, this->stream); this->stream << " = " << value << ";\n"; } os << PrintExpr(op->body); @@ -490,18 +492,18 
@@ void CodeGenWebGPU::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT } void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) - if (op->dtype.bits() == 32) { + if (op->ty().bits() == 32) { std::ostringstream temp; - if (op->dtype.is_int()) { + if (op->ty().MatchesCode(DLDataTypeCode::kDLInt)) { temp << op->value << "i"; } else { - TVM_FFI_ICHECK(op->dtype.is_uint()); + TVM_FFI_ICHECK(op->ty().MatchesCode(DLDataTypeCode::kDLUInt)); temp << op->value << "u"; } this->MarkConst(temp.str()); os << temp.str(); } else { - this->PrintType(op->dtype, os); + this->PrintType(op->ty()->dtype, os); os << "(" << op->value << ")"; } } @@ -509,14 +511,14 @@ void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOL void CodeGenWebGPU::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) std::ostringstream temp; temp << std::scientific << op->value; - if (op->dtype.bits() == 32) { + if (op->ty().bits() == 32) { temp << 'f'; - } else if (op->dtype.bits() == 16) { + } else if (op->ty().bits() == 16) { // Using f16 requires enable directive enable_fp16_ = true; temp << 'h'; } else { - TVM_FFI_THROW(InternalError) << "Unsupported floating point bits " << op->dtype.bits(); + TVM_FFI_THROW(InternalError) << "Unsupported floating point bits " << op->ty().bits(); } MarkConst(temp.str()); os << temp.str(); @@ -530,39 +532,42 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; - DataType value_dtype = op->dtype; + DLDataType value_dtype = op->ty()->dtype; + PrimType value_ty(value_dtype); PrimExpr index = op->indices[0]; Var buffer_var = op->buffer->data; - DataType element_dtype = op->buffer->dtype; + DLDataType element_dtype = op->buffer->dtype->dtype; + PrimType element_ty(element_dtype); - int lanes = 
op->dtype.lanes(); + int lanes = value_ty.lanes(); std::string buffer_vid = GetVarID(buffer_var.get()); - if (value_dtype.lanes() == element_dtype.lanes()) { + if (value_ty.lanes() == element_ty.lanes()) { // Direct buffer loading // Special handle bool loading - if (value_dtype == DataType::Bool()) { + if (value_dtype == DLDataType{kDLBool, 8, 1}) { this->PrintType(value_dtype, os); os << "("; } else { TVM_FFI_ICHECK(value_dtype == element_dtype); } - TVM_FFI_ICHECK_EQ(index.dtype().lanes(), 1); + TVM_FFI_ICHECK_EQ(index.ty().lanes(), 1); os << buffer_vid << "[" << this->PrintExpr(index) << "]"; // Special handle bool loading - if (value_dtype == DataType::Bool()) { + if (value_dtype == DLDataType{kDLBool, 8, 1}) { os << ")"; } } else { // Vector load from scalar buffer - TVM_FFI_ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; - TVM_FFI_ICHECK(value_dtype.element_of() == element_dtype) + TVM_FFI_ICHECK_EQ(element_ty.lanes(), 1) << "Can only vector load scalar array"; + DLDataType value_element_dtype{value_dtype.code, value_dtype.bits, 1}; + TVM_FFI_ICHECK(value_element_dtype == element_dtype) << "WebGPU vector loading requires base type to match"; arith::PVar base; - if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { + if (arith::ramp(base, 1, value_ty.lanes()).Match(index)) { // vec3(buf[base + 0], buf[base + 1], buf[base + 2]); - std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype()); - PrintType(element_dtype.with_lanes(value_dtype.lanes()), os); + std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().ty()->dtype); + PrintType(element_ty.WithLanes(value_ty.lanes())->dtype, os); os << "("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -571,8 +576,8 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // os << ")"; } else { // vec3(buf[index[0]], buf[index[1]], buf[index[2]]); - std::string index_vid = SSAGetID(PrintExpr(index), index.dtype()); - 
PrintType(element_dtype.with_lanes(value_dtype.lanes()), os); + std::string index_vid = SSAGetID(PrintExpr(index), index.ty()->dtype); + PrintType(element_ty.WithLanes(value_ty.lanes())->dtype, os); os << "("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -593,7 +598,7 @@ void CodeGenWebGPU::VisitStmt_(const BindNode* op) { PrintIndent(); std::string value = PrintExpr(op->value); this->stream << "let " << AllocVarID(op->var.get()) << " : "; - PrintType(op->var.dtype(), this->stream); + PrintType(op->var.ty()->dtype, this->stream); this->stream << " = " << value << ";\n"; } } @@ -602,14 +607,16 @@ void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported."; - DataType value_dtype = op->value.dtype(); - DataType element_dtype = op->buffer->dtype; + DLDataType value_dtype = op->value.ty()->dtype; + PrimType value_ty(value_dtype); + DLDataType element_dtype = op->buffer->dtype->dtype; + PrimType element_ty(element_dtype); PrimExpr index = op->indices[0]; Var buffer_var = op->buffer->data; std::string buffer_vid = GetVarID(buffer_var.get()); - if (value_dtype.lanes() == element_dtype.lanes()) { + if (value_ty.lanes() == element_ty.lanes()) { // must execute print expr first // so we won't have recursive append to stream std::string index_vid = PrintExpr(index); @@ -618,7 +625,7 @@ void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { this->PrintIndent(); stream << buffer_vid << "[" << index_vid << "] = "; // special explicit conversion of bool - if (value_dtype == DataType::Bool()) { + if (value_dtype == DLDataType{kDLBool, 8, 1}) { PrintType(element_dtype, stream); stream << "("; } else { @@ -626,22 +633,23 @@ void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { } stream << value_vid; // Special handle bool store - if (value_dtype == DataType::Bool()) { + if 
(value_dtype == DLDataType{kDLBool, 8, 1}) { stream << ")"; } stream << ";\n"; } else { // Vector store into scalar buffer - TVM_FFI_ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; - TVM_FFI_ICHECK(value_dtype.element_of() == element_dtype) + TVM_FFI_ICHECK_EQ(element_ty.lanes(), 1) << "Can only vector load scalar array"; + DLDataType value_element_dtype{value_dtype.code, value_dtype.bits, 1}; + TVM_FFI_ICHECK(value_element_dtype == element_dtype) << "WebGPU vector stire requires base type to match"; std::string value_vid = PrintExpr(op->value); arith::PVar base; - if (arith::ramp(base, 1, value_dtype.lanes()).Match(index)) { + if (arith::ramp(base, 1, value_ty.lanes()).Match(index)) { // buf[base + 0] = value[0] // buf[base + 1] = value[1] - std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype()); - for (int i = 0; i < value_dtype.lanes(); ++i) { + std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().ty()->dtype); + for (int i = 0; i < value_ty.lanes(); ++i) { this->PrintIndent(); stream << buffer_vid << "[" << base_vid << " + " << i << "] = " << value_vid << "[" << i << "];\n"; @@ -649,8 +657,8 @@ void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { } else { // buf[index[0]] = value[0] // buf[index[1]] = value[1] - std::string index_vid = SSAGetID(PrintExpr(index), index.dtype()); - for (int i = 0; i < value_dtype.lanes(); ++i) { + std::string index_vid = SSAGetID(PrintExpr(index), index.ty()->dtype); + for (int i = 0; i < value_ty.lanes(); ++i) { this->PrintIndent(); stream << buffer_vid << "[" << index_vid << "[" << i << "]] = " << value_vid << "[" << i << "];\n"; @@ -673,12 +681,12 @@ void CodeGenWebGPU::VisitStmt_(const AllocBufferNode* op) { if (storage_scope.rank == runtime::StorageRank::kShared) { this->decl_stream << "var " << vid << " : array<"; - PrintType(op->buffer->dtype, this->decl_stream); + PrintType(op->buffer->dtype->dtype, this->decl_stream); this->decl_stream << ", " << 
constant_size << ">;\n"; } else if (storage_scope.rank == runtime::StorageRank::kLocal) { this->PrintIndent(); this->stream << "var " << vid << " : array<"; - PrintType(op->buffer->dtype, this->stream); + PrintType(op->buffer->dtype->dtype, this->stream); this->stream << ", " << constant_size << ">;\n"; } else { TVM_FFI_THROW(InternalError) << "WebGPU: Do not support storage scope: " @@ -694,7 +702,7 @@ void CodeGenWebGPU::VisitStmt_(const ForNode* op) { std::string vid = AllocVarID(op->loop_var.get()); PrintIndent(); stream << "for (var " << vid << " : "; - PrintType(op->loop_var.dtype(), stream); + PrintType(op->loop_var.ty()->dtype, stream); stream << " = " << begin_str << "; " << vid << " < " << end_str << "; " << vid; if (step_str.empty()) { stream << "++"; diff --git a/src/backend/webgpu/codegen/codegen_webgpu.h b/src/backend/webgpu/codegen/codegen_webgpu.h index 4c873ac3db18..c2179c5c48aa 100644 --- a/src/backend/webgpu/codegen/codegen_webgpu.h +++ b/src/backend/webgpu/codegen/codegen_webgpu.h @@ -51,16 +51,17 @@ class CodeGenWebGPU final : public CodeGenC { using CodeGenC::AddFunction; runtime::FunctionInfo AddFunction(const PrimFunc& f, bool skip_readonly_decl); // NOLINT(*) void InitFuncState(const PrimFunc& f) final; - void PrintStorageSync(const CallNode* op) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) + void PrintStorageSync(const CallNode* op) final; // NOLINT(*) + void PrintType(DLDataType t, std::ostream& os) final; // NOLINT(*) + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) // assignment printing - void PrintSSAAssign(const std::string& target, const std::string& src, DataType type) final; + void PrintSSAAssign(const std::string& target, const std::string& src, PrimType type) final; // overload printing vector element load/store - void PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) final; - void 
PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; + void PrintVecElemLoad(const std::string& vec, DLDataType t, int i, std::ostream& os) final; + void PrintVecElemStore(const std::string& vec, DLDataType t, int i, + const std::string& value) final; // overload visitor void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) @@ -90,7 +91,7 @@ class CodeGenWebGPU final : public CodeGenC { /*! * \brief Storage type of bool values. */ - DataType boolean_storage_type_{DataType::Int(8)}; + PrimType boolean_storage_type_{PrimType::Int(8)}; // whether enable fp16 bool enable_fp16_{false}; diff --git a/src/backend/webgpu/codegen/intrin_rule_webgpu.cc b/src/backend/webgpu/codegen/intrin_rule_webgpu.cc index 1c172fcd141b..7992fa9915c0 100644 --- a/src/backend/webgpu/codegen/intrin_rule_webgpu.cc +++ b/src/backend/webgpu/codegen/intrin_rule_webgpu.cc @@ -34,7 +34,7 @@ using tirx::FLowerIntrinsic; // warp-level primitives. Follows implementation in intrin_rule_metal.cc struct WebGPUWarpIntrinsic { - const Op operator()(DataType t, const Op& orig_op) const { + const Op operator()(PrimType t, const Op& orig_op) const { if (orig_op.same_as(builtin::tvm_warp_shuffle())) { static const Op& webgpu_subgroup_shuffle_op = Op::Get("tirx.webgpu.subgroup_shuffle"); return webgpu_subgroup_shuffle_op; @@ -55,9 +55,9 @@ static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size - PrimExpr lane_or_delta = Cast(DataType::UInt(32, call->args[2].dtype().lanes()), call->args[2]); + PrimExpr lane_or_delta = Cast(PrimType::UInt(32, call->args[2].ty().lanes()), call->args[2]); ffi::Array webgpu_args{{call->args[1], lane_or_delta}}; - return Call(call->dtype, T()(call->dtype, call->op.as_or_throw()), webgpu_args); + return Call(e.ty(), T()(e.ty(), call->op.as_or_throw()), webgpu_args); } void 
RegisterWebGPUIntrinRules() { @@ -69,7 +69,7 @@ void RegisterWebGPUIntrinRules() { // See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions struct ReturnAbs { - std::string operator()(DataType t, std::string name) const { return "abs"; } + std::string operator()(PrimType t, std::string name) const { return "abs"; } }; TVM_REGISTER_OP("tirx.fabs") @@ -124,7 +124,7 @@ TVM_REGISTER_OP("tirx.pow") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); struct ReturnRound { - std::string operator()(DataType t, std::string name) const { return "round"; } + std::string operator()(PrimType t, std::string name) const { return "round"; } }; // WGSL round() uses ties-to-even (banker's rounding), matching IEEE 754 and ONNX Round spec. diff --git a/src/ir/expr.cc b/src/ir/expr.cc index ef6ea0ed6dca..f73cd6ae3913 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -48,33 +49,39 @@ TVM_FFI_STATIC_INIT_BLOCK() { PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm::Int32(value)) {} -PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} +PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(PrimType::Float(32), value)) {} PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { return tirx::StringImm(value); } -IntImm::IntImm(DataType dtype, int64_t value, Span span) { - TVM_FFI_CHECK(dtype.is_scalar(), ValueError) - << "IntImm can only take scalar, but " << dtype << " was supplied."; - TVM_FFI_CHECK(dtype.is_int() || dtype.is_uint() || dtype.is_bool(), ValueError) - << "IntImm supports only int or uint or bool type, but " << dtype << " was supplied."; - if (dtype.is_uint()) { +IntImm::IntImm(PrimType value_ty, int64_t value, Span span) { + DLDataType runtime_dtype = value_ty->dtype; + DLDataTypeCode code = value_ty.code(); + int32_t bits = value_ty.bits(); + TVM_FFI_CHECK(!value_ty.IsScalableVector() && !value_ty.IsFixedLengthVector(), ValueError) + << "IntImm 
can only take scalar, but " << runtime_dtype << " was supplied."; + TVM_FFI_CHECK(value_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt, + DLDataTypeCode::kDLBool), + ValueError) + << "IntImm supports only int or uint or bool type, but " << runtime_dtype << " was supplied."; + if (code == DLDataTypeCode::kDLUInt) { TVM_FFI_CHECK_GE(value, 0U, ValueError) - << "Literal value " << value << " is negative for unsigned integer type " << dtype; - if (dtype.bits() < 64) { - TVM_FFI_CHECK_LT(value, 1LL << dtype.bits(), ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; + << "Literal value " << value << " is negative for unsigned integer type " << runtime_dtype; + if (bits < 64) { + TVM_FFI_CHECK_LT(value, 1LL << bits, ValueError) + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; } - } else if (dtype.bits() == 1 || dtype.is_bool()) { + } else if (bits == 1 || code == DLDataTypeCode::kDLBool) { // int(1) - TVM_FFI_CHECK(value == 0 || value == 1, ValueError) << value << " exceeds range of " << dtype; - } else if (dtype.bits() < 64) { - TVM_FFI_CHECK_GE(value, -(1LL << (dtype.bits() - 1)), ValueError) - << "Literal value " << value << " exceeds minimum of " << dtype; - TVM_FFI_CHECK_LT(value, 1LL << (dtype.bits() - 1), ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; + TVM_FFI_CHECK(value == 0 || value == 1, ValueError) + << value << " exceeds range of " << runtime_dtype; + } else if (bits < 64) { + TVM_FFI_CHECK_GE(value, -(1LL << (bits - 1)), ValueError) + << "Literal value " << value << " exceeds minimum of " << runtime_dtype; + TVM_FFI_CHECK_LT(value, 1LL << (bits - 1), ValueError) + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; } ffi::ObjectPtr node = ffi::make_object(); - node->dtype = dtype; + node->BaseExprNode::ty = std::move(value_ty); node->value = value; node->span = span; data_ = std::move(node); @@ -82,103 +89,118 @@ IntImm::IntImm(DataType 
dtype, int64_t value, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.IntImm", [](DataType dtype, int64_t value, Span span) { - return IntImm(dtype, value, span); + refl::GlobalDef().def("ir.IntImm", [](DLDataType dtype, int64_t value, Span span) { + return IntImm(PrimType(dtype), value, span); }); } -FloatImm::FloatImm(DataType dtype, double value, Span span) { - TVM_FFI_CHECK_EQ(dtype.lanes(), 1, ValueError) << "FloatImm can only take scalar."; +FloatImm::FloatImm(PrimType value_ty, double value, Span span) { + DLDataType runtime_dtype = value_ty->dtype; + DLDataTypeCode code = value_ty.code(); + int32_t bits = value_ty.bits(); + TVM_FFI_CHECK(!value_ty.IsScalableVector() && !value_ty.IsFixedLengthVector(), ValueError) + << "FloatImm can only take scalar."; - TVM_FFI_CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || - dtype.is_float4() || dtype.code() >= DataType::kCustomBegin, - ValueError) - << "FloatImm supports only float, but " << dtype << " was supplied."; + TVM_FFI_CHECK( + value_ty.MatchesCode(DLDataTypeCode::kDLFloat, DLDataTypeCode::kDLFloat8_e3m4, + DLDataTypeCode::kDLFloat8_e4m3, DLDataTypeCode::kDLFloat8_e4m3b11fnuz, + DLDataTypeCode::kDLFloat8_e4m3fn, DLDataTypeCode::kDLFloat8_e4m3fnuz, + DLDataTypeCode::kDLFloat8_e5m2, DLDataTypeCode::kDLFloat8_e5m2fnuz, + DLDataTypeCode::kDLFloat8_e8m0fnu, DLDataTypeCode::kDLFloat6_e2m3fn, + DLDataTypeCode::kDLFloat6_e3m2fn) || + value_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16) || + value_ty.MatchesElementType(DLDataTypeCode::kDLFloat4_e2m1fn, 4) || + static_cast(code) >= static_cast(ffi::DLExtDataTypeCode::kDLExtCustomBegin), + ValueError) + << "FloatImm supports only float, but " << runtime_dtype << " was supplied."; // check range for float32 and float16 since they have specified range. 
if (!std::isinf(value) && !std::isnan(value)) { - if (dtype.bits() == 32) { + if (bits == 32) { TVM_FFI_CHECK_GE(value, std::numeric_limits::lowest(), ValueError) - << "Literal value " << value << " exceeds minimum of " << dtype; + << "Literal value " << value << " exceeds minimum of " << runtime_dtype; TVM_FFI_CHECK_LE(value, std::numeric_limits::max(), ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; - } else if (dtype.is_float16()) { + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; + } else if (value_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16)) { TVM_FFI_CHECK_GE(value, -support::kMaxFloat16, ValueError) - << "Literal value " << value << " exceeds minimum of " << dtype; + << "Literal value " << value << " exceeds minimum of " << runtime_dtype; TVM_FFI_CHECK_LE(value, support::kMaxFloat16, ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; - } else if (dtype.is_bfloat16()) { + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; + } else if (value_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { TVM_FFI_CHECK_GE(value, -support::kMaxBFloat16, ValueError) - << "Literal value " << value << " exceeds minimum of " << dtype; + << "Literal value " << value << " exceeds minimum of " << runtime_dtype; TVM_FFI_CHECK_LE(value, support::kMaxBFloat16, ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; - } else if (dtype.is_float8_e3m4() || dtype.is_float8_e4m3() || dtype.is_float8_e4m3b11fnuz() || - dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() || dtype.is_float8_e5m2() || - dtype.is_float8_e5m2fnuz() || dtype.is_float8_e8m0fnu()) { + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; + } else if (value_ty.MatchesCode( + DLDataTypeCode::kDLFloat8_e3m4, DLDataTypeCode::kDLFloat8_e4m3, + DLDataTypeCode::kDLFloat8_e4m3b11fnuz, DLDataTypeCode::kDLFloat8_e4m3fn, + DLDataTypeCode::kDLFloat8_e4m3fnuz, 
DLDataTypeCode::kDLFloat8_e5m2, + DLDataTypeCode::kDLFloat8_e5m2fnuz, DLDataTypeCode::kDLFloat8_e8m0fnu)) { double bound = 0.0; bool nonneg = false; - switch (dtype.code()) { - case DataType::TypeCode::kFloat8_e3m4: + switch (code) { + case DLDataTypeCode::kDLFloat8_e3m4: bound = support::kMaxE3M4; break; - case DataType::TypeCode::kFloat8_e4m3: + case DLDataTypeCode::kDLFloat8_e4m3: bound = support::kMaxE4M3; break; - case DataType::TypeCode::kFloat8_e4m3b11fnuz: + case DLDataTypeCode::kDLFloat8_e4m3b11fnuz: bound = support::kMaxE4M3B11FNUZ; nonneg = true; break; - case DataType::TypeCode::kFloat8_e4m3fn: + case DLDataTypeCode::kDLFloat8_e4m3fn: bound = support::kMaxE4M3FN; break; - case DataType::TypeCode::kFloat8_e4m3fnuz: + case DLDataTypeCode::kDLFloat8_e4m3fnuz: bound = support::kMaxE4M3FNUZ; nonneg = true; break; - case DataType::TypeCode::kFloat8_e5m2: + case DLDataTypeCode::kDLFloat8_e5m2: bound = support::kMaxE5M2; break; - case DataType::TypeCode::kFloat8_e5m2fnuz: + case DLDataTypeCode::kDLFloat8_e5m2fnuz: bound = support::kMaxE5M2FNUZ; nonneg = true; break; - case DataType::TypeCode::kFloat8_e8m0fnu: + case DLDataTypeCode::kDLFloat8_e8m0fnu: bound = support::kMaxE8M0FNU; nonneg = true; break; default: - TVM_FFI_THROW(InternalError) << "Unhandled float8 type: " << dtype; + TVM_FFI_THROW(InternalError) << "Unhandled float8 type: " << runtime_dtype; } if (nonneg) { TVM_FFI_CHECK_GE(value, 0, ValueError) - << "Literal value " << value << " below zero for unsigned " << dtype; + << "Literal value " << value << " below zero for unsigned " << runtime_dtype; } else { TVM_FFI_CHECK_GE(value, -bound, ValueError) - << "Literal value " << value << " below minimum of " << dtype; + << "Literal value " << value << " below minimum of " << runtime_dtype; } TVM_FFI_CHECK_LE(value, bound, ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; - } else if 
(dtype.is_float6_e2m3fn() || dtype.is_float6_e3m2fn()) { - double bound = (dtype.code() == DataType::TypeCode::kFloat6_e2m3fn) ? support::kMaxE2M3FN - : support::kMaxE3M2FN; + } else if (value_ty.MatchesCode(DLDataTypeCode::kDLFloat6_e2m3fn, + DLDataTypeCode::kDLFloat6_e3m2fn)) { + double bound = + (code == DLDataTypeCode::kDLFloat6_e2m3fn) ? support::kMaxE2M3FN : support::kMaxE3M2FN; TVM_FFI_CHECK_GE(value, -bound, ValueError) - << "Literal value " << value << " below minimum of " << dtype; + << "Literal value " << value << " below minimum of " << runtime_dtype; TVM_FFI_CHECK_LE(value, bound, ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; - } else if (dtype.is_float4_e2m1fn()) { + } else if (code == DLDataTypeCode::kDLFloat4_e2m1fn) { double bound = support::kMaxE2M1FN; TVM_FFI_CHECK_GE(value, -bound, ValueError) - << "Literal value " << value << " below minimum of " << dtype; + << "Literal value " << value << " below minimum of " << runtime_dtype; TVM_FFI_CHECK_LE(value, bound, ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; } } ffi::ObjectPtr node = ffi::make_object(); - node->dtype = dtype; + node->BaseExprNode::ty = std::move(value_ty); node->value = value; node->span = span; data_ = std::move(node); @@ -186,8 +208,8 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.FloatImm", [](DataType dtype, double value, Span span) { - return FloatImm(dtype, value, span); + refl::GlobalDef().def("ir.FloatImm", [](DLDataType dtype, double value, Span span) { + return FloatImm(PrimType(dtype), value, span); }); } @@ -206,7 +228,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (end.defined()) { return Range(begin, end.value(), span); } else { - return 
Range(IntImm(begin->dtype, 0), begin, span); + return Range(IntImm(begin.ty(), 0), begin, span); } }); } diff --git a/src/ir/type.cc b/src/ir/type.cc index d6d059dba079..20bbe9c0e58a 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -21,30 +21,133 @@ * \file src/ir/type.cc * \brief Common type system AST nodes throughout the IR. */ +#include #include #include #include + +#include +#include + namespace tvm { +namespace { + +DLDataType ScalableVectorDType(DLDataTypeCode code, int bits, int lanes) { + TVM_FFI_ICHECK_GT(lanes, 1) << "Invalid value for vscale factor " << lanes; + TVM_FFI_ICHECK_LT(lanes, 32768); + return DLDataType{static_cast(code), static_cast(bits), + static_cast(-lanes)}; +} + +uint32_t PackDataTypeKey(DLDataType dtype) { + return (static_cast(dtype.code) << 24) | (static_cast(dtype.bits) << 16) | + static_cast(dtype.lanes); +} + +int64_t PrimTypeAnyHash(const ffi::Any& src) { + return static_cast(PackDataTypeKey(src.cast()->dtype)); +} + +bool PrimTypeAnyEqual(const ffi::Any& lhs, const ffi::Any& rhs) { + return lhs.cast()->dtype == rhs.cast()->dtype; +} + +ffi::ObjectPtr GetCachedPrimTypeNode(DLDataType dtype) { + thread_local std::unordered_map> cache; + uint32_t key = PackDataTypeKey(dtype); + auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + + ffi::ObjectPtr node = ffi::make_object(); + node->dtype = dtype; + return cache.emplace(key, std::move(node)).first->second; +} + +} // namespace + TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; TypeNode::RegisterReflection(); PrimTypeNode::RegisterReflection(); + refl::TypeAttrDef() + .attr(refl::type_attr::kAnyHash, reinterpret_cast(&PrimTypeAnyHash)) + .attr(refl::type_attr::kAnyEqual, reinterpret_cast(&PrimTypeAnyEqual)); PointerTypeNode::RegisterReflection(); TupleTypeNode::RegisterReflection(); FuncTypeNode::RegisterReflection(); TensorMapTypeNode::RegisterReflection(); } -PrimType::PrimType(runtime::DataType dtype, Span span) { - 
ffi::ObjectPtr n = ffi::make_object(); - n->dtype = dtype; - n->span = std::move(span); - data_ = std::move(n); +PrimType::PrimType(DLDataType dtype) { data_ = GetCachedPrimTypeNode(dtype); } + +PrimType::PrimType(DLDataTypeCode code, int bits, int lanes) + : PrimType(DLDataType{static_cast(code), static_cast(bits), + static_cast(lanes)}) {} + +PrimType PrimType::Int(int bits, int lanes) { + if (lanes == 1) { + if (bits == 32) { + static const PrimType i32_ty(DLDataType{kDLInt, 32, 1}); + return i32_ty; + } + if (bits == 64) { + static const PrimType i64_ty(DLDataType{kDLInt, 64, 1}); + return i64_ty; + } + } + return PrimType(DLDataType{kDLInt, static_cast(bits), static_cast(lanes)}); +} + +PrimType PrimType::UInt(int bits, int lanes) { + return PrimType(DLDataType{kDLUInt, static_cast(bits), static_cast(lanes)}); +} + +PrimType PrimType::Float(int bits, int lanes) { + if (bits == 32 && lanes == 1) { + static const PrimType f32_ty(DLDataType{kDLFloat, 32, 1}); + return f32_ty; + } + return PrimType(DLDataType{kDLFloat, static_cast(bits), static_cast(lanes)}); +} + +PrimType PrimType::BFloat(int bits, int lanes) { + return PrimType(DLDataType{kDLBfloat, static_cast(bits), static_cast(lanes)}); +} + +PrimType PrimType::Bool(int lanes) { + if (lanes == 1) { + static const PrimType bool_ty(DLDataType{kDLBool, 8, 1}); + return bool_ty; + } + return PrimType(DLDataType{kDLBool, 8, static_cast(lanes)}); +} + +PrimType PrimType::Handle(int bits, int lanes) { + return PrimType( + DLDataType{kDLOpaqueHandle, static_cast(bits), static_cast(lanes)}); +} + +PrimType PrimType::Void() { return PrimType(DLDataType{kDLOpaqueHandle, 0, 0}); } + +PrimType PrimType::ScalableVector(DLDataTypeCode code, int bits, int lanes) { + return PrimType(ScalableVectorDType(code, bits, lanes)); +} + +size_t PrimType::StorageBytes() const { + int16_t encoded_lanes = static_cast(get()->dtype.lanes); + if (TVM_FFI_PREDICT_FALSE(encoded_lanes < 0)) { + TVM_FFI_THROW(InternalError) + << "Cannot 
compute compile-time storage bytes for non-fixed vector type " << get()->dtype; + } + return ffi::GetDataSize(1, get()->dtype); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.PrimType", [](runtime::DataType dtype) { return PrimType(dtype); }); + refl::GlobalDef().def("ir.PrimType", [](DLDataType dtype) { return PrimType(dtype); }); } PointerType::PointerType(Type element_type, ffi::String storage_scope) { diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index 369f5793d9b5..0bfb48cca94c 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -478,8 +478,8 @@ bool HasReshapePattern(const PrimFunc& func) { } if (nontrivial_indices.defined()) { - DataType dtype = - !block->iter_vars.empty() ? block->iter_vars[0]->var->dtype : DataType::Int(64); + PrimType dtype = + !block->iter_vars.empty() ? block->iter_vars[0]->var.ty() : PrimType::Int(64); tirx::Var fused_var("fused", dtype); ffi::Map inverse_indices_map; PrimExpr stride = IntImm(dtype, /*value=*/1); @@ -494,7 +494,8 @@ bool HasReshapePattern(const PrimFunc& func) { ffi::Array simplify_res = arith::IterMapSimplify( /*indices=*/{flattened_idx}, - /*input_iters=*/{{fused_var, Range(IntImm(dtype, /*value=*/0), stride)}}, + /*input_iters=*/ + ffi::Map{{fused_var, Range(IntImm(dtype, /*value=*/0), stride)}}, /*input_pred=*/IntImm::Bool(true), /*check_level=*/arith::IterMapLevel::Surjective, /*analyzer=*/this->ana_, diff --git a/src/relax/analysis/type_analysis.cc b/src/relax/analysis/type_analysis.cc index 33070051ae63..34f5a4de6216 100644 --- a/src/relax/analysis/type_analysis.cc +++ b/src/relax/analysis/type_analysis.cc @@ -43,7 +43,7 @@ class StaticTypeDeriver : public TypeFunctor { public: Type VisitType_(const ObjectTypeNode* op) final { return ObjectType(op->span); } - Type VisitType_(const PrimTypeNode* op) final { return PrimType(op->dtype, op->span); } + Type 
VisitType_(const PrimTypeNode* op) final { return tvm::PrimType(op->dtype); } Type VisitType_(const ShapeTypeNode* op) final { return ShapeType(op->ndim, op->span); } @@ -86,7 +86,9 @@ Type TypeFromStaticType(const Type& type) { if (type.as()) { return ObjectType(type->span); } else if (const PrimTypeNode* prim_type = type.as()) { - return PrimType(prim_type->dtype, prim_type->span); + return tvm::PrimType(prim_type->dtype); + } else if (const tvm::PrimTypeNode* prim_type = type.as()) { + return tvm::PrimType(prim_type->dtype); } else if (const ShapeTypeNode* shape_type = type.as()) { return ShapeType(shape_type->ndim, type->span); } else if (const TensorTypeNode* tensor_type = type.as()) { @@ -221,9 +223,9 @@ class WellDefinedEraser : public TypeMutator, public ExprMutatorBase, public tir if (ret.defined()) { PrimExpr value = ret.value(); if (value->IsInstance()) { - return tvm::cast(DataType::Int(64), value); + return tvm::cast(PrimType::Int(64), value); } - TVM_FFI_ICHECK(value.dtype() == DataType::Int(64)) + TVM_FFI_ICHECK(value.ty().MatchesElementType(DLDataTypeCode::kDLInt, 64)) << "Can only provide i64 expressions in shape"; return value; } else { @@ -1015,7 +1017,9 @@ class TypeLCAFinder : public TypeFunctor { if (rhs == nullptr) return ObjectType(lhs->span); // find the target dtype, ndim, and vdevice. - DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void(); + PrimType dtype = lhs->dtype->dtype == rhs->dtype->dtype + ? PrimType(lhs->dtype->dtype) + : PrimType(DLDataType{kDLOpaqueHandle, 0, 0}); int ndim = lhs->ndim == rhs->ndim ? 
lhs->ndim : kUnknownNDim; VDevice vdev = VDevice(); if (lhs->vdevice.defined() && rhs->vdevice.defined() && @@ -1028,7 +1032,7 @@ class TypeLCAFinder : public TypeFunctor { !CanProveShapeEqual(lhs->shape.value(), rhs->shape.value(), ffi::GetRef(analyzer_))) { // reuse lhs when possible - if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim && + if (!lhs->shape.defined() && lhs->dtype->dtype == dtype->dtype && lhs->ndim == ndim && (!lhs->vdevice.defined() || vdev.defined())) { return ffi::GetRef(lhs); } else { @@ -1036,7 +1040,7 @@ class TypeLCAFinder : public TypeFunctor { } } // symbolic shape and vdevice match but dtype mismatch - if (lhs->dtype != dtype || (lhs->vdevice.defined() && !vdev.defined())) { + if (lhs->dtype->dtype != dtype->dtype || (lhs->vdevice.defined() && !vdev.defined())) { return TensorType(lhs->shape.value(), dtype, vdev, lhs->span); } else { return ffi::GetRef(lhs); diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 5c3547249c5e..52e974be75f0 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -457,9 +457,9 @@ class WellFormedChecker : public relax::ExprVisitor, for (PrimExpr expr : op->values) { // check if the symbolic vars in the expr are defined, e.g, 2 * m tirx::ExprVisitor::VisitExpr(expr); - if (!expr.dtype().is_int()) { + if (expr.ty().code() != DLDataTypeCode::kDLInt) { TVM_FFI_VISIT_THROW(TypeError, expr) - << "Shape expressions must be of integer type, but got " << expr.dtype(); + << "Shape expressions must be of integer type, but got " << expr.ty()->dtype; } } CheckType(op); diff --git a/src/relax/backend/contrib/codegen_c/codegen_c.h b/src/relax/backend/contrib/codegen_c/codegen_c.h index 1a5fb1dd801e..0c36b04812c8 100644 --- a/src/relax/backend/contrib/codegen_c/codegen_c.h +++ b/src/relax/backend/contrib/codegen_c/codegen_c.h @@ -347,19 +347,20 @@ class CodegenCBase { */ std::string GetDtypeString(const TensorTypeNode* tensor_ty) { 
std::string dtype; - if (runtime::TypeMatch(tensor_ty->dtype, kDLFloat, 32)) { + DLDataType raw_dtype = tensor_ty->dtype->dtype; + if (raw_dtype == DLDataType{kDLFloat, 32, 1}) { dtype = "float"; - } else if (runtime::TypeMatch(tensor_ty->dtype, kDLFloat, 16)) { + } else if (raw_dtype == DLDataType{kDLFloat, 16, 1}) { dtype = "half"; - } else if (runtime::TypeMatch(tensor_ty->dtype, kDLBfloat, 16)) { + } else if (raw_dtype == DLDataType{kDLBfloat, 16, 1}) { dtype = "bfloat"; - } else if (runtime::TypeMatch(tensor_ty->dtype, kDLInt, 32)) { + } else if (raw_dtype == DLDataType{kDLInt, 32, 1}) { dtype = "int"; - } else if (runtime::TypeMatch(tensor_ty->dtype, kDLInt, 64)) { + } else if (raw_dtype == DLDataType{kDLInt, 64, 1}) { dtype = "int64_t"; - } else if (runtime::TypeMatch(tensor_ty->dtype, kDLInt, 8)) { + } else if (raw_dtype == DLDataType{kDLInt, 8, 1}) { dtype = "int8_t"; - } else if (runtime::TypeMatch(tensor_ty->dtype, kDLUInt, 8)) { + } else if (raw_dtype == DLDataType{kDLUInt, 8, 1}) { dtype = "uint8_t"; } else { TVM_FFI_THROW(InternalError) << "Unsupported dtype " << tensor_ty->dtype; diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index edebb7593fca..03133599a58a 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -89,8 +89,8 @@ class OpAttrExtractor { } } - void Visit(const char* key, DataType* value) { - if (!value->is_void()) { + void Visit(const char* key, DLDataType* value) { + if (!(value->code == kDLOpaqueHandle && value->bits == 0 && value->lanes == 0)) { SetNodeAttr(key, ffi::String(ffi::DLDataTypeToString(*value))); } else { SetNodeAttr(key, ffi::String("")); @@ -201,7 +201,7 @@ class OpAttrExtractor { break; } case ffi::TypeIndex::kTVMFFIDataType: { - DataType value(field_value.cast()); + DLDataType value = field_value.cast(); this->Visit(field_info->name.data, &value); break; } @@ -282,7 
+282,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { ShapeExpr output_shape = tensor_ty->shape.value().as_or_throw(); ret.push_back(JSONGraphNodeEntry(node_id, i)); shape.emplace_back(GetIntShape(output_shape->values)); - dtype.emplace_back(DType2String(tensor_ty->dtype)); + dtype.emplace_back(DType2String(tensor_ty->dtype->dtype)); } node->SetNumOutput(tuple_ty->fields.size()); } else { @@ -292,7 +292,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { ShapeExpr output_shape = tensor_ty->shape.value().as_or_throw(); shape.emplace_back(GetIntShape(output_shape->values)); - dtype.emplace_back(DType2String(tensor_ty->dtype)); + dtype.emplace_back(DType2String(tensor_ty->dtype->dtype)); ret.push_back(JSONGraphNodeEntry(node_id, 0)); } node->SetShape(shape); diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 5284de94f622..f2999b172136 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -86,11 +86,11 @@ class CublasJSONSerializer : public JSONSerializer { const auto* const_expr = dequantize_call->args[1].as(); auto ty = const_expr->ty.as_or_throw(); float alpha = 1.0; - if (ty->dtype == DataType::Float(16)) { + if (ty->dtype == PrimType::Float(16)) { alpha = __extendXfYf2__( static_cast(const_expr->data->data)[0]); } else { - TVM_FFI_ICHECK(ty->dtype == DataType::Float(32)); + TVM_FFI_ICHECK(ty->dtype == PrimType::Float(32)); alpha = static_cast(const_expr->data->data)[0]; } diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index 03621c400551..dfe4b24e4f12 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -167,9 +167,9 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, for (const auto& arg : ext_func_args_) { auto ty = GetType(arg); if (const auto* tensor_ty = ty.as()) { - 
arg_types.emplace_back(backend::DType2String(tensor_ty->dtype)); + arg_types.emplace_back(backend::DType2String(tensor_ty->dtype->dtype)); } else if (const auto* shape_ty = ty.as()) { - arg_types.emplace_back(backend::DType2String(shape_ty->values.value()[0]->dtype)); + arg_types.emplace_back(backend::DType2String(shape_ty->values.value()[0].ty()->dtype)); } else { TVM_FFI_THROW(InternalError) << "Unimplemented"; } @@ -302,7 +302,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, std::vector out_types; if (const auto* tensor_ty = ty.as()) { - out_types.emplace_back(backend::DType2String(tensor_ty->dtype)); + out_types.emplace_back(backend::DType2String(tensor_ty->dtype->dtype)); } else { TVM_FFI_THROW(InternalError) << "Unimplemented ty type: " << ty; } diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h index 93916bf23236..6147a6eb2199 100644 --- a/src/relax/backend/contrib/utils.h +++ b/src/relax/backend/contrib/utils.h @@ -59,9 +59,7 @@ inline std::vector GetIntShape(const ffi::Array& shape) { * \param typ * \return std::string string format of type */ -inline std::string DType2String(const tvm::DataType dtype) { - return tvm::ffi::DLDataTypeToString(dtype); -} +inline std::string DType2String(DLDataType dtype) { return tvm::ffi::DLDataTypeToString(dtype); } /*! 
* \brief Check if a call node is calling an op with the given name diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index c1e9af85511c..3e2ac365d4fb 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -88,19 +88,19 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { PrimExpr RegListGet(int64_t slot) const { // use 128 bits to represent any - return tirx::Call(DataType::Handle(), tirx::builtin::anylist_getitem(), + return tirx::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), {reg_anylist_handle_, ConstInt32(slot)}); } PrimExpr ConstListGet(int64_t slot) const { // use 128 bits to represent any - return tirx::Call(DataType::Handle(), tirx::builtin::anylist_getitem(), + return tirx::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), {const_anylist_handle_, ConstInt32(slot)}); } PrimExpr FuncListGet(int64_t slot) const { // use 128 bits to represent any - return tirx::Call(DataType::Handle(), tirx::builtin::anylist_getitem(), + return tirx::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), {func_anylist_handle_, ConstInt32(slot)}); } @@ -121,11 +121,11 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { all_args.push_back(arg); } if (dst_anylist_slot >= 0) { - this->EmitStmt(tirx::Evaluate( - tirx::Call(DataType::Int(32), tirx::builtin::anylist_setitem_call_packed(), all_args))); + this->EmitStmt(tirx::Evaluate(tirx::Call( + tvm::PrimType::Int(32), tirx::builtin::anylist_setitem_call_packed(), all_args))); } else { this->EmitStmt(tirx::Evaluate( - tirx::Call(DataType::Int(32), tirx::builtin::tvm_call_packed(), all_args))); + tirx::Call(tvm::PrimType::Int(32), tirx::builtin::tvm_call_packed(), all_args))); } } @@ -143,11 +143,11 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { all_args.push_back(arg); } if (dst_anylist_slot >= 0) { - this->EmitStmt(tirx::Evaluate( - tirx::Call(DataType::Int(32), 
tirx::builtin::anylist_setitem_call_cpacked(), all_args))); + this->EmitStmt(tirx::Evaluate(tirx::Call( + tvm::PrimType::Int(32), tirx::builtin::anylist_setitem_call_cpacked(), all_args))); } else { this->EmitStmt(tirx::Evaluate( - tirx::Call(DataType::Int(32), tirx::builtin::tvm_call_cpacked(), all_args))); + tirx::Call(tvm::PrimType::Int(32), tirx::builtin::tvm_call_cpacked(), all_args))); } } @@ -160,10 +160,10 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { stmt_stack_ = {}; registers_num_ = 0; var_map_.clear(); - ctx_ptr_ = tirx::Var("ctx_ptr", DataType::Handle()); - reg_anylist_handle_ = tirx::Var("r", DataType::Handle()); - func_anylist_handle_ = tirx::Var("f", DataType::Handle()); - const_anylist_handle_ = tirx::Var("c", DataType::Handle()); + ctx_ptr_ = tirx::Var("ctx_ptr", PrimType::Handle()); + reg_anylist_handle_ = tirx::Var("r", PrimType::Handle()); + func_anylist_handle_ = tirx::Var("f", PrimType::Handle()); + const_anylist_handle_ = tirx::Var("c", PrimType::Handle()); ffi::Array param_names; for (Var param : func->params) { @@ -231,7 +231,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { Call call = ffi::GetRef(call_node); if (call_node->op == null_value_op_) { - return tirx::Call(DataType::Handle(), tirx::builtin::reinterpret(), {IntImm::Int64(0)}); + return tirx::Call(tvm::PrimType::Handle(), tirx::builtin::reinterpret(), {IntImm::Int64(0)}); } int64_t dst_reg = HasVoidType(call) ? 
-1 : NewRegister(); if (call->op.as()) { @@ -264,7 +264,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { size_t merge_register = NewRegister(); PrimExpr cond_value = this->VisitExpr(op->cond).value(); - cond_value = tirx::Call(DataType::Bool(), tirx::builtin::tvm_call_packed(), + cond_value = tirx::Call(tvm::PrimType::Bool(), tirx::builtin::tvm_call_packed(), {tirx::StringImm("vm.builtin.read_if_cond"), cond_value}); tirx::Stmt true_branch = WithNewScope([&]() { @@ -438,7 +438,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { TVM_FFI_ICHECK(tir_call->args[0].same_as(reg_anylist_handle_)); const auto* p_dst_reg = tir_call->args[1].as(); TVM_FFI_ICHECK(p_dst_reg != nullptr); - TVM_FFI_ICHECK(p_dst_reg->dtype == DataType::Int(32)); + TVM_FFI_ICHECK(p_dst_reg->ty().MatchesElementType(DLDataTypeCode::kDLInt, 32)); int64_t dst_reg = p_dst_reg->value; this->EmitCallPacked("vm.builtin.null_value", {}, dst_reg); diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 344fc6a67e65..4a32efd81e5a 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -21,6 +21,7 @@ * \brief Lowers most builtin functions and packed calls. 
*/ #include +#include #include #include #include @@ -29,7 +30,6 @@ #include #include #include -#include #include namespace tvm { @@ -85,7 +85,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { Expr MakeMemAllocStorage(const Call& call) { PrimValue runtime_device_index = call->args[1].as_or_throw(); StringImm storage_scope = call->args[2].as_or_throw(); - DataTypeImm output_dtype = DataTypeImm(DataType::UInt(8)); + DataTypeImm output_dtype = DataTypeImm((DLDataType{kDLUInt, 8, 1})); return Call(vm_alloc_storage_op_, {call->args[0], runtime_device_index, output_dtype, storage_scope}, Attrs()); } diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 3d895349bbc3..6784489c5b32 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -229,7 +229,7 @@ class VMShapeLowerMutator slot_map_.clear(); current_gvar_ = gvar; PrimExprSlotCollector::Collect(func, &slot_vec_, &slot_map_); - heap_size_ = IntImm(ShapeDType(), static_cast(slot_vec_.size())); + heap_size_ = IntImm(tvm::PrimType(ShapeDType()), static_cast(slot_vec_.size())); VarBinding shape_heap_binding = this->AllocShapeHeapBinding(heap_size_); shape_heap_ = shape_heap_binding->var; @@ -298,7 +298,7 @@ class VMShapeLowerMutator //------------------------------------------------------- // PrimExpr slot handling //------------------------------------------------------- - static DataType ShapeDType() { return DataType::Int(64); } + static DLDataType ShapeDType() { return DLDataType{kDLInt, 64, 1}; } /*! \brief populate additional information in the slot. */ void PopulateSlotInfo() { @@ -329,7 +329,7 @@ class VMShapeLowerMutator VarBinding AllocShapeHeapBinding(IntImm heap_size) { if (heap_size->value > 0) { - TensorType heap_ty(ShapeDType(), 1); + TensorType heap_ty(PrimType(ShapeDType()), 1); Var var("shape_heap", heap_ty); // set up the builtin func. 
Call call(call_builtin_with_ctx_op_, @@ -566,7 +566,7 @@ class VMShapeLowerMutator if (to_compute.size() == 0) return 0; TVM_FFI_ICHECK_GT(heap_size_->value, 0); // construct a PrimFunc that compute the shape. - tirx::Var heap("heap", DataType::Handle()); + tirx::Var heap("heap", PrimType::Handle()); ffi::Array buffer_shape{heap_size_}; tirx::Buffer buffer = tirx::decl_buffer(buffer_shape, ShapeDType(), "H", "global"); ffi::Map buffer_map; @@ -575,7 +575,8 @@ class VMShapeLowerMutator auto var_map = [&](const tirx::Var& var) -> ffi::Optional { auto it = slot_map_.find(var); TVM_FFI_ICHECK(it != slot_map_.end()); - return tirx::BufferLoad(buffer, {IntImm(ShapeDType(), it->second->index)}); + return tirx::BufferLoad( + buffer, ffi::Array{IntImm(tvm::PrimType(ShapeDType()), it->second->index)}); }; ffi::Array seq; @@ -583,7 +584,8 @@ class VMShapeLowerMutator TVM_FFI_ICHECK(!slot->value_computed); slot->value_computed = true; PrimExpr value = tirx::Substitute(slot->expr, var_map); - seq.push_back(tirx::BufferStore(buffer, value, {IntImm(ShapeDType(), slot->index)})); + seq.push_back( + tirx::BufferStore(buffer, value, {IntImm(tvm::PrimType(ShapeDType()), slot->index)})); } tirx::Stmt body = tirx::SeqStmt::Flatten(seq); @@ -678,10 +680,11 @@ class VMShapeLowerMutator // if we only check dynamic shapes, and the shape is static, we can skip. 
return; } - if (always_check || !IsBaseOf(TensorType(op->dtype, op->ndim), GetType(value))) { + if (always_check || !IsBaseOf(TensorType(PrimType(op->dtype), op->ndim), GetType(value))) { // check_tensor_info(value, ndim, dtype, err_ctx) Call call(builtin_check_tensor_info_, - {value, PrimValue::Int64(op->ndim), DataTypeImm(op->dtype), GetErrContext(err_ctx)}, + {value, PrimValue::Int64(op->ndim), DataTypeImm(op->dtype->dtype), + GetErrContext(err_ctx)}, Attrs(), {void_ty_}); builder_->Emit(call, "_"); } diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 7b14a1f7e7e9..10fd67de1740 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -736,7 +736,7 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { return ExternFuncPattern(func->global_symbol); } else if (auto prim = expr.as()) { - return TypePattern(WildcardPattern(), PrimType(prim->value.dtype())); + return TypePattern(WildcardPattern(), PrimType(prim->value.ty())); } else { TVM_FFI_THROW(TypeError) << "Cannot convert Relax expression of type " << expr->GetTypeKey() diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 08689bd10f0b..f75c540a96cd 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -573,8 +573,7 @@ bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr // no need to jump, as var.dtype == value.dtype auto expr_ty = expr.as()->ty; if (const TensorTypeNode* tensor_ty = expr_ty.as()) { - return (ffi::StructuralEqual()(op->dtype, tensor_ty->dtype)) && - VisitDFPattern(op->pattern, expr); + return op->dtype == tensor_ty->dtype->dtype && VisitDFPattern(op->pattern, expr); } return false; } diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 5cb5352ec6c2..6302ee85049a 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ 
-369,15 +369,15 @@ RELAX_PATTERN_PRINTER_DEF(SameShapeConstraintNode, [](auto p, auto node) { p->stream << ")"; }); -DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { +DataTypePattern::DataTypePattern(DFPattern pattern, DLDataType dtype) { ffi::ObjectPtr n = ffi::make_object(); n->pattern = std::move(pattern); - n->dtype = std::move(dtype); + n->dtype = dtype; data_ = std::move(n); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.dpl.DataTypePattern", [](DFPattern pattern, DataType dtype) { + refl::GlobalDef().def("relax.dpl.DataTypePattern", [](DFPattern pattern, DLDataType dtype) { return DataTypePattern(pattern, dtype); }); } @@ -474,11 +474,11 @@ AttrPattern DFPattern::HasAttr(const ffi::Map& attrs) const { return AttrPattern(*this, DictAttrs(attrs)); } TypePattern DFPattern::HasType(const Type& ty) const { return TypePattern(*this, ty); } -DataTypePattern DFPattern::HasDtype(const DataType& dtype) const { +DataTypePattern DFPattern::HasDtype(DLDataType dtype) const { return DataTypePattern(*this, dtype); } DataTypePattern DFPattern::HasDtype(const std::string& dtype) const { - return HasDtype(DataType(ffi::StringToDLDataType(dtype))); + return HasDtype(ffi::StringToDLDataType(dtype)); } ShapePattern DFPattern::HasShape(const ffi::Array& shape) const { return ShapePattern(*this, shape); diff --git a/src/relax/ir/dependent_type.cc b/src/relax/ir/dependent_type.cc index 6a2034ccc2a8..d95ebb1534e7 100644 --- a/src/relax/ir/dependent_type.cc +++ b/src/relax/ir/dependent_type.cc @@ -54,9 +54,9 @@ ShapeType::ShapeType(ffi::Array values, Span span) { n->ndim = static_cast(values.size()); n->values = values.Map([](PrimExpr value) { if (value->IsInstance()) { - return tvm::cast(DataType::Int(64), value); + return tvm::cast(PrimType::Int(64), value); } - TVM_FFI_ICHECK(value.dtype() == DataType::Int(64)) + TVM_FFI_ICHECK(value.ty().MatchesElementType(DLDataTypeCode::kDLInt, 64)) << "the value in 
ShapeType can only have dtype of int64"; return value; }); @@ -86,7 +86,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } // Tensor -TensorType::TensorType(Expr shape, DataType dtype, ffi::Optional vdevice, Span span) { +TensorType::TensorType(Expr shape, PrimType dtype, ffi::Optional vdevice, Span span) { ffi::ObjectPtr n = ffi::make_object(); // assign ndim before move TVM_FFI_ICHECK(shape.defined()) << "Must provide a shape in this constructor"; @@ -103,7 +103,7 @@ TensorType::TensorType(Expr shape, DataType dtype, ffi::Optional vdevic data_ = std::move(n); } -TensorType::TensorType(DataType dtype, int ndim, ffi::Optional vdevice, Span span) { +TensorType::TensorType(PrimType dtype, int ndim, ffi::Optional vdevice, Span span) { ffi::ObjectPtr n = ffi::make_object(); TVM_FFI_ICHECK(ndim >= -1) << "ndim of TensorType must be >= -1, but got " << ndim; n->ndim = ndim; @@ -116,13 +116,14 @@ TensorType::TensorType(DataType dtype, int ndim, ffi::Optional vdevice, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "relax.TensorType", [](ffi::Optional shape, ffi::Optional dtype, int ndim, + "relax.TensorType", [](ffi::Optional shape, ffi::Optional dtype, int ndim, VDevice vdevice, Span span) { + PrimType resolved_dtype = dtype.value_or(PrimType(DLDataType{kDLOpaqueHandle, 0, 0})); if (shape.defined()) { TVM_FFI_CHECK_EQ(ndim, kUnknownNDim, ValueError) << "Cannot both specify shape and ndim"; - return TensorType(shape.value(), dtype.value_or(DataType::Void()), vdevice, span); + return TensorType(shape.value(), resolved_dtype, vdevice, span); } else { - return TensorType(dtype.value_or(DataType::Void()), ndim, vdevice, span); + return TensorType(resolved_dtype, ndim, vdevice, span); } }); } diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index 304911c1dca2..68e48eaf93b6 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -42,7 +42,7 @@ te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std:: // checked-type 
might not be properly set. In this case we set the shape and dtype of the returned // TE tensor. if (const auto* constant = value.as()) { - n->dtype = DataType(constant->data->dtype); + n->dtype = PrimType(constant->data->dtype); int ndim = constant->data->ndim; ffi::Shape shape_tuple = constant->data.Shape(); diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 11e80135500a..b4c4486f0dd4 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -257,9 +257,9 @@ ShapeExpr::ShapeExpr(ffi::Array values, Span span) { n->values = values.Map([](PrimExpr value) { if (value->IsInstance()) { - return tvm::cast(DataType::Int(64), value); + return tvm::cast(PrimType::Int(64), value); } - TVM_FFI_ICHECK(value.dtype() == DataType::Int(64)) + TVM_FFI_ICHECK(value.ty().MatchesElementType(DLDataTypeCode::kDLInt, 64)) << "the value in ShapeType can only have dtype of int64"; return value; }); @@ -350,7 +350,7 @@ Constant::Constant(runtime::Tensor data, ffi::Optional ty_annotation, Span if (ty_annotation.defined()) { n->ty = ty_annotation.value(); } else { - TensorType tinfo(ShapeExpr(values), n->data.DataType(), VDevice(), span); + TensorType tinfo(ShapeExpr(values), PrimType(n->data.DataType()), VDevice(), span); n->ty = tinfo; } @@ -366,7 +366,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { PrimValue::PrimValue(PrimExpr value, Span span) { ffi::ObjectPtr n = ffi::make_object(); - n->ty = PrimType(value.dtype()); + n->ty = PrimType(value.ty()); n->value = std::move(value); n->span = std::move(span); data_ = std::move(n); @@ -396,9 +396,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](ffi::String value, Span span) { return StringImm(value, span); }); } -DataTypeImm::DataTypeImm(DataType value, Span span) { +DataTypeImm::DataTypeImm(DLDataType value, Span span) { ffi::ObjectPtr n = ffi::make_object(); - n->value = std::move(value); + n->value = value; n->span = std::move(span); n->ty = ObjectType(); data_ = std::move(n); @@ -407,7 +407,7 @@ DataTypeImm::DataTypeImm(DataType value, Span 
span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.DataTypeImm", - [](DataType value, Span span) { return DataTypeImm(value, span); }); + [](DLDataType value, Span span) { return DataTypeImm(value, span); }); } MatchCast::MatchCast(Var var, Expr value, Type ty, Span span) { diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index dd67f65dea09..15b8064d2b6f 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -85,7 +85,7 @@ Type InferTypeAllGather(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; - DataType output_dtype = input_ty->dtype; + PrimType output_dtype = input_ty->dtype; auto input_shape = input_ty->GetShape(); if (!input_shape.defined()) { return input_ty; @@ -143,7 +143,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeScatter(const Call& call, const BlockBuilder& ctx) { TensorType input_ty = GetUnaryInputTensorType(call, ctx); - DataType output_dtype = input_ty->dtype; + PrimType output_dtype = input_ty->dtype; const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; diff --git a/src/relax/op/distributed/binary.cc b/src/relax/op/distributed/binary.cc index 766d60edb86f..daaacff4121b 100644 --- a/src/relax/op/distributed/binary.cc +++ b/src/relax/op/distributed/binary.cc @@ -31,7 +31,7 @@ Type InferDistTypeBroadcastCMP(const Call& call, const BlockBuilder& ctx) { return InferDistTypeBroadcast( call, ctx, [](const Call& call, const BlockBuilder& ctx, const TensorType& x1_ty, - const TensorType& x2_ty) { return DataType::Bool(); }); + const TensorType& x2_ty) { return DLDataType{kDLBool, 8, 1}; }); } /***************** Arithmetic operators *****************/ diff --git a/src/relax/op/distributed/binary.h b/src/relax/op/distributed/binary.h index 5fd39b50f364..a6d3fd9ba124 100644 --- a/src/relax/op/distributed/binary.h +++ b/src/relax/op/distributed/binary.h @@ -41,8 +41,8 @@ Type 
InferDistTypeBroadcast(const Call& call, const BlockBuilder& ctx, FType f_c TensorType x1_ty = input_dtensor_tys[0]->tensor_ty; TensorType x2_ty = input_dtensor_tys[1]->tensor_ty; - // DateType - DataType output_dtype = f_compute_out_dtype(call, ctx, x1_ty, x2_ty); + // Dtype + PrimType output_dtype(f_compute_out_dtype(call, ctx, x1_ty, x2_ty)); // ndims TVM_FFI_ICHECK(!x1_ty->IsUnknownNdim() && !x2_ty->IsUnknownNdim()) diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc index b009630070cd..ff5bc986c0c7 100644 --- a/src/relax/op/distributed/distributed.cc +++ b/src/relax/op/distributed/distributed.cc @@ -154,7 +154,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeRtoS(const Call& call, const BlockBuilder& ctx) { TensorType input_ty = GetUnaryInputTensorType(call, ctx); - DataType output_dtype = input_ty->dtype; + PrimType output_dtype = input_ty->dtype; const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; diff --git a/src/relax/op/distributed/linear_algebra.cc b/src/relax/op/distributed/linear_algebra.cc index 80fccbe115a9..b498f1a4a953 100644 --- a/src/relax/op/distributed/linear_algebra.cc +++ b/src/relax/op/distributed/linear_algebra.cc @@ -32,9 +32,9 @@ Type InferDistTypeMatmul(const Call& call, const BlockBuilder& ctx) { TensorType x2_ty = input_dtensor_tys[1]->tensor_ty; const auto* attrs = call->attrs.as(); - DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? 
InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty) + : attrs->out_dtype); if (x1_ty->IsUnknownNdim() || x2_ty->IsUnknownNdim()) { TVM_FFI_VISIT_THROW(ValueError, call) diff --git a/src/relax/op/distributed/nn.cc b/src/relax/op/distributed/nn.cc index 1339a18e72d0..fcdc37c54046 100644 --- a/src/relax/op/distributed/nn.cc +++ b/src/relax/op/distributed/nn.cc @@ -33,7 +33,9 @@ Type InferDistTypeSoftmax(const Call& call, const BlockBuilder& ctx) { if (input_tensor_ty->IsUnknownNdim()) { TVM_FFI_VISIT_THROW(ValueError, call) << "Input of distributed operator must have known ndim"; } - if (!input_tensor_ty->IsUnknownDtype() && !input_tensor_ty->dtype.is_float()) { + PrimType input_dtype = input_tensor_ty->dtype; + // Softmax validation preserves the old float-kind check; lanes do not affect this policy. + if (!input_tensor_ty->IsUnknownDtype() && !input_dtype.MatchesCode(DLDataTypeCode::kDLFloat)) { TVM_FFI_VISIT_THROW(TypeError, call) << "Softmax requires the input tensor to have float " "dtype. 
However, the given input dtype is " << input_tensor_ty->dtype; diff --git a/src/relax/op/distributed/unary.cc b/src/relax/op/distributed/unary.cc index 4356b403c6d9..8e4ccce23a9c 100644 --- a/src/relax/op/distributed/unary.cc +++ b/src/relax/op/distributed/unary.cc @@ -25,7 +25,7 @@ namespace distributed { Type InferDistTypeUnaryCheck(const Call& call, const BlockBuilder& ctx) { return InferDistTypeUnary(call, ctx, - [](const TensorType& input_ty) { return DataType::Bool(); }); + [](const TensorType& input_ty) { return PrimType::Bool(); }); } RELAX_REGISTER_UNARY_ARITH_DIST_INFER_TYPE(abs, /*require_float_dtype=*/false); diff --git a/src/relax/op/distributed/unary.h b/src/relax/op/distributed/unary.h index 92c719ad0b98..58e0a41e27cb 100644 --- a/src/relax/op/distributed/unary.h +++ b/src/relax/op/distributed/unary.h @@ -40,15 +40,22 @@ Type InferDistTypeUnary(const Call& call, const BlockBuilder& ctx, FType f_compu distributed::DTensorType input_dtensor_ty = input_dtensor_tys[0]; TensorType input_tensor_ty = input_dtensor_ty->tensor_ty; + PrimType input_dtype = input_tensor_ty->dtype; + // Unary op validation preserves the old float-kind check; lanes do not affect this policy. if (require_float_dtype && !input_tensor_ty->IsUnknownDtype() && - !input_tensor_ty->dtype.is_float()) { + !input_dtype.MatchesCode(DLDataTypeCode::kDLFloat)) { TVM_FFI_VISIT_THROW(TypeError, call) << call->op << " requires the input tensor to have float dtype. 
However, the given input dtype is " << input_tensor_ty->dtype; } auto output_ty = ffi::make_object(*input_tensor_ty.get()); - output_ty->dtype = f_compute_out_dtype(input_tensor_ty); + auto computed_dtype = f_compute_out_dtype(input_tensor_ty); + if constexpr (std::is_same_v, PrimType>) { + output_ty->dtype = computed_dtype; + } else { + output_ty->dtype = PrimType(computed_dtype); + } TensorType out_tensor_ty(output_ty); return distributed::DTensorType(out_tensor_ty, input_dtensor_ty->device_mesh, input_dtensor_ty->placement); diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index b92167e031f1..82b12c0fe26f 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -41,7 +41,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Resize3DAttrs::RegisterReflection(); } Expr resize2d(Expr data, Expr size, ffi::Array roi, ffi::String layout, ffi::String method, ffi::String coordinate_transformation_mode, ffi::String rounding_method, double cubic_alpha, int cubic_exclude, - double extrapolation_value, ffi::Optional out_dtype) { + double extrapolation_value, ffi::Optional out_dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->roi = std::move(roi); attrs->layout = std::move(layout); @@ -51,7 +51,7 @@ Expr resize2d(Expr data, Expr size, ffi::Array roi, ffi::String layout attrs->cubic_alpha = cubic_alpha; attrs->cubic_exclude = cubic_exclude; attrs->extrapolation_value = extrapolation_value; - attrs->out_dtype = out_dtype.value_or(DataType::Void()); + attrs->out_dtype = out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.image.resize2d"); return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); @@ -93,7 +93,9 @@ Type InferTypeResize2D(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCHW", // /*tensor_name=*/"data"); - DataType out_dtype = attrs->out_dtype.is_void() ? 
data_ty->dtype : attrs->out_dtype; + PrimType out_dtype = attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? data_ty->dtype + : PrimType(attrs->out_dtype); ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, ffi::GetRef(data_ty), data_layout); @@ -155,7 +157,7 @@ TVM_REGISTER_OP("relax.image.resize2d") Expr resize3d(Expr data, Expr size, ffi::Array roi, ffi::String layout, ffi::String method, ffi::String coordinate_transformation_mode, ffi::String rounding_method, double cubic_alpha, int cubic_exclude, - double extrapolation_value, ffi::Optional out_dtype) { + double extrapolation_value, ffi::Optional out_dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->roi = std::move(roi); attrs->layout = std::move(layout); @@ -165,7 +167,7 @@ Expr resize3d(Expr data, Expr size, ffi::Array roi, ffi::String layout attrs->cubic_alpha = cubic_alpha; attrs->cubic_exclude = cubic_exclude; attrs->extrapolation_value = extrapolation_value; - attrs->out_dtype = out_dtype.value_or(DataType::Void()); + attrs->out_dtype = out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.image.resize3d"); return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); @@ -207,7 +209,9 @@ Type InferTypeResize3D(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCDHW", // /*tensor_name=*/"data"); - DataType out_dtype = attrs->out_dtype.is_void() ? data_ty->dtype : attrs->out_dtype; + PrimType out_dtype = attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? data_ty->dtype + : PrimType(attrs->out_dtype); ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, ffi::GetRef(data_ty), data_layout); @@ -315,7 +319,7 @@ Type InferTypeGridSample(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/is_ncdhw ? 
"NCDHW" : "NCHW", /*tensor_name=*/"data"); - DataType out_dtype = data_ty->dtype; + PrimType out_dtype = data_ty->dtype; ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, ffi::GetRef(data_ty), data_layout); @@ -422,7 +426,7 @@ Type InferTypeAffineGrid(const Call& call, const BlockBuilder& ctx) { } } - DataType out_dtype = data_ty->dtype; + PrimType out_dtype = data_ty->dtype; if (data_shape == nullptr || size_value == nullptr) { return TensorType(out_dtype, /*ndim=*/4, data_ty->vdevice); diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h index 382a3a162be2..1aaed69f9146 100644 --- a/src/relax/op/image/resize.h +++ b/src/relax/op/image/resize.h @@ -36,13 +36,13 @@ namespace relax { Expr resize2d(Expr data, Expr size, ffi::Array roi, ffi::String layout, ffi::String method, ffi::String coordinate_transformation_mode, ffi::String rounding_method, double cubic_alpha, int cubic_exclude, - double extrapolation_value, ffi::Optional out_dtype); + double extrapolation_value, ffi::Optional out_dtype); /*! \brief Image resize3d operator. */ Expr resize3d(Expr data, Expr size, ffi::Array roi, ffi::String layout, ffi::String method, ffi::String coordinate_transformation_mode, ffi::String rounding_method, double cubic_alpha, int cubic_exclude, - double extrapolation_value, ffi::Optional out_dtype); + double extrapolation_value, ffi::Optional out_dtype); /*! \brief Image grid_sample operator. 
*/ Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout, diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index 25ad9aa66d8e..828eba4950f0 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -87,7 +87,7 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) { } }(); - auto view_dtype = [&]() -> std::optional { + auto view_dtype = [&]() -> std::optional { Type ty = GetType(arg_dtype); if (HasVoidType(arg_dtype)) { @@ -116,7 +116,7 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) { } else if (ty.as()) { // The view changes the datatype, but we don't know what it is // being changed into. - return DataType::Void(); + return DLDataType{kDLOpaqueHandle, 0, 0}; } else { TVM_FFI_THROW(TypeError) << "Operator " << call->op << " expects the dtype argument to be a relax::DataTypeImm, " @@ -131,7 +131,7 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) { // No byte offset is specified, so no change is applied. return IntImm::Int64(0); } else if (auto prim_ty = ty.as()) { - TVM_FFI_CHECK_EQ(prim_ty->dtype, DataType::Int(64), TypeError) + TVM_FFI_CHECK_EQ(prim_ty->dtype, (DLDataType{kDLInt, 64, 1}), TypeError) << "Operator " << call->op << " expects the relative_byte_offset to be a 64-bit integer, but received " << arg_relative_byte_offset << ", which has type " << ty; @@ -167,17 +167,16 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) { output_ndim = data_ty->ndim; } - DataType output_dtype = view_dtype.value_or(data_ty->dtype); + DLDataType output_raw_dtype = view_dtype.value_or(data_ty->dtype->dtype); + PrimType output_dtype(output_raw_dtype); - // Helper function, returns the number of bytes per vectorized - // element. Cannot use `DataType::bytes`, as it returns the - // number of bytes per scalar element. 
- auto get_size_bytes = [](const DataType& dtype) -> ffi::Optional { - if (dtype.is_void()) { + // Helper function returns the number of bytes per vectorized element. + auto get_size_bytes = [](DLDataType dtype) -> ffi::Optional { + PrimType ty(dtype); + if (ty.IsVoid() || ty.IsScalableVector()) { return std::nullopt; } else { - auto size_bits = dtype.bits() * dtype.lanes(); - return IntImm::Int64((size_bits + 7) / 8); + return IntImm::Int64(static_cast(ty.StorageBytes())); } }; @@ -199,8 +198,8 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) { ffi::Optional input_nelements = get_num_elements(input_shape); ffi::Optional output_nelements = get_num_elements(output_shape); - ffi::Optional input_element_size = get_size_bytes(data_ty->dtype); - ffi::Optional output_element_size = get_size_bytes(output_dtype); + ffi::Optional input_element_size = get_size_bytes(data_ty->dtype->dtype); + ffi::Optional output_element_size = get_size_bytes(output_raw_dtype); if (input_nelements && output_nelements && input_element_size && output_element_size && view_relative_byte_offset) { @@ -329,8 +328,9 @@ Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { } if (HasVoidType(dtype)) { - auto data_dtype = data->ty.as().value()->dtype; - TVM_FFI_ICHECK(!data_dtype.is_void()) + DLDataType data_dtype = data->ty.as().value()->dtype->dtype; + TVM_FFI_ICHECK(!(((data_dtype).code == kDLOpaqueHandle) && ((data_dtype).bits == 0) && + ((data_dtype).lanes == 0))) << "Legalization of " << call->op << " requires that either the output dtype be explicitly specified, " << "or the input dtype is known. 
" diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index 83080537c1d0..62e7d2959346 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -143,7 +143,7 @@ Type InferTypeAttention(const Call& call, const BlockBuilder& ctx) { return TensorType(ShapeExpr(output_shape), q_ty->dtype, q_ty->vdevice); } -Call InferMixedPrecisionAttention(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionAttention(const Call& call, DLDataType out_dtype) { return attention(call->args[0], call->args[1], call->args[2], std::nullopt, std::nullopt, std::nullopt, std::nullopt) .as_or_throw(); diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 1fa9b9b1ae94..90d58a9e662d 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -47,7 +47,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype) { + ffi::Optional out_dtype) { padding = GetCompletePadding1D(std::move(padding)); TVM_FFI_ICHECK_GT(groups, 0) @@ -62,7 +62,8 @@ Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array(std::move(data), std::move(weight), std::move(strides), std::move(padding), std::move(dilation), groups, data_layout, std::move(kernel_layout), out_layout.value_or(data_layout), - out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv1d"); + out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})), + /*op_name=*/"relax.nn.conv1d"); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -91,9 +92,9 @@ Type InferTypeConv1d(const Call& call, const BlockBuilder& ctx) { ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); - DataType out_dtype = attrs->out_dtype.is_void() - ? 
InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) + : attrs->out_dtype); ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { return TensorType(out_dtype, out_layout.ndim(), vdevice); @@ -186,7 +187,7 @@ InferLayoutOutput InferLayoutConv1d( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } -Call InferMixedPrecisionConv1d(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionConv1d(const Call& call, DLDataType out_dtype) { const auto* conv1d_attrs = call->attrs.as(); return conv1d(call->args[0], call->args[1], conv1d_attrs->strides, conv1d_attrs->padding, conv1d_attrs->dilation, conv1d_attrs->groups, conv1d_attrs->data_layout, @@ -210,7 +211,7 @@ TVM_REGISTER_OP("relax.nn.conv1d") Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype) { + ffi::Optional out_dtype) { padding = GetCompletePadding2D(std::move(padding)); if (strides.size() == 1) { strides.push_back(strides[0]); @@ -231,7 +232,8 @@ Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array(std::move(data), std::move(weight), std::move(strides), std::move(padding), std::move(dilation), groups, data_layout, std::move(kernel_layout), out_layout.value_or(data_layout), - out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv2d"); + out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})), + /*op_name=*/"relax.nn.conv2d"); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -260,9 +262,9 @@ Type InferTypeConv2d(const Call& call, const BlockBuilder& ctx) { ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, 
weight_layout); - DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) + : attrs->out_dtype); ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { return TensorType(out_dtype, out_layout.ndim(), vdevice); @@ -336,9 +338,10 @@ InferLayoutOutput InferLayoutConv2d( SLayout desired_data_layout = (*it).second[0]; SLayout desired_weight_layout = (*it).second[1]; SLayout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; - tirx::SLayout input_layout(attrs->data_layout, DataType::Int(64)); - tirx::SLayout kernel_layout(attrs->kernel_layout, DataType::Int(64)); - tirx::SLayout out_layout(attrs->out_layout, DataType::Int(64)); + tvm::PrimType i64_ty = tvm::PrimType::Int(64); + tirx::SLayout input_layout(attrs->data_layout, i64_ty); + tirx::SLayout kernel_layout(attrs->kernel_layout, i64_ty); + tirx::SLayout out_layout(attrs->out_layout, i64_ty); if ((desired_data_layout.ndim() == input_layout.ndim()) && (desired_weight_layout.ndim() == kernel_layout.ndim()) && @@ -396,7 +399,7 @@ InferLayoutOutput InferLayoutConv2d( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } -Call InferMixedPrecisionConv2d(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionConv2d(const Call& call, DLDataType out_dtype) { const auto* conv2d_attrs = call->attrs.as(); return conv2d(call->args[0], call->args[1], conv2d_attrs->strides, conv2d_attrs->padding, conv2d_attrs->dilation, conv2d_attrs->groups, conv2d_attrs->data_layout, @@ -420,7 +423,7 @@ TVM_REGISTER_OP("relax.nn.conv2d") Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, 
ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype) { + ffi::Optional out_dtype) { padding = GetCompletePadding3D(std::move(padding)); if (strides.size() == 1) { strides.push_back(strides[0]); @@ -443,7 +446,8 @@ Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array(std::move(data), std::move(weight), std::move(strides), std::move(padding), std::move(dilation), groups, data_layout, std::move(kernel_layout), out_layout.value_or(data_layout), - out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv3d"); + out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})), + /*op_name=*/"relax.nn.conv3d"); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -472,9 +476,9 @@ Type InferTypeConv3d(const Call& call, const BlockBuilder& ctx) { ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); - DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? 
InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) + : attrs->out_dtype); ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { return TensorType(out_dtype, out_layout.ndim(), vdevice); @@ -581,7 +585,7 @@ InferLayoutOutput InferLayoutConv3d( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } -Call InferMixedPrecisionConv3d(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionConv3d(const Call& call, DLDataType out_dtype) { const auto* conv3d_attrs = call->attrs.as(); return conv3d(call->args[0], call->args[1], conv3d_attrs->strides, conv3d_attrs->padding, conv3d_attrs->dilation, conv3d_attrs->groups, conv3d_attrs->data_layout, @@ -604,7 +608,7 @@ Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array output_padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype) { + ffi::Optional out_dtype) { padding = GetCompletePadding1D(std::move(padding)); TVM_FFI_ICHECK_GT(groups, 0) @@ -630,7 +634,7 @@ Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, attrs->data_layout = data_layout; attrs->kernel_layout = std::move(kernel_layout); attrs->out_layout = out_layout.value_or(data_layout); - attrs->out_dtype = std::move(out_dtype.value_or(DataType::Void())); + attrs->out_dtype = out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); const Op& op = Op::Get("relax.nn.conv1d_transpose"); return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -660,9 +664,9 @@ Type InferTypeConv1dTranspose(const Call& call, const BlockBuilder& ctx) { ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); - DataType out_dtype = attrs->out_dtype.is_void() - ? 
InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) + : attrs->out_dtype); ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { return TensorType(out_dtype, out_layout.ndim(), vdevice); @@ -758,7 +762,7 @@ InferLayoutOutput InferLayoutConv1dTranspose( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } -Call InferMixedPrecisionConv1dTranspose(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionConv1dTranspose(const Call& call, DLDataType out_dtype) { const auto* conv1d_transpose_attrs = call->attrs.as(); return conv1d_transpose(call->args[0], call->args[1], conv1d_transpose_attrs->strides, conv1d_transpose_attrs->padding, conv1d_transpose_attrs->output_padding, @@ -786,7 +790,7 @@ Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array output_padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype) { + ffi::Optional out_dtype) { padding = GetCompletePadding2D(std::move(padding)); if (output_padding.size() == 1) { output_padding.push_back(output_padding[0]); @@ -821,7 +825,7 @@ Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, attrs->data_layout = data_layout; attrs->kernel_layout = std::move(kernel_layout); attrs->out_layout = out_layout.value_or(data_layout); - attrs->out_dtype = std::move(out_dtype.value_or(DataType::Void())); + attrs->out_dtype = out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); const Op& op = Op::Get("relax.nn.conv2d_transpose"); return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -852,9 +856,9 @@ Type InferTypeConv2dTranspose(const Call& call, const BlockBuilder& ctx) { 
ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); - DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) + : attrs->out_dtype); ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { return TensorType(out_dtype, out_layout.ndim(), vdevice); @@ -987,7 +991,7 @@ InferLayoutOutput InferLayoutConv2dTranspose( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } -Call InferMixedPrecisionConv2dTranspose(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionConv2dTranspose(const Call& call, DLDataType out_dtype) { const auto* conv2d_transpose_attrs = call->attrs.as(); return conv2d_transpose(call->args[0], call->args[1], conv2d_transpose_attrs->strides, conv2d_transpose_attrs->padding, conv2d_transpose_attrs->output_padding, @@ -1015,7 +1019,7 @@ Expr conv3d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array output_padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype) { + ffi::Optional out_dtype) { padding = GetCompletePadding3D(std::move(padding)); if (output_padding.size() == 1) { output_padding.push_back(output_padding[0]); @@ -1053,7 +1057,7 @@ Expr conv3d_transpose(Expr data, Expr weight, ffi::Array strides, attrs->data_layout = data_layout; attrs->kernel_layout = std::move(kernel_layout); attrs->out_layout = out_layout.value_or(data_layout); - attrs->out_dtype = std::move(out_dtype.value_or(DataType::Void())); + attrs->out_dtype = out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); const Op& op = 
Op::Get("relax.nn.conv3d_transpose"); return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -1084,9 +1088,9 @@ Type InferTypeConv3dTranspose(const Call& call, const BlockBuilder& ctx) { ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); - DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) + : attrs->out_dtype); ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { return TensorType(out_dtype, out_layout.ndim(), vdevice); @@ -1227,7 +1231,7 @@ InferLayoutOutput InferLayoutConv3dTranspose( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } -Call InferMixedPrecisionConv3dTranspose(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionConv3dTranspose(const Call& call, DLDataType out_dtype) { const auto* conv3d_transpose_attrs = call->attrs.as(); return conv3d_transpose(call->args[0], call->args[1], conv3d_transpose_attrs->strides, conv3d_transpose_attrs->padding, conv3d_transpose_attrs->output_padding, diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h index b08eb8a83ff8..b33a19f07057 100644 --- a/src/relax/op/nn/convolution.h +++ b/src/relax/op/nn/convolution.h @@ -39,7 +39,7 @@ template inline Expr MakeConv(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::String out_layout, - DataType out_dtype, std::string op_name) { + DLDataType out_dtype, std::string op_name) { auto attrs = ffi::make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -48,7 +48,7 @@ inline Expr MakeConv(Expr data, Expr 
weight, ffi::Array strides, attrs->data_layout = std::move(data_layout); attrs->kernel_layout = std::move(kernel_layout); attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); + attrs->out_dtype = out_dtype; const Op& op = Op::Get(op_name); return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -57,19 +57,19 @@ inline Expr MakeConv(Expr data, Expr weight, ffi::Array strides, Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype); + ffi::Optional out_dtype); /*! \brief 2D convolution */ Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype); + ffi::Optional out_dtype); /*! \brief 3D convolution */ Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype); + ffi::Optional out_dtype); /*! * \brief One dimensional transposed convolution operator. @@ -81,7 +81,7 @@ Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array output_padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype); + ffi::Optional out_dtype); /*! * \brief Two dimensional transposed convolution operator. @@ -93,7 +93,7 @@ Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array output_padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype); + ffi::Optional out_dtype); /*! * \brief Three dimensional transposed convolution operator. 
@@ -105,7 +105,7 @@ Expr conv3d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array output_padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype); + ffi::Optional out_dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index b24f81c72d49..5deb6db937bb 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -122,7 +122,9 @@ Type InferTypePRelu(const Call& call, const BlockBuilder& ctx) { if (data_ty->IsUnknownNdim()) { return data_ty; } - if (!data_ty->IsUnknownDtype() && !data_ty->dtype.is_float()) { + PrimType data_dtype = data_ty->dtype; + // PRelu preserves the old float-kind check; vector lanes are irrelevant to this check. + if (!data_ty->IsUnknownDtype() && !data_dtype.MatchesCode(DLDataTypeCode::kDLFloat)) { TVM_FFI_VISIT_THROW(TypeError, call) << "Prelu requires the input tensor to have float " "dtype. However, the given input dtype is " << data_ty->dtype; @@ -186,10 +188,14 @@ Type InferTypeSoftmax(const Call& call, const BlockBuilder& ctx) { if (data_ty->IsUnknownNdim()) { return data_ty; } - if (!data_ty->IsUnknownDtype() && !data_ty->dtype.is_float() && !data_ty->dtype.is_bfloat()) { - TVM_FFI_VISIT_THROW(TypeError, call) << "Softmax requires the input tensor to have float " - "dtype. However, the given input dtype is " - << data_ty->dtype; + if (!data_ty->IsUnknownDtype()) { + PrimType data_dtype = data_ty->dtype; + // Softmax only requires a floating element kind; lane encoding is irrelevant to the check. + if (!data_dtype.MatchesCode(kDLFloat, kDLBfloat)) { + TVM_FFI_VISIT_THROW(TypeError, call) << "Softmax requires the input tensor to have float " + "dtype. 
However, the given input dtype is " + << data_ty->dtype; + } } const auto* attrs = call->attrs.as(); NormalizeAxis(call, ctx, data_ty->ndim, attrs->axis); @@ -380,10 +386,14 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, axes_non_neg = NormalizeAxes(call, ctx, data_ty->ndim, axes); } int n_axis = axes.size(); - if (!data_ty->IsUnknownDtype() && (!data_ty->dtype.is_float() && !data_ty->dtype.is_bfloat())) { - TVM_FFI_VISIT_THROW(TypeError, call) - << op << " requires the input data to have float dtype. However, the given data dtype is " - << data_ty->dtype; + if (!data_ty->IsUnknownDtype()) { + PrimType data_dtype = data_ty->dtype; + // Norm ops only require a floating element kind; lane encoding is irrelevant to the check. + if (!data_dtype.MatchesCode(kDLFloat, kDLBfloat)) { + TVM_FFI_VISIT_THROW(TypeError, call) + << op << " requires the input data to have float dtype. However, the given data dtype is " + << data_ty->dtype; + } } for (int i = 1; i < n_input; ++i) { if (input_ty[i]->dtype != data_ty->dtype) { @@ -462,7 +472,7 @@ Type InferTypeBatchNorm(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_ty, {attrs->axis}); - DataType dtype = input_ty[0]->dtype; + PrimType dtype = input_ty[0]->dtype; if (unknown_shape) { auto vdev = input_ty[0]->vdevice; return TupleType({TensorType(dtype, input_ty[0]->ndim, vdev), @@ -620,7 +630,9 @@ Type InferTypeGroupNorm(const Call& call, const BlockBuilder& ctx) { << channel_axis << ", axes: " << attrs->axes; } } - if (!data_ty->IsUnknownDtype() && !data_ty->dtype.is_float()) { + PrimType data_dtype = data_ty->dtype; + // GroupNorm preserves the old float-kind check; vector lanes are irrelevant to this check. 
+ if (!data_ty->IsUnknownDtype() && !data_dtype.MatchesCode(DLDataTypeCode::kDLFloat)) { TVM_FFI_VISIT_THROW(TypeError, call) << op << " expects that data must be float, but got " << data_ty->dtype; } @@ -890,7 +902,7 @@ Type InferTypeCrossEntropy(const Call& call, const BlockBuilder& ctx) { TensorType label_ty = input_ty[1]; // infer dtype - DataType dtype = InferBinaryArithOpOutDtype(call, ctx, pred_ty, label_ty); + PrimType dtype(InferBinaryArithOpOutDtype(call, ctx, pred_ty, label_ty)); // infer vdevice ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, pred_ty, label_ty); @@ -1002,23 +1014,26 @@ Type InferTypeNLLLoss(const Call& call, const BlockBuilder& ctx) { } // infer dtype, vdevice - DataType output_dtype; - ffi::Optional vdevice; - if (wgt_ty != nullptr) { - output_dtype = InferBinaryArithOpOutDtype(call, ctx, ffi::GetRef(pred_ty), - ffi::GetRef(wgt_ty)); - vdevice = InferBinaryArithOpOutVDevice(call, ctx, ffi::GetRef(pred_ty), - ffi::GetRef(wgt_ty)); - } else { - output_dtype = pred_ty->dtype; - vdevice = pred_ty->vdevice; - } + PrimType output_dtype = + wgt_ty != nullptr + ? PrimType(InferBinaryArithOpOutDtype(call, ctx, ffi::GetRef(pred_ty), + ffi::GetRef(wgt_ty))) + : pred_ty->dtype; + ffi::Optional vdevice = + wgt_ty != nullptr ? InferBinaryArithOpOutVDevice(call, ctx, ffi::GetRef(pred_ty), + ffi::GetRef(wgt_ty)) + : pred_ty->vdevice; // the type of targets must be int/uint. - if (!tgt_ty->IsUnknownDtype() && !tgt_ty->dtype.is_int() && !tgt_ty->dtype.is_uint()) { - TVM_FFI_VISIT_THROW(TypeError, call) - << "NLLLoss expects the dtype of targets to be int/uint. However, the dtype of targets is " - << tgt_ty->dtype; + if (!tgt_ty->IsUnknownDtype()) { + PrimType target_dtype = tgt_ty->dtype; + // NLLLoss only needs the target element kind; vector lanes do not affect target indexing. 
+ if (!target_dtype.MatchesCode(DLDataTypeCode::kDLInt) && + !target_dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { + TVM_FFI_VISIT_THROW(TypeError, call) << "NLLLoss expects the dtype of targets to be " + "int/uint. However, the dtype of targets is " + << tgt_ty->dtype; + } } // infer ndim diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 856cd75c5902..84f994bc612f 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -275,7 +275,8 @@ InferLayoutOutput InferLayoutPool2d( ffi::ObjectPtr new_attrs = ffi::make_object(*attrs); if (layout->layout.ndim() != layout->layout.ndim_primal()) { - tirx::SLayout in_layout(attrs->layout, DataType::Int(64)); + tvm::PrimType i64_ty = tvm::PrimType::Int(64); + tirx::SLayout in_layout(attrs->layout, i64_ty); auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); auto data_si = GetType(call->args[0]); TensorType data_ty = data_si.as().value(); @@ -675,7 +676,8 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool2D( LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ffi::ObjectPtr new_attrs = ffi::make_object(*attrs); if (layout->layout.ndim() != layout->layout.ndim_primal()) { - tirx::SLayout in_layout(attrs->layout, DataType::Int(64)); + tvm::PrimType i64_ty = tvm::PrimType::Int(64); + tirx::SLayout in_layout(attrs->layout, i64_ty); auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); auto data_si = GetType(call->args[0]); TensorType data_ty = data_si.as().value(); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 9c58ab769950..16e5d5f20d0e 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -409,9 +409,9 @@ static ffi::Optional InferCallTIROutputTypeFromArguments( TVM_FFI_ICHECK(packed_tuple_ty); PrimType dummy_arg_ty = [&]() { if (packed_tuple_ty->values) { - return PrimType(packed_tuple_ty->values.value()[i].dtype()); + return 
PrimType(packed_tuple_ty->values.value()[i].ty()); } else { - return PrimType(DataType::Int(64)); + return PrimType::Int(64); } }(); dummy_args.push_back(Var("dummy_trailing_arg", dummy_arg_ty)); @@ -1119,7 +1119,7 @@ Type InferTypeSize(const Call& call, const BlockBuilder& ctx) { auto* tensor_ty = GetType(call->args[0]).as(); TVM_FFI_ICHECK(tensor_ty) << "size expects a tensor input, but received " << arg_ty << "; use MatchCast if necessary"; - return TensorType(ShapeExpr(ffi::Array{}), DataType::Int(64)); + return TensorType(ShapeExpr(ffi::Array{}), PrimType::Int(64)); } TVM_REGISTER_OP("relax.size") @@ -1182,7 +1182,7 @@ Type ReturnShapeToTensorType(const Call& call, const BlockBuilder& ctx) { const auto* ty = GetTypeAs(call->args[0]); TVM_FFI_ICHECK(ty); int32_t ndim = ty->ndim; - return TensorType(ShapeExpr({PrimExpr(ndim)}), DataType::Int(64)); + return TensorType(ShapeExpr({PrimExpr(ndim)}), PrimType::Int(64)); } TVM_REGISTER_OP("relax.shape_to_tensor") @@ -1209,10 +1209,10 @@ Type InferTypeAllocateTensor(const Call& call, const BlockBuilder& ctx) { << "must be ShapeExpr, but got " << call->args[0]->GetTypeKey(); TVM_FFI_ICHECK(call->args[1].as()) << "must be DataTypeImm, but got " << call->args[1]->GetTypeKey(); - DataType out_dtype; + PrimType out_dtype = PrimType::Void(); if (const auto* dtype_node = call->args[1].as()) { const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); - out_dtype = dtype_imm->value; + out_dtype = PrimType(dtype_imm->value); } int64_t vdevice_index = -1; if (auto* prim_value_node = call->args[2].as()) { @@ -1284,10 +1284,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeMemAllocTensor(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK(GetTypeAs(call->args[2])) << "must be a Expr of ShapeType, but got " << call->args[1]->GetTypeKey(); - DataType out_dtype; + PrimType out_dtype = PrimType::Void(); if (const auto* dtype_node = call->args[3].as()) { const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); - out_dtype = 
dtype_imm->value; + out_dtype = PrimType(dtype_imm->value); } if (call->args.size() == 5) { @@ -1408,10 +1408,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { // vm alloc_tensor Type InferTypeVMAllocTensor(const Call& call, const BlockBuilder& ctx) { - DataType out_dtype; + PrimType out_dtype = PrimType::Void(); if (const auto* dtype_node = call->args[3].as()) { const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); - out_dtype = dtype_imm->value; + out_dtype = PrimType(dtype_imm->value); } int64_t vdevice_index = -1; if (auto* prim_value_node = call->args[4].as()) { diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index cb0d6034e2d1..a19f59d4d56a 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -33,6 +33,7 @@ #include #include +#include #include #include @@ -184,14 +185,12 @@ std::tuple GetArgType(const Call& call, const BlockBuilder& ctx) { tvm::ffi::reflection::GlobalDef().def("relax.op." OpRegName, OpName); \ } -/************ Utilities ************/ - /*! * \brief Infer the type for unary elementwise ops. * \param call The context Call to the operator. * \param ctx The error reporting context. * \param f_compute_out_dtype The function to compute the output dtype, with - * signature DataType f_compute_out_dtype(const TensorType& input_ty). + * signature DLDataType or PrimType f_compute_out_dtype(const TensorType& input_ty). * \tparam require_float_dtype whether this op requires the input dtype to be float * \tparam Ftype the type of f_compute_out_dtype * \return The inferred type. 
@@ -199,15 +198,21 @@ std::tuple GetArgType(const Call& call, const BlockBuilder& ctx) { template inline Type InferTypeUnary(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { TensorType input_ty = GetUnaryInputTensorType(call, ctx); + DLDataType input_dtype = input_ty->dtype->dtype; if (require_float_dtype && !input_ty->IsUnknownDtype() && - (!input_ty->dtype.is_float() && !input_ty->dtype.is_bfloat())) { + (input_dtype.code != kDLFloat && input_dtype.code != kDLBfloat)) { TVM_FFI_VISIT_THROW(TypeError, call) << call->op << " requires the input tensor to have float dtype. However, the given input dtype is " << input_ty->dtype; } auto output_ty = ffi::make_object(*input_ty.get()); - output_ty->dtype = f_compute_out_dtype(input_ty); + auto computed_dtype = f_compute_out_dtype(input_ty); + if constexpr (std::is_same_v, PrimType>) { + output_ty->dtype = computed_dtype; + } else { + output_ty->dtype = PrimType(computed_dtype); + } if (call->ty_args.size() > 0) { auto defined_ty = call->ty_args[0].as(); TVM_FFI_ICHECK(defined_ty); @@ -274,9 +279,9 @@ InferLayoutOutput InferLayoutUnaryEwise( * \return The inferred element dtype. * \throw Throw exception if the Type doesn't have an element type. */ -inline std::optional GetElementDType(const Type& ty) { +inline std::optional GetElementDType(const Type& ty) { if (const auto* prim = ty.as()) { - return prim->dtype; + return ffi::GetRef(prim); } else if (const auto* tensor = ty.as()) { return tensor->dtype; } else { @@ -296,8 +301,8 @@ inline std::optional GetElementDType(const Type& ty) { * \return The inferred output dtype. 
* \throw Throw exception if the dtype of two input TensorType don’t match */ -inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx, - const Type& lhs_ty, const Type& rhs_ty) { +inline DLDataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx, + const Type& lhs_ty, const Type& rhs_ty) { auto opt_lhs_dtype = GetElementDType(lhs_ty); if (!opt_lhs_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) @@ -318,15 +323,17 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& } auto rhs_dtype = opt_rhs_dtype.value(); - if (lhs_dtype.is_void() || rhs_dtype.is_void()) { - return DataType::Void(); - } else if (lhs_dtype != rhs_dtype && !lhs_dtype.is_bool() && !rhs_dtype.is_bool()) { + if (lhs_dtype.IsVoid() || rhs_dtype.IsVoid()) { + return DLDataType{kDLOpaqueHandle, 0, 0}; + } else if (lhs_dtype->dtype != rhs_dtype->dtype && + !lhs_dtype.MatchesCode(DLDataTypeCode::kDLBool) && + !rhs_dtype.MatchesCode(DLDataTypeCode::kDLBool)) { TVM_FFI_VISIT_THROW(TypeError, call) << "Binary operators must have the same datatype for both operands. " << "However, " << call << " uses datatype " << lhs_dtype << " on the LHS (Type of " << lhs_ty << "), and datatype " << rhs_dtype << " on the RHS (Type of " << rhs_ty << ")."; } - return lhs_dtype; + return lhs_dtype->dtype; } /*! 
@@ -469,7 +476,7 @@ bool IsIdentityPermutation(const std::vector& permutation); */ inline ffi::Array ConvertIntImmToInt64(const ffi::Array& int_imms) { return int_imms.Map( - [](const IntImm& i) { return cast(DataType::Int(64), i).as_or_throw(); }); + [](const IntImm& i) { return cast(PrimType::Int(64), i).as_or_throw(); }); } /************ Utilities for NN operators ************/ @@ -560,8 +567,9 @@ inline ffi::Array GetCompletePadding3D(ffi::Array padding) { inline std::pair CheckTensorLayout( const Call& call, const BlockBuilder& ctx, const ffi::String& tensor_layout, const ffi::String& tgt_layout, const ffi::String& tensor_name) { - tirx::SLayout _tensor_layout(tensor_layout, DataType::Int(64)); - tirx::SBijectiveLayout tensor2tgt(_tensor_layout, tirx::SLayout(tgt_layout, DataType::Int(64))); + tvm::PrimType i64_ty = tvm::PrimType::Int(64); + tirx::SLayout _tensor_layout(tensor_layout, i64_ty); + tirx::SBijectiveLayout tensor2tgt(_tensor_layout, tirx::SLayout(tgt_layout, i64_ty)); if (!tensor2tgt.defined()) { TVM_FFI_VISIT_THROW(ValueError, call) << call->op << " requires the given " << tensor_name << " layout to be convertible from " diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 84c411238473..cbc786de0f8e 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -51,11 +51,11 @@ Type InferTypeBroadcast(const Call& call, const BlockBuilder& ctx, FType f_compu << "Arguments to binary operators must be either R.Tensor or R.Prim types, " << "but expression " << call << " has RHS " << call->args[1] << ", which has Type " << rhs_ty; - // DateType - DataType output_dtype = f_compute_out_dtype(call, ctx, lhs_ty, rhs_ty); + // Dtype + PrimType output_dtype(f_compute_out_dtype(call, ctx, lhs_ty, rhs_ty)); if (lhs_ty.as() && rhs_ty.as()) { - return PrimType(output_dtype); + return output_dtype; } // VDevice @@ -136,7 +136,7 @@ Type InferTypeBroadcastArith(const Call& call, const BlockBuilder& ctx) { Type 
InferTypeBroadcastCMP(const Call& call, const BlockBuilder& ctx) { return InferTypeBroadcast(call, ctx, [](const Call& call, const BlockBuilder& ctx, const Type& lhs_ty, - const Type& rhs_ty) { return DataType::Bool(); }); + const Type& rhs_ty) { return DLDataType{kDLBool, 8, 1}; }); } InferLayoutOutput InferLayoutBinaryEwise( diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index e7a972896569..fbe3a0b0c534 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -46,7 +46,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { /* relax.full */ Expr full(ffi::Variant> shape, Expr fill_value, - ffi::Optional dtype) { + ffi::Optional dtype) { Expr shape_in_expr{nullptr}; if (const auto* expr = shape.as()) { shape_in_expr = ffi::GetRef(expr); @@ -59,7 +59,7 @@ Expr full(ffi::Variant> shape, Expr fill_value, } ffi::ObjectPtr attrs = ffi::make_object(); - attrs->dtype = dtype.value_or(DataType::Void()); + attrs->dtype = dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.full"); return Call(op, {std::move(shape_in_expr), std::move(fill_value)}, Attrs(attrs), {}); @@ -88,7 +88,8 @@ Type InferTypeFull(const Call& call, const BlockBuilder& ctx) { } const auto* attrs = call->attrs.as(); - DataType out_dtype = attrs->dtype.is_void() ? fill_value_ty->dtype : attrs->dtype; + PrimType out_dtype = attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0} ? 
fill_value_ty->dtype + : PrimType(attrs->dtype); return TensorType(/*shape=*/call->args[0], out_dtype, fill_value_ty->vdevice); } @@ -104,9 +105,9 @@ TVM_REGISTER_OP("relax.full") .set_attr("FPurity", true); /* relax.full_like */ -Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype) { +Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype) { ffi::ObjectPtr attrs = ffi::make_object(); - attrs->dtype = dtype.value_or(DataType::Void()); + attrs->dtype = dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.full_like"); return Call(op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); } @@ -127,11 +128,11 @@ Type InferTypeFullLike(const Call& call, const BlockBuilder& ctx) { } const auto* attrs = call->attrs.as(); - if (attrs->dtype.is_void()) { + if (attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0}) { return data_ty; } else { auto output_ty = ffi::make_object(*data_ty.get()); - output_ty->dtype = attrs->dtype; + output_ty->dtype = PrimType(attrs->dtype); return TensorType(output_ty); } } @@ -158,25 +159,26 @@ Type InferTypeOnesZeros(const Call& call, const BlockBuilder& ctx) { << call->args[0]->ty->GetTypeKey(); } const auto* attrs = call->attrs.as(); - return TensorType(/*shape=*/call->args[0], attrs->dtype); + return TensorType(/*shape=*/call->args[0], PrimType(attrs->dtype)); } // Structure info inference for ones_like and zeros_like Type InferTypeOnesLikeZerosLike(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - if (attrs->dtype.is_void()) { + if (attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0}) { return data_ty; } else { auto output_ty = ffi::make_object(*data_ty.get()); - output_ty->dtype = attrs->dtype; + output_ty->dtype = PrimType(attrs->dtype); return TensorType(output_ty); } } /* relax.ones & relax.ones_like */ -Expr ones(Expr shape, DataType dtype) { - TVM_FFI_ICHECK(!dtype.is_void()) << "Ones op expects 
the input dtype not to be void"; +Expr ones(Expr shape, DLDataType dtype) { + TVM_FFI_ICHECK((dtype != DLDataType{kDLOpaqueHandle, 0, 0})) + << "Ones op expects the input dtype not to be void"; ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; @@ -184,9 +186,9 @@ Expr ones(Expr shape, DataType dtype) { return Call(op, {std::move(shape)}, Attrs(attrs), {}); } -Expr ones_like(Expr x, ffi::Optional dtype) { +Expr ones_like(Expr x, ffi::Optional dtype) { ffi::ObjectPtr attrs = ffi::make_object(); - attrs->dtype = dtype.value_or(DataType::Void()); + attrs->dtype = dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.ones_like"); return Call(op, {std::move(x)}, Attrs(attrs), {}); } @@ -212,8 +214,9 @@ TVM_REGISTER_OP("relax.ones_like") .set_attr("FPurity", true); /* relax.zeros & relax.zeros_like */ -Expr zeros(Expr shape, DataType dtype) { - TVM_FFI_ICHECK(!dtype.is_void()) << "Zeros op expects the input dtype not to be void"; +Expr zeros(Expr shape, DLDataType dtype) { + TVM_FFI_ICHECK((dtype != DLDataType{kDLOpaqueHandle, 0, 0})) + << "Zeros op expects the input dtype not to be void"; ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; @@ -221,9 +224,9 @@ Expr zeros(Expr shape, DataType dtype) { return Call(op, {std::move(shape)}, Attrs(attrs), {}); } -Expr zeros_like(Expr x, ffi::Optional dtype) { +Expr zeros_like(Expr x, ffi::Optional dtype) { ffi::ObjectPtr attrs = ffi::make_object(); - attrs->dtype = dtype.value_or(DataType::Void()); + attrs->dtype = dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.zeros_like"); return Call(op, {std::move(x)}, Attrs(attrs), {}); } @@ -249,16 +252,16 @@ TVM_REGISTER_OP("relax.zeros_like") .set_attr("FPurity", true); /* relax.eye & relax.eye_like */ -Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype) { +Expr eye(PrimValue n, PrimValue m, PrimValue k, DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); 
attrs->dtype = dtype; static const Op& op = Op::Get("relax.eye"); return Call(op, {std::move(n), std::move(m), std::move(k)}, Attrs(attrs), {}); } -Expr eye_like(Expr x, PrimValue k, ffi::Optional dtype) { +Expr eye_like(Expr x, PrimValue k, ffi::Optional dtype) { ffi::ObjectPtr attrs = ffi::make_object(); - attrs->dtype = dtype.value_or(DataType::Void()); + attrs->dtype = dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.eye_like"); return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {}); } @@ -285,8 +288,8 @@ Type InferTypeEye(const Call& call, const BlockBuilder& ctx) { PrimExpr n = get_prim_value(call->args[0], "n"); PrimExpr m = get_prim_value(call->args[1], "m"); - DataType dtype = call->attrs.as()->dtype; - return TensorType(ShapeExpr({n, m}), dtype); + DLDataType dtype = call->attrs.as()->dtype; + return TensorType(ShapeExpr({n, m}), PrimType(dtype)); } Type InferTypeEyeLike(const Call& call, const BlockBuilder& ctx) { @@ -309,7 +312,8 @@ Type InferTypeEyeLike(const Call& call, const BlockBuilder& ctx) { } const auto* attrs = call->attrs.as(); - DataType out_dtype = attrs->dtype.is_void() ? x_ty->dtype : attrs->dtype; + PrimType out_dtype = + attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0} ? 
x_ty->dtype : PrimType(attrs->dtype); return TensorType(x_ty->shape.value(), out_dtype, x_ty->vdevice); } @@ -333,7 +337,7 @@ TVM_REGISTER_OP("relax.eye_like") .set_attr("FPurity", true); /* relax.arange */ -Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) { +Expr arange(PrimValue start, PrimValue stop, PrimValue step, DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.arange"); @@ -362,17 +366,18 @@ Type InferTypeArange(const Call& call, const BlockBuilder& ctx) { PrimExpr start = get_prim_value(call->args[0], "start"); PrimExpr end = get_prim_value(call->args[1], "end"); PrimExpr step = get_prim_value(call->args[2], "step"); - DataType dtype = call->attrs.as()->dtype; + DLDataType dtype = call->attrs.as()->dtype; PrimExpr num_elem; - if (start.dtype().is_int() && end.dtype().is_int() && step.dtype().is_int()) { + if (start.ty().code() == DLDataTypeCode::kDLInt && end.ty().code() == DLDataTypeCode::kDLInt && + step.ty().code() == DLDataTypeCode::kDLInt) { num_elem = tvm::floordiv((end - start + step - 1), step); } else { - num_elem = tvm::cast(tvm::DataType::Int(64), - tvm::ceil(tvm::cast(tvm::DataType::Float(32), end - start) / step)); + num_elem = tvm::cast(tvm::PrimType::Int(64), + tvm::ceil(tvm::cast(tvm::PrimType::Float(32), end - start) / step)); } arith::Analyzer analyzer; num_elem = analyzer->Simplify(num_elem); - return TensorType(ShapeExpr({num_elem}), dtype); + return TensorType(ShapeExpr({num_elem}), PrimType(dtype)); } TVM_REGISTER_OP("relax.arange") @@ -387,7 +392,7 @@ TVM_REGISTER_OP("relax.arange") /* relax.hamming_window */ Expr hamming_window(PrimValue window_size, PrimValue periodic, PrimValue alpha, PrimValue beta, - DataType dtype) { + DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.hamming_window"); @@ -401,8 +406,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } Type 
InferTypeHammingWindow(const Call& call, const BlockBuilder& ctx) { - DataType dtype = call->attrs.as()->dtype; - if (dtype.is_int() || dtype.is_uint() || dtype.is_uint()) { + DLDataType dtype = call->attrs.as()->dtype; + if (dtype.code == DLDataTypeCode::kDLInt || dtype.code == DLDataTypeCode::kDLUInt) { TVM_FFI_VISIT_THROW(TypeError, call) << "Hamming Window expects the datatype to be float but got " << dtype; } @@ -422,7 +427,7 @@ Type InferTypeHammingWindow(const Call& call, const BlockBuilder& ctx) { << window_size; } window_size = analyzer->Simplify(window_size); - return TensorType(ShapeExpr({window_size}), dtype); + return TensorType(ShapeExpr({window_size}), PrimType(dtype)); } TVM_REGISTER_OP("relax.hamming_window") diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index 284448111739..497a535a4d0f 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -42,7 +42,7 @@ namespace relax { * \return The result tensor. */ Expr full(ffi::Variant> shape, Expr fill_value, - ffi::Optional dtype); + ffi::Optional dtype); /*! * \brief Construct a tensor such that @@ -55,7 +55,7 @@ Expr full(ffi::Variant> shape, Expr fill_value, * void, the input tensor's dtype will be used. * \return The result tensor. */ -Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype); +Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype); /*! * \brief Construct a tensor of all ones, with the input shape and dtype. @@ -63,7 +63,7 @@ Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype); * \param dtype The data type of the created tensor. * \return The result tensor. */ -Expr ones(Expr shape, DataType dtype); +Expr ones(Expr shape, DLDataType dtype); /*! * \brief Construct a tensor with all ones, with shape of the input tensor shape. @@ -73,7 +73,7 @@ Expr ones(Expr shape, DataType dtype); * void, the input tensor's dtype will be used. * \return The result tensor. 
*/ -Expr ones_like(Expr x, ffi::Optional dtype); +Expr ones_like(Expr x, ffi::Optional dtype); /*! * \brief Construct a tensor of all zeros, with the input shape and dtype. @@ -81,7 +81,7 @@ Expr ones_like(Expr x, ffi::Optional dtype); * \param dtype The data type of the created tensor. * \return The result tensor. */ -Expr zeros(Expr shape, DataType dtype); +Expr zeros(Expr shape, DLDataType dtype); /*! * \brief Construct a tensor with all zeros, with shape of the input tensor shape. @@ -91,7 +91,7 @@ Expr zeros(Expr shape, DataType dtype); * void, the input tensor's dtype will be used. * \return The result tensor. */ -Expr zeros_like(Expr x, ffi::Optional dtype); +Expr zeros_like(Expr x, ffi::Optional dtype); /*! * \brief Construct a 2-D tensor with ones on the diagonal and zeros elsewhere. @@ -102,7 +102,7 @@ Expr zeros_like(Expr x, ffi::Optional dtype); * \param dtype The data type of the created tensor. * \return The result tensor. */ -Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype); +Expr eye(PrimValue n, PrimValue m, PrimValue k, DLDataType dtype); /*! * \brief Construct a tensor with ones on the diagonal and zeros elsewhere, @@ -115,10 +115,10 @@ Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype); * void, the input tensor's dtype will be used. * \return The result tensor. */ -Expr eye_like(Expr x, PrimValue k, ffi::Optional dtype); +Expr eye_like(Expr x, PrimValue k, ffi::Optional dtype); /*! \brief Construct a tensor with evenly spaced elements. */ -Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype); +Expr arange(PrimValue start, PrimValue stop, PrimValue step, DLDataType dtype); /*! * \brief Hamming window function. @@ -131,7 +131,7 @@ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype); * \return The result tensor. */ Expr hamming_window(PrimValue window_size, PrimValue periodic, PrimValue alpha, PrimValue beta, - DataType dtype); + DLDataType dtype); /*! 
\brief Return the lower triangular part of a matrix or a batch of matrices. */ Expr tril(Expr x, Expr k); diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc index 907dffb0b3f3..ec1043a025e1 100644 --- a/src/relax/op/tensor/datatype.cc +++ b/src/relax/op/tensor/datatype.cc @@ -38,7 +38,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { /* relax.astype */ -Expr astype(Expr x, DataType dtype) { +Expr astype(Expr x, DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; @@ -55,7 +55,7 @@ Type InferTypeAstype(const Call& call, const BlockBuilder& ctx) { TensorType ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); ffi::ObjectPtr new_ty = ffi::make_object(*ty.get()); - new_ty->dtype = attrs->dtype; + new_ty->dtype = PrimType(attrs->dtype); return TensorType(new_ty); } @@ -70,7 +70,7 @@ TVM_REGISTER_OP("relax.astype") /* relax.wrap_param */ -Expr MakeWrapParam(Expr data, DataType dtype) { +Expr MakeWrapParam(Expr data, DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; @@ -87,7 +87,7 @@ Type InferTypeWrapParam(const Call& call, const BlockBuilder& ctx) { TensorType ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); ffi::ObjectPtr new_ty = ffi::make_object(*ty.get()); - new_ty->dtype = attrs->dtype; + new_ty->dtype = PrimType(attrs->dtype); return TensorType(new_ty); } diff --git a/src/relax/op/tensor/datatype.h b/src/relax/op/tensor/datatype.h index b612c45fc941..db2ee396c0d6 100644 --- a/src/relax/op/tensor/datatype.h +++ b/src/relax/op/tensor/datatype.h @@ -37,7 +37,7 @@ namespace relax { * \param dtype The target data type * \return The casted result. */ -Expr astype(Expr x, DataType dtype); +Expr astype(Expr x, DLDataType dtype); /*! * \brief A wrapper to wrap the input const tensor to the given data type. @@ -45,7 +45,7 @@ Expr astype(Expr x, DataType dtype); * \param dtype The target data type * \return The wrapped result. 
*/ -Expr wrap_param(Expr x, DataType dtype); +Expr wrap_param(Expr x, DLDataType dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index 515f37126183..5321798b8e48 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -72,7 +72,7 @@ Type InferTypeTake(const Call& call, const BlockBuilder& ctx) { if (auto tensor_ty = ty.as()) { return tensor_ty.value(); } else if (auto prim_ty = ty.as()) { - return TensorType(ShapeExpr(ffi::Array{}), prim_ty->dtype); + return TensorType(ShapeExpr(ffi::Array{}), ffi::GetRef(prim_ty)); } else { TVM_FFI_VISIT_THROW(TypeError, call) << "Operator " << call->op << " requires the indices argument to be " @@ -84,11 +84,14 @@ Type InferTypeTake(const Call& call, const BlockBuilder& ctx) { if (indices_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; - } else if (!(indices_ty->dtype.is_int() || indices_ty->dtype.is_uint())) { - TVM_FFI_VISIT_THROW(TypeError, call) - << "Take op requires the input indices to have integer dtype. However, the " - "given indices dtype is " - << indices_ty->dtype; + } else { + PrimType indices_dtype = indices_ty->dtype; + if (!indices_dtype.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { + TVM_FFI_VISIT_THROW(TypeError, call) + << "Take op requires the input indices to have integer dtype. However, the " + "given indices dtype is " + << indices_ty->dtype; + } } const auto* attrs = call->attrs.as(); @@ -309,7 +312,7 @@ Type InferTypeStridedSlice(const Call& call, const BlockBuilder& ctx) { } }(); - TVM_FFI_ICHECK(IsBaseOf(relax::TensorType(DataType::Void(), kUnknownNDim), GetType(data))) + TVM_FFI_ICHECK(IsBaseOf(relax::TensorType(PrimType::Void(), kUnknownNDim), GetType(data))) << "Operator " << call->op << " requires the first argument to be a tensor. 
" << "However, in expression " << call << ", the first argument " << data << " has type " << GetType(data); @@ -325,9 +328,8 @@ Type InferTypeStridedSlice(const Call& call, const BlockBuilder& ctx) { const auto* tuple = ty.as(); if (!tuple) return false; - return std::all_of(tuple->fields.begin(), tuple->fields.end(), [](const Type& field) { - return IsBaseOf(tvm::PrimType(DataType::Int(64)), field); - }); + return std::all_of(tuple->fields.begin(), tuple->fields.end(), + [](const Type& field) { return IsBaseOf(tvm::PrimType::Int(64), field); }); }; auto check_tuple = [&](const char* name, Expr expr) { auto ty = GetType(expr); @@ -347,7 +349,7 @@ Type InferTypeStridedSlice(const Call& call, const BlockBuilder& ctx) { const auto* data_ty = data->ty.as(); - DataType dtype = DataType::Void(); + PrimType dtype(DLDataType{kDLOpaqueHandle, 0, 0}); ffi::Optional vdevice = std::nullopt; int ndim = kUnknownNDim; if (data_ty) { @@ -545,7 +547,7 @@ Type InferTypeDynStridedSlice(const Call& call, const BlockBuilder& ctx) { LOG(WARNING) << "Dynamic strided slice assumes " << name << " to be int64 when it is not specified."; } else { - TVM_FFI_ICHECK(ty->dtype == DataType::Int(64)) + TVM_FFI_ICHECK(ty->dtype == PrimType::Int(64)) << "Dynamic strided_slice expects the input " << name << "values to be all int64. 
However, " << name << " has dtype " << ty->dtype << "."; } diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index bf57670e7f2a..97955eb62455 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -88,24 +88,21 @@ std::tuple> GetTensorArgInfoWithIndex(const C return {ffi::GetRef(tensor_ty), int_imm_axis}; } -DataType GetTensorDataType(const Call& call) { return GetTensorArgInfo(call)->dtype; } +tirx::PrimFunc GetDLTensorField(tirx::builtin::TVMStructFieldKind field, PrimType field_ty) { + tirx::Var dlpack_handle("dlpack_handle", PrimType::Handle()); -tirx::PrimFunc GetDLTensorField(tirx::builtin::TVMStructFieldKind field, DataType field_dtype) { - tirx::Var dlpack_handle("dlpack_handle", DataType::Handle()); - - tirx::Var value("value", field_dtype); + tirx::Var value("value", field_ty); tirx::Stmt body = tirx::SeqStmt( - {tirx::Bind(value, tirx::Call(field_dtype, tirx::builtin::tvm_struct_get(), + {tirx::Bind(value, tirx::Call(field_ty, tirx::builtin::tvm_struct_get(), {dlpack_handle, IntImm::Int32(0), IntImm::Int32(field)})), tirx::Evaluate(tvm::ret(value))}); DictAttrs attrs({{"tirx.is_scheduled", true}, {"tirx.is_host_func", true}}); - tirx::PrimFunc func(ffi::Array{dlpack_handle}, body, tvm::PrimType(field_dtype), {}, - attrs); + tirx::PrimFunc func(ffi::Array{dlpack_handle}, body, field_ty, {}, attrs); - FuncType ty({TensorType(DataType::Void(), kUnknownNDim)}, PrimType(field_dtype)); + FuncType ty({TensorType(PrimType::Void(), kUnknownNDim)}, field_ty); func->ty = ty; return func; @@ -120,23 +117,14 @@ Expr tensor_dtype_code(Expr expr) { return Call(op, {expr}); } -Type InferTypeTensorDtypeCode(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::UInt(8); - - DataType dtype = GetTensorDataType(call); - if (dtype.is_void()) { - return PrimType(dlpack_type); - } else { - return PrimType(dlpack_type); - } -} +Type InferTypeTensorDtypeCode(const Call& call, const BlockBuilder&) { return 
PrimType::UInt(8); } Expr LegalizeTensorDtypeCode(const BlockBuilder& bb, const Call& call) { - auto field_dtype = call->ty.as_or_throw()->dtype; + PrimType field_ty = call->ty.as_or_throw(); Expr arg = call->args[0]; tirx::PrimFunc getter = - GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeCode, field_dtype); + GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeCode, field_ty); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_code"); return Call(gvar_getter, {arg}); @@ -158,23 +146,14 @@ Expr tensor_dtype_bits(Expr expr) { return Call(op, {expr}); } -Type InferTypeTensorDtypeBits(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::UInt(8); - - DataType dtype = GetTensorDataType(call); - if (dtype.is_void()) { - return PrimType(dlpack_type); - } else { - return PrimType(dlpack_type); - } -} +Type InferTypeTensorDtypeBits(const Call& call, const BlockBuilder&) { return PrimType::UInt(8); } Expr LegalizeTensorDtypeBits(const BlockBuilder& bb, const Call& call) { - auto field_dtype = call->ty.as_or_throw()->dtype; + PrimType field_ty = call->ty.as_or_throw(); Expr arg = call->args[0]; tirx::PrimFunc getter = - GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeBits, field_dtype); + GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeBits, field_ty); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_bits"); return Call(gvar_getter, {arg}); @@ -196,23 +175,14 @@ Expr tensor_dtype_lanes(Expr expr) { return Call(op, {expr}); } -Type InferTypeTensorDtypeLanes(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::UInt(16); - - DataType dtype = GetTensorDataType(call); - if (dtype.is_void()) { - return PrimType(dlpack_type); - } else { - return PrimType(dlpack_type); - } -} +Type InferTypeTensorDtypeLanes(const Call& call, const BlockBuilder&) { return PrimType::UInt(16); } Expr LegalizeTensorDtypeLanes(const BlockBuilder& bb, const Call& call) { 
- auto field_dtype = call->ty.as_or_throw()->dtype; + PrimType field_ty = call->ty.as_or_throw(); Expr arg = call->args[0]; tirx::PrimFunc getter = - GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeLanes, field_dtype); + GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeLanes, field_ty); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_lanes"); return Call(gvar_getter, {arg}); @@ -234,23 +204,14 @@ Expr tensor_ndim(Expr expr) { return Call(op, {expr}); } -Type InferTypeTensorNDim(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::Int(32); - - auto ty = GetTensorArgInfo(call); - if (ty->IsUnknownNdim()) { - return PrimType(dlpack_type); - } else { - return PrimType(dlpack_type); - } -} +Type InferTypeTensorNDim(const Call& call, const BlockBuilder&) { return PrimType::Int(32); } Expr LegalizeTensorNDim(const BlockBuilder& bb, const Call& call) { - auto field_dtype = call->ty.as_or_throw()->dtype; + PrimType field_ty = call->ty.as_or_throw(); Expr arg = call->args[0]; tirx::PrimFunc getter = - GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorNDim, field_dtype); + GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorNDim, field_ty); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_ndim"); return Call(gvar_getter, {arg}); @@ -273,45 +234,45 @@ Expr tensor_shape_i(Expr expr) { } Type InferTypeTensorShape(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::Int(64); + auto dlpack_type = PrimType::Int(64); auto [tensor_ty, int_imm_axis] = GetTensorArgInfoWithIndex(call); auto tensor_shape = tensor_ty->GetShape(); if (int_imm_axis && tensor_shape.defined()) { - return PrimType(tensor_shape.value()[int_imm_axis.value()].dtype()); + return tensor_shape.value()[int_imm_axis.value()].ty(); } else { - return PrimType(dlpack_type); + return dlpack_type; } } Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { - auto field_dtype = 
call->ty.as_or_throw()->dtype; + PrimType field_ty = call->ty.as_or_throw(); tirx::PrimFunc getter = [&]() -> tirx::PrimFunc { - tirx::Var dlpack_handle("dlpack_handle", DataType::Handle()); - tirx::Var axis("axis", DataType::Int(64)); + tirx::Var dlpack_handle("dlpack_handle", PrimType::Handle()); + tirx::Var axis("axis", PrimType::Int(64)); - tirx::Var ndim("ndim", DataType::Int(32)); + tirx::Var ndim("ndim", PrimType::Int(32)); - tirx::Buffer shape_buffer = tirx::decl_buffer({ndim}, field_dtype, "shape"); + tirx::Buffer shape_buffer = tirx::decl_buffer({ndim}, field_ty, "shape"); - tirx::Var extent("extent", field_dtype); + tirx::Var extent("extent", field_ty); tirx::Stmt body = tirx::SeqStmt( {tirx::AssertStmt(0 <= axis, tirx::StringImm("RuntimeError"), {tirx::StringImm("Specified axis may not be negative")}), tirx::Bind(ndim, - tirx::Call(ndim->dtype, tirx::builtin::tvm_struct_get(), + tirx::Call(ndim.ty(), tirx::builtin::tvm_struct_get(), {dlpack_handle, IntImm::Int32(0), IntImm::Int32(tirx::builtin::TVMStructFieldKind::kDLTensorNDim)})), tirx::AssertStmt( - axis < tvm::cast(axis->dtype, ndim), tirx::StringImm("RuntimeError"), + axis < tvm::cast(axis.ty(), ndim), tirx::StringImm("RuntimeError"), {tirx::StringImm( "Specified axis may not be larger than the tensor's dimensionality")}), tirx::Bind(shape_buffer->data, - tirx::Call(DataType::Handle(), tirx::builtin::tvm_struct_get(), + tirx::Call(tvm::PrimType::Handle(), tirx::builtin::tvm_struct_get(), {dlpack_handle, IntImm::Int32(0), IntImm::Int32(tirx::builtin::TVMStructFieldKind::kDLTensorShape)})), tirx::DeclBuffer(shape_buffer), tirx::Bind(extent, tirx::BufferLoad(shape_buffer, {axis})), @@ -319,10 +280,9 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { DictAttrs attrs({{"tirx.is_scheduled", true}, {"tirx.is_host_func", true}}); - tirx::PrimFunc func({dlpack_handle, axis}, body, tvm::PrimType(field_dtype), {}, attrs); + tirx::PrimFunc func({dlpack_handle, axis}, body, field_ty, {}, 
attrs); - FuncType ty({TensorType(DataType::Void(), kUnknownNDim), PrimType(axis->dtype)}, - PrimType(field_dtype)); + FuncType ty({TensorType(PrimType::Void(), kUnknownNDim), axis.ty()}, field_ty); func->ty = ty; return func; }(); @@ -349,7 +309,7 @@ Expr tensor_stride_i(Expr expr) { } Type InferTypeTensorStride(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::Int(64); + auto dlpack_type = PrimType::Int(64); auto [tensor_ty, int_imm_axis] = GetTensorArgInfoWithIndex(call); @@ -373,9 +333,9 @@ Type InferTypeTensorStride(const Call& call, const BlockBuilder&) { for (size_t axis = int_imm_axis.value() + 1; axis < tensor_shape.size(); axis++) { stride = stride * tensor_shape[axis]; } - return PrimType(stride.dtype()); + return stride.ty(); } else { - return PrimType(dlpack_type); + return dlpack_type; } } @@ -396,7 +356,7 @@ Expr tensor_byte_offset(Expr expr) { } Type InferTypeTensorByteOffset(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::UInt(64); + auto dlpack_type = PrimType::UInt(64); auto tensor_ty = GetTensorArgInfo(call); @@ -405,9 +365,9 @@ Type InferTypeTensorByteOffset(const Call& call, const BlockBuilder&) { // Relax implicitly requires that the byte offset is zero for any // legalizable tensor. See InferTypeTensorStride for full // explanation. - return PrimType(dlpack_type); + return dlpack_type; } else { - return PrimType(dlpack_type); + return dlpack_type; } } @@ -427,7 +387,7 @@ Expr tensor_elem_offset(Expr expr) { } Type InferTypeTensorElemOffset(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::UInt(64); + auto dlpack_type = PrimType::UInt(64); auto tensor_ty = GetTensorArgInfo(call); @@ -436,9 +396,9 @@ Type InferTypeTensorElemOffset(const Call& call, const BlockBuilder&) { // Relax implicitly requires that the element offset is zero for // any legalizable tensor. See InferTypeTensorStride for // full explanation. 
- return PrimType(dlpack_type); + return dlpack_type; } else { - return PrimType(dlpack_type); + return dlpack_type; } } diff --git a/src/relax/op/tensor/inspect.h b/src/relax/op/tensor/inspect.h index 3f820ab58a83..92cc4c256c79 100644 --- a/src/relax/op/tensor/inspect.h +++ b/src/relax/op/tensor/inspect.h @@ -36,7 +36,7 @@ namespace inspect { * `TensorType`. * * \returns The uint8_t value of the type_code, with - * `PrimType(DataType::UInt(8))` + * `PrimType::UInt(8)` */ Expr tensor_dtype_code(Expr expr); @@ -46,7 +46,7 @@ Expr tensor_dtype_code(Expr expr); * `TensorType`. * * \returns The uint8_t value of the number of bits, with - * `PrimType(DataType::UInt(8))`. For vectorized types, returns + * `PrimType::UInt(8)`. For vectorized types, returns * the bit width of the underlying scalar type (e.g. 32 for * "float32x4", not 128). */ @@ -58,7 +58,7 @@ Expr tensor_dtype_bits(Expr expr); * `TensorType`. * * \returns The uint16_t value of the number of lanes, with - * `PrimType(DataType::UInt(16))` + * `PrimType::UInt(16)` */ Expr tensor_dtype_lanes(Expr expr); @@ -68,7 +68,7 @@ Expr tensor_dtype_lanes(Expr expr); * `TensorType`. * * \returns The int32_t value of the dimensionality, with - * `PrimType(DataType::Int(32))`. + * `PrimType::Int(32)`. */ Expr tensor_ndim(Expr expr); @@ -81,7 +81,7 @@ Expr tensor_ndim(Expr expr); * axis < tensor_ndim(expr)`, or else the results are undefined. * * \returns The int64_t extent of the specified tensor axis, with - * `PrimType(DataType::Int(64))`. + * `PrimType::Int(64)`. */ Expr tensor_shape_i(Expr expr, Expr axis); @@ -98,7 +98,7 @@ Expr tensor_shape_i(Expr expr, Expr axis); * axis < tensor_ndim(expr)`, or else the results are undefined. * * \returns The int64_t extent of the specified tensor axis, with - * `PrimType(DataType::Int(64))`. + * `PrimType::Int(64)`. */ Expr tensor_stride_i(Expr expr, Expr axis); @@ -107,7 +107,7 @@ Expr tensor_stride_i(Expr expr, Expr axis); * \param expr The relax expression to be inspected. 
Must have * `TensorType`. * - * \returns The uint64_t byte offset, with `PrimType(DataType::UInt(64))`. + * \returns The uint64_t byte offset, with `PrimType::UInt(64)`. */ Expr tensor_byte_offset(Expr expr); @@ -120,7 +120,7 @@ Expr tensor_byte_offset(Expr expr); * \param expr The relax expression to be inspected. Must have * `TensorType`. * - * \returns The uint64_t element offset, with `PrimType(DataType::UInt(64))`. + * \returns The uint64_t element offset, with `PrimType::UInt(64)`. */ Expr tensor_elem_offset(Expr expr); diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index a1693c6563f2..6ea68b422378 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -42,9 +42,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { /* relax.matmul */ -Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype) { +Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype) { ffi::ObjectPtr attrs = ffi::make_object(); - attrs->out_dtype = out_dtype.value_or(DataType::Void()); + attrs->out_dtype = out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.matmul"); return Call(op, {std::move(x1), std::move(x2)}, Attrs(attrs), {}); @@ -74,9 +74,9 @@ Type InferTypeMatmul(const Call& call, const BlockBuilder& ctx) { } const auto* attrs = call->attrs.as(); - DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? 
InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty) + : attrs->out_dtype); if (x1_ty->IsUnknownNdim() || x2_ty->IsUnknownNdim()) { if (vdev.defined()) { @@ -158,7 +158,7 @@ Type InferTypeMatmul(const Call& call, const BlockBuilder& ctx) { return TensorType(ShapeExpr(output_shape), out_dtype); } -Call InferMixedPrecisionMatmul(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionMatmul(const Call& call, DLDataType out_dtype) { return matmul(call->args[0], call->args[1], out_dtype).as_or_throw(); } @@ -218,17 +218,17 @@ Type InferTypeEinsum(const Call& call, const BlockBuilder& ctx) { ffi::String subscripts = attrs->subscripts; - DataType operand_dtype = operands_tensor_ty[0]->dtype; + PrimType operand_ty = operands_tensor_ty[0]->dtype; std::vector> input_shapes; input_shapes.reserve(operands_tensor_ty.size()); for (TensorType tensor_ty : operands_tensor_ty) { // Check the input tuple consists of tensors with same dtype - if (tensor_ty->dtype != operand_dtype) { + if (tensor_ty->dtype != operand_ty) { TVM_FFI_VISIT_THROW(TypeError, call) << "Einsum expects all input tensors to have the same dtype. 
However, the " "input contains tensors with dtype " - << operand_dtype << " and " << tensor_ty->dtype; + << operand_ty << " and " << tensor_ty->dtype; } // Get input shapes @@ -237,18 +237,18 @@ Type InferTypeEinsum(const Call& call, const BlockBuilder& ctx) { input_shapes.push_back(shape_expr->values); } else { if (!vdevice_unknown) { - return TensorType(operand_dtype, tensor_ty->ndim, vdev); + return TensorType(operand_ty, tensor_ty->ndim, vdev); } - return TensorType(operand_dtype, tensor_ty->ndim); + return TensorType(operand_ty, tensor_ty->ndim); } } // Calculate output shape using InferEinsumShape in topi ffi::Array oshape = topi::InferEinsumShape(subscripts, input_shapes); if (!vdevice_unknown) { - return TensorType(ShapeExpr(oshape), operand_dtype, vdev); + return TensorType(ShapeExpr(oshape), operand_ty, vdev); } - return TensorType(ShapeExpr(oshape), operand_dtype); + return TensorType(ShapeExpr(oshape), operand_ty); } TVM_REGISTER_OP("relax.einsum") diff --git a/src/relax/op/tensor/linear_algebra.h b/src/relax/op/tensor/linear_algebra.h index ddfceae4dc35..481193f897b8 100644 --- a/src/relax/op/tensor/linear_algebra.h +++ b/src/relax/op/tensor/linear_algebra.h @@ -41,7 +41,7 @@ namespace relax { * When it is not specified, the output dtype will be the same as input dtype. * \return The computed result. */ -Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype); +Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype); /*! * \brief Einstein summation on the operands. 
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index caa730091383..f0c7947b5ba2 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -35,7 +35,7 @@ #include #include -#include "tvm/runtime/data_type.h" +#include "tvm/ffi/dtype.h" namespace tvm { namespace relax { @@ -219,7 +219,7 @@ Type InferTypeConcat(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); int output_ndim = attrs->axis.has_value() ? kUnknownNDim : 1; - DataType output_dtype = DataType::Void(); + PrimType output_dtype = PrimType::Void(); ffi::Optional vdev = std::nullopt; bool shape_unknown = false; bool is_void_dtype = false; @@ -229,9 +229,9 @@ Type InferTypeConcat(const Call& call, const BlockBuilder& ctx) { for (TensorType ty : tensor_ty) { // Update the output dtype. - if (ty->dtype.is_void()) { + if (ty->IsUnknownDtype()) { is_void_dtype = true; - } else if (output_dtype.is_void()) { + } else if (output_dtype.IsVoid()) { output_dtype = ty->dtype; } else if (ty->dtype != output_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) @@ -285,7 +285,7 @@ Type InferTypeConcat(const Call& call, const BlockBuilder& ctx) { } if (is_void_dtype) { - output_dtype = DataType::Void(); + output_dtype = PrimType::Void(); } if (vdevice_unknown) { vdev = std::nullopt; @@ -573,14 +573,16 @@ Type InferTypeIndexTensor(const Call& call, const BlockBuilder& ctx) { << "index_tensor expects a non‑empty tuple of index tensors"; } - DataType output_dtype = data_ty->dtype; + PrimType output_dtype = data_ty->dtype; int n_indices = static_cast(indices_ty.size()); ffi::Optional vdev = data_ty->vdevice; // Indices must be integers for (int i = 0; i < n_indices; ++i) { const auto& s = indices_ty[i]; - if (!s->IsUnknownDtype() && !s->dtype.is_int()) { + PrimType index_dtype = s->dtype; + // Indexing only requires integer element kind; vector lanes do not affect shape inference. 
+ if (!s->IsUnknownDtype() && index_dtype.code() != DLDataTypeCode::kDLInt) { TVM_FFI_VISIT_THROW(TypeError, call) << "index_tensor requires every index tensor to have an integer dtype; " << "index " << i << " has dtype " << s->dtype; @@ -725,9 +727,10 @@ Type InferTypeLayoutTransform(const Call& call, const BlockBuilder& ctx) { // Check pad_value has same dtype as input. if (optional_pad_value.defined()) { PrimExpr padded_value = optional_pad_value.value()->value; - if (padded_value->dtype != data_ty->dtype) { + PrimType padded_dtype = padded_value.ty(); + if (padded_dtype != data_ty->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) - << "layout_transform pad_value dtype (" << padded_value->dtype << ") and input dtype (" + << "layout_transform pad_value dtype (" << padded_dtype << ") and input dtype (" << data_ty->dtype << ") must be the same"; } } @@ -916,9 +919,10 @@ Expr ConvertNewShapeToExpr(const Expr& data, "Array of PrimExprs. However, the given new shape is " << shape; PrimExpr len = ffi::GetRef(_len); - TVM_FFI_ICHECK(len->dtype.is_int()) << "Reshape requires the new shape values to be all " - "integers. However, the give new shape is " - << shape; + TVM_FFI_ICHECK(len.ty().code() == DLDataTypeCode::kDLInt) + << "Reshape requires the new shape values to be all " + "integers. However, the give new shape is " + << shape; const auto* int_len = len.as(); if (int_len != nullptr && int_len->value == 0) { // Note that this dimension should be copied from the original shape. 
@@ -1108,7 +1112,7 @@ Type InferTypeSplit(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK_NE(axis, -1); - IntImm zero(DataType::Int(64), /*value=*/0); + IntImm zero(tvm::PrimType::Int(64), /*value=*/0); std::vector output_ty; for (size_t i = 0; i < p_indices.size() + 1; i++) { @@ -1489,7 +1493,7 @@ Type InferTypeStack(const Call& call, const BlockBuilder& ctx) { // Default axis is 0 if not specified int output_ndim = tensor_ty[0]->ndim + 1; // Stack adds one dimension - DataType output_dtype = DataType::Void(); + PrimType output_dtype = PrimType::Void(); ffi::Optional vdev = std::nullopt; bool shape_unknown = false; bool is_void_dtype = false; @@ -1499,9 +1503,9 @@ Type InferTypeStack(const Call& call, const BlockBuilder& ctx) { for (TensorType ty : tensor_ty) { // Check dtype consistency - if (ty->dtype.is_void()) { + if (ty->IsUnknownDtype()) { is_void_dtype = true; - } else if (output_dtype.is_void()) { + } else if (output_dtype.IsVoid()) { output_dtype = ty->dtype; } else if (ty->dtype != output_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) @@ -1542,7 +1546,7 @@ Type InferTypeStack(const Call& call, const BlockBuilder& ctx) { } } - if (is_void_dtype) output_dtype = DataType::Void(); + if (is_void_dtype) output_dtype = PrimType::Void(); if (vdevice_unknown) vdev = std::nullopt; // Normalize axis (default to 0 if not specified) @@ -1650,7 +1654,7 @@ Type InferTypeCollapseSumLike(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = input_ty[0]; TensorType collapse_target_ty = input_ty[1]; - DataType output_dtype = data_ty->dtype; + PrimType output_dtype = data_ty->dtype; ffi::Optional> data_shape_value; if (data_ty->shape.defined()) { @@ -1711,7 +1715,7 @@ Type InferTypeCollapseSumTo(const Call& call, const BlockBuilder& ctx) { << call->args[1]->ty->GetTypeKey(); } - DataType output_dtype = data_ty->dtype; + PrimType output_dtype = data_ty->dtype; ffi::Optional> data_shape_value; if (data_ty->shape.defined()) { @@ -2099,14 +2103,15 @@ Type 
InferTypeReverseSequence(const Call& call, const BlockBuilder& ctx) { << "ReverseSequence requires seq_lengths to be 1-D. However, seq_lengths has ndim " << seq_lengths_ty->ndim; } - if (!seq_lengths_ty->dtype.is_void() && !seq_lengths_ty->dtype.is_int()) { + PrimType seq_lengths_dtype = seq_lengths_ty->dtype; + if (!seq_lengths_ty->IsUnknownDtype() && !seq_lengths_dtype.MatchesCode(DLDataTypeCode::kDLInt)) { TVM_FFI_VISIT_THROW(ValueError, call) << "ReverseSequence requires seq_lengths to have dtype int32 or int64. However, " "seq_lengths has dtype " << seq_lengths_ty->dtype; } - if (seq_lengths_ty->dtype.is_int() && seq_lengths_ty->dtype.bits() != 32 && - seq_lengths_ty->dtype.bits() != 64) { + if (seq_lengths_dtype.MatchesCode(DLDataTypeCode::kDLInt) && + seq_lengths_dtype->dtype.bits != 32 && seq_lengths_dtype->dtype.bits != 64) { TVM_FFI_VISIT_THROW(ValueError, call) << "ReverseSequence requires seq_lengths to have dtype int32 or int64. However, " "seq_lengths has dtype " @@ -2192,7 +2197,9 @@ Type InferTypeGatherElements(const Call& call, const BlockBuilder& ctx) { << call->args[1]->ty->GetTypeKey(); } - if (!indices_ty->IsUnknownDtype() && !indices_ty->dtype.is_int()) { + PrimType indices_dtype = indices_ty->dtype; + // Gather indices only require integer element kind; vector lanes do not affect shape inference. + if (!indices_ty->IsUnknownDtype() && indices_dtype.code() != DLDataTypeCode::kDLInt) { TVM_FFI_VISIT_THROW(TypeError, call) << "GatherElements requires the input indices to have int64 dtype. 
However, the " << "given indices dtype is " << indices_ty->dtype; @@ -2295,7 +2302,7 @@ Type InferTypeGatherND(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK_GE(attrs->batch_dims, 0); int batch_dims = static_cast(attrs->batch_dims); int input_dims = data_ty->ndim; - if (!indices_ty->IsUnknownDtype() && indices_ty->dtype != DataType::Int(64)) { + if (!indices_ty->IsUnknownDtype() && indices_ty->dtype != PrimType::Int(64)) { TVM_FFI_VISIT_THROW(TypeError, call) << "GatherND requires the input indices to have int64 dtype. However, the " << "given indices dtype is " << indices_ty->dtype; @@ -2430,10 +2437,14 @@ Type InferTypeIndexPut(const Call& call, const BlockBuilder& ctx) { if (tensor_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of index tensor " << i << " has not been specified. Assume it has an integer type."; - } else if (!(tensor_ty->dtype.is_int() || tensor_ty->dtype.is_uint())) { - TVM_FFI_VISIT_THROW(TypeError, call) - << "IndexPut requires each index tensor to have integer dtype. " - << "However, index tensor " << i << " has dtype=" << tensor_ty->dtype; + } else { + PrimType index_dtype = tensor_ty->dtype; + if (!index_dtype.MatchesCode(DLDataTypeCode::kDLInt) && + !index_dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { + TVM_FFI_VISIT_THROW(TypeError, call) + << "IndexPut requires each index tensor to have integer dtype. 
" + << "However, index tensor " << i << " has dtype=" << tensor_ty->dtype; + } } } @@ -2531,7 +2542,7 @@ Type InferTypeMeshgrid(const Call& call, const BlockBuilder& ctx) { } std::vector lengths; - DataType common_dtype = DataType::Void(); + PrimType common_dtype = PrimType::Void(); bool shape_unknown = false; ffi::Optional vdev = std::nullopt; bool vdevice_unknown = false; @@ -2545,9 +2556,9 @@ Type InferTypeMeshgrid(const Call& call, const BlockBuilder& ctx) { << i; } - if (ty->dtype.is_void()) { + if (ty->IsUnknownDtype()) { continue; - } else if (common_dtype.is_void()) { + } else if (common_dtype.IsVoid()) { common_dtype = ty->dtype; } else if (ty->dtype != common_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) @@ -2683,11 +2694,15 @@ Type InferTypeScatterElements(const Call& call, const BlockBuilder& ctx) { if (indices_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; - } else if (!(indices_ty->dtype.is_int() || indices_ty->dtype.is_uint())) { - TVM_FFI_VISIT_THROW(TypeError, call) - << "ScatterElements op requires the input indices to have integer dtype. However, the " - "given indices dtype is " - << indices_ty->dtype; + } else { + PrimType indices_dtype = indices_ty->dtype; + if (!indices_dtype.MatchesCode(DLDataTypeCode::kDLInt) && + !indices_dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { + TVM_FFI_VISIT_THROW(TypeError, call) + << "ScatterElements op requires the input indices to have integer dtype. However, the " + "given indices dtype is " + << indices_ty->dtype; + } } const auto* indices_shape = indices_ty->shape.as(); @@ -2803,11 +2818,15 @@ Type InferTypeScatterND(const Call& call, const BlockBuilder& ctx) { if (indices_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of indices has not been specified. 
Assume it has an integer type."; - } else if (!(indices_ty->dtype.is_int() || indices_ty->dtype.is_uint())) { - TVM_FFI_VISIT_THROW(TypeError, call) - << "ScatterND op requires the input indices to have integer dtype. However, " - "the given indices dtype is " - << indices_ty->dtype; + } else { + PrimType indices_dtype = indices_ty->dtype; + if (!indices_dtype.MatchesCode(DLDataTypeCode::kDLInt) && + !indices_dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { + TVM_FFI_VISIT_THROW(TypeError, call) + << "ScatterND op requires the input indices to have integer dtype. However, " + "the given indices dtype is " + << indices_ty->dtype; + } } const auto* data_shape = data_ty->shape.as(); @@ -3003,10 +3022,11 @@ Type InferTypeSliceScatter(const Call& call, const BlockBuilder& ctx) { << ") to be a PrimValue, but got " << arg_expr->GetTypeKey(); } const PrimExpr& prim_expr = prim_value_node->value; - if (!prim_expr.dtype().is_int() && !prim_expr.dtype().is_uint()) { + tvm::PrimType prim_ty = prim_expr.ty(); + if (prim_ty.code() != DLDataTypeCode::kDLInt && prim_ty.code() != DLDataTypeCode::kDLUInt) { TVM_FFI_VISIT_THROW(TypeError, call) << "SliceScatter expects `" << key << "` (" << prim_expr - << ") to be an integer PrimValue, but got dtype " << prim_expr.dtype(); + << ") to be an integer PrimValue, but got dtype " << prim_ty; } return prim_expr; }; @@ -3085,8 +3105,8 @@ Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, i attrs->axis = axis; // Check if on_value and off_value have the same dtype - DataType on_dtype = on_value->value->dtype; - DataType off_dtype = off_value->value->dtype; + PrimType on_dtype = on_value->value.ty(); + PrimType off_dtype = off_value->value.ty(); TVM_FFI_ICHECK(on_dtype == off_dtype) << "one_hot: on_value and off_value must have the same dtype, " << "but got " << on_dtype << " and " << off_dtype; @@ -3108,19 +3128,25 @@ Type InferTypeOneHot(const Call& call, const BlockBuilder& ctx) { PrimValue on_value = 
call->args[1].as_or_throw(); PrimValue off_value = call->args[2].as_or_throw(); // Check if on_value and off_value have the same dtype - TVM_FFI_ICHECK(on_value->value->dtype == off_value->value->dtype) + PrimType on_dtype = on_value->value.ty(); + PrimType off_dtype = off_value->value.ty(); + TVM_FFI_ICHECK(on_dtype == off_dtype) << "one_hot: on_value and off_value must have the same dtype, " - << "but got " << on_value->value->dtype << " and " << off_value->value->dtype; - DataType dtype = on_value->value->dtype; + << "but got " << on_dtype << " and " << off_dtype; + PrimType dtype = on_dtype; // Check if indices has an integer dtype if (indices_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; - } else if (!(indices_ty->dtype.is_int() || indices_ty->dtype.is_uint())) { - TVM_FFI_VISIT_THROW(TypeError, call) - << "one_hot op requires the input indices to have integer dtype. However, the " - "given indices dtype is " - << indices_ty->dtype; + } else { + PrimType indices_dtype = indices_ty->dtype; + if (!indices_dtype.MatchesCode(DLDataTypeCode::kDLInt) && + !indices_dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { + TVM_FFI_VISIT_THROW(TypeError, call) + << "one_hot op requires the input indices to have integer dtype. 
However, the " + "given indices dtype is " + << indices_ty->dtype; + } } // Check if indices has unknown dimension if (indices_ty->IsUnknownNdim()) { diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc index 974d70e7300a..8940594abc51 100644 --- a/src/relax/op/tensor/qdq.cc +++ b/src/relax/op/tensor/qdq.cc @@ -39,7 +39,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { QuantizeAttrs::RegisterReflection(); } /* relax.quantize */ -Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dtype) { +Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DLDataType out_dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->axis = axis; attrs->out_dtype = out_dtype; @@ -54,9 +54,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeQuantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); - if (attrs->out_dtype != DataType::Int(8) && attrs->out_dtype != DataType::UInt(8) && - attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16) && - attrs->out_dtype != DataType::Float8E4M3FN() && attrs->out_dtype != DataType::Float8E5M2()) { + if (attrs->out_dtype != DLDataType{kDLInt, 8, 1} && + attrs->out_dtype != DLDataType{kDLUInt, 8, 1} && + attrs->out_dtype != DLDataType{kDLInt, 16, 1} && + attrs->out_dtype != DLDataType{kDLUInt, 16, 1} && + attrs->out_dtype != DLDataType{static_cast(kDLFloat8_e4m3fn), + static_cast(8), static_cast(1)} && + attrs->out_dtype != DLDataType{static_cast(kDLFloat8_e5m2), static_cast(8), + static_cast(1)}) { TVM_FFI_VISIT_THROW(TypeError, call) << "Unsupported output datatype attribute for operation: '" << attrs->out_dtype; } @@ -64,24 +69,27 @@ Type InferTypeQuantize(const Call& call, const BlockBuilder& ctx) { TensorType input_ty = GetInputTensorType(call, ctx)[0]; TensorType scale_ty = GetInputTensorType(call, ctx)[1]; TensorType zp_ty = GetInputTensorType(call, ctx)[2]; + PrimType input_dtype = input_ty->dtype; + PrimType scale_dtype = scale_ty->dtype; + 
PrimType zp_dtype = zp_ty->dtype; // Check input datatype: - if (input_ty->dtype != DataType::Float(16) && input_ty->dtype != DataType::Float(32)) { + if (input_dtype != PrimType::Float(16) && input_dtype != PrimType::Float(32)) { TVM_FFI_VISIT_THROW(TypeError, call) << "Unsupported input datatype for operation: " << input_ty->dtype; } // Check datatype of scale param: - if (scale_ty->dtype != DataType::Float(32) && scale_ty->dtype != DataType::Float(16)) { + if (scale_dtype != PrimType::Float(32) && scale_dtype != PrimType::Float(16)) { TVM_FFI_VISIT_THROW(TypeError, call) << "scale param datatype should be one of [float16, float32], but got " << scale_ty->dtype; } // Check datatype of zero_point param: - if (zp_ty->dtype != DataType::Int(8) && zp_ty->dtype != DataType::UInt(8) && - zp_ty->dtype != DataType::Int(16) && zp_ty->dtype != DataType::UInt(16) && - zp_ty->dtype != DataType::Int(32) && zp_ty->dtype != DataType::UInt(32) && - zp_ty->dtype != DataType::Float(16)) { + if (zp_dtype != PrimType::Int(8) && zp_dtype != PrimType::UInt(8) && + zp_dtype != PrimType::Int(16) && zp_dtype != PrimType::UInt(16) && + zp_dtype != PrimType::Int(32) && zp_dtype != PrimType::UInt(32) && + zp_dtype != PrimType::Float(16)) { TVM_FFI_VISIT_THROW(TypeError, call) << "zero_point param datatype should be one of " << "['int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float16'], " @@ -124,7 +132,7 @@ Type InferTypeQuantize(const Call& call, const BlockBuilder& ctx) { if (!is_scalar_or_singleton_vector(zp_ty)) check_param_size(zp_ty, input_ty, "zero_point"); auto output_ty = ffi::make_object(*input_ty.get()); - output_ty->dtype = attrs->out_dtype; + output_ty->dtype = PrimType(attrs->out_dtype); return TensorType(output_ty); } @@ -139,7 +147,7 @@ TVM_REGISTER_OP("relax.quantize") /* relax.dequantize */ -Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dtype) { +Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DLDataType 
out_dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->axis = axis; attrs->out_dtype = out_dtype; @@ -154,7 +162,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeDequantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); - if (attrs->out_dtype != DataType::Float(16) && attrs->out_dtype != DataType::Float(32)) { + if (attrs->out_dtype != DLDataType{kDLFloat, 16, 1} && + attrs->out_dtype != DLDataType{kDLFloat, 32, 1}) { TVM_FFI_VISIT_THROW(TypeError, call) << "Unsupported output datatype attribute for operation: " << attrs->out_dtype; } @@ -162,28 +171,34 @@ Type InferTypeDequantize(const Call& call, const BlockBuilder& ctx) { TensorType input_ty = GetInputTensorType(call, ctx)[0]; TensorType scale_ty = GetInputTensorType(call, ctx)[1]; TensorType zp_ty = GetInputTensorType(call, ctx)[2]; + PrimType input_dtype = input_ty->dtype; + PrimType scale_dtype = scale_ty->dtype; + PrimType zp_dtype = zp_ty->dtype; // Check input datatype: - if (input_ty->dtype != DataType::Int(8) && input_ty->dtype != DataType::UInt(8) && - input_ty->dtype != DataType::Int(16) && input_ty->dtype != DataType::UInt(16) && - input_ty->dtype != DataType::Int(32) && input_ty->dtype != DataType::Float8E4M3FN() && - input_ty->dtype != DataType::Float8E5M2() && input_ty->dtype != DataType::Float(16) && - input_ty->dtype != DataType::Float(32)) { + if (input_dtype != PrimType::Int(8) && input_dtype != PrimType::UInt(8) && + input_dtype != PrimType::Int(16) && input_dtype != PrimType::UInt(16) && + input_dtype != PrimType::Int(32) && + input_dtype != PrimType(DLDataType{static_cast(kDLFloat8_e4m3fn), + static_cast(8), static_cast(1)}) && + input_dtype != PrimType(DLDataType{static_cast(kDLFloat8_e5m2), + static_cast(8), static_cast(1)}) && + input_dtype != PrimType::Float(16) && input_dtype != PrimType::Float(32)) { TVM_FFI_VISIT_THROW(TypeError, call) << "Unsupported input datatype for operation: " << attrs->out_dtype; } // Check datatype of scale param: - 
if (scale_ty->dtype != DataType::Float(32) && scale_ty->dtype != DataType::Float(16)) { + if (scale_dtype != PrimType::Float(32) && scale_dtype != PrimType::Float(16)) { TVM_FFI_VISIT_THROW(TypeError, call) << "scale param datatype should be one of [float16, float32], but got " << scale_ty->dtype; } // Check datatype of zero_point param: - if (zp_ty->dtype != DataType::Int(8) && zp_ty->dtype != DataType::UInt(8) && - zp_ty->dtype != DataType::Int(16) && zp_ty->dtype != DataType::UInt(16) && - zp_ty->dtype != DataType::Int(32) && zp_ty->dtype != DataType::UInt(32) && - zp_ty->dtype != DataType::Float(16)) { + if (zp_dtype != PrimType::Int(8) && zp_dtype != PrimType::UInt(8) && + zp_dtype != PrimType::Int(16) && zp_dtype != PrimType::UInt(16) && + zp_dtype != PrimType::Int(32) && zp_dtype != PrimType::UInt(32) && + zp_dtype != PrimType::Float(16)) { TVM_FFI_VISIT_THROW(TypeError, call) << "zero_point param datatype should be one of " << "['int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float16'], " @@ -226,7 +241,7 @@ Type InferTypeDequantize(const Call& call, const BlockBuilder& ctx) { if (!is_scalar_or_singleton_vector(zp_ty)) check_param_size(zp_ty, input_ty, "zero_point"); auto output_ty = ffi::make_object(*input_ty.get()); - output_ty->dtype = attrs->out_dtype; + output_ty->dtype = PrimType(attrs->out_dtype); return TensorType(output_ty); } diff --git a/src/relax/op/tensor/qdq.h b/src/relax/op/tensor/qdq.h index 9d13dcde277f..bdb31f87e61e 100644 --- a/src/relax/op/tensor/qdq.h +++ b/src/relax/op/tensor/qdq.h @@ -40,7 +40,7 @@ namespace relax { * \param out_dtype The data type of the output tensor. * \return The computed result. */ -Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dtype); +Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DLDataType out_dtype); /*! * \brief Dequantize op. 
@@ -53,7 +53,7 @@ Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dty * \param out_dtype The data type of the output tensor. * \return The computed result. */ -Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dtype); +Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DLDataType out_dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index 27f9241e2c29..196e6f887649 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -37,7 +37,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { MultinomialFromUniformAttrs::RegisterReflection(); /* relax.multinomial_from_uniform */ -Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indices, DataType dtype) { +Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indices, + DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; @@ -59,19 +60,24 @@ Type InferTypeMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) TensorType sample_indices_ty = GetInputTensorType(call, 2, ctx); const auto* attrs = call->attrs.as(); - if (!prob_ty->dtype.is_float()) { + // Only the element kind matters here; shape inference does not depend on vector lanes. + if (prob_ty->dtype.code() != DLDataTypeCode::kDLFloat && + prob_ty->dtype.code() != DLDataTypeCode::kDLBfloat) { TVM_FFI_VISIT_THROW(TypeError, call) << "Multinomial_from_uniform op requires the input prob to have float dtype. " "However, the given prob dtype is " << prob_ty->dtype; } - if (!uniform_sample_ty->dtype.is_float()) { + // Only the element kind matters here; shape inference does not depend on vector lanes. 
+ if (uniform_sample_ty->dtype.code() != DLDataTypeCode::kDLFloat && + uniform_sample_ty->dtype.code() != DLDataTypeCode::kDLBfloat) { TVM_FFI_VISIT_THROW(TypeError, call) << "Multinomial_from_uniform op requires the input uniform_sample to have float " "dtype. However, the given uniform_sample dtype is " << uniform_sample_ty->dtype; } - if (!sample_indices_ty->dtype.is_int()) { + // Only the element kind matters here; shape inference does not depend on vector lanes. + if (sample_indices_ty->dtype.code() != DLDataTypeCode::kDLInt) { TVM_FFI_VISIT_THROW(TypeError, call) << "Multinomial from uniform op requires the input sample_indices to have int " "dtype. However, the given sample_indices dtype is " @@ -79,7 +85,7 @@ Type InferTypeMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) } if (prob_ty->IsUnknownNdim() || uniform_sample_ty->IsUnknownNdim() || sample_indices_ty->IsUnknownNdim()) { - return TensorType(attrs->dtype, kUnknownNDim, prob_ty->vdevice); + return TensorType(PrimType(attrs->dtype), kUnknownNDim, prob_ty->vdevice); } if (prob_ty->ndim != 2) { TVM_FFI_VISIT_THROW(ValueError, call) @@ -109,7 +115,7 @@ Type InferTypeMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) // The output shape is expected to be `(n, 1)` if (prob_shape == nullptr || uniform_sample_shape == nullptr || sample_indices_shape == nullptr) { - return TensorType(attrs->dtype, 2, prob_ty->vdevice); + return TensorType(PrimType(attrs->dtype), 2, prob_ty->vdevice); } PrimExpr batch = prob_shape->values[0]; @@ -132,7 +138,7 @@ Type InferTypeMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) << uniform_sample_ty->shape << " and the given sample_indices tensor has shape " << sample_indices_ty->shape; } - return TensorType(ShapeExpr({n, 1}), attrs->dtype, prob_ty->vdevice); + return TensorType(ShapeExpr({n, 1}), PrimType(attrs->dtype), prob_ty->vdevice); } TVM_REGISTER_OP("relax.multinomial_from_uniform") diff --git a/src/relax/op/tensor/sampling.h 
b/src/relax/op/tensor/sampling.h index d13aa835d68d..077ef4313669 100644 --- a/src/relax/op/tensor/sampling.h +++ b/src/relax/op/tensor/sampling.h @@ -49,7 +49,8 @@ namespace relax { * \param dtype The data type of the output tensor. * \return The sampled result. */ -Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indices, DataType dtype); +Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indices, + DLDataType dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index d80f484ebcf5..c5021f6f5aef 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -64,10 +64,9 @@ Type InferTypeBucketize(const Call& call, const BlockBuilder& ctx) { } auto attrs = call->attrs.as(); - DataType out_dtype; - out_dtype = DataType::Int(64); + PrimType out_dtype = PrimType::Int(64); if (attrs->out_int32) { - out_dtype = DataType::Int(32); + out_dtype = PrimType::Int(32); } const auto* data_shape = input_tensor_info->shape.as(); @@ -119,13 +118,15 @@ Type InferTypeWhere(const Call& call, const BlockBuilder& ctx) { } } - if (!cond_ty->dtype.is_bool()) { + PrimType cond_dtype = cond_ty->dtype; + // Where condition validation only checks the boolean element kind; lanes are irrelevant here. + if (!cond_dtype.MatchesCode(DLDataTypeCode::kDLBool)) { TVM_FFI_VISIT_THROW(TypeError, call) << "Where requires the input condition tensor to have boolean dtype. 
However, " "the given condition dtype is " << cond_ty->dtype; } - DataType output_dtype = InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty); + PrimType output_dtype(InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty)); int output_ndim; if (cond_ty->IsUnknownNdim() || x1_ty->IsUnknownNdim() || x2_ty->IsUnknownNdim()) { @@ -209,7 +210,7 @@ Type InferTypeArgmaxArgmin(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK_GE(out_ndim, 0); } - DataType out_dtype = DataType::Int(64); + PrimType out_dtype = PrimType::Int(64); // The inference rule for reduction operator output shapes: // - axes is None, keepdims is false -> return the zero-rank shape; // - axes is None, keepdims is true -> return the shape whose ndim is the same as input and every @@ -230,7 +231,7 @@ Type InferTypeArgmaxArgmin(const Call& call, const BlockBuilder& ctx) { } if (data_ty->ndim > 0) { - out_dtype = data_shape->values[0]->dtype; + out_dtype = data_shape->values[0].ty(); } ffi::Array out_shape; diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index 57999a3356b7..a92cbee4a001 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -106,9 +106,9 @@ Type InferTypeUnique(const Call& call, const BlockBuilder& ctx) { if (f_convert_to_int64(return_index->value)) { if (data_ty->ndim == 0) { output_ty.push_back( - TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), data_ty->vdevice)); + TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), PrimType::Int(64), data_ty->vdevice)); } else { - output_ty.push_back(TensorType(DataType::Int(64), /*ndim=*/1, data_ty->vdevice)); + output_ty.push_back(TensorType(PrimType::Int(64), /*ndim=*/1, data_ty->vdevice)); } } @@ -116,9 +116,9 @@ Type InferTypeUnique(const Call& call, const BlockBuilder& ctx) { if (f_convert_to_int64(return_inverse->value)) { if (data_ty->ndim == 0) { output_ty.push_back( - TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), data_ty->vdevice)); + 
TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), PrimType::Int(64), data_ty->vdevice)); } else { - output_ty.push_back(TensorType(DataType::Int(64), /*ndim=*/1, data_ty->vdevice)); + output_ty.push_back(TensorType(PrimType::Int(64), /*ndim=*/1, data_ty->vdevice)); } } @@ -126,9 +126,9 @@ Type InferTypeUnique(const Call& call, const BlockBuilder& ctx) { if (f_convert_to_int64(return_counts->value)) { if (data_ty->ndim == 0) { output_ty.push_back( - TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), data_ty->vdevice)); + TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), PrimType::Int(64), data_ty->vdevice)); } else { - output_ty.push_back(TensorType(DataType::Int(64), /*ndim=*/1, data_ty->vdevice)); + output_ty.push_back(TensorType(PrimType::Int(64), /*ndim=*/1, data_ty->vdevice)); } } @@ -175,7 +175,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeNonzero(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = GetInputTensorType(call, 0, ctx); - return TensorType(DataType::Int(64), 2, data_ty->vdevice); + return TensorType(PrimType::Int(64), 2, data_ty->vdevice); } TVM_REGISTER_OP("relax.nonzero") diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index 2d014cded4ec..c470fa0d4f6e 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -66,7 +66,7 @@ TVM_REGISTER_OP("relax.sort") /* relax.argsort */ -Expr argsort(Expr data, int axis, bool descending, DataType dtype) { +Expr argsort(Expr data, int axis, bool descending, DLDataType dtype) { auto attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->descending = std::move(descending); @@ -84,7 +84,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeArgsort(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - DataType out_type = attrs->dtype.is_void() ? 
data_ty->dtype : attrs->dtype; + PrimType out_type = + attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0} ? data_ty->dtype : PrimType(attrs->dtype); if (data_ty->shape.defined()) { return TensorType(data_ty->shape.value(), out_type, data_ty->vdevice); } @@ -100,7 +101,7 @@ TVM_REGISTER_OP("relax.argsort") /* relax.topk */ -Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, DataType dtype) { +Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, DLDataType dtype) { auto attrs = ffi::make_object(); attrs->k = std::move(k); attrs->axis = std::move(axis); @@ -121,7 +122,8 @@ Type InferTypeTopK(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* data_shape = data_ty->shape.as(); const auto* attrs = call->attrs.as(); - DataType indices_type = attrs->dtype.is_void() ? data_ty->dtype : attrs->dtype; + PrimType indices_type = + attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0} ? data_ty->dtype : PrimType(attrs->dtype); int ndim = data_ty->ndim; int k = attrs->k; ffi::String ret_type = attrs->ret_type; diff --git a/src/relax/op/tensor/sorting.h b/src/relax/op/tensor/sorting.h index a4154ce416ad..8a2ec98388df 100644 --- a/src/relax/op/tensor/sorting.h +++ b/src/relax/op/tensor/sorting.h @@ -51,7 +51,7 @@ Expr sort(Expr data, int axis, bool descending); * \param dtype The data type of the output indices. * \return The computed result. */ -Expr argsort(Expr data, int axis, bool descending, DataType dtype); +Expr argsort(Expr data, int axis, bool descending, DLDataType dtype); /*! * \brief Get the top k elements in an input tensor along the given axis. @@ -63,7 +63,7 @@ Expr argsort(Expr data, int axis, bool descending, DataType dtype); * \param dtype The data type of the indices output. * \return The computed result. 
*/ -Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, DataType dtype); +Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, DLDataType dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index 9fe68afe2901..15bbd701e67f 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -155,7 +155,8 @@ Type InferTypeScan(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - DataType out_type = attrs->dtype.is_void() ? data_ty->dtype : attrs->dtype; + PrimType out_type = + attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0} ? data_ty->dtype : PrimType(attrs->dtype); if (!attrs->axis.has_value()) { // flattened @@ -216,7 +217,7 @@ Type InferTypeStatisticalExtension(const Call& call, const BlockBuilder& ctx) { return TensorType(ShapeExpr(ffi::Array()), data_ty->dtype, data_ty->vdevice); } return TupleType({TensorType(data_ty->dtype, out_ndim, data_ty->vdevice), - TensorType(DataType::Int(64), out_ndim, data_ty->vdevice)}); + TensorType(PrimType::Int(64), out_ndim, data_ty->vdevice)}); } ffi::Array out_shape; @@ -234,15 +235,15 @@ Type InferTypeStatisticalExtension(const Call& call, const BlockBuilder& ctx) { return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); else return TupleType({TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice), - TensorType(ShapeExpr(out_shape), DataType::Int(64), data_ty->vdevice)}); + TensorType(ShapeExpr(out_shape), PrimType::Int(64), data_ty->vdevice)}); } /* relax.cumprod */ -Expr cumprod(Expr data, ffi::Optional axis, ffi::Optional dtype, +Expr cumprod(Expr data, ffi::Optional axis, ffi::Optional dtype, bool exclusive) { auto attrs = ffi::make_object(); attrs->axis = std::move(axis); - attrs->dtype = std::move(dtype.value_or(DataType::Void())); + attrs->dtype = 
dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); attrs->exclusive = exclusive; static const Op& op = Op::Get("relax.cumprod"); @@ -262,10 +263,11 @@ TVM_REGISTER_OP("relax.cumprod") .set_attr("FPurity", true); /* relax.cumsum */ -Expr cumsum(Expr data, ffi::Optional axis, ffi::Optional dtype, bool exclusive) { +Expr cumsum(Expr data, ffi::Optional axis, ffi::Optional dtype, + bool exclusive) { auto attrs = ffi::make_object(); attrs->axis = std::move(axis); - attrs->dtype = std::move(dtype.value_or(DataType::Void())); + attrs->dtype = dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); attrs->exclusive = exclusive; static const Op& op = Op::Get("relax.cumsum"); diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h index 2d80790926ed..3ab998110603 100644 --- a/src/relax/op/tensor/statistical.h +++ b/src/relax/op/tensor/statistical.h @@ -99,7 +99,7 @@ Expr sum(Expr x, ffi::Optional> axis, bool keepdims); * result. */ Expr cumprod(Expr data, ffi::Optional axis = std::nullopt, - ffi::Optional dtype = std::nullopt, bool exclusive = false); + ffi::Optional dtype = std::nullopt, bool exclusive = false); /*! * \brief Numpy style cumsum op. Return the cumulative inclusive sum of the elements along @@ -114,7 +114,7 @@ Expr cumprod(Expr data, ffi::Optional axis = std::nullopt, * \return The computed result. */ Expr cumsum(Expr data, ffi::Optional axis = std::nullopt, - ffi::Optional dtype = std::nullopt, bool exclusive = false); + ffi::Optional dtype = std::nullopt, bool exclusive = false); /*! \brief Computes the variance of tensor elements over given axes. 
*/ Expr variance(Expr x, ffi::Optional> axis, bool keepdims); diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index 6daacfe16578..1e21e7dbdcc7 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -57,9 +57,9 @@ Type InferTypeEwiseFMA(const Call& call, const BlockBuilder& ctx) { } } - DataType output_dtype; + PrimType output_dtype = PrimType::Void(); if (t1->IsUnknownDtype() || t2->IsUnknownDtype() || t3->IsUnknownDtype()) { - output_dtype = DataType::Void(); + output_dtype = PrimType::Void(); } else if (t1->dtype != t2->dtype || t2->dtype != t3->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << "Data types " << t1->dtype << ", " << t2->dtype << ", and " << t3->dtype << " must be equal for EwiseFMA"; diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index 598ec78aacda..bd15223df878 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -33,7 +33,7 @@ namespace relax { Type InferTypeUnaryCheck(const Call& call, const BlockBuilder& ctx) { return InferTypeUnary(call, ctx, - [](const TensorType& input_ty) { return DataType::Bool(); }); + [](const TensorType& input_ty) { return PrimType::Bool(); }); } /***************** Arithmetic operators *****************/ diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc index bde579f0ed5a..6f289d6b8755 100644 --- a/src/relax/op/vision/nms.cc +++ b/src/relax/op/vision/nms.cc @@ -84,8 +84,8 @@ Type InferTypeAllClassNMS(const Call& call, const BlockBuilder& ctx) { ShapeExpr oshape(oshape_values); tvm::ffi::Array counts_values = {1}; ShapeExpr counts_shape(counts_values); - tvm::ffi::Array fields = {TensorType(oshape, DataType::Int(64), vdev), - TensorType(counts_shape, DataType::Int(64), vdev)}; + tvm::ffi::Array fields = {TensorType(oshape, PrimType::Int(64), vdev), + TensorType(counts_shape, PrimType::Int(64), vdev)}; return TupleType(fields); } @@ -96,9 +96,9 @@ Type InferTypeAllClassNMS(const Call& call, 
const BlockBuilder& ctx) { ShapeExpr scores_shape(scores_values); tvm::ffi::Array counts_values = {batch}; ShapeExpr counts_shape(counts_values); - tvm::ffi::Array fields = {TensorType(indices_shape, DataType::Int(64), vdev), - TensorType(scores_shape, DataType::Float(32), vdev), - TensorType(counts_shape, DataType::Int(64), vdev)}; + tvm::ffi::Array fields = {TensorType(indices_shape, PrimType::Int(64), vdev), + TensorType(scores_shape, PrimType::Float(32), vdev), + TensorType(counts_shape, PrimType::Int(64), vdev)}; return TupleType(fields); } @@ -153,9 +153,9 @@ Type InferTypeGetValidCounts(const Call& call, const BlockBuilder& ctx) { auto vdev = data_ty->vdevice; const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { - tvm::ffi::Array fields = {TensorType(DataType::Int(32), /*ndim=*/1, vdev), + tvm::ffi::Array fields = {TensorType(PrimType::Int(32), /*ndim=*/1, vdev), TensorType(data_ty->dtype, /*ndim=*/3, vdev), - TensorType(DataType::Int(32), /*ndim=*/2, vdev)}; + TensorType(PrimType::Int(32), /*ndim=*/2, vdev)}; return TupleType(fields); } @@ -177,9 +177,9 @@ Type InferTypeGetValidCounts(const Call& call, const BlockBuilder& ctx) { } tvm::ffi::Array fields = { - TensorType(ShapeExpr({batch}), DataType::Int(32), vdev), + TensorType(ShapeExpr({batch}), PrimType::Int(32), vdev), TensorType(ShapeExpr({batch, num_anchors, elem_length}), data_ty->dtype, vdev), - TensorType(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev)}; + TensorType(ShapeExpr({batch, num_anchors}), PrimType::Int(32), vdev)}; return TupleType(fields); } @@ -251,12 +251,12 @@ Type InferTypeNMS(const Call& call, const BlockBuilder& ctx) { TVM_FFI_VISIT_THROW(ValueError, call) << "non_max_suppression expects indices to be 2-D, got ndim " << indices_ty->ndim; } - if (!valid_count_ty->IsUnknownDtype() && valid_count_ty->dtype != DataType::Int(32)) { + if (!valid_count_ty->IsUnknownDtype() && valid_count_ty->dtype != PrimType::Int(32)) { TVM_FFI_VISIT_THROW(TypeError, 
call) << "non_max_suppression expects valid_count to have dtype int32, got " << valid_count_ty->dtype; } - if (!indices_ty->IsUnknownDtype() && indices_ty->dtype != DataType::Int(32)) { + if (!indices_ty->IsUnknownDtype() && indices_ty->dtype != PrimType::Int(32)) { TVM_FFI_VISIT_THROW(TypeError, call) << "non_max_suppression expects indices to have dtype int32, got " << indices_ty->dtype; } @@ -319,30 +319,30 @@ Type InferTypeNMS(const Call& call, const BlockBuilder& ctx) { // valid_box_count[batch, 1]) if (data_shape == nullptr) { tvm::ffi::Array fields = {TensorType(data_ty->dtype, /*ndim=*/3, vdev), - TensorType(DataType::Int(32), /*ndim=*/2, vdev), - TensorType(DataType::Int(32), /*ndim=*/2, vdev)}; + TensorType(PrimType::Int(32), /*ndim=*/2, vdev), + TensorType(PrimType::Int(32), /*ndim=*/2, vdev)}; return TupleType(fields); } auto batch = data_shape->values[0]; auto num_anchors = data_shape->values[1]; tvm::ffi::Array fields = { TensorType(ffi::GetRef(data_shape), data_ty->dtype, vdev), - TensorType(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev), - TensorType(ShapeExpr({batch, IntImm::Int64(1)}), DataType::Int(32), vdev)}; + TensorType(ShapeExpr({batch, num_anchors}), PrimType::Int(32), vdev), + TensorType(ShapeExpr({batch, IntImm::Int64(1)}), PrimType::Int(32), vdev)}; return TupleType(fields); } // Hard NMS returns (box_indices[batch, num_anchors], valid_box_count[batch, 1]) if (data_shape == nullptr) { - tvm::ffi::Array fields = {TensorType(DataType::Int(32), /*ndim=*/2, vdev), - TensorType(DataType::Int(32), /*ndim=*/2, vdev)}; + tvm::ffi::Array fields = {TensorType(PrimType::Int(32), /*ndim=*/2, vdev), + TensorType(PrimType::Int(32), /*ndim=*/2, vdev)}; return TupleType(fields); } auto batch = data_shape->values[0]; auto num_anchors = data_shape->values[1]; tvm::ffi::Array fields = { - TensorType(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev), - TensorType(ShapeExpr({batch, IntImm::Int64(1)}), DataType::Int(32), vdev)}; + 
TensorType(ShapeExpr({batch, num_anchors}), PrimType::Int(32), vdev), + TensorType(ShapeExpr({batch, IntImm::Int64(1)}), PrimType::Int(32), vdev)}; return TupleType(fields); } diff --git a/src/relax/script/printer/dependent_type.cc b/src/relax/script/printer/dependent_type.cc index a37c21406fac..e3a14c0cdafe 100644 --- a/src/relax/script/printer/dependent_type.cc +++ b/src/relax/script/printer/dependent_type.cc @@ -100,7 +100,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } if (!n->IsUnknownDtype()) { kwargs_keys.push_back("dtype"); - kwargs_values.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))); + kwargs_values.push_back(LiteralDoc::DataType(n->dtype->dtype, n_p->Attr("dtype"))); } if (!n->shape.defined() && !n->IsUnknownNdim()) { kwargs_keys.push_back("ndim"); diff --git a/src/relax/script/printer/distributed.cc b/src/relax/script/printer/distributed.cc index f05ec8fe714a..97d800d5d139 100644 --- a/src/relax/script/printer/distributed.cc +++ b/src/relax/script/printer/distributed.cc @@ -61,11 +61,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } if (!n->tensor_ty->IsUnknownDtype()) { if (!require_kwargs) { - args.push_back(LiteralDoc::DataType(n->tensor_ty->dtype, n_p->Attr("dtype"))); + args.push_back(LiteralDoc::DataType(n->tensor_ty->dtype->dtype, n_p->Attr("dtype"))); } else { kwargs_keys.push_back("dtype"); kwargs_values.push_back( - LiteralDoc::DataType(n->tensor_ty->dtype, n_p->Attr("dtype"))); + LiteralDoc::DataType(n->tensor_ty->dtype->dtype, n_p->Attr("dtype"))); } } else { require_kwargs = true; diff --git a/src/relax/script/printer/expr.cc b/src/relax/script/printer/expr.cc index dfce2b40b1f9..7b2f39ecf335 100644 --- a/src/relax/script/printer/expr.cc +++ b/src/relax/script/printer/expr.cc @@ -81,21 +81,21 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); ffi::Optional SpecialScalar(const runtime::Tensor& n, const AccessPath& p) { - DataType dtype = n.DataType(); + DLDataType dtype = n.DataType(); const void* data = n->data; if 
(n->ndim != 0 || n->device.device_type != kDLCPU) { return std::nullopt; } - if (dtype == DataType::Int(8)) { + if (dtype == DLDataType{kDLInt, 8, 1}) { return LiteralDoc::Int(*reinterpret_cast(data), p); - } else if (dtype == DataType::Int(16)) { + } else if (dtype == DLDataType{kDLInt, 16, 1}) { return LiteralDoc::Int(*reinterpret_cast(data), p); - } else if (dtype == DataType::Int(32)) { + } else if (dtype == DLDataType{kDLInt, 32, 1}) { return LiteralDoc::Int(*reinterpret_cast(data), p); - } else if (dtype == DataType::Int(64)) { + } else if (dtype == DLDataType{kDLInt, 64, 1}) { return LiteralDoc::Int(*reinterpret_cast(data), p); - } else if (dtype == DataType::Float(16)) { + } else if (dtype == DLDataType{kDLFloat, 16, 1}) { // From IEEE-754 float16 definition // // Ref: https://en.wikipedia.org/wiki/Half-precision_floating-point_format @@ -122,11 +122,11 @@ ffi::Optional SpecialScalar(const runtime::Tensor& n, const AccessPath& } return LiteralDoc::Float(value, p); - } else if (dtype == DataType::Float(32)) { + } else if (dtype == DLDataType{kDLFloat, 32, 1}) { return LiteralDoc::Float(*reinterpret_cast(data), p); - } else if (dtype == DataType::Float(64)) { + } else if (dtype == DLDataType{kDLFloat, 64, 1}) { return LiteralDoc::Float(*reinterpret_cast(data), p); - } else if (dtype == DataType::Bool()) { + } else if (dtype == DLDataType{kDLBool, 8, 1}) { return LiteralDoc::Boolean(*reinterpret_cast(data), p); } else { return std::nullopt; diff --git a/src/relax/script/printer/tir.cc b/src/relax/script/printer/tir.cc index e0742f8edd44..06bce7c1ff8c 100644 --- a/src/relax/script/printer/tir.cc +++ b/src/relax/script/printer/tir.cc @@ -43,9 +43,10 @@ RelaxFrameNode* GetRelaxFrame(IRDocsifier d) { } Doc PrintTIRVar(tirx::Var n, AccessPath n_p, IRDocsifier d) { - TVM_FFI_CHECK(n->dtype.is_scalar(), TypeError) + PrimType n_ty = n.ty(); + TVM_FFI_CHECK(!n_ty.IsScalableVector() && !n_ty.IsFixedLengthVector(), TypeError) << "Relax only uses scalar TIR variables," - 
<< "but received TIR variable " << n << " with dtype " << n->dtype; + << "but received TIR variable " << n << " with dtype " << n_ty->dtype; if (!d->IsVarDefined(n)) { RelaxFrameNode* f = GetRelaxFrame(d); @@ -77,7 +78,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::IntImm n, AccessPath n_p, IRDocsifier d) -> Doc { // // TODO(@junrushao): support non-int64 cases - if (n->dtype.is_bool()) { + if (n->ty().MatchesElementType(DLDataTypeCode::kDLBool, 8)) { return LiteralDoc::Boolean(n->value, n_p); } else { return LiteralDoc::Int(n->value, n_p); diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 4cf8831514dc..2d6e6fcc5e33 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -208,22 +208,24 @@ std::tuple)>> // If two of the three are compile-time, group those two values // together, to allow them to be lifted out and pre-computed. if (is_compile_time(expr_a) && is_compile_time(expr_b)) { - return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); + return matmul(matmul(expr_a, expr_b, (DLDataType{kDLOpaqueHandle, 0, 0})), expr_c, + (DLDataType{kDLOpaqueHandle, 0, 0})); } else if (is_compile_time(expr_b) && is_compile_time(expr_c)) { - return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void()); + return matmul(expr_a, matmul(expr_b, expr_c, (DLDataType{kDLOpaqueHandle, 0, 0})), + (DLDataType{kDLOpaqueHandle, 0, 0})); } // Otherwise, select the order that reduces the total number of // operations required, assuming a naive matmul (see below). 
if (shape_a.size() == 1) { - shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]}; + shape_a = {IntImm(shape_a[0].ty(), 1), shape_a[0]}; } if (shape_b.size() == 1) { if (matches.count(pat_matmul_on_lhs)) { - shape_b = {shape_b[0], IntImm(shape_b[0].dtype(), 1)}; + shape_b = {shape_b[0], IntImm(shape_b[0].ty(), 1)}; } else if (matches.count(pat_matmul_on_rhs)) { - shape_b = {IntImm(shape_b[0].dtype(), 1), shape_b[0]}; + shape_b = {IntImm(shape_b[0].ty(), 1), shape_b[0]}; } else { TVM_FFI_THROW(InternalError) << "OrPattern " << pat << " matched, but neither " << pat_matmul_on_lhs << " nor " @@ -231,7 +233,7 @@ std::tuple)>> } } if (shape_c.size() == 1) { - shape_c = {shape_c[0], IntImm(shape_c[0].dtype(), 1)}; + shape_c = {shape_c[0], IntImm(shape_c[0].ty(), 1)}; } PrimExpr size_N = shape_a[shape_a.size() - 2]; // row of A @@ -285,9 +287,11 @@ std::tuple)>> size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0); if (analyzer->CanProve(ops_with_lhs_first < ops_with_rhs_first)) { - return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); + return matmul(matmul(expr_a, expr_b, (DLDataType{kDLOpaqueHandle, 0, 0})), expr_c, + (DLDataType{kDLOpaqueHandle, 0, 0})); } else if (analyzer->CanProve(ops_with_rhs_first < ops_with_lhs_first)) { - return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void()); + return matmul(expr_a, matmul(expr_b, expr_c, (DLDataType{kDLOpaqueHandle, 0, 0})), + (DLDataType{kDLOpaqueHandle, 0, 0})); } // If we cannot determine which order is best, keep the existing order. diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index 4dfc84b822da..a593cb7ffee7 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -61,7 +61,7 @@ class ExternFunctionRewriter : ExprMutator { // Append the workspace parameter to this function. 
ffi::Array new_params = func_node->params; - auto ty = TensorType(ShapeExpr({IntImm::Int32(max_workspace_size_)}), DataType::UInt(8)); + auto ty = TensorType(ShapeExpr({IntImm::Int32(max_workspace_size_)}), PrimType::UInt(8)); Var workspace_param(name_sup_->FreshName("workspace"), ty); if (func_node->GetAttr(attr::kCodegen)) { @@ -149,7 +149,7 @@ class WorkspaceProvider : ExprMutator { builder_->BeginDataflowBlock(); if (!workspace_var_main_.defined()) { auto shape = ShapeExpr({IntImm::Int32(max_workspace_size_)}); - auto ty = DataTypeImm(DataType::UInt(8)); + auto ty = DataTypeImm((DLDataType{kDLUInt, 8, 1})); auto workspace = MakeAllocTensor(shape, ty, PrimValue::Int64(0)); workspace_var_main_ = builder_->Emit(workspace, "workspace_main"); } diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index a938b946d20c..7a3b5743f423 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -45,7 +45,7 @@ static constexpr const char* kOperatorName = "operator_name"; /*! \brief Construct ranges from shape dimensions */ static ffi::Array ConstructRangeFromShape(const ffi::Array& shape) { - return shape.Map([](const PrimExpr& dim) { return Range(IntImm(dim.dtype(), 0), dim); }); + return shape.Map([](const PrimExpr& dim) { return Range(IntImm(dim.ty(), 0), dim); }); } static ffi::Array GetShapeFromTensorType(const TensorType& tensor_ty) { @@ -206,7 +206,7 @@ class AlterOpImplMutator : public ExprMutator { * \brief Adds the \p remove_pad op to the module if it has not already been added before. * \returns The global var associated with the remove_pad PrimFunc. 
*/ - GlobalVar GetOrCreateRemovePadOp(const ffi::Array& old_shape, const DataType& dtype) { + GlobalVar GetOrCreateRemovePadOp(const ffi::Array& old_shape, DLDataType dtype) { int t_shape = old_shape.size(); if (remove_pad_map_.count(t_shape) != 0) { return remove_pad_map_[t_shape]; @@ -214,8 +214,8 @@ class AlterOpImplMutator : public ExprMutator { // Create dynamic shapes for input and output tensors ffi::Array dyn_padded_shape, dyn_old_shape; for (int i = 0; i < t_shape; i++) { - tirx::Var var1("p" + std::to_string(i), old_shape[i].dtype()); - tirx::Var var2("i" + std::to_string(i), old_shape[i].dtype()); + tirx::Var var1("p" + std::to_string(i), old_shape[i].ty()); + tirx::Var var2("i" + std::to_string(i), old_shape[i].ty()); dyn_padded_shape.push_back(var1); dyn_old_shape.push_back(var2); } @@ -264,7 +264,7 @@ class AlterOpImplMutator : public ExprMutator { TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator)); const auto& tensor_ty = padded_expr->ty.as_or_throw(); - GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_ty->dtype); + GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_ty->dtype->dtype); return Call(call_tir_op_, {gv_remove_pad, Tuple({padded_expr})}, {}, {old_tensor_ty}); } } diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index 61fee5be7f8d..5a1bbcaa0040 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -90,12 +90,12 @@ class CallTIRMutator : public ExprMutator { } if (!is_inplace) { - outs.push_back(builder_->Emit( - Call(alloc_tensor_op, - {output_ty->shape.value().as_or_throw(), - DataTypeImm(output_ty->dtype), PrimValue::Int64(dev_index), StringImm(scope)}, - Attrs(), {output_ty}), - "alloc")); + outs.push_back(builder_->Emit(Call(alloc_tensor_op, + {output_ty->shape.value().as_or_throw(), + DataTypeImm(output_ty->dtype->dtype), + PrimValue::Int64(dev_index), StringImm(scope)}, + Attrs(), 
{output_ty}), + "alloc")); } else { // if there is only one output, it must be an in-place argument, but check anyway TVM_FFI_ICHECK(inplace_attrs->inplace_indices[0] != -1) @@ -129,8 +129,8 @@ class CallTIRMutator : public ExprMutator { outs.push_back( builder_->Emit(Call(alloc_tensor_op, {field_tensor->shape.value().as_or_throw(), - DataTypeImm(field_tensor->dtype), PrimValue::Int64(dev_index), - StringImm(scope)}, + DataTypeImm(field_tensor->dtype->dtype), + PrimValue::Int64(dev_index), StringImm(scope)}, Attrs(), {field_tensor}), "alloc")); } else { diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 1319356ee169..128202063695 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -202,7 +202,7 @@ ffi::TypedFunction(ffi::Map, ffi::Mapdtype; + DLDataType out_dtype = GetTensorType(matchings[patterns.matmul[indices[0]]])->dtype->dtype; auto matmul_combined = matmul(lhs, concat_rhs, out_dtype); if (branch_info.bias_dim) { diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc index 4ad34d04367d..4c937fe135dc 100644 --- a/src/relax/transform/compute_prim_value.cc +++ b/src/relax/transform/compute_prim_value.cc @@ -43,11 +43,12 @@ class PrimValueComputeInjector : public ExprMutator { return node; } - auto ret_dtype = node->value->dtype; + tvm::PrimType ret_ty = node->value.ty(); auto param_vars = tirx::UndefinedVars(node->value); - tirx::Stmt body = tirx::Evaluate(tirx::Call(ret_dtype, tirx::builtin::ret(), {node->value})); + tirx::Stmt body = + tirx::Evaluate(tirx::Call(node->value.ty(), tirx::builtin::ret(), {node->value})); - tirx::PrimFunc func(param_vars, body, tvm::PrimType(ret_dtype), {}, + tirx::PrimFunc func(param_vars, body, ret_ty, {}, DictAttrs({{tirx::attr::kIsHostFunc, true}, {tvm::attr::kSTir, true}})); func = s_tir::RenewDefs(func); diff --git a/src/relax/transform/convert_layout.cc 
b/src/relax/transform/convert_layout.cc index ed2a9b1c8a8a..bd4631bb4cf8 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -102,7 +102,7 @@ class LayoutConvertMutator : public ExprMutator { ffi::Array initial_indices_expr; initial_indices.reserve(ndim); for (int i = 0; i < ndim; ++i) { - auto var = tvm::tirx::Var("i" + std::to_string(i), DataType::Int(32)); + auto var = tvm::tirx::Var("i" + std::to_string(i), PrimType::Int(32)); initial_indices.push_back(var); initial_indices_expr.push_back(var); } diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index fcedd3119599..289c1c3c3b40 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -383,7 +383,7 @@ std::unordered_set GatherCandidat const Type& result_ty) { if (auto* tensor_info = result_ty.as()) { // don't consider void dtype (don't know the size at compile time) - if (tensor_info->dtype.is_void()) { + if (tensor_info->dtype.IsVoid()) { return {}; } // don't consider cases where we don't know the shape at compile time diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc index 494e4a67a4a4..156d3c278c46 100644 --- a/src/relax/transform/decompose_ops.cc +++ b/src/relax/transform/decompose_ops.cc @@ -66,7 +66,7 @@ Tuple DecomposeBatchNorm(const Call& call) { Expr moving_var = ExpandToMatchInput(call->args[4], ty->ndim, {attrs->axis}); // output = (x - mean) / sqrt(var + epsilon) * gamma + beta - Expr epsilon = MakeConstantScalar(attrs->epsilon, ty->dtype); + Expr epsilon = MakeConstantScalar(attrs->epsilon, ty->dtype->dtype); Expr sqrt_var = sqrt(add(moving_var, epsilon)); Expr out = divide(subtract(data, moving_mean), sqrt_var); @@ -103,8 +103,8 @@ Expr MutateBatchNormForTraining(Call call) { Expr data_mean = mean(data, reduce_axes, false); Expr data_var = variance(data, reduce_axes, false); - Expr momentum = 
MakeConstantScalar(attrs->momentum, ty->dtype); - Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, ty->dtype); + Expr momentum = MakeConstantScalar(attrs->momentum, ty->dtype->dtype); + Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, ty->dtype->dtype); Expr new_moving_mean = add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean)); Expr new_moving_var = add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var)); @@ -128,7 +128,7 @@ Expr DecomposeLayerNorm(const Call& call) { Expr data_var = variance(data, attrs->axes, true); // output = (x - mean) / sqrt(var + epsilon) * gamma + beta - Expr epsilon = MakeConstantScalar(attrs->epsilon, ty->dtype); + Expr epsilon = MakeConstantScalar(attrs->epsilon, ty->dtype->dtype); Expr sqrt_var = sqrt(add(data_var, epsilon)); Expr out = divide(subtract(data, data_mean), sqrt_var); @@ -159,7 +159,7 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) { // ffi::Array), we define symbolic variables and returns them as a ShapeExpr. 
ffi::Array shape_var; for (int i = 0; i < ty->ndim; i++) { - shape_var.push_back(tirx::Var("x", DataType::Int(64))); + shape_var.push_back(tirx::Var("x", PrimType::Int(64))); } // bind symbolic variables to the shape tuple relax::Var var("y", ShapeType(shape_var)); diff --git a/src/relax/transform/expand_matmul_of_sum.cc b/src/relax/transform/expand_matmul_of_sum.cc index 1e768478fd95..9bf5fbd53b2d 100644 --- a/src/relax/transform/expand_matmul_of_sum.cc +++ b/src/relax/transform/expand_matmul_of_sum.cc @@ -88,7 +88,8 @@ std::tuple)>> rhs_b = permute_dims(rhs_b, axes); } - return add(matmul(lhs, rhs_a, DataType::Void()), matmul(lhs, rhs_b, DataType::Void())); + return add(matmul(lhs, rhs_a, (DLDataType{kDLOpaqueHandle, 0, 0})), + matmul(lhs, rhs_b, (DLDataType{kDLOpaqueHandle, 0, 0}))); }; return {pat_matmul, rewriter}; diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index d615c014709b..7c92ae49c578 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -197,7 +197,7 @@ class ConstantFolder : public ExprMutator { // Returns std::nullopt on failure. ffi::Optional ConstEvaluateCallTIR(tirx::PrimFunc tir_func, ffi::Array arr_args, ffi::Shape shape, - DataType ret_type) { + DLDataType ret_type) { // obtain function from the cache. ffi::Optional func = GetCachedBuild(tir_func); if (!func) return std::nullopt; @@ -243,7 +243,8 @@ class ConstantFolder : public ExprMutator { if (!shape) return std::nullopt; auto tensor_ty = tuple_ty->fields[i].as_or_throw(); if (tensor_ty->IsUnknownDtype()) return std::nullopt; - ret_tensors.push_back(runtime::Tensor::Empty(shape.value(), tensor_ty->dtype, cpu_dev)); + ret_tensors.push_back( + runtime::Tensor::Empty(shape.value(), tensor_ty->dtype->dtype, cpu_dev)); } // Pack input args + all output tensors. 
@@ -288,7 +289,8 @@ class ConstantFolder : public ExprMutator { ffi::Optional shape = MatchConstShape(call->ty_args[0]); if (shape) { TensorType ret_ty = call->ty.as_or_throw(); - return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_ty->dtype) + return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), + ret_ty->dtype->dtype) .value_or({}); } return {}; @@ -391,7 +393,7 @@ class ConstantFolder : public ExprMutator { for (size_t i = 0; i < values.size(); i++) { PrimExpr val = values[i]; arr.push_back(val.as()->value); - is_known &= (val.dtype() == DataType::Int(64)); + is_known &= val.ty().MatchesElementType(DLDataTypeCode::kDLInt, 64); } if (is_known) { const auto func = tvm::ffi::Function::GetGlobalRequired("relax.run.shape_to_tensor"); diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index d5e656d15256..00c1029a98d1 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -60,10 +60,10 @@ class SymbolicMatcher : ExprFunctordtype + << " cannot match to argument " << other << " with dtype " << other.ty()->dtype; } else { ExprFunctor::VisitExpr(node, other); } @@ -120,9 +120,10 @@ class SymbolicMatcher : ExprFunctor(); if (!rhs) { - TVM_FFI_THROW(InternalError) << "Parameter expression " << ffi::GetRef(op) - << " expected an cast to " << op->dtype << " as the argument, " - << "but was provided with the argument " << other; + TVM_FFI_THROW(InternalError) + << "Parameter expression " << ffi::GetRef(op) << " expected an cast to " + << op->ty()->dtype << " as the argument, " + << "but was provided with the argument " << other; } VisitExpr(op->value, rhs->value); } @@ -132,10 +133,11 @@ class SymbolicMatcher : ExprFunctordtype.code() != rhs->dtype.code()) { + } else if (op->ty().code() != rhs.ty().code()) { TVM_FFI_THROW(InternalError) - << "Parameter expression " << ffi::GetRef(op) << " with dtype " << op->dtype - << " cannot match to argument " << rhs << " with dtype " 
<< rhs.dtype(); + << "Parameter expression " << ffi::GetRef(op) << " with dtype " + << op->ty()->dtype << " cannot match to argument " << rhs << " with dtype " + << rhs.ty()->dtype; } else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) { VisitExpr((*it).second, rhs); } else { @@ -592,7 +594,7 @@ class FusedTIRConstructor : public ExprVisitor { // printed, it's more readable when done explicitly. Since // Buffer is used more than param it gets the name with better // readability. - tirx::Var param = tirx::Var("p_" + buffer->name, tvm::PrimType(DataType::Handle())); + tirx::Var param = tirx::Var("p_" + buffer->name, tvm::PrimType::Handle()); func_info_.params.push_back(param); func_info_.buffer_map.Set(param, buffer); } @@ -636,8 +638,7 @@ class FusedTIRConstructor : public ExprVisitor { continue; } - tirx::Var param = - tirx::Var("p_output" + std::to_string(out_idx), tvm::PrimType(DataType::Handle())); + tirx::Var param = tirx::Var("p_output" + std::to_string(out_idx), tvm::PrimType::Handle()); out_idx++; func_info_.buffer_map.Set(param, buffers[i]); func_info_.params.push_back(param); @@ -855,9 +856,10 @@ class FusedTIRConstructor : public ExprVisitor { for (int64_t idx : output_indices) { int i = static_cast(idx); const tirx::Var& param = func->params[static_cast(i)]; - if (param->dtype.is_int() || param->dtype.is_uint()) { + tvm::PrimType param_ty = param.ty(); + if (param_ty.code() == DLDataTypeCode::kDLInt || param_ty.code() == DLDataTypeCode::kDLUInt) { if (symbolic_var_index == -1) symbolic_var_index = i; - } else if (param->dtype.is_handle()) { + } else if (param_ty.IsHandle()) { TVM_FFI_ICHECK(symbolic_var_index == -1) << "The scalar input should be at the ending of the " "parameter list."; @@ -865,7 +867,7 @@ class FusedTIRConstructor : public ExprVisitor { } else { TVM_FFI_THROW(InternalError) << "The params of PrimFunc are expected to be Buffer handle or scalar, but got: " - << param->dtype; + << param_ty->dtype; } } @@ -967,7 +969,7 @@ 
class FusedTIRConstructor : public ExprVisitor { // Case 1. The relax param is a Tensor, we directly create a tirx var and buffer const auto* shape_expr = tensor->shape.as(); TVM_FFI_ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape."; - DataType dtype = tensor->dtype; + DLDataType dtype = tensor->dtype->dtype; tirx::Buffer buffer; if (tir_buffer_param.defined()) { buffer = tirx::decl_buffer(shape_expr->values, dtype, name_hint, @@ -980,7 +982,7 @@ class FusedTIRConstructor : public ExprVisitor { } else if (const auto* prim_value = ty.as()) { // Case 2. The relax param is a scalar, we directly create a tirx var - out->push_back(tirx::Var(name_hint, prim_value->dtype)); + out->push_back(tirx::Var(name_hint, tvm::PrimType(prim_value->dtype))); } else if (const auto* shape_expr = ty.as()) { // Case 3. The relax param is a tuple of scalars, each represented as a tirx var @@ -1257,7 +1259,7 @@ class TIRFuseMutator : public ExprMutator { if (const auto* literal = arg.as()) { tir_vars.push_back(literal->value); } else if (const auto* var = arg.as()) { - tir_vars.push_back(tirx::Var(var->name_hint(), prim_value->dtype)); + tir_vars.push_back(tirx::Var(var->name_hint(), tvm::PrimType(prim_value->dtype))); } else { TVM_FFI_THROW(TypeError) << "FuseTIR expects scalar arguments to be PrimValue or Var, " << "but received " << arg; diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index df22650e036d..992103de7d91 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -304,7 +304,7 @@ class BackwardBindingGenerator : private ExprVisitor { // Initialize the adjoint of target_var as ones op. We have already checked the target. 
auto* target_ty = GetTypeAs(target_var); - generator.UpdateAdjoint(target_var, ones(target_ty->shape.value(), target_ty->dtype)); + generator.UpdateAdjoint(target_var, ones(target_ty->shape.value(), target_ty->dtype->dtype)); // Do reverse-mode ad, so visit bindings backwards for (auto it = forward_block->bindings.rbegin(); it != forward_block->bindings.rend(); ++it) { @@ -546,7 +546,7 @@ class BackwardBindingGenerator : private ExprVisitor { auto* tensor_ty = ty.as(); TVM_FFI_ICHECK(tensor_ty) << "The leaf of adjoint should be a Tensor."; TVM_FFI_ICHECK(tensor_ty->shape.defined()) << "Missing shape when building zeros tuple."; - const Expr& init = zeros(tensor_ty->shape.value(), tensor_ty->dtype); + const Expr& init = zeros(tensor_ty->shape.value(), tensor_ty->dtype->dtype); return init; }); return AdjointMsgToExpr(msg); @@ -707,7 +707,8 @@ class GradientMutator : private ExprMutator { static bool IsFloatTensorType(const Type& ty) { auto* tensor_ty = ty.as(); - return tensor_ty && tensor_ty->dtype.is_float(); + // Gradient eligibility preserves the old float-kind check; lanes do not affect this policy. 
+ return tensor_ty && tensor_ty->dtype.MatchesCode(DLDataTypeCode::kDLFloat); } // When the return value is a Var, it is the target; diff --git a/src/relax/transform/infer_amp_utils.cc b/src/relax/transform/infer_amp_utils.cc index 41c6cfe5ae42..4952aeea8fa2 100644 --- a/src/relax/transform/infer_amp_utils.cc +++ b/src/relax/transform/infer_amp_utils.cc @@ -22,19 +22,19 @@ namespace tvm { namespace relax { -NType NTypeFrom(const Type& ty, DataType dtype) { +NType NTypeFrom(const Type& ty, DLDataType dtype) { auto fmapleaf = [&](const Type& ty) -> NType { const auto* tensor = ty.as(); TVM_FFI_ICHECK(tensor) << "Expected TensorType, but got " << ty; - if (dtype == DataType::Void()) - return NType(DLDataTypeToString(tensor->dtype)); + if (dtype == DLDataType{kDLOpaqueHandle, 0, 0}) + return NType(DLDataTypeToString(tensor->dtype->dtype)); else return NType(DLDataTypeToString(dtype)); }; return MapToNestedMsg(ty, fmapleaf); } -NType NTypeFrom(const Expr& expr, DataType dtype) { return NTypeFrom(GetType(expr), dtype); } +NType NTypeFrom(const Expr& expr, DLDataType dtype) { return NTypeFrom(GetType(expr), dtype); } NType NTypeMerge(const NType& a, const NType& b) { auto fcombine = [&](const ffi::String& a_str, const ffi::String& b_str) -> ffi::String { @@ -44,20 +44,20 @@ NType NTypeMerge(const NType& a, const NType& b) { return a_str; } - DataType a = DataType(ffi::StringToDLDataType(a_str)); - DataType b = DataType(ffi::StringToDLDataType(b_str)); - TVM_FFI_ICHECK_EQ(a.code(), b.code()); - TVM_FFI_ICHECK_EQ(a.lanes(), b.lanes()); - return a.bits() > b.bits() ? a_str : b_str; + DLDataType a = ffi::StringToDLDataType(a_str); + DLDataType b = ffi::StringToDLDataType(b_str); + TVM_FFI_ICHECK_EQ(a.code, b.code); + TVM_FFI_ICHECK_EQ(a.lanes, b.lanes); + return a.bits > b.bits ? 
a_str : b_str; }; return CombineNestedMsg(a, b, fcombine); } -ffi::Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype) { +ffi::Array InferMixedPrecisionFollow(const Call& call, DLDataType out_dtype) { return {IntImm::Int32(MixedPrecisionPolicyKind::kFollow), call}; } -ffi::Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype) { +ffi::Array InferMixedPrecisionNever(const Call& call, DLDataType out_dtype) { return {IntImm::Int32(MixedPrecisionPolicyKind::kNever), call}; } diff --git a/src/relax/transform/infer_amp_utils.h b/src/relax/transform/infer_amp_utils.h index faa33edd4a18..7f9f884a29d0 100644 --- a/src/relax/transform/infer_amp_utils.h +++ b/src/relax/transform/infer_amp_utils.h @@ -58,10 +58,10 @@ struct NTypeEqual { }; // Construct a NType from an Type -NType NTypeFrom(const Type& ty, DataType dtype = DataType::Void()); +NType NTypeFrom(const Type& ty, DLDataType dtype = DLDataType{kDLOpaqueHandle, 0, 0}); // Construct a NType from an Expr -NType NTypeFrom(const Expr& expr, DataType dtype = DataType::Void()); +NType NTypeFrom(const Expr& expr, DLDataType dtype = DLDataType{kDLOpaqueHandle, 0, 0}); // Merge two messages, we keep the higher precision type for each leaf tensor NType NTypeMerge(const NType& a, const NType& b); @@ -70,12 +70,11 @@ NType NTypeMerge(const NType& a, const NType& b); using VarDTypeMap = std::unordered_map; // Call is a call node, out_dtype is the expected output_dtype -using FInferMixedPrecision = - ffi::TypedFunction; +using FInferMixedPrecision = ffi::TypedFunction; -ffi::Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype); +ffi::Array InferMixedPrecisionFollow(const Call& call, DLDataType out_dtype); -ffi::Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype); +ffi::Array InferMixedPrecisionNever(const Call& call, DLDataType out_dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/lazy_transform_params.cc 
b/src/relax/transform/lazy_transform_params.cc index 7c42928d7d87..b800199610b8 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -65,8 +65,7 @@ class LazyInputMutator : public ExprMutator { param_lookup.insert({func->params[i], i - num_input_params}); } - Var fget_param("fget_param", - FuncType({PrimType(DataType::Int(64)), ObjectType()}, ObjectType())); + Var fget_param("fget_param", FuncType({PrimType::Int(64), ObjectType()}, ObjectType())); ffi::Array new_params(func->params.begin(), func->params.begin() + num_input_params); new_params.push_back(fget_param); @@ -145,7 +144,7 @@ class LazyOutputMutator : public ExprMutator { define_lookup(0, func_body->body); } - Var fset_output("fset_output", FuncType({PrimType(DataType::Int(64)), ObjectType()}, + Var fset_output("fset_output", FuncType({PrimType::Int(64), ObjectType()}, TupleType(ffi::Array{}), /* purity = */ false)); plan_ = FunctionPlan{std::move(output_lookup), fset_output}; diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 00bd8e859ac3..2c518cfbbeae 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -282,7 +282,7 @@ class LegalizeMutator : public ExprMutator { // This fallback would only be applicable for cases where // both the dtype and the dimensionality are known. While // Relax can express a tensor with unknown dtype and - // dimensionality as `TensorType(DataType::Void(), + // dimensionality as `TensorType(DLDataType{kDLOpaqueHandle, 0, 0}, // kUnknownNDim)`, TIR cannot express unknown dtype or // unknown dimensionality. 
return false; diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index 66c2c95b89c2..52bca3e707eb 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -72,7 +72,10 @@ class Mutator : public ExprMutator { }(); PrimExpr nbytes = [&]() -> PrimExpr { - PrimExpr nbytes = IntImm::Int64(dtype->value.bytes()); + PrimType dtype_ty(dtype->value); + TVM_FFI_ICHECK(!dtype_ty.IsScalableVector()) + << "Cannot statically compute allocation size for scalable vector dtype " << dtype_ty; + PrimExpr nbytes = IntImm::Int64(static_cast(dtype_ty.StorageBytes())); for (const auto& dim : shape) { nbytes *= dim; } @@ -112,7 +115,7 @@ class Mutator : public ExprMutator { auto offset = PrimValue::Int64(0); Expr storage = relax::Call(mem_alloc_storage_op, {size, runtime_device_index, storage_scope, - DataTypeImm(DataType::UInt(8))}); + DataTypeImm((DLDataType{kDLUInt, 8, 1}))}); storage = builder_->Emit(storage, "storage"); Expr tensor = relax::Call(mem_alloc_tensor_op, {storage, offset, shape_arg, dtype, op->args[2]}); diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index 995fe019be04..f8a9e8cde70b 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -289,7 +289,7 @@ Pass RemoveUnusedOutputs() { // into the old tuple, but it's simpler to just let // CanonicalizeBindings and DCE handle it. 
new_results.push_back( - relax::PrimValue(FloatImm(DataType::Float(64), std::nan("")))); + relax::PrimValue(FloatImm(tvm::PrimType::Float(64), std::nan("")))); } } diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index ebe9fa000f77..4f28f9d13132 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -100,7 +100,7 @@ std::optional AnalyzeCallee(Function func) { } for (const auto& tir_var : free_tir_vars) { - Var relax_var("param_" + tir_var->name_hint, PrimType(tir_var.dtype())); + Var relax_var("param_" + tir_var->name_hint, PrimType(tir_var.ty())); params.push_back(relax_var); } diff --git a/src/relax/transform/reorder_take_after_matmul.cc b/src/relax/transform/reorder_take_after_matmul.cc index bd36c5cb89c5..7fd0fb7eecaa 100644 --- a/src/relax/transform/reorder_take_after_matmul.cc +++ b/src/relax/transform/reorder_take_after_matmul.cc @@ -92,7 +92,7 @@ std::tuple)>> // indices.shape = [outfeatures] // out_table.shape = [*batch, table_size] - auto out_table = matmul(lhs, weights, DataType::Void()); + auto out_table = matmul(lhs, weights, (DLDataType{kDLOpaqueHandle, 0, 0})); // new_output.shape = [*batch, outfeatures] auto new_output = take(out_table, indices, matmul_ty->ndim - 1); @@ -116,7 +116,7 @@ std::tuple)>> auto fused_weight = reshape(reordered_weight, ShapeExpr({weight_shape[1], weight_shape[0] * weight_shape[2]})); // fused_output.shape = [batch1, batch2, table_size * outfeatures] - auto fused_output = matmul(lhs, fused_weight, DataType::Void()); + auto fused_output = matmul(lhs, fused_weight, (DLDataType{kDLOpaqueHandle, 0, 0})); // indexed_output.shape = [batch1, batch2, table_size, outfeatures] auto indexed_output = reshape( fused_output, ShapeExpr({lhs_shape[0], lhs_shape[1], weight_shape[0], weight_shape[2]})); diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc 
index 4d15c0fd88f5..19e0dfdf8f00 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -129,7 +129,7 @@ class ForMatcher : public TensorizeComparator { if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); evaluated_symbols.back()[ffi::GetRef(operand_a)] = - MakeConstScalar(rhs_ptr->b.dtype(), 1); + MakeConstScalar(rhs_ptr->b.ty(), 1); return true; } } @@ -142,7 +142,7 @@ class ForMatcher : public TensorizeComparator { if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); evaluated_symbols.back()[ffi::GetRef(operand_b)] = - MakeConstScalar(rhs_ptr->a.dtype(), 1); + MakeConstScalar(rhs_ptr->a.ty(), 1); return true; } } @@ -160,7 +160,7 @@ class ForMatcher : public TensorizeComparator { if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); evaluated_symbols.back()[ffi::GetRef(operand_a)] = - MakeConstScalar(rhs_ptr->b.dtype(), 0); + MakeConstScalar(rhs_ptr->b.ty(), 0); return true; } } @@ -173,7 +173,7 @@ class ForMatcher : public TensorizeComparator { if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); evaluated_symbols.back()[ffi::GetRef(operand_b)] = - MakeConstScalar(rhs_ptr->a.dtype(), 0); + MakeConstScalar(rhs_ptr->a.ty(), 0); return true; } } @@ -622,7 +622,7 @@ std::pair> SplitFunctions( } } arg_partition->push_back(arg_partition1); - new_params1.push_back(Var("output", DataType::Handle())); + new_params1.push_back(Var("output", PrimType::Handle())); ffi::Map new_buffer_map1; for (const auto& kv : func->buffer_map) { if (partitioner.input1.count(kv.second)) { @@ -635,7 +635,7 @@ std::pair> SplitFunctions( // Step 4. Craft the second function. 
ffi::Array new_params2; std::vector arg_partition2; - new_params2.push_back(Var("input", DataType::Handle())); + new_params2.push_back(Var("input", PrimType::Handle())); for (int i = 0; i < static_cast(func->params.size()); i++) { Var param = func->params[i]; if (partitioner.input2.count(func->buffer_map[param])) { @@ -752,7 +752,7 @@ class SplitMutator : public ExprMutator { TVM_FFI_ICHECK(lib_func->IsInstance()); builder_->UpdateFunction(gv, lib_func); tirx::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back()); - DataType dtype = intermediate_buffer->dtype; + PrimType dtype = intermediate_buffer->dtype; Call call1(call_dps_packed_, {lib_func, Tuple(args1)}, call->attrs, {TensorType(ShapeExpr(intermediate_buffer->shape), dtype)}); Var call_var1 = builder_->Emit(call1); diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index 0560582fac59..e09e377e8a70 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -65,11 +65,11 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { ffi::Map buffer_map; for (const auto& info : rewrite_infos_) { - params.push_back(Var(info.pre_rewrite_buffer->name, DataType::Handle())); + params.push_back(Var(info.pre_rewrite_buffer->name, PrimType::Handle())); buffer_map.Set(params.back(), info.pre_rewrite_buffer); } for (const auto& info : rewrite_infos_) { - params.push_back(Var(info.post_rewrite_buffer->name, DataType::Handle())); + params.push_back(Var(info.post_rewrite_buffer->name, PrimType::Handle())); buffer_map.Set(params.back(), info.post_rewrite_buffer); } diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 651b70961090..2a04461555d0 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -106,7 +106,7 @@ class StorageTokenNode : public 
ffi::Object { /*! \brief Number of bytes that this token requires. */ PrimExpr bytes; /*! \brief The dtype of this token. */ - DataType dtype; + DLDataType dtype; /*! \brief The memory scope of the token. */ std::string storage_scope; /*! \brief The VDevice information. */ @@ -135,10 +135,13 @@ class StorageTokenNode : public ffi::Object { */ class StorageToken : public ffi::ObjectRef { public: - explicit StorageToken(ffi::Array shape, DataType dtype, std::string storage_scope, + explicit StorageToken(ffi::Array shape, DLDataType dtype, std::string storage_scope, ffi::Optional vdevice = std::nullopt) { // Compute the tensor size from the shape. - int64_t const_coeff = dtype.bytes() * dtype.lanes(); + PrimType dtype_ty(dtype); + TVM_FFI_ICHECK(!dtype_ty.IsScalableVector()) + << "Cannot statically plan storage size for scalable vector dtype " << dtype_ty; + int64_t const_coeff = static_cast(dtype_ty.StorageBytes()); PrimExpr size = IntImm::Int64(1); bool size_computed = false; @@ -303,13 +306,16 @@ class TokenAllocatorMixed { } private: - /*! \brief The hash class to enable std::pair as map key class. */ - struct PairHash { - template - std::size_t operator()(const std::pair& p) const { - auto h1 = std::hash{}(p.first); - auto h2 = std::hash{}(p.second); - return h1 ^ h2; + using PoolKey = std::pair; + + /*! \brief The hash class to enable storage scope and raw dtype as map key class. */ + struct PoolKeyHash { + std::size_t operator()(const PoolKey& p) const { + std::size_t h = std::hash{}(p.first); + h ^= static_cast(p.second.code) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= static_cast(p.second.bits) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= static_cast(p.second.lanes) + 0x9e3779b9 + (h << 6) + (h >> 2); + return h; } }; @@ -318,9 +324,7 @@ class TokenAllocatorMixed { /*! \brief A constant scale representing the token search range. */ const int match_range_{16}; /*! \brief The pool of available storage tokens for each storage scope and dtype. 
*/ - std::unordered_map, std::multimap, - PairHash> - available_pool_; + std::unordered_map, PoolKeyHash> available_pool_; /*! \brief All the storage tokens that have been allocated with actual storage. */ std::vector full_pool_; }; @@ -636,7 +640,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { const auto* shape = ty->shape.as(); TVM_FFI_ICHECK_NOTNULL(shape); TVM_FFI_ICHECK(!ty->IsUnknownDtype()); - TVM_FFI_ICHECK(ty->dtype == call->args[1].as_or_throw()->value); + TVM_FFI_ICHECK(ty->dtype->dtype == call->args[1].as_or_throw()->value); TVM_FFI_ICHECK(!token_map_.count(call)); // Use the upper bounds of TIR vars as their values. The upper bound shape can still be dynamic @@ -653,7 +657,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { } ffi::Optional vdevice = GetGlobalVDevice(ctx_mod_, vdevice_index); - StorageToken token(upper_bounded_shape, ty->dtype, storage_scope->value, vdevice); + StorageToken token(upper_bounded_shape, ty->dtype->dtype, storage_scope->value, vdevice); Tokens tokens(token); SetTokens(call, tokens); @@ -938,7 +942,7 @@ class StorageAllocationRewriter : public ExprMutator { if (it_token == token2storage_var_.end()) { ShapeExpr size({token->bytes}); PrimValue virtual_device_index = runtime_device_index; - DataType dtype = token->dtype; + DLDataType dtype = token->dtype; Call alloc_storage(mem_alloc_storage, {std::move(size), virtual_device_index, StringImm(token->storage_scope), DataTypeImm(dtype)}, @@ -951,7 +955,7 @@ class StorageAllocationRewriter : public ExprMutator { // And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`. 
PrimValue offset = PrimValue::Int64(0); - DataType dtype = ty->dtype; + DLDataType dtype = ty->dtype->dtype; return Call(mem_alloc_tensor, {storage_var, offset, ty->shape.value(), DataTypeImm(dtype), call->args[2]}, Attrs()); @@ -970,22 +974,26 @@ class StorageAllocationRewriter : public ExprMutator { GetUpperBoundShape(shape->values, ana_.get(), dom_map_); if (!IsStaticShape(shape->values)) { TVM_FFI_ICHECK(!ty->IsUnknownDtype()); - TVM_FFI_ICHECK_EQ(ty->dtype, call->args[1].as_or_throw()->value); + TVM_FFI_ICHECK_EQ(ty->dtype->dtype, call->args[1].as_or_throw()->value); PrimExpr bytes = upper_bounded_shape[0]; for (int i = 1; i < static_cast(upper_bounded_shape.size()); ++i) { bytes *= upper_bounded_shape[i]; } - bytes *= ty->dtype.bytes() * ty->dtype.lanes(); + DLDataType dtype = ty->dtype->dtype; + PrimType dtype_ty(dtype); + TVM_FFI_ICHECK(!dtype_ty.IsScalableVector()) + << "Cannot statically plan storage size for scalable vector dtype " << dtype_ty; + bytes *= IntImm::Int64(static_cast(dtype_ty.StorageBytes())); Call alloc_storage(mem_alloc_storage, {/*size=*/ShapeExpr({bytes}), /*virtual_device_index=*/call->args[2].as_or_throw(), /*storage_scope=*/call->args[3].as_or_throw(), // - /*dtype=*/DataTypeImm(ty->dtype)}); + /*dtype=*/DataTypeImm(dtype)}); Var storage = builder_->Emit(alloc_storage, "storage"); return Call(mem_alloc_tensor, {storage, // /*offset=*/PrimValue::Int64(0), /*shape=*/ffi::GetRef(shape), // - /*dtype=*/DataTypeImm(ty->dtype), + /*dtype=*/DataTypeImm(dtype), /*vdevice_index=*/call->args[2]}); } } @@ -1040,7 +1048,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.transform.StaticPlanBlockMemory", StaticPlanBlockMemory); } -PrimExpr GetTextureMemorySizeFromVDevice(ffi::Array pshape, DataType dtype, +PrimExpr GetTextureMemorySizeFromVDevice(ffi::Array pshape, DLDataType dtype, VDevice vdevice) { int image_row_align = static_cast( vdevice->target->GetAttr("image_base_address_alignment").value_or(64)); @@ -1056,7 +1064,9 @@ 
PrimExpr GetTextureMemorySizeFromVDevice(ffi::Array pshape, DataType d }; auto shape = Shape{pshape}; - size_t size = runtime::GetTextureMemorySize(shape, dtype.bytes() * 8, dtype.lanes(), + int lanes = static_cast(dtype.lanes); + TVM_FFI_ICHECK_GE(lanes, 0) << "Can't fetch the bytes of a scalable vector at a compile time."; + size_t size = runtime::GetTextureMemorySize(shape, dtype.bits, lanes, vdevice->memory_scope, image_row_align); return IntImm::Int64(size); } diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index ddd23ce2ea7b..45d2af9b8579 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -116,9 +116,9 @@ int GetMixedPrecisionInfo(const CallNode* call_node) { */ class DTypeDecisionCollector : public ExprVisitor { public: - explicit DTypeDecisionCollector(DataType output_dtype) : output_dtype_(output_dtype) {} + explicit DTypeDecisionCollector(DLDataType output_dtype) : output_dtype_(output_dtype) {} - static VarDTypeMap Collect(Function func, DataType output_dtype) { + static VarDTypeMap Collect(Function func, DLDataType output_dtype) { DTypeDecisionCollector collector(output_dtype); collector.VisitExpr(func); return std::move(collector.only_fp16_map_); @@ -165,7 +165,7 @@ class DTypeDecisionCollector : public ExprVisitor { } // merge the message for all vars in the expr list - void RequireArgsToType(ffi::Array args, DataType to) { + void RequireArgsToType(ffi::Array args, DLDataType to) { std::vector arg_arr; std::vector to_arr; for (const Expr& arg : args) { @@ -262,16 +262,16 @@ class DTypeDecisionCollector : public ExprVisitor { } } - DataType unknown_ = DataType(DataType::TypeCode::kFloat, 0, 1); - DataType fp16_ = DataType(DataType::TypeCode::kFloat, 16, 1); - DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1); - DataType output_dtype_; + DLDataType unknown_ = DLDataType{kDLFloat, 0, 1}; + DLDataType fp16_ = DLDataType{kDLFloat, 16, 1}; + 
DLDataType fp32_ = DLDataType{kDLFloat, 32, 1}; + DLDataType output_dtype_; VarDTypeMap only_fp16_map_; }; class ToMixedPrecisionRewriter : public ExprMutator { public: - explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DataType output_dtype, + explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DLDataType output_dtype, const std::unordered_set& fp16_input_names) : only_fp16_map_(only_fp16_map), output_dtype_(output_dtype), @@ -290,7 +290,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { if (tensor_ty->vdevice.defined()) { vdev = tensor_ty->vdevice.value(); } - TensorType fp16_ty(tensor_ty->shape.value(), DataType::Float(16), vdev, tensor_ty->span); + TensorType fp16_ty(tensor_ty->shape.value(), PrimType::Float(16), vdev, tensor_ty->span); Var fp16_var(var->vid, fp16_ty, var->span); var_remap_[var->vid] = fp16_var; return fp16_var; @@ -315,13 +315,14 @@ class ToMixedPrecisionRewriter : public ExprMutator { if (NTypeEqual()(to[0], NTypeFrom(expr))) return expr; // We only rewrite the expr if the dtype is fp16 or fp32, dtypes such as int32, float64 is not // supported to be rewritten - if (tensor->dtype != fp16_ && tensor->dtype != fp32_) return expr; - return astype(expr, DataType(ffi::StringToDLDataType(to[0].LeafValue()))); + DLDataType tensor_dtype = tensor->dtype->dtype; + if (tensor_dtype != fp16_ && tensor_dtype != fp32_) return expr; + return astype(expr, ffi::StringToDLDataType(to[0].LeafValue())); }; return TransformTupleLeaf(expr, std::array({to}), fvisitleaf); } - ffi::Array RewriteArgs(const ffi::Array& args, DataType to) { + ffi::Array RewriteArgs(const ffi::Array& args, DLDataType to) { ffi::Array new_args; for (const Expr& arg : args) { if (IsNestedTensor(arg)) { @@ -346,7 +347,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { bool AllFP16Castable(const ffi::Array& args) { auto is_fp16 = [](Type ty) { if (auto tensor_ty = ty.as(); - tensor_ty && tensor_ty->dtype == DataType::Float(16)) { + tensor_ty 
&& tensor_ty->dtype == PrimType::Float(16)) { return true; } return false; @@ -359,7 +360,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { return false; } - if (data.DataType() == DataType::Float(16)) { + if (data.DataType() == DLDataType{kDLFloat, 16, 1}) { return true; } @@ -372,17 +373,17 @@ class ToMixedPrecisionRewriter : public ExprMutator { std::vector bytes(size_1d * elem_bytes); data.CopyToBytes(bytes.data(), bytes.size()); - if (data.DataType() == DataType::Float(32)) { + if (data.DataType() == DLDataType{kDLFloat, 32, 1}) { return CheckInFP16Range(bytes, size_1d); - } else if (data.DataType() == DataType::Float(64)) { + } else if (data.DataType() == DLDataType{kDLFloat, 64, 1}) { return CheckInFP16Range(bytes, size_1d); - } else if (data.DataType() == DataType::Int(8)) { + } else if (data.DataType() == DLDataType{kDLInt, 8, 1}) { return CheckInFP16Range(bytes, size_1d); - } else if (data.DataType() == DataType::Int(16)) { + } else if (data.DataType() == DLDataType{kDLInt, 16, 1}) { return CheckInFP16Range(bytes, size_1d); - } else if (data.DataType() == DataType::Int(32)) { + } else if (data.DataType() == DLDataType{kDLInt, 32, 1}) { return CheckInFP16Range(bytes, size_1d); - } else if (data.DataType() == DataType::Int(64)) { + } else if (data.DataType() == DLDataType{kDLInt, 64, 1}) { return CheckInFP16Range(bytes, size_1d); } return false; @@ -476,7 +477,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { new_call.CopyOnWrite()->args = RemapArgs(new_call->args); // Then we rewrite the args according to the policy - std::optional opt_new_dtype = std::nullopt; + std::optional opt_new_dtype = std::nullopt; if (policy == kAlways) { opt_new_dtype = fp16_; @@ -589,16 +590,16 @@ class ToMixedPrecisionRewriter : public ExprMutator { const VarDTypeMap* only_fp16_map_; - DataType fp16_ = DataType(DataType::TypeCode::kFloat, 16, 1); - DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1); - DataType output_dtype_; + DLDataType fp16_ = 
DLDataType{kDLFloat, 16, 1}; + DLDataType fp32_ = DLDataType{kDLFloat, 32, 1}; + DLDataType output_dtype_; ffi::Array params_; std::unordered_set fp16_input_names_; const Op& wrap_param_op = Op::Get("relax.wrap_param"); }; -Expr ToMixedPrecision(const Function& f, const DataType& out_dtype, +Expr ToMixedPrecision(const Function& f, DLDataType out_dtype, ffi::Optional> fp16_input_names) { VarDTypeMap only_fp16_map = DTypeDecisionCollector::Collect(f, out_dtype); std::unordered_set fp16_input_names_set; @@ -611,7 +612,7 @@ Expr ToMixedPrecision(const Function& f, const DataType& out_dtype, namespace transform { -Pass ToMixedPrecision(const DataType& out_dtype, +Pass ToMixedPrecision(DLDataType out_dtype, ffi::Optional> fp16_input_names) { auto pass_func = [=](Function f, IRModule m, PassContext pc) { return ToMixedPrecision(f, out_dtype, fp16_input_names).as_or_throw(); diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 275c7ca94f8d..d4607459c74f 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -319,39 +319,39 @@ class FunctionCopier : public SymbolicVarRenewMutator { * \return A Constant. 
*/ template -inline Constant MakeConstantScalar(T value, DataType dtype) { +inline Constant MakeConstantScalar(T value, DLDataType dtype) { runtime::Tensor arr = runtime::Tensor::Empty({}, dtype, {kDLCPU, 0}); - if (dtype == DataType::Float(32)) { + if (dtype == DLDataType{kDLFloat, 32, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Float(64)) { + } else if (dtype == DLDataType{kDLFloat, 64, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Int(32)) { + } else if (dtype == DLDataType{kDLInt, 32, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Int(64)) { + } else if (dtype == DLDataType{kDLInt, 64, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Bool()) { + } else if (dtype == DLDataType{kDLBool, 8, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::UInt(8)) { + } else if (dtype == DLDataType{kDLUInt, 8, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::UInt(16)) { + } else if (dtype == DLDataType{kDLUInt, 16, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::UInt(32)) { + } else if (dtype == DLDataType{kDLUInt, 32, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::UInt(64)) { + } else if (dtype == DLDataType{kDLUInt, 64, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Int(8)) { + } else if (dtype == DLDataType{kDLInt, 8, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Int(16)) { + } else if (dtype == DLDataType{kDLInt, 16, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Int(32)) { + } else if (dtype == DLDataType{kDLInt, 32, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Int(64)) { + } else if (dtype == DLDataType{kDLInt, 64, 1}) { 
*static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Float(16)) { + } else if (dtype == DLDataType{kDLFloat, 16, 1}) { // convert to float16 storage is uint16_t *static_cast(arr->data) = __truncXfYf2__(static_cast(value)); - } else if (dtype == DataType::BFloat(16)) { + } else if (dtype == DLDataType{kDLBfloat, 16, 1}) { // convert to bfloat16 storage is uint16_t *static_cast(arr->data) = __truncXfYf2__(static_cast(value)); diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 370947e4b01f..2f5cc6d9dea8 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -179,11 +179,11 @@ tvm::ffi::Map InferSymbolicVarMap( } bool IsBoolType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) { - DataType dtype; + DLDataType dtype; int ndim; if (const auto* tensor = ty.as()) { - dtype = tensor->dtype; + dtype = tensor->dtype->dtype; ndim = tensor->ndim; } else if (const auto* prim = ty.as()) { dtype = prim->dtype; @@ -192,7 +192,9 @@ bool IsBoolType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dt return false; } - bool correct_dtype = dtype.is_bool() || (permit_unknown_dtype && dtype.is_void()); + // Bool-type matching preserves the old element-code-only behavior; rank is checked separately. + bool correct_dtype = dtype.code == DLDataTypeCode::kDLBool || + (permit_unknown_dtype && dtype == DLDataType{kDLOpaqueHandle, 0, 0}); bool correct_rank = ndim == 0 || (permit_unknown_rank && ndim == -1); return correct_dtype && correct_rank; } diff --git a/src/runtime/extra/contrib/cblas/cblas.cc b/src/runtime/extra/contrib/cblas/cblas.cc index d71eaeb17672..a19ccc99bb3f 100644 --- a/src/runtime/extra/contrib/cblas/cblas.cc +++ b/src/runtime/extra/contrib/cblas/cblas.cc @@ -21,10 +21,10 @@ * \file Use external cblas library call. 
*/ #include +#include #include #include #include -#include extern "C" { #include @@ -35,7 +35,6 @@ extern "C" { namespace tvm { namespace contrib { -using namespace runtime; inline CBLAS_TRANSPOSE CBLASBooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; } @@ -128,38 +127,39 @@ struct CblasDgemmBatchIterativeOp { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_packed( - "tvm.contrib.cblas.matmul", - [](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + .def_packed("tvm.contrib.cblas.matmul", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLFloat, 32, 1} || + A->dtype == DLDataType{kDLFloat, 64, 1})); - if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, CblasSgemmOp()); - else - CallGemm(args, ret, CblasDgemmOp()); - }) - .def_packed( - "tvm.contrib.cblas.batch_matmul", - [](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, CblasSgemmBatchOp()); - } else { - CallBatchGemm(args, ret, CblasDgemmBatchOp()); - } - }) - .def_packed( - "tvm.contrib.cblas.batch_matmul_iterative", [](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); - } else { - CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); - } - }); + if (A->dtype == DLDataType{kDLFloat, 32, 1}) + CallGemm(args, ret, CblasSgemmOp()); + else + CallGemm(args, ret, CblasDgemmOp()); + }) + .def_packed("tvm.contrib.cblas.batch_matmul", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = 
args[0].cast(); + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLFloat, 32, 1} || + A->dtype == DLDataType{kDLFloat, 64, 1})); + if (A->dtype == DLDataType{kDLFloat, 32, 1}) { + CallBatchGemm(args, ret, CblasSgemmBatchOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchOp()); + } + }) + .def_packed("tvm.contrib.cblas.batch_matmul_iterative", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLFloat, 32, 1} || + A->dtype == DLDataType{kDLFloat, 64, 1})); + if (A->dtype == DLDataType{kDLFloat, 32, 1}) { + CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); + } + }); } } // namespace contrib } // namespace tvm diff --git a/src/runtime/extra/contrib/cblas/dnnl_blas.cc b/src/runtime/extra/contrib/cblas/dnnl_blas.cc index 08d72e57b7ad..c0828c12e8b6 100644 --- a/src/runtime/extra/contrib/cblas/dnnl_blas.cc +++ b/src/runtime/extra/contrib/cblas/dnnl_blas.cc @@ -21,10 +21,10 @@ * \file Use external cblas library call. */ #include +#include #include #include #include -#include extern "C" { #include @@ -35,7 +35,6 @@ extern "C" { namespace tvm { namespace contrib { -using namespace runtime; inline char DNNLBooleanToTransposeChar(bool trans) { return trans ? 
'T' : 'N'; } struct DNNLSgemmOp { @@ -52,7 +51,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.contrib.dnnl.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLFloat, 32, 1})); CallGemm(args, ret, DNNLSgemmOp()); }); } diff --git a/src/runtime/extra/contrib/cblas/gemm_common.h b/src/runtime/extra/contrib/cblas/gemm_common.h index 52f306e86238..65b13aa4c728 100644 --- a/src/runtime/extra/contrib/cblas/gemm_common.h +++ b/src/runtime/extra/contrib/cblas/gemm_common.h @@ -26,8 +26,8 @@ #define TVM_RUNTIME_CONTRIB_CBLAS_GEMM_COMMON_H_ #include +#include #include -#include #include #include @@ -37,7 +37,6 @@ namespace contrib { using ffi::Any; using ffi::PackedArgs; -using runtime::TypeMatch; inline int ColumnStride(const DLTensor* tensor) { // If the tensor itself is transposed then it will have strides @@ -96,8 +95,8 @@ inline void CallGemm(ffi::PackedArgs args, ffi::Any* ret, TGemmOp op) { transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? !transb : transb; - TVM_FFI_ICHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); - TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); + TVM_FFI_ICHECK((B->dtype == DLDataType{kDLFloat, static_cast(bit_depth), 1})); + TVM_FFI_ICHECK((C->dtype == DLDataType{kDLFloat, static_cast(bit_depth), 1})); double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa), @@ -143,9 +142,9 @@ inline void CallU8S8S32Gemm(ffi::PackedArgs args, ffi::Any* ret, TGemmOp op) { transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? 
!transb : transb; - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLUInt, 8)); - TVM_FFI_ICHECK(TypeMatch(B->dtype, kDLInt, 8)); - TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLInt, 32)); + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLUInt, 8, 1})); + TVM_FFI_ICHECK((B->dtype == DLDataType{kDLInt, 8, 1})); + TVM_FFI_ICHECK((C->dtype == DLDataType{kDLInt, 32, 1})); double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa), @@ -207,8 +206,8 @@ inline void CallBatchGemm(ffi::PackedArgs args, ffi::Any* ret, TBatchGemmOp op) transa = IsInPlaceTransposed3D(A) ? !transa : transa; transb = IsInPlaceTransposed3D(B) ? !transb : transb; - TVM_FFI_ICHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); - TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); + TVM_FFI_ICHECK((B->dtype == DLDataType{kDLFloat, static_cast(bit_depth), 1})); + TVM_FFI_ICHECK((C->dtype == DLDataType{kDLFloat, static_cast(bit_depth), 1})); double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; diff --git a/src/runtime/extra/contrib/cblas/mkl.cc b/src/runtime/extra/contrib/cblas/mkl.cc index 20f0c539076b..366ada41d2f1 100644 --- a/src/runtime/extra/contrib/cblas/mkl.cc +++ b/src/runtime/extra/contrib/cblas/mkl.cc @@ -21,10 +21,10 @@ * \file Use external mkl library call. */ #include +#include #include #include #include -#include extern "C" { #include @@ -35,7 +35,6 @@ extern "C" { namespace tvm { namespace contrib { -using namespace runtime; inline CBLAS_TRANSPOSE MKLBooleanToTranspose(bool trans) { return trans ? 
CblasTrans : CblasNoTrans; } @@ -160,9 +159,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.contrib.mkl.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + TVM_FFI_ICHECK( + (A->dtype == DLDataType{kDLFloat, 32, 1} || A->dtype == DLDataType{kDLFloat, 64, 1})); - if (TypeMatch(A->dtype, kDLFloat, 32)) + if (A->dtype == DLDataType{kDLFloat, 32, 1}) CallGemm(args, ret, MKLSgemmOp()); else CallGemm(args, ret, MKLDgemmOp()); @@ -178,33 +178,34 @@ TVM_FFI_STATIC_INIT_BLOCK() { auto A = args[0].cast(); auto B = args[1].cast(); auto C = args[2].cast(); - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLUInt, 8) && - TypeMatch(B->dtype, kDLInt, 8) && - TypeMatch(C->dtype, kDLInt, 32)); + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLUInt, 8, 1} && + B->dtype == DLDataType{kDLInt, 8, 1} && + C->dtype == DLDataType{kDLInt, 32, 1})); CallU8S8S32Gemm(args, ret, MKLGemmU8S8S32Op()); }) - .def_packed( - "tvm.contrib.mkl.batch_matmul", - [](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, MKLSgemmBatchOp()); - } else { - CallBatchGemm(args, ret, MKLDgemmBatchOp()); - } - }) - .def_packed( - "tvm.contrib.mkl.batch_matmul_iterative", [](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, MKLSgemmBatchIterativeOp()); - } else { - CallBatchGemm(args, ret, MKLDgemmBatchIterativeOp()); - } - }); + .def_packed("tvm.contrib.mkl.batch_matmul", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLFloat, 32, 1} || + A->dtype == 
DLDataType{kDLFloat, 64, 1})); + if (A->dtype == DLDataType{kDLFloat, 32, 1}) { + CallBatchGemm(args, ret, MKLSgemmBatchOp()); + } else { + CallBatchGemm(args, ret, MKLDgemmBatchOp()); + } + }) + .def_packed("tvm.contrib.mkl.batch_matmul_iterative", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLFloat, 32, 1} || + A->dtype == DLDataType{kDLFloat, 64, 1})); + if (A->dtype == DLDataType{kDLFloat, 32, 1}) { + CallBatchGemm(args, ret, MKLSgemmBatchIterativeOp()); + } else { + CallBatchGemm(args, ret, MKLDgemmBatchIterativeOp()); + } + }); } } // namespace contrib } // namespace tvm diff --git a/src/runtime/extra/contrib/coreml/coreml_runtime.mm b/src/runtime/extra/contrib/coreml/coreml_runtime.mm index a72948b250a7..d9823407fb0a 100644 --- a/src/runtime/extra/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/extra/contrib/coreml/coreml_runtime.mm @@ -44,15 +44,15 @@ [shape addObject:[NSNumber numberWithInteger:data_in->shape[i]]]; } - DataType dtype(data_in->dtype); + DLDataType dtype = data_in->dtype; MLMultiArrayDataType dataType; - if (dtype == DataType::Float(64)) { + if (dtype == DLDataType{kDLFloat, 64, 1}) { dataType = MLMultiArrayDataTypeDouble; size *= sizeof(double); - } else if (dtype == DataType::Float(32)) { + } else if (dtype == DLDataType{kDLFloat, 32, 1}) { dataType = MLMultiArrayDataTypeFloat32; size *= sizeof(float); - } else if (dtype == DataType::Int(32)) { + } else if (dtype == DLDataType{kDLInt, 32, 1}) { dataType = MLMultiArrayDataTypeInt32; size *= sizeof(int); } else { @@ -87,15 +87,15 @@ shape.push_back(n); } - DataType dtype; + DLDataType dtype = DLDataType{kDLOpaqueHandle, 0, 0}; if (data_desc.dataType == MLMultiArrayDataTypeDouble) { - dtype = DataType::Float(64); + dtype = DLDataType{kDLFloat, 64, 1}; size *= sizeof(double); } else if (data_desc.dataType == MLMultiArrayDataTypeFloat32) { - dtype = DataType::Float(32); + dtype = DLDataType{kDLFloat, 32, 1}; size *= 
sizeof(float); } else if (data_desc.dataType == MLMultiArrayDataTypeInt32) { - dtype = DataType::Int(32); + dtype = DLDataType{kDLInt, 32, 1}; size *= sizeof(int); } else { LOG(FATAL) << "unexpected data type " << data_desc.dataType; diff --git a/src/runtime/extra/contrib/cublas/cublas.cc b/src/runtime/extra/contrib/cublas/cublas.cc index 4ef1b702c16c..461bbee1f86c 100644 --- a/src/runtime/extra/contrib/cublas/cublas.cc +++ b/src/runtime/extra/contrib/cublas/cublas.cc @@ -21,11 +21,11 @@ * \file Use external cblas library call. */ #include +#include #include #include #include #include -#include #include "../../../../../3rdparty/compiler-rt/builtin_fp16.h" #include "../cblas/gemm_common.h" @@ -34,7 +34,6 @@ namespace tvm { namespace contrib { -using namespace runtime; inline cublasOperation_t CUBLASBooleanToTranspose(bool item) { return item ? CUBLAS_OP_T : CUBLAS_OP_N; } @@ -125,11 +124,11 @@ struct CublasDgemmBatchOp { // Check cublas supported mix-precision computation type and return computeType bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_support = true) { - if (int_support && TypeMatch(out_dtype, kDLInt, 32)) { - return TypeMatch(in_dtype, kDLInt, 8); - } else if (TypeMatch(out_dtype, kDLFloat, 32)) { - return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16) || - TypeMatch(in_dtype, kDLBfloat, 16); + if (int_support && out_dtype == DLDataType{kDLInt, 32, 1}) { + return in_dtype == DLDataType{kDLInt, 8, 1}; + } else if (out_dtype == DLDataType{kDLFloat, 32, 1}) { + return in_dtype == DLDataType{kDLInt, 8, 1} || in_dtype == DLDataType{kDLFloat, 16, 1} || + in_dtype == DLDataType{kDLBfloat, 16, 1}; } else { return false; } @@ -145,7 +144,7 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, const DLTensor* C, bool transa, bool transb, void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue, std::optional dq_scale) { - TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); + 
TVM_FFI_ICHECK(A->dtype == B->dtype); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? !transb : transb; @@ -164,26 +163,26 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, void* alpha = &alpha_value; void* beta = &zero_fp32; - if (TypeMatch(A->dtype, kDLFloat, 16)) { + if (A->dtype == DLDataType{kDLFloat, 16, 1}) { ab_type = CUDA_R_16F; - } else if (TypeMatch(A->dtype, kDLBfloat, 16)) { + } else if (A->dtype == DLDataType{kDLBfloat, 16, 1}) { ab_type = CUDA_R_16BF; - } else if (TypeMatch(A->dtype, kDLInt, 8)) { + } else if (A->dtype == DLDataType{kDLInt, 8, 1}) { ab_type = CUDA_R_8I; - } else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) { + } else if (A->dtype == DLDataType{kDLFloat8_e4m3fn, 8, 1}) { #if CUDART_VERSION >= 11080 - TVM_FFI_ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)); + TVM_FFI_ICHECK((B->dtype == DLDataType{kDLFloat8_e4m3fn, 8, 1})); ab_type = CUDA_R_8F_E4M3; #else TVM_FFI_THROW(InternalError) << "Float8 (E4M3) is only supported in CUDA 11.8 and above."; #endif } - if (TypeMatch(C->dtype, kDLFloat, 16)) { + if (C->dtype == DLDataType{kDLFloat, 16, 1}) { c_type = CUDA_R_16F; - } else if (TypeMatch(C->dtype, kDLBfloat, 16)) { + } else if (C->dtype == DLDataType{kDLBfloat, 16, 1}) { c_type = CUDA_R_16BF; - } else if (TypeMatch(C->dtype, kDLInt, 32)) { + } else if (C->dtype == DLDataType{kDLInt, 32, 1}) { c_type = CUDA_R_32I; compute_type = CUBLAS_COMPUTE_32I; scale_type = CUDA_R_32I; @@ -346,9 +345,9 @@ inline void CallLtIgemm(ffi::PackedArgs args, ffi::Any* ret, cublasLtHandle_t hd TVM_FFI_ICHECK_EQ(ElementStride(B), 1); TVM_FFI_ICHECK_EQ(ElementStride(C), 1); - TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLInt, 8)); - TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLInt, 32)); + TVM_FFI_ICHECK(A->dtype == B->dtype); + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLInt, 8, 
1})); + TVM_FFI_ICHECK((C->dtype == DLDataType{kDLInt, 32, 1})); TVM_FFI_ICHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; int32_t alpha = args.size() > 5 ? args[5].cast() : 1; @@ -405,7 +404,7 @@ inline void CallGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t hdl) TVM_FFI_ICHECK_EQ(ElementStride(B), 1); TVM_FFI_ICHECK_EQ(ElementStride(C), 1); - TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); + TVM_FFI_ICHECK(A->dtype == B->dtype); // C can never be transposed. TVM_FFI_ICHECK(!IsInPlaceTransposed(C)); @@ -415,9 +414,9 @@ inline void CallGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t hdl) transb = IsInPlaceTransposed(B) ? !transb : transb; TVM_FFI_ICHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; - TVM_FFI_ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) + TVM_FFI_ICHECK((!(A->dtype == DLDataType{kDLInt, 8, 1}) || ColumnStride(A) % 4 == 0)) << "leading dimension must divide 4 for int8 gemm"; - TVM_FFI_ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) + TVM_FFI_ICHECK((!(B->dtype == DLDataType{kDLInt, 8, 1}) || ColumnStride(B) % 4 == 0)) << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; @@ -464,7 +463,7 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t TVM_FFI_ICHECK_EQ(ElementStride3D(B), 1); TVM_FFI_ICHECK_EQ(ElementStride3D(C), 1); - TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); + TVM_FFI_ICHECK(A->dtype == B->dtype); // C can never be transposed. TVM_FFI_ICHECK(!IsInPlaceTransposed3D(C)); @@ -474,9 +473,9 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t transb = IsInPlaceTransposed3D(B) ? 
!transb : transb; TVM_FFI_ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type"; - TVM_FFI_ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride3D(A) % 4 == 0) + TVM_FFI_ICHECK((!(A->dtype == DLDataType{kDLInt, 8, 1}) || ColumnStride3D(A) % 4 == 0)) << "leading dimension must divide 4 for int8 gemm"; - TVM_FFI_ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0) + TVM_FFI_ICHECK((!(B->dtype == DLDataType{kDLInt, 8, 1}) || ColumnStride3D(B) % 4 == 0)) << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; @@ -538,13 +537,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { CUBLASTryEnableTensorCore(entry_ptr->handle); - if (TypeEqual(A->dtype, C->dtype)) { - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); + if (A->dtype == C->dtype) { + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLFloat, 16, 1} || + A->dtype == DLDataType{kDLFloat, 32, 1} || + A->dtype == DLDataType{kDLFloat, 64, 1})); - if (TypeMatch(A->dtype, kDLFloat, 16)) + if (A->dtype == DLDataType{kDLFloat, 16, 1}) CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle)); - else if (TypeMatch(A->dtype, kDLFloat, 32)) + else if (A->dtype == DLDataType{kDLFloat, 32, 1}) CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle)); else CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle)); @@ -565,7 +565,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { CUBLASTryEnableTensorCore(entry_ptr->handle); - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLInt, 8, 1})) << "Expects dtype to be int8\n"; cublasLtHandle_t ltHandle; CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); cudaStream_t stream = @@ -586,13 +586,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(A->device); CUBLASTryEnableTensorCore(entry_ptr->handle); 
- if (TypeEqual(A->dtype, C->dtype)) { - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); + if (A->dtype == C->dtype) { + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLFloat, 16, 1} || + A->dtype == DLDataType{kDLFloat, 32, 1} || + A->dtype == DLDataType{kDLFloat, 64, 1})); - if (TypeMatch(A->dtype, kDLFloat, 16)) + if (A->dtype == DLDataType{kDLFloat, 16, 1}) CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle)); - else if (TypeMatch(A->dtype, kDLFloat, 32)) + else if (A->dtype == DLDataType{kDLFloat, 32, 1}) CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle)); else CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle)); diff --git a/src/runtime/extra/contrib/cudnn/conv_backward.cc b/src/runtime/extra/contrib/cudnn/conv_backward.cc index df3d7c8e6ff7..47b8ab50cdbf 100644 --- a/src/runtime/extra/contrib/cudnn/conv_backward.cc +++ b/src/runtime/extra/contrib/cudnn/conv_backward.cc @@ -21,9 +21,9 @@ * \file cuDNN kernel calls for backward algorithms. */ #include +#include #include #include -#include #include #include @@ -32,8 +32,6 @@ namespace tvm { namespace contrib { -using namespace runtime; - void ConvolutionBackwardData(int mode, int format, int algo, int dims, int groups, const int pad[], const int stride[], const int dilation[], DLTensor* dy, DLTensor* w, DLTensor* dx, const std::string& conv_dtype) { diff --git a/src/runtime/extra/contrib/cudnn/conv_forward.cc b/src/runtime/extra/contrib/cudnn/conv_forward.cc index 3a573297f29e..aba57b7a9de7 100644 --- a/src/runtime/extra/contrib/cudnn/conv_forward.cc +++ b/src/runtime/extra/contrib/cudnn/conv_forward.cc @@ -21,9 +21,9 @@ * \file cuDNN kernel calls for the forward algorithm. 
*/ #include +#include #include #include -#include #include #include @@ -32,8 +32,6 @@ namespace tvm { namespace contrib { -using namespace runtime; - void ConvolutionForward(int mode, int format, int algo, int dims, int groups, const int pad[], const int stride[], const int dilation[], const DLTensor* x, const DLTensor* w, const DLTensor* y, const std::string& conv_dtype) { diff --git a/src/runtime/extra/contrib/cudnn/cudnn_utils.cc b/src/runtime/extra/contrib/cudnn/cudnn_utils.cc index 5c34d4a2b0a6..3edb20dbacbc 100644 --- a/src/runtime/extra/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/extra/contrib/cudnn/cudnn_utils.cc @@ -23,10 +23,10 @@ #include "cudnn_utils.h" +#include #include #include #include -#include #include #include diff --git a/src/runtime/extra/contrib/cudnn/softmax.cc b/src/runtime/extra/contrib/cudnn/softmax.cc index fde7d5e4e182..50b4f69f7383 100644 --- a/src/runtime/extra/contrib/cudnn/softmax.cc +++ b/src/runtime/extra/contrib/cudnn/softmax.cc @@ -31,8 +31,6 @@ namespace tvm { namespace contrib { -using namespace runtime; - void softmax_impl(cudnnSoftmaxAlgorithm_t alg, ffi::PackedArgs args, ffi::Any* ret) { auto x = args[0].cast(); auto y = args[1].cast(); diff --git a/src/runtime/extra/contrib/cutlass/fp16_group_gemm.cuh b/src/runtime/extra/contrib/cutlass/fp16_group_gemm.cuh index 35c4a5767236..85653222169b 100644 --- a/src/runtime/extra/contrib/cutlass/fp16_group_gemm.cuh +++ b/src/runtime/extra/contrib/cutlass/fp16_group_gemm.cuh @@ -49,17 +49,17 @@ void tvm_cutlass_group_gemm_impl(Tensor x, Tensor weight, Tensor indptr, Tensor float alpha = 1.0f; float beta = 0.0f; - if (DataType(x->dtype) == DataType::Float(16)) { - TVM_FFI_ICHECK(DataType(weight->dtype) == DataType::Float(16)); - TVM_FFI_ICHECK(DataType(out->dtype) == DataType::Float(16)); + if (x->dtype == DLDataType{kDLFloat, 16, 1}) { + TVM_FFI_ICHECK((weight->dtype == DLDataType{kDLFloat, 16, 1})); + TVM_FFI_ICHECK((out->dtype == DLDataType{kDLFloat, 16, 1})); using Dtype = 
cutlass::half_t; CutlassGroupGemm::run( static_cast(x->data), static_cast(weight->data), static_cast(indptr->data), static_cast(workspace->data), workspace->shape[0], n, k, num_groups, alpha, beta, static_cast(out->data), stream); - } else if (DataType(x->dtype) == DataType::BFloat(16)) { - TVM_FFI_ICHECK(DataType(weight->dtype) == DataType::BFloat(16)); - TVM_FFI_ICHECK(DataType(out->dtype) == DataType::BFloat(16)); + } else if (x->dtype == DLDataType{kDLBfloat, 16, 1}) { + TVM_FFI_ICHECK((weight->dtype == DLDataType{kDLBfloat, 16, 1})); + TVM_FFI_ICHECK((out->dtype == DLDataType{kDLBfloat, 16, 1})); using Dtype = cutlass::bfloat16_t; CutlassGroupGemm::run( static_cast(x->data), static_cast(weight->data), diff --git a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh index db88ec0faaed..1af60af4da3a 100644 --- a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh +++ b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh @@ -66,14 +66,15 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(Tensor a, Tensor b, Tensor scale TVM_FFI_ICHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[0]); TVM_FFI_ICHECK_EQ(scales_b->shape[1] * block_size_1, k); - using tvm::runtime::DataType; - TVM_FFI_ICHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); - TVM_FFI_ICHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); - TVM_FFI_ICHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); - TVM_FFI_ICHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); - TVM_FFI_ICHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); + TVM_FFI_ICHECK_EQ(a->dtype, DLDataType{kDLFloat8_e4m3fn, 8, 1}); + TVM_FFI_ICHECK_EQ(b->dtype, DLDataType{kDLFloat8_e4m3fn, 8, 1}); + TVM_FFI_ICHECK_EQ(scales_a->dtype, DLDataType{kDLFloat, 32, 1}); + TVM_FFI_ICHECK_EQ(scales_b->dtype, DLDataType{kDLFloat, 32, 1}); + TVM_FFI_ICHECK_EQ(workspace->dtype, DLDataType{kDLUInt, 8, 1}); + int64_t 
workspace_nbytes = + workspace->shape[0] * ((workspace->dtype.bits * workspace->dtype.lanes + 7) / 8); - if (DataType(out->dtype) == DataType::Float(16)) { + if (out->dtype == DLDataType{kDLFloat, 16, 1}) { CutlassFP8GroupwiseGemm::run(static_cast(a->data), @@ -81,10 +82,9 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(Tensor a, Tensor b, Tensor scale static_cast(scales_a->data), static_cast(scales_b->data), static_cast(out->data), - static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + static_cast(workspace->data), workspace_nbytes, m, n, k, 1, stream); - } else if (DataType(out->dtype) == DataType::BFloat(16)) { + } else if (out->dtype == DLDataType{kDLBfloat, 16, 1}) { CutlassFP8GroupwiseGemm::run(static_cast(a->data), @@ -92,11 +92,10 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(Tensor a, Tensor b, Tensor scale static_cast(scales_a->data), static_cast(scales_b->data), static_cast(out->data), - static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + static_cast(workspace->data), workspace_nbytes, m, n, k, 1, stream); } else { - LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype); + LOG(FATAL) << "Unsupported output dtype: " << out->dtype; } } @@ -131,14 +130,15 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(Tensor a, Tensor b, Tensor scales TVM_FFI_ICHECK_EQ(scales_b->shape[1] * block_size_0, n); TVM_FFI_ICHECK_EQ(scales_b->shape[2] * block_size_1, k); - using tvm::runtime::DataType; - TVM_FFI_ICHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); - TVM_FFI_ICHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); - TVM_FFI_ICHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); - TVM_FFI_ICHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); - TVM_FFI_ICHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); + TVM_FFI_ICHECK_EQ(a->dtype, DLDataType{kDLFloat8_e4m3fn, 8, 1}); + TVM_FFI_ICHECK_EQ(b->dtype, DLDataType{kDLFloat8_e4m3fn, 8, 1}); + 
TVM_FFI_ICHECK_EQ(scales_a->dtype, DLDataType{kDLFloat, 32, 1}); + TVM_FFI_ICHECK_EQ(scales_b->dtype, DLDataType{kDLFloat, 32, 1}); + TVM_FFI_ICHECK_EQ(workspace->dtype, DLDataType{kDLUInt, 8, 1}); + int64_t workspace_nbytes = + workspace->shape[0] * ((workspace->dtype.bits * workspace->dtype.lanes + 7) / 8); - if (DataType(out->dtype) == DataType::Float(16)) { + if (out->dtype == DLDataType{kDLFloat, 16, 1}) { CutlassFP8GroupwiseGemm::run(static_cast(a->data), @@ -146,10 +146,9 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(Tensor a, Tensor b, Tensor scales static_cast(scales_a->data), static_cast(scales_b->data), static_cast(out->data), - static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + static_cast(workspace->data), workspace_nbytes, m, n, k, batch_size, stream); - } else if (DataType(out->dtype) == DataType::BFloat(16)) { + } else if (out->dtype == DLDataType{kDLBfloat, 16, 1}) { CutlassFP8GroupwiseGemm::run(static_cast(a->data), @@ -157,11 +156,10 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(Tensor a, Tensor b, Tensor scales static_cast(scales_a->data), static_cast(scales_b->data), static_cast(out->data), - static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + static_cast(workspace->data), workspace_nbytes, m, n, k, batch_size, stream); } else { - LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype); + LOG(FATAL) << "Unsupported output dtype: " << out->dtype; } } diff --git a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu index ea70eee38650..6bd9f45ab25e 100644 --- a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu +++ b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu @@ -57,15 +57,14 @@ void tvm_fp8_groupwise_scaled_group_gemm_sm100(Tensor a, Tensor b, Tensor scales TVM_FFI_ICHECK_EQ((n + 
block_size_0 - 1) / block_size_0, scales_b->shape[1]); TVM_FFI_ICHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_b->shape[2]); - using tvm::runtime::DataType; - TVM_FFI_ICHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); - TVM_FFI_ICHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); - TVM_FFI_ICHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); - TVM_FFI_ICHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); - TVM_FFI_ICHECK_EQ(DataType(indptr->dtype), DataType::Int(64)); - TVM_FFI_ICHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); + TVM_FFI_ICHECK_EQ(a->dtype, DLDataType{kDLFloat8_e4m3fn, 8, 1}); + TVM_FFI_ICHECK_EQ(b->dtype, DLDataType{kDLFloat8_e4m3fn, 8, 1}); + TVM_FFI_ICHECK_EQ(scales_a->dtype, DLDataType{kDLFloat, 32, 1}); + TVM_FFI_ICHECK_EQ(scales_b->dtype, DLDataType{kDLFloat, 32, 1}); + TVM_FFI_ICHECK_EQ(indptr->dtype, DLDataType{kDLInt, 64, 1}); + TVM_FFI_ICHECK_EQ(workspace->dtype, DLDataType{kDLUInt, 8, 1}); - if (DataType(out->dtype) == DataType::Float(16)) { + if (out->dtype == DLDataType{kDLFloat, 16, 1}) { using Dtype = cutlass::half_t; cutlass_fp8_groupwise_scaled_group_gemm_sm100( @@ -73,7 +72,7 @@ void tvm_fp8_groupwise_scaled_group_gemm_sm100(Tensor a, Tensor b, Tensor scales static_cast(scales_a->data), static_cast(scales_b->data), static_cast(indptr->data), static_cast(workspace->data), workspace->shape[0], n, k, num_groups, static_cast(out->data), stream); - } else if (DataType(out->dtype) == DataType::BFloat(16)) { + } else if (out->dtype == DLDataType{kDLBfloat, 16, 1}) { using Dtype = cutlass::bfloat16_t; cutlass_fp8_groupwise_scaled_group_gemm_sm100( diff --git a/src/runtime/extra/contrib/dnnl/dnnl_utils.cc b/src/runtime/extra/contrib/dnnl/dnnl_utils.cc index 23992209f2ad..e41d378b3d30 100644 --- a/src/runtime/extra/contrib/dnnl/dnnl_utils.cc +++ b/src/runtime/extra/contrib/dnnl/dnnl_utils.cc @@ -32,21 +32,21 @@ namespace contrib { dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype) { using dt = 
dnnl::memory::data_type; dt dnnl_type = dt::undef; - if (dltype.code == DataType::TypeCode::kFloat) { + if (dltype.code == DLDataTypeCode::kDLFloat) { if (dltype.bits == 16) { dnnl_type = dt::f16; } else if (dltype.bits == 32) { dnnl_type = dt::f32; } - } else if (dltype.code == DataType::TypeCode::kBFloat && dltype.bits == 16) { + } else if (dltype.code == DLDataTypeCode::kDLBfloat && dltype.bits == 16) { dnnl_type = dt::bf16; - } else if (dltype.code == DataType::TypeCode::kInt) { + } else if (dltype.code == DLDataTypeCode::kDLInt) { if (dltype.bits == 8) { dnnl_type = dt::s8; } else if (dltype.bits == 32) { dnnl_type = dt::s32; } - } else if (dltype.code == DataType::TypeCode::kUInt && dltype.bits == 8) { + } else if (dltype.code == DLDataTypeCode::kDLUInt && dltype.bits == 8) { dnnl_type = dt::u8; } if (dnnl_type == dt::undef) { diff --git a/src/runtime/extra/contrib/dnnl/dnnl_utils.h b/src/runtime/extra/contrib/dnnl/dnnl_utils.h index a598b6704450..6f36ed4d8fbe 100644 --- a/src/runtime/extra/contrib/dnnl/dnnl_utils.h +++ b/src/runtime/extra/contrib/dnnl/dnnl_utils.h @@ -34,7 +34,7 @@ // -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command #include -#include "tvm/runtime/data_type.h" +#include "tvm/ffi/dtype.h" namespace tvm { namespace runtime { diff --git a/src/runtime/extra/contrib/hipblas/hipblas.cc b/src/runtime/extra/contrib/hipblas/hipblas.cc index 5276b4f7956d..18e136b0fdec 100644 --- a/src/runtime/extra/contrib/hipblas/hipblas.cc +++ b/src/runtime/extra/contrib/hipblas/hipblas.cc @@ -21,10 +21,10 @@ * \file Use external hipblas library call. */ #include +#include #include #include #include -#include #include "../../../../../3rdparty/compiler-rt/builtin_fp16.h" #include "../cblas/gemm_common.h" @@ -33,7 +33,6 @@ namespace tvm { namespace contrib { -using namespace runtime; inline hipblasOperation_t HIPBLASBooleanToTranspose(bool item) { return item ? 
HIPBLAS_OP_T : HIPBLAS_OP_N; } @@ -117,10 +116,10 @@ struct HipblasDgemmBatchOp { // Check supported mix-precision computation type and return computeType bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_support = true) { - if (int_support && TypeMatch(out_dtype, kDLInt, 32)) { - return TypeMatch(in_dtype, kDLInt, 8); - } else if (TypeMatch(out_dtype, kDLFloat, 32)) { - return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16); + if (int_support && out_dtype == DLDataType{kDLInt, 32, 1}) { + return in_dtype == DLDataType{kDLInt, 8, 1}; + } else if (out_dtype == DLDataType{kDLFloat, 32, 1}) { + return in_dtype == DLDataType{kDLInt, 8, 1} || in_dtype == DLDataType{kDLFloat, 16, 1}; } else { return false; } @@ -131,7 +130,7 @@ void CallHipblasLt(hipblasLtHandle_t hdl, hipStream_t stream, const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa, bool transb, void* workspace_ptr, size_t workspace_size, hipblasLtEpilogue_t epilogue) { - TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); + TVM_FFI_ICHECK(A->dtype == B->dtype); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? 
!transb : transb; @@ -147,15 +146,15 @@ void CallHipblasLt(hipblasLtHandle_t hdl, hipStream_t stream, void* alpha = &one_fp32; void* beta = &zero_fp32; - if (TypeMatch(A->dtype, kDLFloat, 16)) { + if (A->dtype == DLDataType{kDLFloat, 16, 1}) { ab_type = HIP_R_16F; - } else if (TypeMatch(A->dtype, kDLInt, 8)) { + } else if (A->dtype == DLDataType{kDLInt, 8, 1}) { ab_type = HIP_R_8I; } - if (TypeMatch(C->dtype, kDLFloat, 16)) { + if (C->dtype == DLDataType{kDLFloat, 16, 1}) { c_type = HIP_R_16F; - } else if (TypeMatch(C->dtype, kDLInt, 32)) { + } else if (C->dtype == DLDataType{kDLInt, 32, 1}) { c_type = HIP_R_32I; compute_type = HIPBLAS_COMPUTE_32I; scale_type = HIP_R_32I; @@ -288,7 +287,7 @@ inline void CallGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t hdl) TVM_FFI_ICHECK_EQ(ElementStride(B), 1); TVM_FFI_ICHECK_EQ(ElementStride(C), 1); - TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); + TVM_FFI_ICHECK(A->dtype == B->dtype); // C can never be transposed. TVM_FFI_ICHECK(!IsInPlaceTransposed(C)); @@ -298,9 +297,9 @@ inline void CallGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t hdl) transb = IsInPlaceTransposed(B) ? !transb : transb; TVM_FFI_ICHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; - TVM_FFI_ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) + TVM_FFI_ICHECK((!(A->dtype == DLDataType{kDLInt, 8, 1}) || ColumnStride(A) % 4 == 0)) << "leading dimension must divide 4 for int8 gemm"; - TVM_FFI_ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) + TVM_FFI_ICHECK((!(B->dtype == DLDataType{kDLInt, 8, 1}) || ColumnStride(B) % 4 == 0)) << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? 
args[6].cast() : 0.0; @@ -347,7 +346,7 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t TVM_FFI_ICHECK_EQ(ElementStride3D(B), 1); TVM_FFI_ICHECK_EQ(ElementStride3D(C), 1); - TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); + TVM_FFI_ICHECK(A->dtype == B->dtype); // C can never be transposed. TVM_FFI_ICHECK(!IsInPlaceTransposed3D(C)); @@ -357,9 +356,9 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t transb = IsInPlaceTransposed3D(B) ? !transb : transb; TVM_FFI_ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type"; - TVM_FFI_ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride3D(A) % 4 == 0) + TVM_FFI_ICHECK((!(A->dtype == DLDataType{kDLInt, 8, 1}) || ColumnStride3D(A) % 4 == 0)) << "leading dimension must divide 4 for int8 gemm"; - TVM_FFI_ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0) + TVM_FFI_ICHECK((!(B->dtype == DLDataType{kDLInt, 8, 1}) || ColumnStride3D(B) % 4 == 0)) << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? 
args[6].cast() : 0.0; @@ -419,14 +418,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(A->device); - if (TypeEqual(A->dtype, C->dtype)) { - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || - TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); + if (A->dtype == C->dtype) { + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLFloat, 16, 1} || + A->dtype == DLDataType{kDLFloat, 32, 1} || + A->dtype == DLDataType{kDLFloat, 64, 1})); - if (TypeMatch(A->dtype, kDLFloat, 16)) { + if (A->dtype == DLDataType{kDLFloat, 16, 1}) { CallGemm(args, ret, HipblasHgemmOp(entry_ptr->handle)); - } else if (TypeMatch(A->dtype, kDLFloat, 32)) { + } else if (A->dtype == DLDataType{kDLFloat, 32, 1}) { CallGemm(args, ret, HipblasSgemmOp(entry_ptr->handle)); } else { CallGemm(args, ret, HipblasDgemmOp(entry_ptr->handle)); @@ -441,13 +440,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(A->device); - if (TypeEqual(A->dtype, C->dtype)) { - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); + if (A->dtype == C->dtype) { + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLFloat, 16, 1} || + A->dtype == DLDataType{kDLFloat, 32, 1} || + A->dtype == DLDataType{kDLFloat, 64, 1})); - if (TypeMatch(A->dtype, kDLFloat, 16)) { + if (A->dtype == DLDataType{kDLFloat, 16, 1}) { CallBatchGemm(args, ret, HipblasHgemmBatchOp(entry_ptr->handle)); - } else if (TypeMatch(A->dtype, kDLFloat, 32)) { + } else if (A->dtype == DLDataType{kDLFloat, 32, 1}) { CallBatchGemm(args, ret, HipblasSgemmBatchOp(entry_ptr->handle)); } else { CallBatchGemm(args, ret, HipblasDgemmBatchOp(entry_ptr->handle)); diff --git a/src/runtime/extra/contrib/json/json_node.h b/src/runtime/extra/contrib/json/json_node.h index c165f6b05cf3..40c96d826914 100644 --- a/src/runtime/extra/contrib/json/json_node.h +++ b/src/runtime/extra/contrib/json/json_node.h @@ 
-29,9 +29,9 @@ #include #include #include +#include #include #include -#include #include #include diff --git a/src/runtime/extra/contrib/nvshmem/memory_allocator.cc b/src/runtime/extra/contrib/nvshmem/memory_allocator.cc index cb6e3520c8c1..1483563b6200 100644 --- a/src/runtime/extra/contrib/nvshmem/memory_allocator.cc +++ b/src/runtime/extra/contrib/nvshmem/memory_allocator.cc @@ -57,7 +57,7 @@ class NVSHMEMAllocator final : public PooledAllocator { return allocator; } - Tensor Empty(ffi::Shape shape, DataType dtype, Device device) { + Tensor Empty(ffi::Shape shape, DLDataType dtype, Device device) { class NVSHMEMAlloc { public: explicit NVSHMEMAlloc(Buffer buffer) : buffer_(buffer) {} @@ -87,7 +87,7 @@ class NVSHMEMAllocator final : public PooledAllocator { void DeviceFreeDataSpace(Device dev, void* ptr) final { nvshmem_free(ptr); } }; -Tensor NVSHMEMEmpty(ffi::Shape shape, DataType dtype, ffi::Optional device) { +Tensor NVSHMEMEmpty(ffi::Shape shape, DLDataType dtype, ffi::Optional device) { return NVSHMEMAllocator::Global()->Empty(shape, dtype, UseDefaultDeviceIfNone(device)); } diff --git a/src/runtime/extra/contrib/random/random.cc b/src/runtime/extra/contrib/random/random.cc index a3d0cd8b85a8..0a96185933e3 100644 --- a/src/runtime/extra/contrib/random/random.cc +++ b/src/runtime/extra/contrib/random/random.cc @@ -21,10 +21,10 @@ * \file External random functions for tensor. 
*/ #include +#include #include #include #include -#include #include #include @@ -69,8 +69,6 @@ namespace tvm { namespace contrib { -using namespace runtime; - struct RandomThreadLocalEntry { RandomEngine random_engine; static RandomThreadLocalEntry* ThreadLocal(); diff --git a/src/runtime/extra/contrib/sort/sort.cc b/src/runtime/extra/contrib/sort/sort.cc index 51a94111b6e6..6e3a99f93522 100644 --- a/src/runtime/extra/contrib/sort/sort.cc +++ b/src/runtime/extra/contrib/sort/sort.cc @@ -23,10 +23,10 @@ #include #include +#include #include #include #include -#include #include #include @@ -36,8 +36,6 @@ namespace tvm { namespace contrib { -using namespace runtime; - template bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { if constexpr (stable_comparison) { diff --git a/src/runtime/extra/contrib/vllm/cache_alloc.cc b/src/runtime/extra/contrib/vllm/cache_alloc.cc index 266138406cb9..42601d7a5e69 100644 --- a/src/runtime/extra/contrib/vllm/cache_alloc.cc +++ b/src/runtime/extra/contrib/vllm/cache_alloc.cc @@ -39,9 +39,9 @@ ffi::Array AllocateKVCache(int head_size, int num_layers, int num_heads, for (int i = 0; i < num_layers; ++i) { Tensor key_blocks = Tensor::Empty({num_blocks, num_heads, head_size / vec_size, block_size, vec_size}, - runtime::DataType::Float(16), dev); + DLDataType{kDLFloat, 16, 1}, dev); Tensor value_blocks = Tensor::Empty({num_blocks, num_heads, head_size, block_size}, - runtime::DataType::Float(16), dev); + DLDataType{kDLFloat, 16, 1}, dev); cache.push_back(key_blocks); cache.push_back(value_blocks); } diff --git a/src/runtime/extra/contrib/vllm/cache_kernels.cu b/src/runtime/extra/contrib/vllm/cache_kernels.cu index 5af93a1fd904..6a09497a8d12 100644 --- a/src/runtime/extra/contrib/vllm/cache_kernels.cu +++ b/src/runtime/extra/contrib/vllm/cache_kernels.cu @@ -206,16 +206,16 @@ TVM_FFI_STATIC_INIT_BLOCK() { DLDevice dev = key_cache->device; Tensor key_cache_ptrs_gpu = - Tensor::Empty({static_cast(num_layers)}, 
runtime::DataType::Int(64), dev); + Tensor::Empty({static_cast(num_layers)}, DLDataType{kDLInt, 64, 1}, dev); Tensor value_cache_ptrs_gpu = - Tensor::Empty({static_cast(num_layers)}, runtime::DataType::Int(64), dev); + Tensor::Empty({static_cast(num_layers)}, DLDataType{kDLInt, 64, 1}, dev); key_cache_ptrs_gpu.CopyFromBytes(key_cache_ptrs.data(), sizeof(int64_t) * key_cache_ptrs.size()); value_cache_ptrs_gpu.CopyFromBytes(value_cache_ptrs.data(), sizeof(int64_t) * value_cache_ptrs.size()); Tensor block_mapping_gpu = - Tensor::Empty(block_mapping.Shape(), runtime::DataType::Int(64), dev); + Tensor::Empty(block_mapping.Shape(), DLDataType{kDLInt, 64, 1}, dev); block_mapping_gpu.CopyFromBytes(block_mapping->data, sizeof(int64_t) * block_mapping->shape[0]); diff --git a/src/runtime/extra/disco/builtin.cc b/src/runtime/extra/disco/builtin.cc index da9f472b3e76..d9d5fc132768 100644 --- a/src/runtime/extra/disco/builtin.cc +++ b/src/runtime/extra/disco/builtin.cc @@ -71,7 +71,7 @@ ffi::Module LoadVMModule(std::string path, ffi::Optional device) { return mod; } -Tensor DiscoEmptyTensor(ffi::Shape shape, DataType dtype, ffi::Optional device) { +Tensor DiscoEmptyTensor(ffi::Shape shape, DLDataType dtype, ffi::Optional device) { return Tensor::Empty(shape, dtype, UseDefaultDeviceIfNone(device)); } @@ -131,7 +131,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef() .def("runtime.disco.load_vm_module", LoadVMModule) .def("runtime.disco.empty", - [](ffi::Shape shape, DataType dtype, ffi::Optional device, bool worker0_only, + [](ffi::Shape shape, DLDataType dtype, ffi::Optional device, bool worker0_only, bool in_group) -> ffi::Optional { int worker_id = WorkerId(); int group_size = diff --git a/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc b/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc index 426557b7b7ad..a8a8030f0169 100644 --- a/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc +++ b/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc @@ -97,10 +97,12 @@ class 
CUDAIPCMemoryAllocator final : public memory::PooledAllocator { auto [data_ptr, data_comm_ptrs] = AllocIPCMemory(dev, size, alignment, type_hint, /*reset_memory_to_zero=*/false); int barrier_ptr_size = sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE; - auto [barrier_in_ptr, barrier_in_comm_ptrs] = AllocIPCMemory( - dev, barrier_ptr_size, alignment, DataType::UInt(32), /*reset_memory_to_zero=*/true); - auto [barrier_out_ptr, barrier_out_comm_ptrs] = AllocIPCMemory( - dev, barrier_ptr_size, alignment, DataType::UInt(32), /*reset_memory_to_zero=*/true); + auto [barrier_in_ptr, barrier_in_comm_ptrs] = + AllocIPCMemory(dev, barrier_ptr_size, alignment, DLDataType{kDLUInt, 32, 1}, + /*reset_memory_to_zero=*/true); + auto [barrier_out_ptr, barrier_out_comm_ptrs] = + AllocIPCMemory(dev, barrier_ptr_size, alignment, DLDataType{kDLUInt, 32, 1}, + /*reset_memory_to_zero=*/true); // Create the CUDAIPCMemory object. ffi::ObjectPtr ipc_memory = ffi::make_object(); diff --git a/src/runtime/extra/disco/cuda_ipc/custom_allreduce.cc b/src/runtime/extra/disco/cuda_ipc/custom_allreduce.cc index ffe00d5feef9..3eaca5ba98d4 100644 --- a/src/runtime/extra/disco/cuda_ipc/custom_allreduce.cc +++ b/src/runtime/extra/disco/cuda_ipc/custom_allreduce.cc @@ -81,7 +81,7 @@ void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) { // Dispatch to nccl AllReduce if the customized all-reduce cannot apply. deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclAllReduce(send->data, recv->data, num_elements, - /*datatype=*/nccl::AsNCCLDataType(DataType(send->dtype)), + /*datatype=*/nccl::AsNCCLDataType(send->dtype), /*op=*/ncclSum, ctx->global_comm, stream)); return; } diff --git a/src/runtime/extra/disco/loader.cc b/src/runtime/extra/disco/loader.cc index 86caac6573ed..f714112aecf3 100644 --- a/src/runtime/extra/disco/loader.cc +++ b/src/runtime/extra/disco/loader.cc @@ -17,10 +17,10 @@ * under the License. 
*/ #include +#include #include #include #include -#include #include #include @@ -45,7 +45,7 @@ using ParamRecord = TensorCacheMetadata::FileRecord::ParamRecord; struct ShardInfo { struct TensorInfo { ffi::Shape shape; - DataType dtype; + DLDataType dtype; }; struct ShardFunc { std::string name; @@ -67,8 +67,7 @@ ShardInfo::TensorInfo LoadTensorInfoFromJSON(const json::Array& json_tensor_info shape.push_back(shape_json[i].cast()); } std::string dtype = json_tensor_info[1].cast(); - return ShardInfo::TensorInfo{ffi::Shape(std::move(shape)), - DataType(ffi::StringToDLDataType(dtype))}; + return ShardInfo::TensorInfo{ffi::Shape(std::move(shape)), ffi::StringToDLDataType(dtype)}; } ShardInfo::ShardFunc LoadShardFuncFromJSON(const json::Array& json_shard_func) { @@ -301,7 +300,7 @@ Tensor ShardLoaderObj::Load(int weight_index) const { bool needs_sharding = !param_info.shard_info.funcs.empty(); if (needs_sharding) { ffi::Shape shape = param_info.shard_info.funcs.back().output_info.shape; - DataType dtype = param_info.shard_info.funcs.back().output_info.dtype; + DLDataType dtype = param_info.shard_info.funcs.back().output_info.dtype; TVM_FFI_CHECK(shape.size() >= 1 && shape[0] == num_shards, ValueError) << "The first dimension of the " << "output shape must be equal to the " diff --git a/src/runtime/extra/disco/nccl/nccl.cc b/src/runtime/extra/disco/nccl/nccl.cc index 887f440b1b4f..cd00a1ac3d6b 100644 --- a/src/runtime/extra/disco/nccl/nccl.cc +++ b/src/runtime/extra/disco/nccl/nccl.cc @@ -122,8 +122,8 @@ void AllReduce(Tensor send, ReduceKind reduce_kind, bool in_group, Tensor recv) ffi::Shape shape = send.Shape(); int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); - DataType dtype = DataType(send->dtype); - if (dtype == DataType::Float8E4M3FN() || dtype == DataType::Float8E5M2()) { + DLDataType dtype = send->dtype; + if (dtype == DLDataType{kDLFloat8_e4m3fn, 8, 1} || dtype == DLDataType{kDLFloat8_e5m2, 8, 1}) { TVM_FFI_THROW(InternalError) 
<< "Float8 data type cannot be allreduced, as nccl does not support this data type."; } @@ -139,7 +139,7 @@ void AllGather(Tensor send, bool in_group, Tensor recv) { int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclAllGather(send->data, recv->data, numel, - /*datatype=*/AsNCCLDataType(DataType(send->dtype)), + /*datatype=*/AsNCCLDataType(send->dtype), in_group ? ctx->group_comm : ctx->global_comm, stream)); } @@ -162,7 +162,7 @@ void BroadcastFromWorker0(ffi::Optional send, bool in_group, Tensor recv deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclBroadcast(send_data, recv->data, numel, - /*datatype=*/AsNCCLDataType(DataType(recv->dtype)), + /*datatype=*/AsNCCLDataType(recv->dtype), /*root=*/0, in_group ? ctx->group_comm : ctx->global_comm, stream)); } @@ -185,9 +185,9 @@ void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv) "of elements in the buffer to be " "divisible by the number of workers, but got numel = " << numel << " and " << num_receiver << " workers."; - DataType dtype(buffer->dtype); + DLDataType dtype = buffer->dtype; int64_t numel_per_shard = numel / num_receiver; - int64_t bytes_per_shard = numel_per_shard * dtype.bytes(); + int64_t bytes_per_shard = numel_per_shard * ((dtype.bits * dtype.lanes + 7) / 8); TVM_FFI_CHECK_EQ(numel_per_shard, recv.Shape().Product(), ValueError) << "The number of elements in buffer `recv` must be the same as each shard " "of " @@ -209,7 +209,7 @@ void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv) NCCL_CALL(ncclGroupStart()); } int64_t numel = recv.Shape().Product(); - DataType dtype(recv->dtype); + DLDataType dtype = recv->dtype; NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, in_group ? 
ctx->group_comm : ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); @@ -234,9 +234,9 @@ void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv) { "of elements in the buffer to be " "divisible by the number of workers, but got numel = " << numel << " and " << num_receiver << " workers."; - DataType dtype(buffer->dtype); + DLDataType dtype = buffer->dtype; int64_t numel_per_shard = numel / num_receiver; - int64_t bytes_per_shard = numel_per_shard * dtype.bytes(); + int64_t bytes_per_shard = numel_per_shard * ((dtype.bits * dtype.lanes + 7) / 8); TVM_FFI_CHECK_EQ(numel_per_shard, send.Shape().Product(), ValueError) << "The number of elements in buffer `send` must be the same as each shard " "of " @@ -258,7 +258,7 @@ void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv) { NCCL_CALL(ncclGroupStart()); } int64_t numel = send.Shape().Product(); - DataType dtype(send->dtype); + DLDataType dtype = send->dtype; NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0, in_group ? ctx->group_comm : ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); diff --git a/src/runtime/extra/disco/nccl/nccl_context.h b/src/runtime/extra/disco/nccl/nccl_context.h index 7a99be0897c0..d529ab441d11 100644 --- a/src/runtime/extra/disco/nccl/nccl_context.h +++ b/src/runtime/extra/disco/nccl/nccl_context.h @@ -86,39 +86,39 @@ inline void StreamDestroy(deviceStream_t stream) { ROCM_CALL(hipStreamDestroy(st #endif -/*! \brief Convert DataType to ncclDataType. */ -inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) { - if (dtype == DataType::Int(8)) { +/*! \brief Convert DLPack dtype to ncclDataType. 
*/ +inline ncclDataType_t AsNCCLDataType(DLDataType dtype) { + if (dtype == DLDataType{kDLInt, 8, 1}) { return ncclInt8; } - if (dtype == DataType::UInt(8) || dtype == DataType::Float8E4M3FN() || - dtype == DataType::Float8E5M2()) { + if (dtype == DLDataType{kDLUInt, 8, 1} || dtype == DLDataType{kDLFloat8_e4m3fn, 8, 1} || + dtype == DLDataType{kDLFloat8_e5m2, 8, 1}) { // For float8 data type, pretend to be uint8 in nccl. // And will throw error when allreduce, as it makes no sense in this case. return ncclUint8; } - if (dtype == DataType::Int(32)) { + if (dtype == DLDataType{kDLInt, 32, 1}) { return ncclInt32; } - if (dtype == DataType::UInt(32)) { + if (dtype == DLDataType{kDLUInt, 32, 1}) { return ncclUint32; } - if (dtype == DataType::Int(64)) { + if (dtype == DLDataType{kDLInt, 64, 1}) { return ncclInt64; } - if (dtype == DataType::UInt(64)) { + if (dtype == DLDataType{kDLUInt, 64, 1}) { return ncclUint64; } - if (dtype == DataType::Float(16)) { + if (dtype == DLDataType{kDLFloat, 16, 1}) { return ncclFloat16; } - if (dtype == DataType::Float(32)) { + if (dtype == DLDataType{kDLFloat, 32, 1}) { return ncclFloat32; } - if (dtype == DataType::Float(64)) { + if (dtype == DLDataType{kDLFloat, 64, 1}) { return ncclFloat64; } - if (dtype == DataType::BFloat(16)) { + if (dtype == DLDataType{kDLBfloat, 16, 1}) { return ncclBfloat16; } TVM_FFI_THROW(ValueError) << "Unsupported data type " << dtype; diff --git a/src/runtime/tensor.cc b/src/runtime/tensor.cc index 887d576537f2..ed12d0b4885a 100644 --- a/src/runtime/tensor.cc +++ b/src/runtime/tensor.cc @@ -33,7 +33,7 @@ #include "../support/base64.h" #include "../support/bytes_io.h" -#include "tvm/runtime/data_type.h" +#include "tvm/ffi/dtype.h" namespace tvm { namespace runtime { @@ -52,11 +52,11 @@ inline void VerifyDataType(DLDataType dtype) { return; else if (dtype.bits == 4 && dtype.code == kDLInt) return; - else if (dtype.bits == 6 && dtype.code == DataType::kFloat6_e2m3fn) + else if (dtype.bits == 6 && dtype.code 
== kDLFloat6_e2m3fn) return; - else if (dtype.bits == 6 && dtype.code == DataType::kFloat6_e3m2fn) + else if (dtype.bits == 6 && dtype.code == kDLFloat6_e3m2fn) return; - else if (dtype.bits == 4 && dtype.code == DataType::kFloat4_e2m1fn) + else if (dtype.bits == 4 && dtype.code == kDLFloat4_e2m1fn) return; else TVM_FFI_ICHECK_EQ(dtype.bits % 8, 0); diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h index 067fa8d10dc1..6aececc755ea 100644 --- a/src/runtime/vm/attn_backend.h +++ b/src/runtime/vm/attn_backend.h @@ -321,7 +321,7 @@ class PagedDecodeFunc : public AttnBackendFunc { Tensor page_locked_int_workspace_buffer, HostMemoryVector* page_indptr, int64_t batch_size, int64_t page_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, - RoPEMode rope_mode, DataType q_dtype, DataType kv_dtype, + RoPEMode rope_mode, DLDataType q_dtype, DLDataType kv_dtype, TVMStreamHandle copy_stream) { // Do nothing. Subclasses can override to customize behavior. 
} @@ -377,7 +377,7 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc { Tensor page_locked_int_workspace_buffer, HostMemoryVector* page_indptr, int64_t batch_size, int64_t page_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, - RoPEMode rope_mode, DataType q_dtype, DataType kv_dtype, + RoPEMode rope_mode, DLDataType q_dtype, DLDataType kv_dtype, TVMStreamHandle copy_stream) final { // Todo(tvm-team): enable cuda graph ffi::Shape plan_info_vec = diff --git a/src/runtime/vm/attn_utils.h b/src/runtime/vm/attn_utils.h index 7a2c93414c0f..4f9cd648e9d7 100644 --- a/src/runtime/vm/attn_utils.h +++ b/src/runtime/vm/attn_utils.h @@ -359,7 +359,7 @@ class HostMemoryVector { explicit HostMemoryVector(int64_t reserved_size, DLDataType dtype, Device device) : reserved_size_(reserved_size) { - TVM_FFI_ICHECK(DataType(dtype) == DataType::Int(32)); + TVM_FFI_ICHECK((dtype == DLDataType{kDLInt, 32, 1})); data_ = Tensor::Empty({reserved_size}, dtype, device); } @@ -368,7 +368,7 @@ class HostMemoryVector { if (current_size_ == reserved_size_) { reserved_size_ *= 2; Tensor new_data = Tensor::Empty({reserved_size_}, data_->dtype, data_->device); - std::memcpy(new_data->data, data_->data, current_size_ * DataType(data_->dtype).bytes()); + std::memcpy(new_data->data, data_->data, current_size_ * (((data_->dtype).bits + 7) / 8)); data_ = new_data; } static_cast(data_->data)[current_size_++] = value; @@ -382,7 +382,7 @@ class HostMemoryVector { reserved_size_ *= 2; } Tensor new_data = Tensor::Empty({reserved_size_}, data_->dtype, data_->device); - std::memcpy(new_data->data, data_->data, current_size_ * DataType(data_->dtype).bytes()); + std::memcpy(new_data->data, data_->data, current_size_ * (((data_->dtype).bits + 7) / 8)); data_ = new_data; } std::memcpy(static_cast(data_->data) + current_size_, values.data(), @@ -466,7 +466,7 @@ class PagedKVCacheAuxDataManager { device_(device), preferred_host_device_(preferred_host_device), 
copy_stream_(copy_stream) { - TVM_FFI_ICHECK(DataType(dtype_aux) == DataType::Int(32)); + TVM_FFI_ICHECK((dtype_aux == DLDataType{kDLInt, 32, 1})); } virtual ~PagedKVCacheAuxDataManager() = default; diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index 8fc18c5c0722..30fbf77b9c7f 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -22,11 +22,11 @@ #include #include #include +#include #include #include #include #include -#include #include #include #include @@ -243,14 +243,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { void CheckTensorInfo(ffi::PackedArgs args, ffi::Any* rv) { ffi::AnyView arg = args[0]; int ndim = args[1].cast(); - DataType dtype; + DLDataType dtype; ffi::Optional err_ctx; if (args.size() == 3) { - dtype = DataType::Void(); + dtype = DLDataType{kDLOpaqueHandle, 0, 0}; err_ctx = args[2].cast>(); } else { - dtype = args[2].cast(); + dtype = args[2].cast(); err_ctx = args[3].cast>(); } @@ -264,10 +264,10 @@ void CheckTensorInfo(ffi::PackedArgs args, ffi::Any* rv) { << err_ctx.value_or("") << " expect Tensor with ndim " << ndim << " but get " << ptr->ndim; } - if (dtype != DataType::Void()) { - TVM_FFI_CHECK(DataType(ptr->dtype) == dtype, ValueError) + if (dtype != DLDataType{kDLOpaqueHandle, 0, 0}) { + TVM_FFI_CHECK(ptr->dtype == dtype, ValueError) << err_ctx.value_or("") << " expect Tensor with dtype " << dtype << " but get " - << DataType(ptr->dtype); + << ptr->dtype; } } @@ -301,23 +301,24 @@ TVM_FFI_STATIC_INIT_BLOCK() { /*! * \brief Builtin function to check if arg is PrimValue(dtype) * \param arg The input argument. - * \param dtype Expected dtype of the PrimValue. Can be DataType::Void() for unknown dtype. + * \param dtype Expected dtype of the PrimValue. Can be DLDataType{kDLOpaqueHandle, 0, 0} for + * unknown dtype. * \param err_ctx Additional context if error occurs. 
*/ -void CheckPrimValueInfo(ffi::AnyView arg, DataType dtype, ffi::Optional err_ctx) { +void CheckPrimValueInfo(ffi::AnyView arg, DLDataType dtype, ffi::Optional err_ctx) { if (auto opt_obj = arg.as()) { TVM_FFI_THROW(TypeError) << err_ctx.value_or("") << ", expected dtype " << dtype << ", but received ObjectRef of type " << opt_obj.value()->GetTypeKey(); - } else if (dtype.is_bool()) { + } else if (((dtype).code == kDLBool)) { arg.cast(); - } else if (dtype.is_int()) { + } else if (((dtype).code == kDLInt)) { arg.cast(); - } else if (dtype.is_uint()) { + } else if (((dtype).code == kDLUInt)) { arg.cast(); - } else if (dtype.is_float()) { + } else if (((dtype).code == kDLFloat)) { arg.cast(); - } else if (dtype.is_handle()) { + } else if (dtype.code == kDLOpaqueHandle && !(dtype.bits == 0 && dtype.lanes == 0)) { arg.cast(); } else { TVM_FFI_THROW(TypeError) << err_ctx.value_or("") << ", unsupported dtype " << dtype; @@ -398,7 +399,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Storage sobj = args[0].cast(); int64_t offset = args[1].cast(); ffi::Shape shape = args[2].cast(); - DataType dtype = args[3].cast(); + DLDataType dtype = args[3].cast(); if (args.size() == 5) { ffi::String scope = args[4].cast(); *rv = sobj->AllocTensorScoped(offset, shape, dtype, scope); diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 33ff1503f823..9e3a5f932309 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -101,8 +101,7 @@ std::string VMExecutable::Stats() const { oss << opt_int.value(); oss << ", "; } else if (auto opt_dtype = it.as()) { - DataType dtype(opt_dtype.value()); - oss << dtype; + oss << opt_dtype.value(); oss << ", "; } else { TVM_FFI_THROW(InternalError) << "Unsupported constant pool type " << it.GetTypeKey(); diff --git a/src/runtime/vm/lm_support.cc b/src/runtime/vm/lm_support.cc index 51b271441a27..2516e0d8a1af 100644 --- a/src/runtime/vm/lm_support.cc +++ b/src/runtime/vm/lm_support.cc @@ -362,7 +362,7 @@ 
TVM_FFI_STATIC_INIT_BLOCK() { // NOTE this is a built-in highly related to LM so we put it here. int SampleTopPFromLogits(Tensor logits, double temperature, double top_p, double uniform_sample) { TVM_FFI_ICHECK(logits.IsContiguous()); - TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)); + TVM_FFI_ICHECK((logits.DataType() == DLDataType{kDLFloat, 32, 1})); if (logits->device.device_type != kDLCPU) { logits = logits.CopyTo(DLDevice{kDLCPU, 0}); @@ -428,7 +428,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { int SampleTopPFromProb(Tensor prob, double top_p, double uniform_sample) { TVM_FFI_ICHECK(prob.IsContiguous()); - TVM_FFI_ICHECK(prob.DataType() == DataType::Float(32)); + TVM_FFI_ICHECK((prob.DataType() == DLDataType{kDLFloat, 32, 1})); if (prob->device.device_type != kDLCPU) { prob = prob.CopyTo(DLDevice{kDLCPU, 0}); @@ -543,7 +543,8 @@ Tensor MultinomialFromUniform(Tensor prob, Tensor uniform_sample) { int64_t vocab_size = prob->shape[prob->ndim - 1]; const float* pprob = static_cast(prob->data); const float* psample = static_cast(uniform_sample->data); - Tensor new_array = Tensor::Empty({batch_size, 1}, DataType::Int(64), uniform_sample->device); + Tensor new_array = + Tensor::Empty({batch_size, 1}, DLDataType{kDLInt, 64, 1}, uniform_sample->device); int64_t* parray = static_cast(new_array->data); for (int64_t i = 0; i < batch_size; ++i) { float cum_sum_prob = 0.0f; @@ -569,8 +570,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { void ApplyRepetitionPenalty(Tensor logits, Tensor token_ids, double penalty) { TVM_FFI_ICHECK(logits.IsContiguous()); TVM_FFI_ICHECK(token_ids.IsContiguous()); - TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - TVM_FFI_ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!"; + TVM_FFI_ICHECK((logits.DataType() == DLDataType{kDLFloat, 32, 1})) + << "Logits data type is not float32!"; + TVM_FFI_ICHECK((token_ids.DataType() == DLDataType{kDLInt, 32, 1})) << "token ids must be int32!"; 
TVM_FFI_ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; TVM_FFI_ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!"; float* logits_raw_data = static_cast(logits->data); @@ -606,9 +608,11 @@ void ApplyPresenceAndFrequencyPenalty(Tensor logits, Tensor token_ids, Tensor to TVM_FFI_ICHECK(logits.IsContiguous()); TVM_FFI_ICHECK(token_ids.IsContiguous()); TVM_FFI_ICHECK(token_freqs.IsContiguous()); - TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - TVM_FFI_ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!"; - TVM_FFI_ICHECK(token_freqs.DataType() == DataType::Int(32)) << "token freqs must be int32!"; + TVM_FFI_ICHECK((logits.DataType() == DLDataType{kDLFloat, 32, 1})) + << "Logits data type is not float32!"; + TVM_FFI_ICHECK((token_ids.DataType() == DLDataType{kDLInt, 32, 1})) << "token ids must be int32!"; + TVM_FFI_ICHECK((token_freqs.DataType() == DLDataType{kDLInt, 32, 1})) + << "token freqs must be int32!"; TVM_FFI_ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; TVM_FFI_ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!"; TVM_FFI_ICHECK(token_freqs->device.device_type == kDLCPU) << "token_ids device must be CPU!"; @@ -633,7 +637,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { // This is an inplace operation. 
void ApplySoftmaxWithTemperature(Tensor logits, double temperature) { TVM_FFI_ICHECK(logits.IsContiguous()); - TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; + TVM_FFI_ICHECK((logits.DataType() == DLDataType{kDLFloat, 32, 1})) + << "Logits data type is not float32!"; TVM_FFI_ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; int vocab_size = logits->shape[logits->ndim - 1]; float* logits_raw_data = static_cast(logits->data); diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index e5c4576e01c1..cd7920d6eef0 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -116,9 +116,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { const ffi::Optional rope_ext_factors_; /*! \brief The KV cache dtype. */ - const DataType kv_dtype_; + const DLDataType kv_dtype_; /*! \brief We fix int32 to be the index dtype of auxiliary data. */ - const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1)); + const DLDataType dtype_aux_ = DLDataType{kDLInt, 32, 1}; /********************* Page Structures *********************/ @@ -326,7 +326,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { rotary_scale_(rotary_scale), rotary_theta_(rotary_theta), rope_ext_factors_(std::move(rope_ext_factors)), - kv_dtype_(DataType(dtype)), + kv_dtype_(dtype), reserved_num_seqs_(reserved_num_seqs), f_transpose_append_mha_(std::move(f_transpose_append_mha)), f_transpose_append_mla_(std::move(f_transpose_append_mla)), @@ -372,7 +372,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { pages_.push_back(nvshmem_pages_.CreateView( {num_total_pages_, 2, num_kv_heads_, page_size_, qk_head_dim_}, nvshmem_pages_->dtype, i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * qk_head_dim_ * - nvshmem_pages_.DataType().bytes())); + ((nvshmem_pages_.DataType().bits + 7) / 8))); } const auto f_transfer_kv_ptr =
tvm::ffi::Function::GetGlobal("nvshmem.KVTransfer"); @@ -450,9 +450,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { if (NeedKernelBeginForward()) { temp_int_attn_workspace_.push_back( - Tensor::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device)); + Tensor::Empty({kIntAttnWorkspaceByte}, DLDataType{kDLUInt, 8, 1}, device)); temp_int_pinned_attn_workspace_.push_back(Tensor::Empty( - {kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device))); + {kIntAttnWorkspaceByte}, DLDataType{kDLUInt, 8, 1}, GetPreferredHostDevice(device))); } qo_indptr_on_depths_view_.push_back(Tensor()); page_indptr_on_depths_view_.push_back(Tensor()); @@ -470,11 +470,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Additional workspace for the "prefill with ragged kv" kernel. if (NeedKernelBeginForward()) { temp_int_attn_workspace_.push_back( - Tensor::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device)); + Tensor::Empty({kIntAttnWorkspaceByte}, DLDataType{kDLUInt, 8, 1}, device)); temp_int_pinned_attn_workspace_.push_back(Tensor::Empty( - {kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device))); + {kIntAttnWorkspaceByte}, DLDataType{kDLUInt, 8, 1}, GetPreferredHostDevice(device))); temp_float_attn_workspace_ = - Tensor::Empty({kFloatAttnWorkspaceByte}, DataType::UInt(8), device); + Tensor::Empty({kFloatAttnWorkspaceByte}, DLDataType{kDLUInt, 8, 1}, device); } if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHA) != attn_kinds_.end()) { @@ -488,9 +488,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { temp_attn_output_device_ = Tensor::Empty({prefill_chunk_size_, num_qo_heads, v_head_dim}, dtype, device); temp_attn_lse_device_ = - Tensor::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device); + Tensor::Empty({prefill_chunk_size_, num_qo_heads}, DLDataType{kDLFloat, 32, 1}, device); 
merged_attn_lse_device_ = - Tensor::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device); + Tensor::Empty({prefill_chunk_size_, num_qo_heads}, DLDataType{kDLFloat, 32, 1}, device); for (int64_t page_id = num_total_pages - 1; page_id >= 0; --page_id) { free_page_ids_.push_back(page_id); } diff --git a/src/runtime/vm/rnn_state.cc b/src/runtime/vm/rnn_state.cc index 9926b3d235e8..a38acf6e1cdf 100644 --- a/src/runtime/vm/rnn_state.cc +++ b/src/runtime/vm/rnn_state.cc @@ -83,7 +83,7 @@ class RNNStateImpObj : public RNNStateObj { const ffi::Array init_layer_value_; /*! \brief We fix int32 to be the index dtype of auxiliary data. */ - const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1)); + const DLDataType dtype_aux_ = DLDataType{kDLInt, 32, 1}; /******************* Storage Structures *******************/ diff --git a/src/runtime/vm/tensor_cache_support.cc b/src/runtime/vm/tensor_cache_support.cc index ee77c5ddd8f0..62fd1a34c62f 100644 --- a/src/runtime/vm/tensor_cache_support.cc +++ b/src/runtime/vm/tensor_cache_support.cc @@ -64,7 +64,7 @@ TensorCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const json::Objec TensorCacheMetadata::FileRecord::ParamRecord result; std::string dtype = json["dtype"].cast(); result.name = json["name"].cast(); - result.dtype = DataType(ffi::StringToDLDataType(dtype)); + result.dtype = ffi::StringToDLDataType(dtype); result.format = json["format"].cast(); result.nbytes = json["nbytes"].cast(); result.byte_offset = json["byteOffset"].cast(); @@ -154,7 +154,7 @@ void CopyTensorFromBytes(Tensor param, const void* data, size_t nbytes, Tensor TensorCacheMetadata::FileRecord::ParamRecord::Load( Device device, const std::string* raw_data, ffi::Optional* staging_buffer) const { Tensor arr = Tensor::Empty(shape, dtype, device); - if (dtype == DataType::Float(32) && format == "f32-to-bf16") { + if (dtype == DLDataType{kDLFloat, 32, 1} && format == "f32-to-bf16") { // decode bf16 to f32 std::vector buffer(nbytes / 
2); std::vector decoded(nbytes / 2); diff --git a/src/s_tir/analysis/calculate_allocated_memory.cc b/src/s_tir/analysis/calculate_allocated_memory.cc index 51330a63e88b..41df4ee4bb8a 100644 --- a/src/s_tir/analysis/calculate_allocated_memory.cc +++ b/src/s_tir/analysis/calculate_allocated_memory.cc @@ -76,7 +76,7 @@ class AllocBufferCalculator : public StmtExprVisitor { break; } } - size *= op->buffer->dtype.bytes() * op->buffer->dtype.lanes(); + size *= ((op->buffer->dtype.bits() + 7) / 8) * op->buffer->dtype.lanes(); _current_size[storage_scope] += size; _max_size[storage_scope] = std::max(_current_size[storage_scope], _max_size[storage_scope]); StmtExprVisitor::VisitStmt_(op); diff --git a/src/s_tir/analysis/estimate_flops.cc b/src/s_tir/analysis/estimate_flops.cc index d77e715db1b6..bcde2d4b70bd 100644 --- a/src/s_tir/analysis/estimate_flops.cc +++ b/src/s_tir/analysis/estimate_flops.cc @@ -26,15 +26,13 @@ namespace tvm { namespace s_tir { using namespace tvm::tirx; -int32_t DataType2Int(const tvm::DataType& dtype) { +int32_t DataType2Int(DLDataType dtype) { static_assert(sizeof(DLDataType) == sizeof(int32_t), "Incorrect size of DLDataType"); union { DLDataType src; int32_t dst; } converter; - converter.src.code = dtype.code(); - converter.src.bits = dtype.bits(); - converter.src.lanes = dtype.lanes(); + converter.src = dtype; return converter.dst; } @@ -57,7 +55,7 @@ ffi::String Int2DataTypeStr(int32_t dtype) { struct TResult { TResult() = default; - void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; } + void Add(DLDataType dtype) { data_[DataType2Int(dtype)] += 1; } TResult operator+=(const TResult& rhs) { for (const auto& kv : rhs.data_) { @@ -98,7 +96,7 @@ class FlopEstimator : private ExprFunctor, TResult VisitExpr_(const Node* op) final { \ TResult result = VisitExpr(op->a); \ result += VisitExpr(op->b); \ - result.Add(op->dtype); \ + result.Add(op->ty()->dtype); \ return result; \ } TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(AddNode); diff 
--git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index 18eef8e2fe01..9fa0a7b0b325 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -348,7 +348,7 @@ ffi::Array BlockReadWriteDetector::CollectRegions( const tvm::arith::IntSet& range = regions[i][j]; if (range.CanProveSinglePoint(ana_)) { PrimExpr min = range.min(); - region.push_back(Range::FromMinExtent(min, MakeConst(min.dtype(), 1))); + region.push_back(Range::FromMinExtent(min, MakeConst(min.ty(), 1))); } else { region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); } diff --git a/src/s_tir/analysis/verify_gpu_code.cc b/src/s_tir/analysis/verify_gpu_code.cc index bd7b7c92ba7c..8155fd791e4b 100644 --- a/src/s_tir/analysis/verify_gpu_code.cc +++ b/src/s_tir/analysis/verify_gpu_code.cc @@ -76,19 +76,19 @@ class GPUCodeVerifier : public StmtExprVisitor { break; } } + PrimType dtype_ty = op->buffer->dtype; + TVM_FFI_ICHECK(!dtype_ty.IsScalableVector()) + << "Cannot verify GPU memory usage for scalable vector dtype " << dtype_ty; if (storage_scope.rank == runtime::StorageRank::kLocal) { - local_memory_per_block_ += - static_cast(const_size) * op->buffer->dtype.bytes() * op->buffer->dtype.lanes(); + local_memory_per_block_ += static_cast(const_size) * ElementBytes(dtype_ty); } else if (storage_scope.rank == runtime::StorageRank::kShared) { - shared_memory_per_block_ += - static_cast(const_size) * op->buffer->dtype.bytes() * op->buffer->dtype.lanes(); + shared_memory_per_block_ += static_cast(const_size) * ElementBytes(dtype_ty); } - if (op->buffer->dtype.is_vector()) { - if (static_cast(op->buffer->dtype.lanes() * op->buffer->dtype.bytes()) > - max_vector_bytes_) { + if (dtype_ty.IsFixedLengthVector()) { + if (ElementBytes(dtype_ty) > max_vector_bytes_) { std::stringstream s; - s << "Number of lanes (" << op->buffer->dtype.lanes() << ") times number of 
bytes (" - << op->buffer->dtype.bytes() << ") for dtype " << op->buffer->dtype + s << "Number of lanes (" << dtype_ty.lanes() << ") times number of bytes (" + << ((dtype_ty.bits() + 7) / 8) << ") for dtype " << dtype_ty << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; errors_.push_back(s.str()); } @@ -202,11 +202,12 @@ class GPUCodeVerifier : public StmtExprVisitor { void CheckBufferIndicesVectorizable(const ffi::Array indices) { for (const auto index : indices) { if (const auto* ramp = index.as()) { - if (!is_one(ramp->stride) && - static_cast(ramp->dtype.lanes() * ramp->dtype.bytes()) > max_vector_bytes_) { + PrimType ramp_ty = ramp->ty(); + if (!is_one(ramp->stride) && ramp_ty.IsFixedLengthVector() && + ElementBytes(ramp_ty) > max_vector_bytes_) { std::stringstream s; - s << "Number of lanes (" << ramp->dtype.lanes() << ") times number of bytes (" - << ramp->dtype.bytes() << ") for dtype " << ramp->dtype + s << "Number of lanes (" << ramp_ty.lanes() << ") times number of bytes (" + << ((ramp_ty.bits() + 7) / 8) << ") for dtype " << ramp_ty << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; errors_.push_back(s.str()); } @@ -215,11 +216,12 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const CastNode* op) { - if (op->dtype.is_vector()) { - if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { + PrimType op_ty = op->ty(); + if (op_ty.IsFixedLengthVector()) { + if (ElementBytes(op_ty) > max_vector_bytes_) { std::stringstream s; - s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" - << op->dtype.bytes() << ") for dtype " << op->dtype + s << "Number of lanes (" << op_ty.lanes() << ") times number of bytes (" + << ((op_ty.bits() + 7) / 8) << ") for dtype " << op_ty << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; errors_.push_back(s.str()); } @@ -228,11 +230,12 @@ class 
GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) { - if (op->dtype.is_vector()) { - if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { + PrimType op_ty = op->ty(); + if (op_ty.IsFixedLengthVector()) { + if (ElementBytes(op_ty) > max_vector_bytes_) { std::stringstream s; - s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" - << op->dtype.bytes() << ") for dtype " << op->dtype + s << "Number of lanes (" << op_ty.lanes() << ") times number of bytes (" + << ((op_ty.bits() + 7) / 8) << ") for dtype " << op_ty << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; errors_.push_back(s.str()); } @@ -242,12 +245,12 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const BufferStoreNode* op) { - if (op->value->dtype.is_vector()) { - if (static_cast(op->value->dtype.lanes() * op->value->dtype.bytes()) > - max_vector_bytes_) { + PrimType value_ty = op->value.ty(); + if (value_ty.IsFixedLengthVector()) { + if (ElementBytes(value_ty) > max_vector_bytes_) { std::stringstream s; - s << "Number of lanes (" << op->value->dtype.lanes() << ") times number of bytes (" - << op->value->dtype.bytes() << ") for dtype " << op->value->dtype + s << "Number of lanes (" << value_ty.lanes() << ") times number of bytes (" + << ((value_ty.bits() + 7) / 8) << ") for dtype " << value_ty << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; errors_.push_back(s.str()); } @@ -277,6 +280,8 @@ class GPUCodeVerifier : public StmtExprVisitor { std::vector errors_; + static size_t ElementBytes(const PrimType& ty) { return ty.StorageBytes(); } + void Reset_() { local_memory_per_block_ = 0; shared_memory_per_block_ = 0; diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc b/src/s_tir/backend/adreno/inject_texture_alloc.cc index e4e7c322ef55..5b6aeda19362 100644 --- a/src/s_tir/backend/adreno/inject_texture_alloc.cc +++ 
b/src/s_tir/backend/adreno/inject_texture_alloc.cc @@ -79,11 +79,11 @@ class TextureAllocInjector : public arith::IRMutatorWithAnalyzer { ffi::Array args; args.push_back(StringImm(storage_scope)); args.push_back(IntImm::Int64(3)); - args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), + args.push_back(Call(PrimType::Handle(), builtin::tvm_stack_make_shape(), {texture.width, texture.height, texture.depth})); args.push_back(IntImm::Int64(channel_size)); stmt = Bind(op->buffer->data, - Call(op->buffer->data.dtype(), builtin::nd_mem_alloc_with_scope(), args)); + Call(op->buffer->data.ty(), builtin::nd_mem_alloc_with_scope(), args)); } return stmt; } diff --git a/src/s_tir/backend/adreno/texture_flatten.cc b/src/s_tir/backend/adreno/texture_flatten.cc index 0dd939ad817a..d4297e42e4d2 100644 --- a/src/s_tir/backend/adreno/texture_flatten.cc +++ b/src/s_tir/backend/adreno/texture_flatten.cc @@ -100,7 +100,7 @@ class TextureFlattener : public TextureLoweringBase { if (IsTextureStorage(storage_scope)) { ffi::Array args = GetTextureAccessArgs(op, op->buffer); args.push_back(op->value); - stmt = Evaluate(Call(args[0]->dtype, builtin::texture2d_store(), args)); + stmt = Evaluate(Call(args[0].ty(), builtin::texture2d_store(), args)); } return stmt; @@ -147,7 +147,7 @@ class TextureFlattener : public TextureLoweringBase { PrimExpr col_offset = SimplifyOffset(col_dims, col_indices); PrimExpr depth_offset = SimplifyOffset(depth_dims, depth_indices); PrimExpr channel_size = IntImm( - DataType::Int(32, 1), *tirx::as_const_int(buffer->shape.back()) * buffer->dtype.bits()); + PrimType::Int(32, 1), *tirx::as_const_int(buffer->shape.back()) * buffer->dtype.bits()); args.push_back(row_offset); args.push_back(col_offset); args.push_back(depth_offset); diff --git a/src/s_tir/data_layout.cc b/src/s_tir/data_layout.cc index 787386c8ccb9..6fa2db0206e4 100644 --- a/src/s_tir/data_layout.cc +++ b/src/s_tir/data_layout.cc @@ -22,10 +22,10 @@ * \brief Data SLayout expression. 
*/ #include +#include #include #include #include -#include #include #include #include @@ -113,8 +113,9 @@ SLayout::SLayout(const ffi::Array& axes) { data_ = std::move(node); } -SLayout::SLayout(const std::string& name, DataType dtype) { // NOLINT(*) - TVM_FFI_CHECK(dtype.is_int(), TypeError) << "The input dtype should be integer type"; +SLayout::SLayout(const std::string& name, PrimType index_ty) { // NOLINT(*) + TVM_FFI_CHECK(index_ty.code() == DLDataTypeCode::kDLInt, TypeError) + << "The input dtype should be integer type"; if (name == "__undef__") return; auto node = ffi::make_object(); @@ -131,8 +132,8 @@ SLayout::SLayout(const std::string& name, DataType dtype) { // NOLINT(*) if (c >= 'A' && c <= 'Z') { TVM_FFI_ICHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " before dimension " << c; - IterVar axis(Range(IntImm(dtype, 0), Var(std::string(1, c), dtype)), - Var(std::string(1, c), dtype), tirx::kDataPar); + IterVar axis(Range(IntImm(index_ty, 0), Var(std::string(1, c), index_ty)), + Var(std::string(1, c), index_ty), tirx::kDataPar); if (!in_packing) { node->axes.push_back(axis); } else { @@ -143,7 +144,7 @@ SLayout::SLayout(const std::string& name, DataType dtype) { // NOLINT(*) << factor << " for dimension " << c; std::stringstream name; name << factor << c; - IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(name.str(), dtype), + IterVar axis(Range(IntImm(index_ty, 0), IntImm(index_ty, factor)), Var(name.str(), index_ty), tirx::kDataPar); if (!in_packing) { node->axes.push_back(axis); @@ -174,8 +175,8 @@ SLayout::SLayout(const std::string& name, DataType dtype) { // NOLINT(*) extent = extent * factor->value; } std::string grouped_name = ss.str(); - IterVar grouped_axis(Range(IntImm(dtype, 0), IntImm(dtype, extent)), Var(grouped_name, dtype), - tirx::kDataPar); + IterVar grouped_axis(Range(IntImm(index_ty, 0), IntImm(index_ty, extent)), + Var(grouped_name, index_ty), tirx::kDataPar); 
node->axes.push_back(grouped_axis); in_packing = false; @@ -231,21 +232,21 @@ ffi::Array SLayout::UnpackIterVar(IterVar packed_iter) { int64_t factor = 0, final_factor = 1; std::string name(packed_iter->var->name_hint.c_str()); - DataType dtype = packed_iter->var.dtype(); + PrimType index_ty = packed_iter->var.ty(); for (auto ch : name) { if (ch >= '0' && ch <= '9') { factor = factor * 10 + (ch - '0'); } else if (ch >= 'a' && ch <= 'z') { TVM_FFI_ICHECK(factor != 0) << "Invalid Factor Size"; - result.push_back(IterVar(Range(IntImm(dtype, 0), IntImm(dtype, factor)), - Var(std::string(1, ch), dtype), tirx::kDataPar)); + result.push_back(IterVar(Range(IntImm(index_ty, 0), IntImm(index_ty, factor)), + Var(std::string(1, ch), index_ty), tirx::kDataPar)); final_factor *= factor; factor = 0; } else if (ch >= 'A' && ch <= 'Z') { TVM_FFI_ICHECK(factor == 0) << "Can't have non-zero factors for primal axis"; - result.push_back(IterVar(Range(IntImm(dtype, 0), Var(std::string(1, ch), dtype)), - Var(std::string(1, ch), dtype), tirx::kDataPar)); + result.push_back(IterVar(Range(IntImm(index_ty, 0), Var(std::string(1, ch), index_ty)), + Var(std::string(1, ch), index_ty), tirx::kDataPar)); } } @@ -256,7 +257,7 @@ IterVar SLayout::PackIterVar(ffi::Array iter_vars) { std::stringstream name; size_t extent = 1; - DataType dtype = iter_vars[0]->dom->extent.as().value()->dtype; + PrimType index_ty = iter_vars[0]->dom->extent.as().value().ty(); for (auto itvar : iter_vars) { TVM_FFI_ICHECK(itvar->dom->extent.as()) << "Packed Axis can contain only Subordinate Axes"; @@ -264,7 +265,7 @@ IterVar SLayout::PackIterVar(ffi::Array iter_vars) { extent = extent * itvar->dom->extent.as().value()->value; } - return IterVar(Range(IntImm(dtype, 0), IntImm(dtype, extent)), Var(name.str(), dtype), + return IterVar(Range(IntImm(index_ty, 0), IntImm(index_ty, extent)), Var(name.str(), index_ty), tirx::kDataPar); } @@ -357,7 +358,8 @@ inline bool GetStoreRule(ffi::Array* index_rule, ffi::Array* if (axis == 
sub_axis) { const auto* sub_extent = inter_unpacked_axes[l]->dom->extent.as(); TVM_FFI_ICHECK(sub_extent) << "Expected Integer Extents for Offset Calculation"; - factor_ij = factor_ij * IntImm(sub_extent->dtype, sub_extent->value); + factor_ij = + factor_ij * IntImm(ffi::GetRef(sub_extent).ty(), sub_extent->value); } } } @@ -498,11 +500,11 @@ inline ffi::Array TransformShape(const ffi::Array& src_shape << ", get " << orig_shape; } } - bind_map[orig_axis->var.get()] = IntImm(orig_axis->var->dtype, 0); + bind_map[orig_axis->var.get()] = IntImm(orig_axis->var.ty(), 0); } else { - bind_map[orig_axis->var.get()] = orig_axis->var->dtype == orig_shape->dtype + bind_map[orig_axis->var.get()] = orig_axis->var.ty()->dtype == orig_shape.ty()->dtype ? orig_shape - : cast(orig_axis->var->dtype, orig_shape); + : cast(orig_axis->var.ty(), orig_shape); } } // infer the target shape, @@ -583,7 +585,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("s_tir.SLayout", [](std::string name, DataType dtype) { return SLayout(name, dtype); }) + .def("s_tir.SLayout", [](std::string name, PrimType dtype) { return SLayout(name, dtype); }) .def("s_tir.SLayoutIndexOf", [](SLayout layout, std::string axis) -> int { return layout.IndexOf(axis); }) .def("s_tir.SLayoutFactorOf", diff --git a/src/s_tir/meta_schedule/arg_info.cc b/src/s_tir/meta_schedule/arg_info.cc index dc452b370037..73fa41773883 100644 --- a/src/s_tir/meta_schedule/arg_info.cc +++ b/src/s_tir/meta_schedule/arg_info.cc @@ -98,7 +98,7 @@ ffi::Array ArgInfo::FromPrimFunc(const tirx::PrimFunc& func) { for (const tirx::Var& arg : func->params) { if (ffi::Optional _buffer = func->buffer_map.Get(arg)) { tirx::Buffer buffer = _buffer.value(); - result.push_back(TensorInfo(/*dtype=*/buffer->dtype, + result.push_back(TensorInfo(/*dtype=*/buffer->dtype->dtype, /*shape=*/AsVector(buffer->shape))); } else { TVM_FFI_THROW(ValueError) << "Unsupported argument type: " 
<< arg; @@ -117,7 +117,7 @@ ffi::Array ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_prep /******** TensorInfo ********/ -TensorInfo::TensorInfo(runtime::DataType dtype, ffi::Shape shape) { +TensorInfo::TensorInfo(DLDataType dtype, ffi::Shape shape) { ffi::ObjectPtr n = ffi::make_object(); n->dtype = dtype; n->shape = shape; @@ -150,7 +150,7 @@ TensorInfo TensorInfo::FromJSON(const ffi::ObjectRef& json_obj) { } std::vector s; std::transform(shape.begin(), shape.end(), std::back_inserter(s), [](int64_t i) { return i; }); - return TensorInfo(DataType(dtype), ffi::Shape(s.begin(), s.end())); + return TensorInfo(dtype, ffi::Shape(s.begin(), s.end())); } /******** Repr ********/ @@ -182,10 +182,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("s_tir.meta_schedule.ArgInfoFromPrimFunc", ArgInfo::FromPrimFunc) .def("s_tir.meta_schedule.ArgInfoFromEntryFunc", ArgInfo::FromEntryFunc) .def("s_tir.meta_schedule.ArgInfoFromJSON", ArgInfo::FromJSON) - .def("s_tir.meta_schedule.TensorInfo", - [](runtime::DataType dtype, ffi::Shape shape) -> TensorInfo { - return TensorInfo(dtype, shape); - }); + .def("s_tir.meta_schedule.TensorInfo", [](DLDataType dtype, ffi::Shape shape) -> TensorInfo { + return TensorInfo(dtype, shape); + }); } } // namespace meta_schedule diff --git a/src/s_tir/meta_schedule/database/database_utils.cc b/src/s_tir/meta_schedule/database/database_utils.cc index ea1473ae6500..826c38c8d1b0 100644 --- a/src/s_tir/meta_schedule/database/database_utils.cc +++ b/src/s_tir/meta_schedule/database/database_utils.cc @@ -32,7 +32,9 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { os << "null"; } else if (auto opt_int_imm = json_obj.try_cast()) { IntImm int_imm = *std::move(opt_int_imm); - if (int_imm->dtype == DataType::Bool()) { + PrimType int_ty = int_imm.ty(); + if (int_ty.MatchesElementType(DLDataTypeCode::kDLBool, 8) && !int_ty.IsScalableVector() && + !int_ty.IsFixedLengthVector()) { if (int_imm->value) { os << "true"; } else { @@ -154,7 +156,6 @@ class 
JSONTokenizer { bool NextFalse() { return NextLiteral("false", 5); } bool NextNumber(Token* token) { - using runtime::DataType; bool is_float = false; const char* st = cur_; for (; cur_ != end_; ++cur_) { diff --git a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc index f0e3aa897cdd..2f87217db065 100644 --- a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc @@ -273,12 +273,12 @@ Pass SimplifyForFeatureExtraction() { HasBufferLoad(node->condition)) { return ffi::GetRef