Skip to content

Commit 7bf1722

Browse files
committed
Use overload-based registration for all binary operators
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 232650b commit 7bf1722

5 files changed

Lines changed: 157 additions & 57 deletions

File tree

experimental/cuda-lang/src/cuda/lang/_ir/ops.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
binary_arithmetic,
3333
loosely_typed_const,
3434
tile_impl_registry,
35-
add_impl as tile_add_impl,
3635
bind_method,
3736
build_tuple,
3837
strictly_typed_const,
@@ -541,8 +540,8 @@ def _is_pointer_type(ty):
541540
return isinstance(ty, TileTy) and is_pointer_dtype(ty.dtype)
542541

543542

544-
@impl(operator.add)
545-
def add_impl(x: Var, y: Var) -> Var:
543+
@impl(operator.add, overload=(TileTy, TileTy))
544+
async def add_impl(x: Var, y: Var) -> Var:
546545
xty, yty = x.get_type(), y.get_type()
547546
if _is_pointer_type(yty):
548547
xty, yty = yty, xty
@@ -551,11 +550,14 @@ def add_impl(x: Var, y: Var) -> Var:
551550
if not datatype.is_integral(offset_dtype):
552551
raise TileTypeError(f"Expected integer pointer offset, got {offset_dtype}")
553552
return _pointer_with_offset(x, y)
554-
return tile_add_impl(x, y)
553+
# HACK HACK HACK
554+
with tile_impl_registry.as_current():
555+
from cuda.tile._passes.hir2ir import call_function
556+
return await call_function(operator.add, x, y)
555557

556558

557-
@impl(operator.sub)
558-
def sub_impl(x: Var, y: Var) -> Var:
559+
@impl(operator.sub, overload=(TileTy, TileTy))
560+
async def sub_impl(x: Var, y: Var) -> Var:
559561
xty, yty = x.get_type(), y.get_type()
560562
if _is_pointer_type(xty):
561563
offset_dtype = require_scalar_tile_type(y).dtype
@@ -567,7 +569,10 @@ def sub_impl(x: Var, y: Var) -> Var:
567569
return _pointer_with_offset(x, offset)
568570
if _is_pointer_type(yty):
569571
raise TileTypeError('It is invalid to subtract a pointer from an integer')
570-
return binary_arithmetic('sub', x, y)
572+
# HACK HACK HACK
573+
with tile_impl_registry.as_current():
574+
from cuda.tile._passes.hir2ir import call_function
575+
return await call_function(operator.sub, x, y)
571576

572577

573578
@impl(stub.address_space_cast)

src/cuda/tile/_ir/op_impl.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def clone(self) -> "ImplRegistry":
8181
ret._overloaded_implementations[stub] = dict(overloads)
8282
return ret
8383

84-
def overload_dispatcher(self, stub):
84+
def overload_dispatcher(self, stub, *, fixed_args: Sequence[Any] = ()):
8585
"""
8686
Decorates a function to attach an overloaded implementation dispatcher to a stub.
8787
@@ -92,6 +92,10 @@ def overload_dispatcher(self, stub):
9292
"""
9393

9494
def decorate(key_func):
95+
orig_func = key_func
96+
if len(fixed_args) > 0:
97+
key_func = functools.partial(orig_func, *fixed_args)
98+
9599
@functools.wraps(key_func)
96100
async def implementation(*args):
97101
generator = key_func(*args)
@@ -117,7 +121,8 @@ async def implementation(*args):
117121

118122
assert stub not in self._overloaded_implementations
119123
self._overloaded_implementations[stub] = dict()
120-
return self.impl(stub)(implementation)
124+
self.impl(stub)(implementation)
125+
return orig_func
121126

122127
return decorate
123128

@@ -126,12 +131,11 @@ def _find_overload(self, stub: Callable, overload: tuple[Any, ...]) -> Callable
126131
best_matches = []
127132
best_priority = -1
128133

129-
for parameters, (priority, impl) in candidates.items():
134+
for priority, predicates, impl in candidates.values():
130135
if priority < best_priority:
131136
continue
132137

