[REFACTOR][IR] Use PrimType for compiler dtypes#19875
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the codebase to transition from using DataType to PrimType (backed by DLDataType), unifying the type system across IR dialects and improving type safety. The reviewer provided valuable feedback on leveraging the new PrimType APIs (such as bits(), lanes(), and IsScalableVector()) to avoid manual DLDataType unpacking and arithmetic, particularly in tensor allocation, memory view handling, static memory planning, and GPU code verification. Additionally, a bug was identified in src/relax/transform/utils.h where kDLBool was incorrectly checked against a bit width of 1 instead of the standard 8 bits.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
20c4a00 to
3e75c55
Compare
Use PrimType as the compiler-facing dtype/type carrier so primitive expression dtype information is unified with Expr.ty instead of flowing through a dedicated dtype path. This keeps compiler IR type information in the type system and leaves room for future expression type annotations. Use raw DLDataType at runtime, ABI, storage-helper, and dtype-valued attr boundaries where a plain DLPack dtype value is the real interface. Keep the PrimType API minimal and hot-path friendly with value equality, matching helpers, documented factories, and cached common constructors. Update TIRX, TE, TOPI, Relax, codegen, Python bindings, and tests to follow the compiler PrimType versus runtime DLDataType boundary.
3e75c55 to
74fba5e
Compare
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request refactors the TVM codebase by replacing the usage of tvm::DataType with tvm::PrimType and DLDataType across C++ and Python files to better align with DLPack. The code review feedback is highly consistent, recommending the use of the more idiomatic MatchesCode helper on PrimType instead of directly comparing raw data type codes (such as kDLFloat, kDLInt, and kDLBool). Additionally, it is suggested to simplify redundant checks by passing multiple type codes to a single variadic MatchesCode call.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
|
/gemini Please run another review pass on latest head b8c7505. Since the last pass, I applied the Gemini MatchesCode cleanups, fixed the Python PrimType-to-runtime-dtype boundary in Relax gradient constants and TE NumPy test uses, reran full LLVM build plus focused Relax/TE/TIRX validation, and resolved the previous review threads that are addressed. |
There was a problem hiding this comment.
Code Review
This pull request refactors TVM's type system by replacing the runtime-specific DataType with a unified compile-time PrimType (backed by DLDataType) across the compiler, runtime, and Python bindings. This extensive refactoring updates buffer declarations, expression nodes, and codegen backends to use PrimType or DLDataType directly. Feedback on the changes highlights a potential division-by-zero bug in Hexagon's GetVectorBytes for sub-byte types, a potential AttributeError in the Python PrimExpr.dtype property when self.ty is null or a pointer, and several opportunities to simplify type checks using the newly introduced IsScalar() helper and PrimType equality operators.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| TVM_FFI_INLINE int GetVectorBytes(const PrimType& dtype) { | ||
| TVM_FFI_ICHECK(dtype.IsFixedLengthVector() || dtype.IsScalar()); | ||
| return dtype.bits() * dtype.lanes() / 8; | ||
| } |
There was a problem hiding this comment.
In GetVectorBytes, using dtype.bits() * dtype.lanes() / 8 can return 0 for sub-byte types (such as 4-bit integers or booleans). This will lead to a division-by-zero error in VectorLookupLoad where native_vector_bytes / GetVectorBytes(buffer_type) is calculated. Using (dtype.bits() * dtype.lanes() + 7) / 8 safely computes the byte size and prevents this potential compiler crash.
TVM_FFI_INLINE int GetVectorBytes(const PrimType& dtype) {
TVM_FFI_ICHECK(dtype.IsFixedLengthVector() || dtype.IsScalar());
return (dtype.bits() * dtype.lanes() + 7) / 8;
}| @property | ||
| def dtype(self): | ||
| """Return the runtime dtype represented by this expression's PrimType.""" | ||
| return self.ty.dtype |
There was a problem hiding this comment.
The dtype property on PrimExpr is implemented as return self.ty.dtype. However, self.ty can be None or a PointerType (which does not have a dtype attribute). This will raise an AttributeError when accessing .dtype on handle variables or un-typed expressions. Adding safety checks to handle None and PointerType (returning "handle" for pointers) ensures backward compatibility and robustness.
| @property | |
| def dtype(self): | |
| """Return the runtime dtype represented by this expression's PrimType.""" | |
| return self.ty.dtype | |
| @property | |
| def dtype(self): | |
| """Return the runtime dtype represented by this expression's PrimType.""" | |
| if self.ty is None: | |
| return None | |
| if hasattr(self.ty, "dtype"): | |
| return self.ty.dtype | |
| return "handle" |
| TVM_FFI_ICHECK(!type1.IsScalableVector() && !type2.IsScalableVector()); | ||
| TVM_FFI_ICHECK_EQ(type1.lanes(), 1); | ||
| TVM_FFI_ICHECK_EQ(type2.lanes(), 1); |
There was a problem hiding this comment.
The checks to ensure type1 and type2 are scalar can be simplified and made more robust by using the newly introduced IsScalar() helper method on PrimType.
| TVM_FFI_ICHECK(!type1.IsScalableVector() && !type2.IsScalableVector()); | |
| TVM_FFI_ICHECK_EQ(type1.lanes(), 1); | |
| TVM_FFI_ICHECK_EQ(type2.lanes(), 1); | |
| TVM_FFI_ICHECK(type1.IsScalar() && type2.IsScalar()); |
| auto cast_if_needed = [](PrimType to_type, PrimExpr expr) { | ||
| return to_type->dtype == expr.ty()->dtype ? expr : cast(to_type, expr); | ||
| }; |
There was a problem hiding this comment.
Since PrimType has operator== defined, the check to_type->dtype == expr.ty()->dtype can be simplified to to_type == expr.ty().
| auto cast_if_needed = [](PrimType to_type, PrimExpr expr) { | |
| return to_type->dtype == expr.ty()->dtype ? expr : cast(to_type, expr); | |
| }; | |
| auto cast_if_needed = [](PrimType to_type, PrimExpr expr) { | |
| return to_type == expr.ty() ? expr : cast(to_type, expr); | |
| }; |
| TVM_FFI_DCHECK(expr->BaseExprNode::ty.defined()); | ||
| const auto* prim_ty = expr->BaseExprNode::ty.as<PrimTypeNode>(); | ||
| TVM_FFI_DCHECK(prim_ty != nullptr); | ||
| return IsIndexType(prim_ty->dtype); |
There was a problem hiding this comment.
The manual checks on expr->BaseExprNode::ty are redundant because PrimExprNode::ty() already implements these checks and returns PrimType. We can simplify this to a single call to expr->ty().
| TVM_FFI_DCHECK(expr->BaseExprNode::ty.defined()); | |
| const auto* prim_ty = expr->BaseExprNode::ty.as<PrimTypeNode>(); | |
| TVM_FFI_DCHECK(prim_ty != nullptr); | |
| return IsIndexType(prim_ty->dtype); | |
| return IsIndexType(expr->ty()->dtype); |
|
|
||
| bool IsValidScalar(const PrimExpr& expr) const { | ||
| return expr.defined() && expr.dtype().is_scalar(); | ||
| if (!expr.defined()) return false; |
|
/gemini Please run another review pass on latest head 4b0c660. Since the last pass, I fixed CI compile failures in CBLAS/CUBLAS/HIPBLAS/MKL/DNNL/CUDNN/random and CUDA runtime paths after DataType helper removal: contrib backends now use direct raw DLDataType equality instead of TypeMatch/TypeEqual/runtime namespace helpers, and CUDA tensor-map runtime code uses raw DLDataType fields. Local validation: changed-file pre-commit passed, LLVM-enabled ninja build passed, and syntax-only checks passed for the reported disabled-backend sources with local CUDA/CUDNN/CBLAS/DNNL headers. |
There was a problem hiding this comment.
Code Review
This pull request performs a major refactoring of TVM's type system by removing the tvm::runtime::DataType wrapper class and replacing its usage with tvm::PrimType and raw DLDataType across the entire codebase, including TIR, Relax, TOPI, and various hardware backends (CUDA, Vulkan, Metal, WebGPU, Hexagon, and Trainium). PrimExpr and PrimExprNode now expose their types via PrimType (ty()) instead of DataType (dtype()), and corresponding Python bindings have been updated to reflect these changes. Since no review comments were provided, there is no specific feedback to address, and the changes appear to successfully unify and simplify the type representation.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
This PR removes the old runtime
DataTypewrapper as the compiler-facing dtype abstraction and routes dtype usage by boundary.Rationale:
PrimTypeis the compiler IR type for primitive expressions, so using it as the dtype/type carrier unifies dtype information withExpr.ty.DLDataType, where a plain DLPack dtype value is the real boundary object.Migration guide:
PrimTypewhen code reasons about compiler expression types, tensor element compiler types, or constructs aPrimExpr/compiler type.expr.ty(),ExprOp.expr_ty(), or TE tensor elementdtypewhere possible instead of rebuilding a type from dtype text.DLDataTypefor runtime constants, ABI paths, dtype-valued attrs, and storage/runtime helper logic.PrimTypeequality,MatchesCode(...),MatchesElementType(...), andWithCode(...)over local wrappers or string dtype checks.GetDataTypeandoutput_dtypewhere they are API terminology, but align their value type with the compiler/runtime boundary.Validation:
git diff --checkpassedPrimType::IsPredicate,.IsBool(,.IsInt(,.IsUInt(, removedtype.cchelper names, orruntime/data_type.hreferences in checked pathsninja -C build -j$(nproc)completed and linkedlib/libtvm_compiler.soThe branch is intentionally stacked on #19874 (
2bdedc93aa) so CI has the needed base fix.