@@ -132,7 +132,7 @@ def __rmod__(self, other: PrimExpr) -> PrimExpr:
132132 return _ffi_api ._OpFloorMod (other , self , None ) # type: ignore
133133
134134 def __neg__ (self ) -> PrimExpr :
135- neg_one = const (- 1 , self .dtype ) # type: ignore
135+ neg_one = const (- 1 , self .expr_ty (). dtype )
136136 return self .__mul__ (neg_one )
137137
138138 def __lshift__ (self , other : PrimExpr ) -> PrimExpr :
@@ -215,7 +215,7 @@ def equal(self, other: PrimExpr, span: Span | None = None) -> bool:
215215 """
216216 return _ffi_api ._OpEQ (self , other , span ) # type: ignore
217217
218- def astype (self , dtype : str , span : Span | None = None ) -> PrimExpr :
218+ def astype (self , dtype : str | ir . PrimType , span : Span | None = None ) -> PrimExpr :
219219 """Cast the expression to other type.
220220
221221 Parameters
@@ -477,12 +477,10 @@ def __init__(
477477 raise TypeError ("dom need to be Range" )
478478
479479 name = var if var is not None else "iter"
480- dtype = "int32" if dom is None else dom .extent .dtype
480+ dtype = "int32" if dom is None else dom .extent .ty
481481 var = Var (name , dtype = dtype , span = span ) if not isinstance (var , Var ) else var
482482 if dom is not None :
483- assert var .dtype == dom .extent .dtype , (
484- "IterVar's Var dtype must match its domain's extent's dtype"
485- )
483+ assert var .ty == dom .extent .ty , "IterVar's Var type must match its domain's extent type"
486484 self .__init_handle_by_constructor__ (
487485 _ffi_api .IterVar ,
488486 dom ,
@@ -618,7 +616,9 @@ class FloatImm(ConstExpr):
618616
619617 value : float
620618
621- def __init__ (self , dtype : str , value : float , span : Span | None = None ) -> None :
619+ def __init__ (self , dtype : str | ir .PrimType , value : float , span : Span | None = None ) -> None :
620+ if isinstance (dtype , ir .PrimType ):
621+ dtype = dtype .dtype
622622 self .__init_handle_by_constructor__ (
623623 tvm .ir ._ffi_api .FloatImm ,
624624 dtype ,
@@ -648,7 +648,9 @@ class IntImm(ConstExpr):
648648
649649 value : int
650650
651- def __init__ (self , dtype : str , value : int , span : Span | None = None ) -> None :
651+ def __init__ (self , dtype : str | ir .PrimType , value : int , span : Span | None = None ) -> None :
652+ if isinstance (dtype , ir .PrimType ):
653+ dtype = dtype .dtype
652654 self .__init_handle_by_constructor__ (
653655 tvm .ir ._ffi_api .IntImm ,
654656 dtype ,
@@ -725,7 +727,9 @@ class Cast(PrimExprWithOp):
725727
726728 value : PrimExpr
727729
728- def __init__ (self , dtype , value , span : Span | None = None ) -> None :
730+ def __init__ (self , dtype : str | ir .PrimType , value , span : Span | None = None ) -> None :
731+ if isinstance (dtype , ir .PrimType ):
732+ dtype = dtype .dtype
729733 self .__init_handle_by_constructor__ (_ffi_api .Cast , dtype , value , span ) # type: ignore
730734
731735
@@ -1336,7 +1340,7 @@ class Call(PrimExprWithOp):
13361340
13371341 def __init__ (
13381342 self ,
1339- dtype : str ,
1343+ dtype : str | ir . PrimType ,
13401344 op : Op | str ,
13411345 args : list [PrimExpr ],
13421346 attrs : ir .Attrs | dict | None = None ,
0 commit comments