133-
if not all(p == WILDCARD or p == arg
134-
for p, arg in zip(parameters, overload, strict=True)):
138+
if not all(p(arg) for p, arg in zip(predicates, overload, strict=True)):
135139
continue
136140

137141
if priority > best_priority:
@@ -200,17 +204,27 @@ def wrapper(*args, **kwargs):
200204
if len(overload) == 0:
201205
self.op_implementations[stub] = wrapper
202206
else:
207+
predicates = tuple(_predicate_from_overload_pattern(p) for p in overload)
203208
self._overloaded_implementations[stub][overload] = \
204-
(sum(p != WILDCARD for p in overload), wrapper)
209+
(sum(p != WILDCARD for p in overload), predicates, wrapper)
205210

206211
return orig_func
207212

208213
return decorate
209214

210215
def _have_overload_matching_first_param(self, stub: Callable, first_param: Any) -> bool:
211216
candidates = self._overloaded_implementations[stub]
212-
return any(parameters[0] == first_param
213-
for parameters in candidates.keys())
217+
return any(predicates[0](first_param)
218+
for _priority, predicates, _impl in candidates.values())
219+
220+
221+
def _predicate_from_overload_pattern(pattern):
    """Build a one-argument predicate that tests an overload key against `pattern`.

    The returned callable takes a single overload-key element and returns a bool:

    * ``WILDCARD`` matches every key.
    * A class matches any key that is itself a class and is `pattern` or a
      subclass of it.
    * Any other value matches only keys that compare equal to it.
    """
    if pattern == WILDCARD:
        return lambda _: True
    if isinstance(pattern, type):
        # issubclass() raises TypeError when its first argument is not a class.
        # Keys are not guaranteed to be classes (value patterns are matched by
        # equality below), so a non-class key must simply fail to match a class
        # pattern instead of aborting the whole overload lookup.
        return lambda x: isinstance(x, type) and issubclass(x, pattern)
    return lambda x: pattern == x
214228

215229

216230
class _CurrentRegistry(threading.local):

src/cuda/tile/_ir/ops.py

Lines changed: 70 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,34 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
708708
self.flush_to_zero)
709709

710710

711+
@overload_dispatcher(operator.add, fixed_args=["+"])
@overload_dispatcher(operator.sub, fixed_args=["-"])
@overload_dispatcher(operator.mul, fixed_args=["*"])
@overload_dispatcher(operator.floordiv, fixed_args=["//"])
@overload_dispatcher(operator.truediv, fixed_args=["/"])
@overload_dispatcher(operator.pow, fixed_args=["**"])
@overload_dispatcher(operator.mod, fixed_args=["%"])
@overload_dispatcher(operator.eq, fixed_args=["=="])
@overload_dispatcher(operator.ne, fixed_args=["!="])
@overload_dispatcher(operator.lt, fixed_args=["<"])
@overload_dispatcher(operator.le, fixed_args=["<="])
@overload_dispatcher(operator.gt, fixed_args=[">"])
@overload_dispatcher(operator.ge, fixed_args=[">="])
@overload_dispatcher(operator.and_, fixed_args=["&"])
@overload_dispatcher(operator.or_, fixed_args=["|"])
@overload_dispatcher(operator.xor, fixed_args=["^"])
@overload_dispatcher(operator.lshift, fixed_args=["<<"])
@overload_dispatcher(operator.rshift, fixed_args=[">>"])
@overload_dispatcher(operator.matmul, fixed_args=["@"])
def binop_overload_dispatcher(name: str, x: Var, y: Var):
    """Overload-key generator shared by all binary operators.

    Yields the pair ``(type(x_ty), type(y_ty))``, i.e. the classes of the two
    operand types, as the key used to select an overload.

    Args:
        name: Operator symbol supplied through ``fixed_args`` (e.g. ``"+"``);
            used only in the error message.
        x: Left operand.
        y: Right operand.

    Raises:
        TileTypeError: If ``OverloadNotFoundError`` surfaces at the ``yield``.
    """
    x_ty = x.get_type()
    y_ty = y.get_type()
    try:
        yield type(x_ty), type(y_ty)
    except OverloadNotFoundError:
        # NOTE(review): presumably the dispatcher throws OverloadNotFoundError
        # into this generator when no overload matches -- confirm against
        # overload_dispatcher. The lookup failure is an internal detail, so it
        # is suppressed from the user-facing traceback with `from None`.
        raise TileTypeError(
            f"Unsupported operand types for {name}: {x_ty} and {y_ty}"
        ) from None
