44import builtins
55import dataclasses
66import enum
7+ import functools
78import math
89import operator
910from contextlib import contextmanager
1920from cuda .tile import _datatype as datatype
2021from cuda .tile import RoundingMode , MemoryOrder , MemoryScope
2122from cuda .tile ._mutex import tile_mutex
22- from cuda .tile ._exception import TileTypeError , TileSyntaxError , TileError , \
23+ from cuda .tile ._exception import TileInternalError , TileTypeError , TileSyntaxError , TileError , \
2324 TileStaticAssertionError , TileStaticEvalError , TileValueError , TileUnsupportedFeatureError
2425from cuda .tile ._ir .ir import (
2526 Operation , Var , Loc , Block , add_operation , Builder , enter_nested_block , nested_block ,
3435from . import hir , hir_stubs
3536from .hir import ResolvedName
3637from .op_impl import (
37- ImplRegistry , require_constant_int , require_constant_int_tuple ,
38+ ImplRegistry , is_0d_tile , require_constant_int , require_constant_int_tuple ,
39+ require_optional_dtype_spec ,
3840 require_signed_integer_0d_tile_type ,
3941 require_tile_type , normalize_axis , require_dtype_spec ,
4042 require_constant_bool , require_optional_constant_enum ,
@@ -3283,14 +3285,14 @@ def num_tiles_impl(array: Var, axis: Var, shape: Var, order: Var) -> Var:
32833285 return space_shape [axis ]
32843286
32853287
def _const(shape: Sequence[int], value: int | float | bool | tuple, dtype: DType) -> Var:
    """Create a constant tile of type ``make_tile_ty(dtype, shape)`` from ``value``.

    Args:
        shape: Shape of the resulting tile.
        value: Python constant (scalar or tuple) handed unchanged to
            ``strictly_typed_const``.
        dtype: Element dtype of the resulting tile.

    Returns:
        The ``Var`` produced by ``strictly_typed_const`` for the computed tile type.
    """
    tile_ty = make_tile_ty(dtype, shape)
    return strictly_typed_const(value, tile_ty)
32893291
32903292
def full(shape: Sequence[int], fill_value: Var, dtype: DType) -> Var:
    """Create a tile of ``shape`` and ``dtype`` filled with ``fill_value``.

    A constant ``fill_value`` is emitted directly as a constant tile. Any other
    value is first converted to ``dtype`` and then broadcast to ``shape``.
    """
    if not fill_value.is_constant():
        converted = astype(fill_value, dtype)
        return broadcast_to(converted, shape)
    return _const(shape, fill_value.get_constant(), dtype)
32963298
@@ -3307,14 +3309,93 @@ def full_impl(shape: Var, fill_value: Var, dtype: Var) -> Var:
def ones_impl(shape: Var, dtype: Var) -> Var:
    """Build a constant tile whose fill value is ``1``.

    ``shape`` must be a constant shape (a single int is accepted) and ``dtype``
    must be a dtype spec; both are validated by the ``require_*`` helpers before
    the constant is created.
    """
    const_shape = require_constant_shape(shape, allow_single_int=True)
    elem_dtype = require_dtype_spec(dtype)
    return _const(const_shape, 1, elem_dtype)
33113313
33123314
@impl(ct.zeros)
def zeros_impl(shape: Var, dtype: Var) -> Var:
    """Build a constant tile whose fill value is ``0``.

    ``shape`` must be a constant shape (a single int is accepted) and ``dtype``
    must be a dtype spec; both are validated by the ``require_*`` helpers before
    the constant is created.
    """
    const_shape = require_constant_shape(shape, allow_single_int=True)
    elem_dtype = require_dtype_spec(dtype)
    return _const(const_shape, 0, elem_dtype)
3320+
3321+
3322+ def _path_str (path : tuple [int , ...]) -> str :
3323+ return "value" + "" .join (f"[{ i } ]" for i in path )
3324+
3325+
def _tuple_shape(ty: Type, path: tuple[int, ...]) -> tuple[int, ...]:
    """Compute the tile shape implied by a (possibly nested) tuple type.

    Args:
        ty: Type of the element found at ``path``.
        path: Indices leading from the root value to this element; only used to
            build error messages.

    Returns:
        ``()`` when ``ty`` is a 0-d tile. For a tuple type, its length followed
        by the shape shared by all of its items.

    Raises:
        TileTypeError: if a non-tuple element is not a 0-d tile, if a tuple is
            empty or its length is not a power of 2, or if the items of a tuple
            do not all have the same shape.
    """
    path_str = _path_str(path)
    if not isinstance(ty, TupleTy):
        if not isinstance(ty, TileTy):
            raise TileTypeError(
                f"Expected scalar elements at {path_str}; "
                f"got element of type {ty}")

        if ty.ndim != 0:
            raise TileTypeError(
                f"Expected scalar elements at {path_str}; "
                f"got a tile of shape {ty.shape}")

        assert is_0d_tile(ty)
        return ()

    n = len(ty)
    # Reject the empty tuple here, independently of how _is_power_of_2 treats 0;
    # otherwise it could fall through to the misleading "non-uniform" error below.
    if n == 0 or not _is_power_of_2(n):
        raise TileTypeError(f"Tuple length {n} at {path_str} is not a power of 2")

    # Collecting into a set collapses identical shapes; more than one survivor
    # means the items disagree.
    inner_shapes = {_tuple_shape(t, path + (i,)) for i, t in enumerate(ty.value_types)}
    if len(inner_shapes) != 1:
        raise TileTypeError(f"Tuple has non-uniform inner shapes at {path_str}")

    return (n,) + inner_shapes.pop()
3351+
3352+
def _flatten_tuple(value: Var) -> tuple[Var, ...]:
    """Return the non-tuple leaves of ``value`` in depth-first, left-to-right order.

    A value whose type is not ``TupleTy`` is returned as a one-element tuple.
    """
    if not isinstance(value.get_type(), TupleTy):
        return (value,)
    # Build the result in one pass instead of summing tuples, which re-copies
    # the accumulated prefix on every addition.
    return tuple(leaf
                 for item in value.get_aggregate().items
                 for leaf in _flatten_tuple(item))
3358+
3359+
def _cat_tuple(tiles: tuple[Var, ...]) -> Var:
    """Join 0-d tiles into one tile by recursive halving along axis 0.

    A single tile is validated as 0-d and reshaped to ``(1,)``. Longer inputs
    are split in the middle and the two halves are joined pairwise, since
    ``cat`` accepts at most two tiles.

    Raises:
        TileInternalError: if ``tiles`` is empty.
    """
    count = len(tiles)
    if count == 0:
        raise TileInternalError("Expected non-empty tile tuple")

    if count == 1:
        only = tiles[0]
        require_0d_tile_type(only)
        return reshape(only, (1,))

    # Callers are expected to pass an even number of tiles at every level.
    assert count % 2 == 0
    half = count // 2
    lower = _cat_tuple(tiles[:half])
    upper = _cat_tuple(tiles[half:])
    return cat((lower, upper), axis=0)
3373+
3374+
@impl(ct.astile)
def astile_impl(value: Var, dtype: Var) -> Var:
    """Convert a scalar or a (possibly nested) tuple of scalars into a tile.

    Args:
        value: A 0-d tile, or a tuple whose leaves are all 0-d tiles.
        dtype: Optional dtype spec. When omitted, a 0-d tile is returned as is
            and a tuple uses the promotion of all leaf dtypes.

    Returns:
        For a 0-d tile: ``value`` itself, or ``value`` converted to ``dtype``.
        For a tuple: a tile whose shape is derived from the tuple nesting.

    Raises:
        TileTypeError: if ``value`` is neither a 0-d tile nor a tuple, or if the
            tuple structure is rejected by ``_tuple_shape``.
    """
    dtype = require_optional_dtype_spec(dtype)
    value_ty = value.get_type()
    if is_0d_tile(value_ty):
        return value if dtype is None else astype(value, dtype)

    if not isinstance(value_ty, TupleTy):
        raise TileTypeError(
            f"Expected a scalar or (possibly nested) tuple of scalars; "
            f"got value of type {value_ty}")

    # Validate the nesting first so structural errors are reported before any
    # dtype work is done on the leaves.
    shape = _tuple_shape(value_ty, path=())
    tiles = _flatten_tuple(value)
    if dtype is None:
        dtype = functools.reduce(promote_dtypes,
                                 (require_0d_tile_type(t).dtype for t in tiles))

    if value.is_constant():
        return _const(shape, value.get_constant(), dtype)

    # Non-constant path: convert every leaf, concatenate into a flat tile, then
    # restore the nested shape.
    tiles = tuple(astype(t, dtype) for t in tiles)
    flat = _cat_tuple(tiles)
    return reshape(flat, shape)
33183399
33193400
33203401_TileShape = Tuple [int , ...]
@@ -4116,29 +4197,24 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
41164197 return bc .encode_CatOp (ctx .builder , return_type_id , x_value , y_value , self .axis )
41174198
41184199
4119- def cat (tiles : Var , axis : int ) -> Var :
4120- tuple_ty = require_tuple_type (tiles )
4121- items = tiles .get_aggregate ().items
4122- if len (items ) == 1 :
4123- return items [0 ]
4124-
4125- if len (tuple_ty ) == 0 :
4200+ def cat (tiles : tuple [Var , ...], axis : int ) -> Var :
4201+ if len (tiles ) == 0 :
41264202 raise TileTypeError ("cat() received an empty tuple" )
4127- elif len (items ) == 1 :
4128- return items [0 ]
4129- elif len (tuple_ty ) > 2 :
4130- raise TileTypeError (f"cat() supports at most 2 tiles, got { len (tuple_ty )} " )
4203+ if len (tiles ) == 1 :
4204+ return tiles [0 ]
4205+ if len (tiles ) > 2 :
4206+ raise TileTypeError (f"cat() supports at most 2 tiles, got { len (tiles )} " )
41314207
4132- x_tile , y_tile = items
4208+ x_tile , y_tile = tiles
41334209
4134- if not isinstance (first_tile := tuple_ty . value_types [0 ], TileTy ):
4135- raise TileTypeError (f"Expected tuple of Tile, got a { first_tile } " )
4210+ if not isinstance (first_tile_ty := tiles [0 ]. get_type () , TileTy ):
4211+ raise TileTypeError (f"Expected tuple of Tile, got a { first_tile_ty } " )
41364212
4137- dtype = first_tile .dtype
4138- rank = first_tile .ndim
4139- shape_value = list (first_tile .shape )
4213+ dtype = first_tile_ty .dtype
4214+ rank = first_tile_ty .ndim
4215+ shape_value = list (first_tile_ty .shape )
41404216 axis = normalize_axis (axis , rank )
4141- for tile_ty in tuple_ty . value_types [1 :]:
4217+ for tile_ty in ( t . get_type () for t in tiles [1 :]) :
41424218 if not isinstance (tile_ty , TileTy ):
41434219 raise TileTypeError (f"Expected tuple of Tile, got a { tile_ty } " )
41444220 if tile_ty .ndim != rank :
@@ -4167,8 +4243,9 @@ def _is_power_of_2(x: int):
41674243
@impl(ct.cat)
def cat_impl(tiles: Var, axis: Var) -> Var:
    """Implementation of ``ct.cat``: validate the arguments and delegate to ``cat``.

    ``tiles`` must have a tuple type and ``axis`` must be a constant int; the
    items of the tuple are then passed on together with the constant axis.
    """
    require_tuple_type(tiles)
    axis_value = require_constant_int(axis)
    members = tiles.get_aggregate().items
    return cat(members, axis_value)
41724249
41734250
41744251# Does not support broadcasting or type promotion
0 commit comments