Skip to content

Commit d68923f

Browse files
committed
[FIX][IR] Repair PrimType dtype CI fallout
1 parent 4b0c660 commit d68923f

27 files changed

Lines changed: 183 additions & 127 deletions

File tree

include/tvm/ir/base_expr.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <tvm/ffi/reflection/registry.h>
3030
#include <tvm/ir/source_map.h>
3131

32+
#include <cstddef>
3233
#include <cstdint>
3334

3435
namespace tvm {
@@ -207,6 +208,14 @@ class PrimType final : public Type {
207208
return static_cast<int16_t>(get()->dtype.lanes) > 1;
208209
}
209210

211+
/*!
212+
* \brief Return the number of bytes needed to store one value of this type.
213+
*
214+
* This uses the same packed sub-byte dtype sizing rule as runtime tensors.
215+
* Scalable vector types have no compile-time storage size and are rejected.
216+
*/
217+
TVM_DLL size_t StorageBytes() const;
218+
210219
/*! \brief Return the same type with a different dtype code, preserving bits and lanes. */
211220
TVM_FFI_INLINE PrimType WithCode(DLDataTypeCode code) const {
212221
DLDataType dtype = get()->dtype;

include/tvm/topi/detail/broadcast.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@ struct BroadcastHelper {
4343
};
4444

4545
static inline PrimType CommonType(const PrimType& type1, const PrimType& type2) {
46-
TVM_FFI_ICHECK(!type1.IsScalableVector() && !type2.IsScalableVector());
47-
TVM_FFI_ICHECK_EQ(type1.lanes(), 1);
48-
TVM_FFI_ICHECK_EQ(type2.lanes(), 1);
46+
TVM_FFI_ICHECK(type1.IsScalar() && type2.IsScalar());
4947
TVM_FFI_ICHECK(type1.code() == type2.code());
5048
return type1.bits() < type2.bits() ? type1.WithBits(type2.bits()) : type1;
5149
}
@@ -59,7 +57,7 @@ inline BroadcastHelper BroadcastShape(const tvm::ffi::Array<tvm::PrimExpr>& shap
5957
int i;
6058

6159
auto cast_if_needed = [](PrimType to_type, PrimExpr expr) {
62-
return to_type->dtype == expr.ty()->dtype ? expr : cast(to_type, expr);
60+
return to_type == expr.ty() ? expr : cast(to_type, expr);
6361
};
6462

6563
for (i = 1; i <= std::min(s1_size, s2_size); ++i) {

python/tvm/ir/expr.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,16 @@ class PrimExpr(BaseExpr):
4545

4646
@property
4747
def dtype(self):
48-
"""Return the runtime dtype represented by this expression's PrimType."""
49-
return self.ty.dtype
48+
"""Compatibility alias for the runtime dtype of scalar PrimExpr.
49+
50+
New code should inspect ``expr.ty`` directly. For scalar primitive
51+
expressions, use ``expr.ty.dtype``.
52+
"""
53+
if self.ty is None:
54+
return None
55+
if hasattr(self.ty, "dtype"):
56+
return self.ty.dtype
57+
return "handle"
5058

5159

5260
@tvm_ffi.register_object("ir.RelaxExpr")

python/tvm/script/parser/core/evaluator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,11 @@ def _eval_if_exp(self, fields: dict[str, Any]) -> Any:
396396
orelse = self._eval_expr(fields["orelse"])
397397
if isinstance(test, bool):
398398
return body if test else orelse
399-
elif isinstance(test, tvm.tirx.PrimExpr) and test.dtype.type_code == tvm.DataTypeCode.BOOL:
399+
elif (
400+
isinstance(test, tvm.tirx.PrimExpr)
401+
and isinstance(test.ty, tvm.ir.PrimType)
402+
and test.ty.matches_code(tvm.DataTypeCode.BOOL)
403+
):
400404
return tvm.tirx.op.if_then_else(test, body, orelse)
401405
else:
402406
raise TypeError(f"Expected Python bool or TIR bool, but got {type(test)}")

python/tvm/tirx/buffer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def _infer_shape(shape):
352352
shape = args
353353
assert all(
354354
isinstance(arg, int)
355-
or (isinstance(arg, PrimExpr) and arg.dtype in ["int32", "int64"])
355+
or (isinstance(arg, PrimExpr) and arg.ty.dtype in ["int32", "int64"])
356356
for arg in shape
357357
), "shape must be a list of integers or PrimExprs with dtype int32 or int64"
358358
# Safely get optional keyword arguments
@@ -462,7 +462,7 @@ def permute(self, *dims) -> "Buffer":
462462

463463
def __getitem__(self, indices):
464464
from ..arith import Analyzer # pylint: disable=import-outside-toplevel
465-
from .expr import BufferLoad, Ramp, const # pylint: disable=import-outside-toplevel
465+
from .expr import BufferLoad, Ramp # pylint: disable=import-outside-toplevel
466466
from .stmt import BufferRegion # pylint: disable=import-outside-toplevel
467467

468468
if not isinstance(indices, tuple | list):
@@ -483,7 +483,8 @@ def __getitem__(self, indices):
483483
else:
484484
region.append(
485485
Range.from_min_extent(
486-
index, const(1, index.dtype) if isinstance(index, PrimExpr) else 1
486+
index,
487+
tvm.tirx.expr.IntImm(index.ty, 1) if isinstance(index, PrimExpr) else 1,
487488
)
488489
)
489490
if has_implicit_slice:
@@ -499,7 +500,7 @@ def __getitem__(self, indices):
499500
step = 1 if index.step is None else index.step
500501
# We should ensure the dtype of start is the same with that of step.
501502
if isinstance(start, tvm.tirx.expr.PrimExpr) and isinstance(step, int):
502-
step = tvm.tirx.expr.IntImm(start.dtype, step)
503+
step = tvm.tirx.expr.IntImm(start.ty, step)
503504
lanes = analyzer.simplify((stop - start + step - 1) // step)
504505
if lanes == 1:
505506
expr_indices.append(start)
@@ -540,8 +541,8 @@ def decl_buffer(
540541
layout = TileLayout(S[tuple(shape)]) if shape else None
541542

542543
if offset_factor != 0 and elem_offset is None:
543-
shape_dtype = shape[0].dtype if shape and hasattr(shape[0], "dtype") else "int32"
544-
elem_offset = Var(f"{name}_elem_offset", shape_dtype)
544+
shape_ty = shape[0].ty if shape and isinstance(shape[0], PrimExpr) else "int32"
545+
elem_offset = Var(f"{name}_elem_offset", shape_ty)
545546
if data is None:
546547
# Bool is represented as uint1 in the IR, but stored as int8
547548
storage_type = dtype if isinstance(dtype, PrimType) else PrimType(dtype)

python/tvm/tirx/expr.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __rmod__(self, other: PrimExpr) -> PrimExpr:
132132
return _ffi_api._OpFloorMod(other, self, None) # type: ignore
133133

134134
def __neg__(self) -> PrimExpr:
135-
neg_one = const(-1, self.dtype) # type: ignore
135+
neg_one = const(-1, self.expr_ty().dtype)
136136
return self.__mul__(neg_one)
137137

138138
def __lshift__(self, other: PrimExpr) -> PrimExpr:
@@ -215,7 +215,7 @@ def equal(self, other: PrimExpr, span: Span | None = None) -> bool:
215215
"""
216216
return _ffi_api._OpEQ(self, other, span) # type: ignore
217217

218-
def astype(self, dtype: str, span: Span | None = None) -> PrimExpr:
218+
def astype(self, dtype: str | ir.PrimType, span: Span | None = None) -> PrimExpr:
219219
"""Cast the expression to other type.
220220
221221
Parameters
@@ -477,12 +477,10 @@ def __init__(
477477
raise TypeError("dom need to be Range")
478478

479479
name = var if var is not None else "iter"
480-
dtype = "int32" if dom is None else dom.extent.dtype
480+
dtype = "int32" if dom is None else dom.extent.ty
481481
var = Var(name, dtype=dtype, span=span) if not isinstance(var, Var) else var
482482
if dom is not None:
483-
assert var.dtype == dom.extent.dtype, (
484-
"IterVar's Var dtype must match its domain's extent's dtype"
485-
)
483+
assert var.ty == dom.extent.ty, "IterVar's Var type must match its domain's extent type"
486484
self.__init_handle_by_constructor__(
487485
_ffi_api.IterVar,
488486
dom,
@@ -618,7 +616,9 @@ class FloatImm(ConstExpr):
618616

619617
value: float
620618

621-
def __init__(self, dtype: str, value: float, span: Span | None = None) -> None:
619+
def __init__(self, dtype: str | ir.PrimType, value: float, span: Span | None = None) -> None:
620+
if isinstance(dtype, ir.PrimType):
621+
dtype = dtype.dtype
622622
self.__init_handle_by_constructor__(
623623
tvm.ir._ffi_api.FloatImm,
624624
dtype,
@@ -648,7 +648,9 @@ class IntImm(ConstExpr):
648648

649649
value: int
650650

651-
def __init__(self, dtype: str, value: int, span: Span | None = None) -> None:
651+
def __init__(self, dtype: str | ir.PrimType, value: int, span: Span | None = None) -> None:
652+
if isinstance(dtype, ir.PrimType):
653+
dtype = dtype.dtype
652654
self.__init_handle_by_constructor__(
653655
tvm.ir._ffi_api.IntImm,
654656
dtype,
@@ -725,7 +727,9 @@ class Cast(PrimExprWithOp):
725727

726728
value: PrimExpr
727729

728-
def __init__(self, dtype, value, span: Span | None = None) -> None:
730+
def __init__(self, dtype: str | ir.PrimType, value, span: Span | None = None) -> None:
731+
if isinstance(dtype, ir.PrimType):
732+
dtype = dtype.dtype
729733
self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) # type: ignore
730734

731735

@@ -1336,7 +1340,7 @@ class Call(PrimExprWithOp):
13361340

13371341
def __init__(
13381342
self,
1339-
dtype: str,
1343+
dtype: str | ir.PrimType,
13401344
op: Op | str,
13411345
args: list[PrimExpr],
13421346
attrs: ir.Attrs | dict | None = None,

python/tvm/tirx/layout.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,10 +332,10 @@ def _get_default_strides(data: list[int | PrimExpr], stride: int = 1) -> tuple:
332332
# produce for int64-shaped buffers (otherwise the last stride stays a
333333
# Python ``int`` -> int32 IntImm and breaks structural-equal).
334334
for t in data:
335-
if isinstance(t, PrimExpr) and t.dtype != "int32":
335+
if isinstance(t, PrimExpr) and t.ty.dtype != "int32":
336336
from .expr import IntImm # pylint: disable=import-outside-toplevel
337337

338-
stride = IntImm(t.dtype, stride)
338+
stride = IntImm(t.ty, stride)
339339
break
340340
res = list()
341341
for t in reversed(data):

0 commit comments

Comments
 (0)