737+
738+
711739
# Does not do broadcasting or type promotion, hence the name "Raw"
712740
@dataclass(eq=False)
713741
class RawComparisonOperation(Operation, opcode="raw_cmp"):
@@ -756,12 +784,22 @@ def _binop_propagate_constant(fn: str, x: Any, y: Any, type: Optional[Type]) ->
756784
return strictly_typed_const(res, type)
757785

758786

787+
def comparison_operator_impl(lhs_ty: type[Type], rhs_ty: type[Type]):
    """Return a decorator that registers a function for all comparison operators.

    The decorated function is registered as the ``(lhs_ty, rhs_ty)`` overload
    of ``operator.eq``, ``ne``, ``lt``, ``le``, ``gt`` and ``ge``. Each
    registration passes the operator's short name (e.g. ``"eq"``) through
    ``fixed_args``. The decorator returns the function unchanged.
    """
    operand_types = (lhs_ty, rhs_ty)

    def decorate(func):
        for op_name in ("eq", "ne", "lt", "le", "gt", "ge"):
            stub = getattr(operator, op_name)
            register = impl(stub, fixed_args=[op_name], overload=operand_types)
            register(func)
        return func

    return decorate
794+
795+
759796
@impl(ct.equal, fixed_args=["eq"])
760797
@impl(ct.greater, fixed_args=["gt"])
761798
@impl(ct.not_equal, fixed_args=["ne"])
762799
@impl(ct.greater_equal, fixed_args=["ge"])
763800
@impl(ct.less, fixed_args=["lt"])
764801
@impl(ct.less_equal, fixed_args=["le"])
802+
@comparison_operator_impl(TileTy, TileTy)
765803
def comparison(fn: str, x: Var, y: Var) -> Var:
766804
x_ty = require_tile_maybe_loose_type(x)
767805
y_ty = require_tile_maybe_loose_type(y)
@@ -798,7 +836,8 @@ def operator_is_not_impl(x: Var, y: Var):
798836
return _is_none_compare(x, y, negate=True, op_name="is not")
799837

800838

801-
def _tuple_comparison(fn: str, x: Var, y: Var) -> Var:
839+
@comparison_operator_impl(TupleTy, TupleTy)
840+
async def comparison_operator_tuple_impl(fn: str, x: Var, y: Var) -> Var:
802841
if fn not in ("eq", "ne"):
803842
raise TileTypeError(f"Operator '{fn}' is not supported for tuples")
804843

@@ -824,7 +863,8 @@ def _tuple_comparison(fn: str, x: Var, y: Var) -> Var:
824863
f"Tuple comparison is not supported for elements of type {item_ty}"
825864
)
826865

827-
elem_cmps = [comparison_operator_impl("eq", xi, yi) for xi, yi in zip(x_items, y_items)]
866+
from cuda.tile._passes.hir2ir import call_function
867+
elem_cmps = [await call_function(operator.eq, xi, yi) for xi, yi in zip(x_items, y_items)]
828868
result = functools.reduce(lambda a, b: binary_bitwise("and_", a, b), elem_cmps,
829869
loosely_typed_const(True))
830870

@@ -834,25 +874,14 @@ def _tuple_comparison(fn: str, x: Var, y: Var) -> Var:
834874
return result
835875

836876

837-
@impl(operator.eq, fixed_args=["eq"])
838-
@impl(operator.ne, fixed_args=["ne"])
839-
@impl(operator.lt, fixed_args=["lt"])
840-
@impl(operator.le, fixed_args=["le"])
841-
@impl(operator.gt, fixed_args=["gt"])
842-
@impl(operator.ge, fixed_args=["ge"])
843-
def comparison_operator_impl(fn: str, x: Var, y: Var) -> Var:
844-
x_ty = x.get_type()
845-
y_ty = y.get_type()
877+
@comparison_operator_impl(DTypeSpec, DTypeSpec)
def comparison_dtype_spec_impl(fn: str, x: Var, y: Var):
    """Compare two operands whose types are both ``DTypeSpec``.

    Reads the ``dtype`` attribute of each operand's type and passes the pair
    to ``_binop_propagate_constant`` with ``None`` as the result type.
    """
    lhs_dtype = x.get_type().dtype
    rhs_dtype = y.get_type().dtype
    return _binop_propagate_constant(fn, lhs_dtype, rhs_dtype, None)
