2424 require_tile_type , require_constant_bool , require_constant_pointer_info ,
2525 require_scalar_pointer_type ,
2626)
27- from cuda .lang ._ir .type import (
28- LocalArrayContextManagerTy , ContextManagerState , TensorMapTy ,
29- dtype_to_tensor_map_type , ArrayValue , PointerInfoTy
30- )
27+ from cuda .tile ._ir .type import TensorLikeTy
3128from cuda .tile ._ir .ops import (
32- binary_arithmetic ,
29+ binary_arithmetic_tensorlike ,
30+ binary_arithmetic_tensorlike_raw ,
3331 loosely_typed_const ,
3432 tile_impl_registry ,
3533 bind_method ,
3634 build_tuple ,
3735 strictly_typed_const ,
3836 astype ,
39- raw_binary_arithmetic ,
4037 Return ,
4138 return_ ,
4239 Assign ,
7673from .. import _stub as stub
7774
7875from .type import (
76+ LocalArrayContextManagerTy , ContextManagerState , TensorMapTy ,
77+ dtype_to_tensor_map_type , ArrayValue , PointerInfoTy ,
7978 MemorySpace ,
8079 Type ,
8180 make_vector_ty ,
@@ -285,8 +284,8 @@ def _array_linear_offset(array: Var, indices: tuple[Var, ...]) -> Var:
285284 for index , stride in zip (indices , array_val .strides , strict = True ):
286285 index = astype (index , datatype .uint64 )
287286 stride = astype (stride , datatype .uint64 )
288- scaled = raw_binary_arithmetic ("mul" , index , stride )
289- offset = raw_binary_arithmetic ("add" , offset , scaled )
287+ scaled = binary_arithmetic_tensorlike_raw ("mul" , index , stride )
288+ offset = binary_arithmetic_tensorlike_raw ("add" , offset , scaled )
290289 return offset
291290
292291
@@ -538,7 +537,7 @@ def _is_pointer_type(ty):
538537 return isinstance (ty , TileTy ) and is_pointer_dtype (ty .dtype )
539538
540539
541- @impl (operator .add , overload = (TileTy , TileTy ))
540+ @impl (operator .add , overload = (TensorLikeTy , TensorLikeTy ))
542541async def add_impl (x : Var , y : Var ) -> Var :
543542 xty , yty = x .get_type (), y .get_type ()
544543 if _is_pointer_type (yty ):
@@ -554,7 +553,7 @@ async def add_impl(x: Var, y: Var) -> Var:
554553 return await call_function (operator .add , x , y )
555554
556555
557- @impl (operator .sub , overload = (TileTy , TileTy ))
556+ @impl (operator .sub , overload = (TensorLikeTy , TensorLikeTy ))
558557async def sub_impl (x : Var , y : Var ) -> Var :
559558 xty , yty = x .get_type (), y .get_type ()
560559 if _is_pointer_type (xty ):
@@ -563,7 +562,7 @@ async def sub_impl(x: Var, y: Var) -> Var:
563562 raise TileTypeError (f"Expected integer pointer offset, got { offset_dtype } " )
564563 y = astype (y , datatype .int64 )
565564 c0 = loosely_typed_const (0 )
566- offset = binary_arithmetic ('sub' , c0 , y )
565+ offset = binary_arithmetic_tensorlike ('sub' , c0 , y )
567566 return _pointer_with_offset (x , offset )
568567 if _is_pointer_type (yty ):
569568 raise TileTypeError ('It is invalid to subtract a pointer from an integer' )
@@ -868,7 +867,7 @@ def shared_array_impl(shape: Var, dtype: Var, dynamic: Var, alignment: Var) -> O
868867
869868 if size is None or total_size is None :
870869 total_size = None
871- total_size_var = raw_binary_arithmetic ("mul" , total_size_var , size_var )
870+ total_size_var = binary_arithmetic_tensorlike_raw ("mul" , total_size_var , size_var )
872871 else :
873872 total_size *= size
874873 total_size_var = strictly_typed_const (total_size , size_ty )
0 commit comments