Skip to content

Commit 20c4a00

Browse files
committed
[REFACTOR][IR] Use PrimType for compiler dtypes
Use PrimType as the compiler-facing dtype/type carrier so primitive expression dtype information is unified with Expr.ty instead of flowing through a dedicated dtype path. This keeps compiler IR type information in the type system and leaves room for future expression type annotations. Use raw DLDataType at runtime, ABI, storage-helper, and dtype-valued attr boundaries where a plain DLPack dtype value is the real interface. Keep the PrimType API minimal and hot-path friendly with value equality, matching helpers, documented factories, and cached common constructors. Update TIRX, TE, TOPI, Relax, codegen, Python bindings, and tests to follow the compiler PrimType versus runtime DLDataType boundary.
1 parent 2bdedc9 commit 20c4a00

420 files changed

Lines changed: 6461 additions & 5737 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/tvm/ir/base_expr.h

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/ir/base_expr.h
22+
* \brief Base expression and primitive type nodes.
23+
*/
24+
#ifndef TVM_IR_BASE_EXPR_H_
25+
#define TVM_IR_BASE_EXPR_H_
26+
27+
#include <tvm/ffi/cast.h>
28+
#include <tvm/ffi/dtype.h>
29+
#include <tvm/ffi/reflection/registry.h>
30+
#include <tvm/ir/source_map.h>
31+
32+
#include <cstdint>
33+
34+
namespace tvm {
35+
36+
/*!
37+
* \brief Type is the base type of all types.
38+
*
39+
* TVM's type system contains following subclasses:
40+
*
41+
* - PrimType: type of primitive type values used in the low-level IR.
42+
* - FuncType: type of a function.
43+
* - TensorType: type of certain Tensor values in the expression.
44+
*
45+
* There are also advanced types to support generic(polymorphic types).
46+
* \sa Type
47+
*/
48+
class TypeNode : public ffi::Object {
49+
public:
50+
/*!
51+
* \brief Span that points to the original source code.
52+
* Reserved debug information.
53+
*/
54+
mutable Span span;
55+
56+
static void RegisterReflection() {
57+
namespace refl = tvm::ffi::reflection;
58+
// span do not participate in structural equal and hash.
59+
refl::ObjectDef<TypeNode>().def_ro("span", &TypeNode::span, refl::DefaultValue(Span()),
60+
refl::AttachFieldFlag::SEqHashIgnore());
61+
}
62+
63+
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
64+
65+
static constexpr const uint32_t _type_child_slots = 14;
66+
TVM_FFI_DECLARE_OBJECT_INFO("ir.Type", TypeNode, ffi::Object);
67+
};
68+
69+
/*!
70+
* \brief Managed reference to TypeNode.
71+
* \sa TypeNode
72+
*/
73+
class Type : public ffi::ObjectRef {
74+
public:
75+
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Type, ffi::ObjectRef, TypeNode);
76+
};
77+
78+
/*!
79+
* \brief Primitive data types used in the low-level IR.
80+
*
81+
* PrimType represents POD-values and handles that are
82+
* not automatically managed by the runtime.
83+
*
84+
* \sa PrimType
85+
*/
86+
class PrimTypeNode final : public TypeNode {
87+
public:
88+
/*!
89+
* \brief The raw DLPack dtype represented by this primitive type.
90+
*/
91+
DLDataType dtype;
92+
93+
static void RegisterReflection() {
94+
namespace refl = tvm::ffi::reflection;
95+
refl::ObjectDef<PrimTypeNode>().def_ro("dtype", &PrimTypeNode::dtype);
96+
}
97+
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.PrimType", PrimTypeNode, TypeNode);
98+
};
99+
100+
/*
101+
* \brief Managed reference to PrimTypeNode.
102+
* \sa PrimTypeNode
103+
*/
104+
class PrimType final : public Type {
105+
public:
106+
/*!
107+
* \brief Construct from a raw DLPack dtype.
108+
* \param dtype The corresponding DLPack dtype.
109+
*/
110+
TVM_DLL explicit PrimType(DLDataType dtype);
111+
112+
/*!
113+
* \brief Construct from DLPack dtype fields.
114+
* \param code The DLPack dtype code.
115+
* \param bits The scalar bit width.
116+
* \param lanes The fixed lane count.
117+
*/
118+
TVM_DLL PrimType(DLDataTypeCode code, int bits, int lanes = 1);
119+
120+
/*! \brief Construct a signed integer type with fixed lanes. */
121+
TVM_DLL static PrimType Int(int bits, int lanes = 1);
122+
/*! \brief Construct an unsigned integer type with fixed lanes. */
123+
TVM_DLL static PrimType UInt(int bits, int lanes = 1);
124+
/*! \brief Construct a floating-point type with fixed lanes. */
125+
TVM_DLL static PrimType Float(int bits, int lanes = 1);
126+
/*! \brief Construct a bfloat type with fixed lanes. */
127+
TVM_DLL static PrimType BFloat(int bits, int lanes = 1);
128+
/*! \brief Construct a boolean type with fixed lanes. */
129+
TVM_DLL static PrimType Bool(int lanes = 1);
130+
/*! \brief Construct an opaque handle type. */
131+
TVM_DLL static PrimType Handle(int bits = 64, int lanes = 1);
132+
/*! \brief Construct the void sentinel type, encoded as handle(0, 0). */
133+
TVM_DLL static PrimType Void();
134+
/*!
135+
* \brief Construct a scalable vector type.
136+
* \param code The DLPack dtype code.
137+
* \param bits The scalar bit width.
138+
* \param lanes The positive vscale factor to encode in the DLPack lane field.
139+
*/
140+
TVM_DLL static PrimType ScalableVector(DLDataTypeCode code, int bits, int lanes);
141+
142+
/*! \return The DLPack dtype code. */
143+
TVM_FFI_INLINE DLDataTypeCode code() const {
144+
return static_cast<DLDataTypeCode>(static_cast<int>(get()->dtype.code));
145+
}
146+
147+
/*! \return The scalar bit width. */
148+
TVM_FFI_INLINE int32_t bits() const { return get()->dtype.bits; }
149+
150+
/*!
151+
* \return The fixed lane count.
152+
* \note Throws on scalable vector types, where the encoded lane field stores a vscale factor.
153+
*/
154+
TVM_FFI_INLINE int32_t lanes() const {
155+
int16_t encoded_lanes = static_cast<int16_t>(get()->dtype.lanes);
156+
if (TVM_FFI_PREDICT_FALSE(encoded_lanes < 0)) {
157+
TVM_FFI_THROW(InternalError)
158+
<< "Can't fetch the lanes of a scalable vector at a compile time.";
159+
}
160+
return encoded_lanes;
161+
}
162+
163+
/*!
164+
* \brief Check the scalar element code and bit width.
165+
* \note Lane count and scalable-vector encoding are intentionally ignored.
166+
*/
167+
TVM_FFI_INLINE bool MatchesElementType(DLDataTypeCode code, int bits) const {
168+
DLDataType dtype = get()->dtype;
169+
return dtype.code == static_cast<uint8_t>(code) && dtype.bits == bits;
170+
}
171+
172+
/*!
173+
* \brief Check whether the dtype code matches any of the provided DLPack codes.
174+
* \note Bit width and lanes are intentionally ignored.
175+
*/
176+
template <typename... Codes>
177+
TVM_FFI_INLINE bool MatchesCode(Codes... codes) const {
178+
uint8_t dtype_code = get()->dtype.code;
179+
return ((dtype_code == static_cast<uint8_t>(codes)) || ...);
180+
}
181+
182+
/*! \brief Whether this type is a scalar, excluding fixed and scalable vectors. */
183+
TVM_FFI_INLINE bool IsScalar() const {
184+
int16_t encoded_lanes = static_cast<int16_t>(get()->dtype.lanes);
185+
return encoded_lanes == 1;
186+
}
187+
188+
/*! \brief Whether this type is the void sentinel `handle(0, 0)`. */
189+
TVM_FFI_INLINE bool IsVoid() const {
190+
DLDataType dtype = get()->dtype;
191+
return dtype.code == static_cast<uint8_t>(DLDataTypeCode::kDLOpaqueHandle) && dtype.bits == 0 &&
192+
static_cast<int16_t>(dtype.lanes) == 0;
193+
}
194+
195+
/*! \brief Whether this type is an opaque handle, excluding the void sentinel. */
196+
TVM_FFI_INLINE bool IsHandle() const {
197+
return this->code() == DLDataTypeCode::kDLOpaqueHandle && !this->IsVoid();
198+
}
199+
200+
/*! \brief Whether this type is a scalable vector. */
201+
TVM_FFI_INLINE bool IsScalableVector() const {
202+
return static_cast<int16_t>(get()->dtype.lanes) < -1;
203+
}
204+
205+
/*! \brief Whether this type is a fixed-length vector. */
206+
TVM_FFI_INLINE bool IsFixedLengthVector() const {
207+
return static_cast<int16_t>(get()->dtype.lanes) > 1;
208+
}
209+
210+
/*! \brief Return the same type with a different dtype code, preserving bits and lanes. */
211+
TVM_FFI_INLINE PrimType WithCode(DLDataTypeCode code) const {
212+
DLDataType dtype = get()->dtype;
213+
int16_t encoded_lanes = static_cast<int16_t>(dtype.lanes);
214+
if (encoded_lanes < -1) {
215+
return ScalableVector(code, dtype.bits, -encoded_lanes);
216+
}
217+
return PrimType(code, dtype.bits, encoded_lanes);
218+
}
219+
220+
/*! \brief Return the same type with a different scalar bit width, preserving code and lanes. */
221+
TVM_FFI_INLINE PrimType WithBits(int bits) const {
222+
DLDataType dtype = get()->dtype;
223+
int16_t encoded_lanes = static_cast<int16_t>(dtype.lanes);
224+
if (encoded_lanes < -1) {
225+
return ScalableVector(this->code(), bits, -encoded_lanes);
226+
}
227+
return PrimType(this->code(), bits, encoded_lanes);
228+
}
229+
230+
/*! \brief Return the same scalar element type with a fixed lane count. */
231+
TVM_FFI_INLINE PrimType WithLanes(int lanes) const {
232+
return PrimType(this->code(), this->bits(), lanes);
233+
}
234+
235+
/*! \return The vscale factor encoded in a scalable vector type. */
236+
TVM_FFI_INLINE int32_t VScaleFactor() const {
237+
int16_t encoded_lanes = static_cast<int16_t>(get()->dtype.lanes);
238+
if (encoded_lanes >= -1) {
239+
TVM_FFI_THROW(InternalError) << "A fixed length vector doesn't have a vscale factor.";
240+
}
241+
return -encoded_lanes;
242+
}
243+
244+
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrimType, Type, PrimTypeNode);
245+
};
246+
247+
inline bool operator==(const PrimType& lhs, const PrimType& rhs) {
248+
return lhs->dtype == rhs->dtype;
249+
}
250+
251+
inline bool operator!=(const PrimType& lhs, const PrimType& rhs) { return !(lhs == rhs); }
252+
253+
/*!
254+
* \brief Base type of all the expressions.
255+
* \sa Expr
256+
*/
257+
class BaseExprNode : public ffi::Object {
258+
public:
259+
/*!
260+
* \brief Span that points to the original source code.
261+
* Reserved debug information.
262+
*/
263+
mutable Span span;
264+
265+
/*!
266+
* \brief The deduced or annotated type of the expression.
267+
*
268+
* This field is intentionally nullable because type information may
269+
* be populated by later analysis passes instead of expression
270+
* constructors.
271+
*/
272+
mutable Type ty;
273+
274+
static void RegisterReflection() {
275+
namespace refl = tvm::ffi::reflection;
276+
// span and ty do not participate in structural equal and hash.
277+
refl::ObjectDef<BaseExprNode>()
278+
.def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span()),
279+
refl::AttachFieldFlag::SEqHashIgnore())
280+
.def_ro("ty", &BaseExprNode::ty, refl::DefaultValue(Type()),
281+
refl::AttachFieldFlag::SEqHashIgnore());
282+
}
283+
284+
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
285+
286+
static constexpr const uint32_t _type_child_slots = 64;
287+
TVM_FFI_DECLARE_OBJECT_INFO("ir.BaseExpr", BaseExprNode, ffi::Object);
288+
};
289+
290+
/*!
291+
* \brief Managed reference to BaseExprNode.
292+
* \sa BaseExprNode
293+
*/
294+
class BaseExpr : public ffi::ObjectRef {
295+
public:
296+
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BaseExpr, ffi::ObjectRef, BaseExprNode);
297+
};
298+
299+
namespace ffi {
300+
template <>
301+
inline constexpr bool use_default_type_traits_v<PrimType> = false;
302+
303+
template <>
304+
struct TypeTraits<PrimType> : public ObjectRefWithFallbackTraitsBase<PrimType, DLDataType> {
305+
TVM_FFI_INLINE static PrimType ConvertFallbackValue(DLDataType dtype) { return PrimType(dtype); }
306+
};
307+
} // namespace ffi
308+
309+
} // namespace tvm
310+
311+
#endif // TVM_IR_BASE_EXPR_H_

0 commit comments

Comments
 (0)