880+
846881

847-
match x_ty, y_ty:
848-
case DTypeSpec(), DTypeSpec():
849-
return _binop_propagate_constant(fn, x_ty.dtype, y_ty.dtype, None)
850-
case StringTy(), StringTy():
851-
return _binop_propagate_constant(fn, x_ty.value, y_ty.value, None)
852-
case TupleTy(), TupleTy():
853-
return _tuple_comparison(fn, x, y)
854-
case _, _:
855-
return comparison(fn, x, y)
882+
@comparison_operator_impl(StringTy, StringTy)
def comparison_string_impl(fn: str, x: Var, y: Var):
    """Compare two operands whose types are both ``StringTy``.

    Reads the ``value`` attribute of each operand's type and passes the pair
    to ``_binop_propagate_constant`` with ``None`` as the result type.
    """
    lhs_value = x.get_type().value
    rhs_value = y.get_type().value
    return _binop_propagate_constant(fn, lhs_value, rhs_value, None)
856885

857886

858887
def _promote_and_broadcast_to(x: Var, ty: TileTy) -> Var:
@@ -888,9 +917,9 @@ def raw_binary_bitwise(fn: str, x: Var, y: Var) -> Var:
888917
@impl(ct.bitwise_and, fixed_args=["and_"])
889918
@impl(ct.bitwise_or, fixed_args=["or_"])
890919
@impl(ct.bitwise_xor, fixed_args=["xor"])
891-
@impl(operator.and_, fixed_args=["and_"])
892-
@impl(operator.or_, fixed_args=["or_"])
893-
@impl(operator.xor, fixed_args=["xor"])
920+
@impl(operator.and_, fixed_args=["and_"], overload=(TileTy, TileTy))
921+
@impl(operator.or_, fixed_args=["or_"], overload=(TileTy, TileTy))
922+
@impl(operator.xor, fixed_args=["xor"], overload=(TileTy, TileTy))
894923
def binary_bitwise(fn: str, x: Var, y: Var) -> Var:
895924
x_ty = require_tile_maybe_loose_type(x)
896925
y_ty = require_tile_maybe_loose_type(y)
@@ -956,8 +985,8 @@ def raw_bitwise_shift(fn: str, x: Var, y: Var) -> Var:
956985

