@@ -708,6 +708,34 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
708708 self .flush_to_zero )
709709
710710
@overload_dispatcher(operator.add, fixed_args=["+"])
@overload_dispatcher(operator.sub, fixed_args=["-"])
@overload_dispatcher(operator.mul, fixed_args=["*"])
@overload_dispatcher(operator.floordiv, fixed_args=["//"])
@overload_dispatcher(operator.truediv, fixed_args=["/"])
@overload_dispatcher(operator.pow, fixed_args=["**"])
@overload_dispatcher(operator.mod, fixed_args=["%"])
@overload_dispatcher(operator.eq, fixed_args=["=="])
@overload_dispatcher(operator.ne, fixed_args=["!="])
@overload_dispatcher(operator.lt, fixed_args=["<"])
@overload_dispatcher(operator.le, fixed_args=["<="])
@overload_dispatcher(operator.gt, fixed_args=[">"])
@overload_dispatcher(operator.ge, fixed_args=[">="])
@overload_dispatcher(operator.and_, fixed_args=["&"])
@overload_dispatcher(operator.or_, fixed_args=["|"])
@overload_dispatcher(operator.xor, fixed_args=["^"])
@overload_dispatcher(operator.lshift, fixed_args=["<<"])
@overload_dispatcher(operator.rshift, fixed_args=[">>"])
@overload_dispatcher(operator.matmul, fixed_args=["@"])
def binop_overload_dispatcher(name: str, x: Var, y: Var):
    """Pick the overload of a binary operator from the operand type classes.

    Registered once per ``operator`` function above; ``fixed_args`` binds the
    operator's source-level symbol to ``name``.

    Args:
        name: Operator symbol (e.g. ``"+"``); used only in the error message.
        x: Left operand.
        y: Right operand.

    Yields:
        The pair ``(type(x_ty), type(y_ty))``, where ``x_ty`` / ``y_ty`` are the
        results of ``x.get_type()`` / ``y.get_type()``, as the overload key.

    Raises:
        TileTypeError: When ``OverloadNotFoundError`` surfaces at the ``yield``.
    """
    x_ty = x.get_type()
    y_ty = y.get_type()
    try:
        yield type(x_ty), type(y_ty)
    except OverloadNotFoundError as err:
        # NOTE(review): presumably the dispatch machinery throws this into the
        # generator when no overload is registered for the yielded key - confirm.
        # Chain explicitly so the lookup failure is kept as __cause__.
        raise TileTypeError(
            f"Unsupported operand types for {name}: {x_ty} and {y_ty}"
        ) from err
737+
738+
711739# Does not do broadcasting or type promotion, hence the name "Raw"
712740@dataclass (eq = False )
713741class RawComparisonOperation (Operation , opcode = "raw_cmp" ):
@@ -756,12 +784,22 @@ def _binop_propagate_constant(fn: str, x: Any, y: Any, type: Optional[Type]) ->
756784 return strictly_typed_const (res , type )
757785
758786
def comparison_operator_impl(lhs_ty: type[Type], rhs_ty: type[Type]):
    """Build a decorator that registers one function for all six comparisons.

    The returned decorator calls ``impl`` once for each of ``operator.eq``,
    ``ne``, ``lt``, ``le``, ``gt`` and ``ge`` (in that order), passing the
    operator's short name through ``fixed_args`` and ``(lhs_ty, rhs_ty)`` as
    the ``overload`` key, and then hands the function back unchanged.

    Args:
        lhs_ty: Type class of the left operand for this overload.
        rhs_ty: Type class of the right operand for this overload.

    Returns:
        A decorator that performs the registrations and returns its argument.
    """
    overload_key = (lhs_ty, rhs_ty)

    def register(func):
        for op_name in ("eq", "ne", "lt", "le", "gt", "ge"):
            op_func = getattr(operator, op_name)
            bind = impl(op_func, fixed_args=[op_name], overload=overload_key)
            # The value returned by the registration call is deliberately ignored.
            bind(func)
        return func

    return register
