5959)
6060from .scope import Scope , JumpInfo , ControlFlowInfo
6161from .typing_support import (
62- typeof_pyval , loose_type_of_pyval , get_constant_value , get_dataclass_info ,
62+ get_dataclass_info , as_third_party_dtype_spec , type_of_constant_python_value ,
63+ loose_type_of_constant_python_value ,
6364)
6465from .type import (
6566 PartitionViewTy , StridedViewTy , GatherScatterViewTy , TupleTy , TileTy , NoneType ,
7273)
7374from cuda .tile ._datatype import (
7475 DType , is_integral , is_float , is_signed , is_boolean , is_pointer_dtype , PointerInfo ,
75- opaque_pointer_dtype , pointer_dtype ,
76+ opaque_pointer_dtype , pointer_dtype
7677)
7778from cuda .tile ._ir2bytecode import (
7879 BytecodeContext , typeid ,
@@ -651,21 +652,29 @@ def loosely_typed_const(value: Any,
651652 ty : Optional [Type ] = None ,
652653 loose_ty : Optional [Type ] = None ,
653654 name : str | None = None ) -> Var :
655+ builder = Builder .get_current ()
654656 if ty is None :
655- if isinstance (value , tuple ):
656- return build_tuple (tuple (loosely_typed_const (item ) for item in value ))
657- ty = typeof_pyval (value )
658- ret = strictly_typed_const (value , ty , name = name )
657+ ty = type_of_constant_python_value (value , builder .ir_ctx .typing_hooks )
658+ assert not ty .is_aggregate (), "Use sym2var(value, constant_only=True) instead"
659+
660+ # Normalize third party dtype spec objects (e.g. torch.float32 -> ct.float32)
661+ if isinstance (ty , DTypeSpec ):
662+ value = ty .dtype
663+
664+ ret = _strictly_typed_const_inner (builder , value , ty , name = name )
659665 if loose_ty is None :
660- loose_ty = loose_type_of_pyval (value )
666+ loose_ty = loose_type_of_constant_python_value (value , builder . ir_ctx . typing_hooks )
661667 ret .set_loose_type (loose_ty )
662668 return ret
663669
664670
665671def strictly_typed_const (value : Any , ty : Type , name : str | None = None ) -> Var :
666- builder = Builder .get_current ()
667- result = None if name is None else builder .ir_ctx .make_var (name , builder .loc )
672+ return _strictly_typed_const_inner (Builder .get_current (), value , ty , name )
668673
674+
675+ def _strictly_typed_const_inner (builder : Builder ,
676+ value : Any , ty : Type , name : str | None = None ) -> Var :
677+ result = None if name is None else builder .ir_ctx .make_var (name , builder .loc )
669678 ret = builder .add_operation (TypedConst , ty , dict (value = value ), result = result )
670679 if not isinstance (ty , TileTy ) or ty .ndim == 0 :
671680 # We currently don't have a way to represent an N-dimensional tile constant
@@ -1893,7 +1902,7 @@ def getattr_tile_dtype_impl(object: Var, name: Var):
18931902
18941903@impl (getattr , overload = (TileTy , "shape" ))
18951904def getattr_tile_shape_impl (object : Var , name : Var ):
1896- return loosely_typed_const (object .get_type ().shape )
1905+ return sym2var (object .get_type ().shape , constant_only = True )
18971906
18981907
18991908@impl (getattr , overload = (TileTy , "ndim" ))
@@ -1924,12 +1933,12 @@ def getattr_tiled_view_dtype_impl(object: Var, name: Var):
19241933
19251934@impl (getattr , overload = (TiledViewTy , "tile_shape" ))
19261935def getattr_tiled_view_tile_shape_impl (object : Var , name : Var ):
1927- return loosely_typed_const (object .get_type ().tile_shape )
1936+ return sym2var (object .get_type ().tile_shape , constant_only = True )
19281937
19291938
19301939@impl (getattr , overload = (TiledViewTy , "traversal_steps" ))
19311940def getattr_tiled_view_traversal_steps_impl (object : Var , name : Var ):
1932- return loosely_typed_const (object .get_type ().traversal_steps )
1941+ return sym2var (object .get_type ().traversal_steps , constant_only = True )
19331942
19341943
19351944@impl (getattr , overload = (TiledViewTy , "num_tiles" ))
@@ -1981,7 +1990,7 @@ def getattr_module_impl(object: Var, name: Var):
19811990 ty = object .get_type ()
19821991 attr_name = require_constant_str (name )
19831992 try :
1984- return loosely_typed_const (getattr (ty .py_mod , attr_name ))
1993+ return sym2var (getattr (ty .py_mod , attr_name ), constant_only = True )
19851994 except AttributeError :
19861995 raise TileTypeError (f"Module '{ ty .py_mod .__name__ } ' has no attribute '{ attr_name } '" )
19871996
@@ -1991,7 +2000,7 @@ def getattr_type_impl(object: Var, name: Var):
19912000 ty = object .get_type ()
19922001 attr_name = require_constant_str (name )
19932002 try :
1994- return loosely_typed_const (getattr (ty .ty , attr_name ))
2003+ return sym2var (getattr (ty .ty , attr_name ), constant_only = True )
19952004 except AttributeError :
19962005 raise TileTypeError (f"'{ ty .ty .__name__ } ' object has no attribute '{ attr_name } '" )
19972006
@@ -2023,7 +2032,7 @@ async def getattr_dataclass_impl(object: Var, name: Var):
20232032 getter = loosely_typed_const (cls_attr .fget )
20242033 return await call (getter , (object ,), {})
20252034 else :
2026- return loosely_typed_const (cls_attr )
2035+ return sym2var (cls_attr , constant_only = True )
20272036
20282037
20292038# ===========================================================================================
@@ -2058,7 +2067,11 @@ def assign(value: Var, res: Var) -> None:
20582067@impl (hir_stubs .identity )
20592068def identity_impl (x : Var ) -> Var :
20602069 if x .is_constant ():
2061- return loosely_typed_const (x .get_constant (), x .get_type (), x .get_loose_type ())
2070+ ty = x .get_type ()
2071+ if ty .is_aggregate ():
2072+ return make_aggregate (x .get_aggregate (), ty , x .get_loose_type ())
2073+ else :
2074+ return loosely_typed_const (x .get_constant (), x .get_type (), x .get_loose_type ())
20622075 else :
20632076 return x
20642077
@@ -5268,8 +5281,7 @@ def load_var_impl(name):
52685281 return ret
52695282 elif rn .index >= 0 :
52705283 val = scope .func_hir .frozen_global_values [rn .index ]
5271- val = get_constant_value (val )
5272- return loosely_typed_const (val )
5284+ return sym2var (val , constant_only = True )
52735285 else :
52745286 raise TileSyntaxError (f"Undefined variable { name } used" )
52755287
@@ -5450,30 +5462,38 @@ async def static_foreach_impl(body: hir.Block, items: Var):
54505462 await dispatch_hir_block (body )
54515463
54525464
5453- def sym2var (x : Any ) -> Var :
5465+ def sym2var (x : Any , constant_only : bool = False ) -> Var :
54545466 # TODO: verify we don't have a stale closure
54555467
54565468 if isinstance (x , Symbol ):
5469+ if constant_only :
5470+ raise TileTypeError ("Cannot create a constant from a symbolic value" )
54575471 return x ._var
54585472
54595473 if isinstance (x , tuple ):
5460- return build_tuple (tuple (sym2var (item ) for item in x ))
5474+ return build_tuple (tuple (sym2var (item , constant_only = constant_only ) for item in x ))
54615475
54625476 cls = type (x )
54635477 if dataclasses .is_dataclass (cls ):
54645478 info = get_dataclass_info (cls )
5465- field_vars = tuple (sym2var (getattr (x , f .name ))
5479+ field_vars = tuple (sym2var (getattr (x , f .name ), constant_only = constant_only )
54665480 for f in dataclasses .fields (cls ))
54675481 return build_dataclass_instance (field_vars , info )
54685482
54695483 if isinstance (x , MethodType ):
5470- self_var = sym2var (x .__self__ )
5484+ self_var = sym2var (x .__self__ , constant_only = constant_only )
54715485 if not isinstance (x .__func__ , FunctionType | BuiltinFunctionType ):
54725486 raise TileTypeError (f"Object of type { type (x ).__name__ } "
54735487 f" cannot be used as a function for binding a method" )
54745488 return bind_method (self_var , x .__func__ )
54755489
5476- x = get_constant_value (x )
5490+ # Transform a third party typed scalar (e.g., np.int16(5)) into a strictly typed constant
5491+ dtype_spec = as_third_party_dtype_spec (type (x ))
5492+ if dtype_spec is not None :
5493+ pyval = datatype .numeric_dtype_category (dtype_spec .dtype ).pytype (x )
5494+ ty = Builder .get_current ().ir_ctx .typing_hooks .get_tensor_like_type (dtype_spec .dtype , ())
5495+ return strictly_typed_const (pyval , ty )
5496+
54775497 return loosely_typed_const (x )
54785498
54795499
0 commit comments