957986
@impl(ct.bitwise_lshift, fixed_args=["lshift"])
958987
@impl(ct.bitwise_rshift, fixed_args=["rshift"])
959-
@impl(operator.lshift, fixed_args=["lshift"])
960-
@impl(operator.rshift, fixed_args=["rshift"])
988+
@impl(operator.lshift, fixed_args=["lshift"], overload=(TileTy, TileTy))
989+
@impl(operator.rshift, fixed_args=["rshift"], overload=(TileTy, TileTy))
961990
def bitwise_shift(fn: str, x: Var, y: Var) -> Var:
962991
x_ty = require_tile_maybe_loose_type(x)
963992
y_ty = require_tile_maybe_loose_type(y)
@@ -1114,23 +1143,26 @@ def binary_arithmetic(fn: str, x: Var, y: Var, rounding_mode: Optional[RoundingM
11141143
@impl(ct.floordiv, fixed_args=["floordiv"])
11151144
@impl(ct.cdiv, fixed_args=["cdiv"])
11161145
@impl(ct.pow, fixed_args=["pow"])
1117-
@impl(operator.sub, fixed_args=["sub"])
1118-
@impl(operator.mul, fixed_args=["mul"])
1119-
@impl(operator.floordiv, fixed_args=["floordiv"])
1120-
@impl(operator.truediv, fixed_args=["truediv"])
1121-
@impl(operator.pow, fixed_args=["pow"])
1146+
@impl(operator.sub, fixed_args=["sub"], overload=(TileTy, TileTy))
1147+
@impl(operator.mul, fixed_args=["mul"], overload=(TileTy, TileTy))
1148+
@impl(operator.floordiv, fixed_args=["floordiv"], overload=(TileTy, TileTy))
1149+
@impl(operator.truediv, fixed_args=["truediv"], overload=(TileTy, TileTy))
1150+
@impl(operator.pow, fixed_args=["pow"], overload=(TileTy, TileTy))
11221151
@impl(min, fixed_args=["min"])
11231152
@impl(max, fixed_args=["max"])
11241153
def binary_arithmetic_impl(fn: str, x: Var, y: Var) -> Var:
11251154
return binary_arithmetic(fn, x, y)
11261155

11271156

1128-
@impl(operator.add)
1129-
def add_impl(x: Var, y: Var) -> Var:
1130-
if isinstance(x.get_type(), TupleTy) and isinstance(y.get_type(), TupleTy):
1131-
x_items = x.get_aggregate().items
1132-
y_items = y.get_aggregate().items
1133-
return build_tuple(x_items + y_items)
1157+
@impl(operator.add, overload=(TupleTy, TupleTy))
def add_tuple_impl(x: Var, y: Var):
    """Implement ``+`` for two tuple operands.

    Builds a new tuple from the items of ``x``'s aggregate followed by the
    items of ``y``'s aggregate.
    """
    combined_items = x.get_aggregate().items + y.get_aggregate().items
    return build_tuple(combined_items)
1162+
1163+
1164+
@impl(operator.add, overload=(TileTy, TileTy))
def add_tile_impl(x: Var, y: Var) -> Var:
    """Implement ``+`` for two tile operands by delegating to ``binary_arithmetic``."""
    total = binary_arithmetic("add", x, y)
    return total
11351167

11361168

@@ -1157,7 +1189,7 @@ def binary_arithmetic_impl_with_rd_and_ftz(fn: str, x: Var, y: Var,
11571189
return binary_arithmetic(fn, x, y, rounding_mode, flush_to_zero)
11581190

11591191

1160-
@impl(operator.mod)
1192+
@impl(operator.mod, overload=(TileTy, TileTy))
11611193
@impl(ct.mod)
11621194
def mod(x: Var, y: Var) -> Var:
11631195
x_ty = require_tile_maybe_loose_type(x)
@@ -3577,7 +3609,7 @@ def mma_impl(x: Var, y: Var, acc: Var, use_fast_acc: Var) -> Var:
35773609

35783610

35793611
@impl(ct.matmul)
3580-
@impl(operator.matmul)
3612+
@impl(operator.matmul, overload=(TileTy, TileTy))
35813613
def matmul_impl(x: Var, y: Var) -> Var:
35823614
x_tile_type = require_tile_type(x)
35833615
y_tile_type = require_tile_type(y)

src/cuda/tile/_passes/hir2ir.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
import sys
66
from contextlib import contextmanager
77
import dataclasses
8-
from types import BuiltinFunctionType, FunctionType
9-
from typing import Sequence, Mapping
8+
from typing import Sequence, Mapping, Callable
109

1110
from .ast2hir import get_function_hir
1211
from .. import TileTypeError
@@ -205,7 +204,11 @@ async def _call_user_defined(callee_hir: hir.Function,
205204
return ret
206205

207206

208-
async def _call_function(callee: FunctionType | BuiltinFunctionType,
207+
async def call_function(callee: Callable, *args: Var, **kwargs: Var):
    """Call ``callee`` with the given arguments using the current IR builder.

    Wrapper around ``_call_function`` that forwards the positional and keyword
    arguments and supplies the builder from ``ir.Builder.get_current()``.
    """
    current_builder = ir.Builder.get_current()
    return await _call_function(callee, args, kwargs, current_builder)
209+
210+
211+
async def _call_function(callee: Callable,
209212
args: Sequence[Var],
210213
kwargs: Mapping[str, Var],
211214
builder: ir.Builder):

0 commit comments

Comments
 (0)