794+
795+
759796@impl (ct .equal , fixed_args = ["eq" ])
760797@impl (ct .greater , fixed_args = ["gt" ])
761798@impl (ct .not_equal , fixed_args = ["ne" ])
762799@impl (ct .greater_equal , fixed_args = ["ge" ])
763800@impl (ct .less , fixed_args = ["lt" ])
764801@impl (ct .less_equal , fixed_args = ["le" ])
802+ @comparison_operator_impl (TileTy , TileTy )
765803def comparison (fn : str , x : Var , y : Var ) -> Var :
766804 x_ty = require_tile_maybe_loose_type (x )
767805 y_ty = require_tile_maybe_loose_type (y )
@@ -798,7 +836,8 @@ def operator_is_not_impl(x: Var, y: Var):
798836 return _is_none_compare (x , y , negate = True , op_name = "is not" )
799837
800838
801- def _tuple_comparison (fn : str , x : Var , y : Var ) -> Var :
839+ @comparison_operator_impl (TupleTy , TupleTy )
840+ async def comparison_operator_tuple_impl (fn : str , x : Var , y : Var ) -> Var :
802841 if fn not in ("eq" , "ne" ):
803842 raise TileTypeError (f"Operator '{ fn } ' is not supported for tuples" )
804843
@@ -824,7 +863,8 @@ def _tuple_comparison(fn: str, x: Var, y: Var) -> Var:
824863 f"Tuple comparison is not supported for elements of type { item_ty } "
825864 )
826865
827- elem_cmps = [comparison_operator_impl ("eq" , xi , yi ) for xi , yi in zip (x_items , y_items )]
866+ from cuda .tile ._passes .hir2ir import call_function
867+ elem_cmps = [await call_function (operator .eq , xi , yi ) for xi , yi in zip (x_items , y_items )]
828868 result = functools .reduce (lambda a , b : binary_bitwise ("and_" , a , b ), elem_cmps ,
829869 loosely_typed_const (True ))
830870
@@ -834,25 +874,14 @@ def _tuple_comparison(fn: str, x: Var, y: Var) -> Var:
834874 return result
835875
836876
837- @impl (operator .eq , fixed_args = ["eq" ])
838- @impl (operator .ne , fixed_args = ["ne" ])
839- @impl (operator .lt , fixed_args = ["lt" ])
840- @impl (operator .le , fixed_args = ["le" ])
841- @impl (operator .gt , fixed_args = ["gt" ])
842- @impl (operator .ge , fixed_args = ["ge" ])
843- def comparison_operator_impl (fn : str , x : Var , y : Var ) -> Var :
844- x_ty = x .get_type ()
845- y_ty = y .get_type ()
@comparison_operator_impl(DTypeSpec, DTypeSpec)
def comparison_dtype_spec_impl(fn: str, x: Var, y: Var):
    """Comparison overload for two ``DTypeSpec`` operands.

    Reads the ``dtype`` attribute of each operand's type and forwards both,
    together with the comparison name ``fn``, to ``_binop_propagate_constant``
    with no result type (``None``).
    """
    lhs_dtype = x.get_type().dtype
    rhs_dtype = y.get_type().dtype
    return _binop_propagate_constant(fn, lhs_dtype, rhs_dtype, None)
880+
846881
847- match x_ty , y_ty :
848- case DTypeSpec (), DTypeSpec ():
849- return _binop_propagate_constant (fn , x_ty .dtype , y_ty .dtype , None )
850- case StringTy (), StringTy ():
851- return _binop_propagate_constant (fn , x_ty .value , y_ty .value , None )
852- case TupleTy (), TupleTy ():
853- return _tuple_comparison (fn , x , y )
854- case _, _:
855- return comparison (fn , x , y )
@comparison_operator_impl(StringTy, StringTy)
def comparison_string_impl(fn: str, x: Var, y: Var):
    """Comparison overload for two ``StringTy`` operands.

    Reads the ``value`` attribute of each operand's type and forwards both,
    together with the comparison name ``fn``, to ``_binop_propagate_constant``
    with no result type (``None``).
    """
    lhs_value = x.get_type().value
    rhs_value = y.get_type().value
    return _binop_propagate_constant(fn, lhs_value, rhs_value, None)
856885
857886
858887def _promote_and_broadcast_to (x : Var , ty : TileTy ) -> Var :
@@ -888,9 +917,9 @@ def raw_binary_bitwise(fn: str, x: Var, y: Var) -> Var:
888917@impl (ct .bitwise_and , fixed_args = ["and_" ])
889918@impl (ct .bitwise_or , fixed_args = ["or_" ])
890919@impl (ct .bitwise_xor , fixed_args = ["xor" ])
891- @impl (operator .and_ , fixed_args = ["and_" ])
892- @impl (operator .or_ , fixed_args = ["or_" ])
893- @impl (operator .xor , fixed_args = ["xor" ])
920+ @impl (operator .and_ , fixed_args = ["and_" ], overload = ( TileTy , TileTy ) )
921+ @impl (operator .or_ , fixed_args = ["or_" ], overload = ( TileTy , TileTy ) )
922+ @impl (operator .xor , fixed_args = ["xor" ], overload = ( TileTy , TileTy ) )
894923def binary_bitwise (fn : str , x : Var , y : Var ) -> Var :
895924 x_ty = require_tile_maybe_loose_type (x )
896925 y_ty = require_tile_maybe_loose_type (y )
@@ -956,8 +985,8 @@ def raw_bitwise_shift(fn: str, x: Var, y: Var) -> Var:
956985
957986@impl (ct .bitwise_lshift , fixed_args = ["lshift" ])
958987@impl (ct .bitwise_rshift , fixed_args = ["rshift" ])
959- @impl (operator .lshift , fixed_args = ["lshift" ])
960- @impl (operator .rshift , fixed_args = ["rshift" ])
988+ @impl (operator .lshift , fixed_args = ["lshift" ], overload = ( TileTy , TileTy ) )
989+ @impl (operator .rshift , fixed_args = ["rshift" ], overload = ( TileTy , TileTy ) )
961990def bitwise_shift (fn : str , x : Var , y : Var ) -> Var :
962991 x_ty = require_tile_maybe_loose_type (x )
963992 y_ty = require_tile_maybe_loose_type (y )
@@ -1114,23 +1143,26 @@ def binary_arithmetic(fn: str, x: Var, y: Var, rounding_mode: Optional[RoundingM
11141143@impl (ct .floordiv , fixed_args = ["floordiv" ])
11151144@impl (ct .cdiv , fixed_args = ["cdiv" ])
11161145@impl (ct .pow , fixed_args = ["pow" ])
1117- @impl (operator .sub , fixed_args = ["sub" ])
1118- @impl (operator .mul , fixed_args = ["mul" ])
1119- @impl (operator .floordiv , fixed_args = ["floordiv" ])
1120- @impl (operator .truediv , fixed_args = ["truediv" ])
1121- @impl (operator .pow , fixed_args = ["pow" ])
1146+ @impl (operator .sub , fixed_args = ["sub" ], overload = ( TileTy , TileTy ) )
1147+ @impl (operator .mul , fixed_args = ["mul" ], overload = ( TileTy , TileTy ) )
1148+ @impl (operator .floordiv , fixed_args = ["floordiv" ], overload = ( TileTy , TileTy ) )
1149+ @impl (operator .truediv , fixed_args = ["truediv" ], overload = ( TileTy , TileTy ) )
1150+ @impl (operator .pow , fixed_args = ["pow" ], overload = ( TileTy , TileTy ) )
11221151@impl (min , fixed_args = ["min" ])
11231152@impl (max , fixed_args = ["max" ])
11241153def binary_arithmetic_impl (fn : str , x : Var , y : Var ) -> Var :
11251154 return binary_arithmetic (fn , x , y )
11261155
11271156
1128- @impl (operator .add )
1129- def add_impl (x : Var , y : Var ) -> Var :
1130- if isinstance (x .get_type (), TupleTy ) and isinstance (y .get_type (), TupleTy ):
1131- x_items = x .get_aggregate ().items
1132- y_items = y .get_aggregate ().items
1133- return build_tuple (x_items + y_items )
@impl(operator.add, overload=(TupleTy, TupleTy))
def add_tuple_impl(x: Var, y: Var):
    """``+`` overload for two tuple operands.

    Builds a new tuple from the items of ``x`` followed by the items of ``y``.
    """
    return build_tuple(x.get_aggregate().items + y.get_aggregate().items)
1162+
1163+
@impl(operator.add, overload=(TileTy, TileTy))
def add_tile_impl(x: Var, y: Var) -> Var:
    """``+`` overload for the ``(TileTy, TileTy)`` key.

    Delegates to ``binary_arithmetic`` with the ``"add"`` operation name.
    """
    total = binary_arithmetic("add", x, y)
    return total
11351167
11361168
@@ -1157,7 +1189,7 @@ def binary_arithmetic_impl_with_rd_and_ftz(fn: str, x: Var, y: Var,
11571189 return binary_arithmetic (fn , x , y , rounding_mode , flush_to_zero )
11581190
11591191
1160- @impl (operator .mod )
1192+ @impl (operator .mod , overload = ( TileTy , TileTy ) )
11611193@impl (ct .mod )
11621194def mod (x : Var , y : Var ) -> Var :
11631195 x_ty = require_tile_maybe_loose_type (x )
@@ -3577,7 +3609,7 @@ def mma_impl(x: Var, y: Var, acc: Var, use_fast_acc: Var) -> Var:
35773609
35783610
35793611@impl (ct .matmul )
3580- @impl (operator .matmul )
3612+ @impl (operator .matmul , overload = ( TileTy , TileTy ) )
35813613def matmul_impl (x : Var , y : Var ) -> Var :
35823614 x_tile_type = require_tile_type (x )
35833615 y_tile_type = require_tile_type (y )
0 commit comments