Skip to content

Commit ebaa570

Browse files
committed
Support compare operations on tuples
Signed-off-by: Qiqi Xiao <qiqix@nvidia.com>
1 parent f31c762 commit ebaa570

2 files changed

Lines changed: 168 additions & 0 deletions

File tree

src/cuda/tile/_ir/ops.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,42 @@ def operator_is_not_impl(x: Var, y: Var):
797797
return _is_none_compare(x, y, negate=True, op_name="is not")
798798

799799

800+
def _tuple_comparison(fn: str, x: Var, y: Var) -> Var:
801+
if fn not in ("eq", "ne"):
802+
raise TileTypeError(f"Operator '{fn}' is not supported for tuples")
803+
804+
x_ty = require_tuple_type(x)
805+
y_ty = require_tuple_type(y)
806+
807+
if x.is_constant() and y.is_constant():
808+
res = x.get_constant() == y.get_constant()
809+
return loosely_typed_const(res if fn == "eq" else not res)
810+
811+
if len(x_ty) != len(y_ty):
812+
return loosely_typed_const(fn == "ne")
813+
814+
x_items = x.get_aggregate().items
815+
y_items = y.get_aggregate().items
816+
817+
for item in (*x_items, *y_items):
818+
item_ty = item.get_type()
819+
if isinstance(item_ty, TileTy) and item_ty.ndim > 0:
820+
raise TileTypeError("Tuple comparison is not supported for N-D tile elements")
821+
if not isinstance(item_ty, (TileTy, TupleTy, LooselyTypedScalar, DTypeSpec, StringTy)):
822+
raise TileTypeError(
823+
f"Tuple comparison is not supported for elements of type {item_ty}"
824+
)
825+
826+
elem_cmps = [comparison_operator_impl("eq", xi, yi) for xi, yi in zip(x_items, y_items)]
827+
result = functools.reduce(lambda a, b: binary_bitwise("and_", a, b), elem_cmps,
828+
loosely_typed_const(True))
829+
830+
if fn == "ne":
831+
result = logical_not_impl(result)
832+
833+
return result
834+
835+
800836
@impl(operator.eq, fixed_args=["eq"])
801837
@impl(operator.ne, fixed_args=["ne"])
802838
@impl(operator.lt, fixed_args=["lt"])
@@ -812,6 +848,8 @@ def comparison_operator_impl(fn: str, x: Var, y: Var) -> Var:
812848
return _binop_propagate_constant(fn, x_ty.dtype, y_ty.dtype, None)
813849
case StringTy(), StringTy():
814850
return _binop_propagate_constant(fn, x_ty.value, y_ty.value, None)
851+
case TupleTy(), TupleTy():
852+
return _tuple_comparison(fn, x, y)
815853
case _, _:
816854
return comparison(fn, x, y)
817855

test/test_tuple.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,133 @@ def kernel():
135135

136136
with pytest.raises(TileTypeError, match=re.escape("Expected a tuple after *")):
137137
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
138+
139+
140+
def test_tuple_compare_empty_eq():
141+
@ct.kernel
142+
def kernel(x):
143+
if () == ():
144+
ct.scatter(x, (), 1)
145+
else:
146+
ct.scatter(x, (), 0)
147+
148+
x = torch.zeros((), dtype=torch.int32, device="cuda")
149+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x,))
150+
assert x.item() == 1
151+
152+
153+
def test_tuple_compare_constants_eq():
154+
@ct.kernel
155+
def kernel(x):
156+
if (1, 2, 3) == (1, 2, 3):
157+
ct.scatter(x, (), 1)
158+
else:
159+
ct.scatter(x, (), 0)
160+
161+
x = torch.zeros((), dtype=torch.int32, device="cuda")
162+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x,))
163+
assert x.item() == 1
164+
165+
166+
def test_tuple_compare_constants_ne():
167+
@ct.kernel
168+
def kernel(x):
169+
if (1, 2) != (1, 3):
170+
ct.scatter(x, (), 1)
171+
else:
172+
ct.scatter(x, (), 0)
173+
174+
x = torch.zeros((), dtype=torch.int32, device="cuda")
175+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x,))
176+
assert x.item() == 1
177+
178+
179+
def test_tuple_compare_different_lengths():
180+
@ct.kernel
181+
def kernel(x):
182+
a = ct.bid(0)
183+
if (a, 1) != (a, 1, 2):
184+
ct.scatter(x, (), 1)
185+
else:
186+
ct.scatter(x, (), 0)
187+
188+
x = torch.zeros((), dtype=torch.int32, device="cuda")
189+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x,))
190+
assert x.item() == 1
191+
192+
193+
def test_tuple_compare_0d_tiles_eq():
194+
@ct.kernel
195+
def kernel(x):
196+
a = ct.bid(0)
197+
b = ct.bid(1)
198+
if (a, b) == (0, 0):
199+
ct.scatter(x, (a, b), 1)
200+
else:
201+
ct.scatter(x, (a, b), -1)
202+
203+
x = torch.zeros((2, 2), dtype=torch.int32, device="cuda")
204+
ct.launch(torch.cuda.current_stream(), (2, 2), kernel, (x,))
205+
assert x.tolist() == [[1, -1], [-1, -1]]
206+
207+
208+
def test_tuple_compare_nd_tile_error():
209+
@ct.kernel
210+
def kernel():
211+
t = ct.ones((4,), dtype=ct.int32)
212+
if (t,) == (t,):
213+
pass
214+
215+
with pytest.raises(TileTypeError, match="not supported for N-D tile elements"):
216+
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
217+
218+
219+
def test_tuple_compare_unsupported_op():
220+
@ct.kernel
221+
def kernel():
222+
if (1, 2) < (3, 4):
223+
pass
224+
225+
with pytest.raises(TileTypeError, match="not supported for tuples"):
226+
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
227+
228+
229+
def test_tuple_compare_nested():
230+
@ct.kernel
231+
def kernel(x):
232+
a = ct.bid(0)
233+
if ((a, 1), 2) == ((0, 1), 2):
234+
ct.scatter(x, (a, ), 1)
235+
else:
236+
ct.scatter(x, (a, ), -1)
237+
238+
x = torch.zeros((2, ), dtype=torch.int32, device="cuda")
239+
ct.launch(torch.cuda.current_stream(), (2, ), kernel, (x,))
240+
assert x.tolist() == [1, -1]
241+
242+
243+
def test_tuple_compare_array_element_error():
244+
@ct.kernel
245+
def kernel(x, y):
246+
if (x,) == (y,):
247+
pass
248+
249+
with pytest.raises(TileTypeError, match="not supported for elements of type"):
250+
ct.launch(torch.cuda.current_stream(), (1,), kernel,
251+
(torch.zeros(4, dtype=torch.int32, device="cuda"),
252+
torch.zeros(4, dtype=torch.int32, device="cuda")))
253+
254+
255+
def test_tuple_compare_constant_args():
256+
@ct.kernel
257+
def kernel(x, M: ct.Constant[int], N: ct.Constant[int]):
258+
if (M, N) == (4, 8):
259+
ct.scatter(x, (), 1)
260+
else:
261+
ct.scatter(x, (), -1)
262+
263+
x = torch.zeros((), dtype=torch.int32, device="cuda")
264+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x, 4, 8))
265+
assert x.item() == 1
266+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x, 4, 9))
267+
assert x.item() == -1

0 commit comments

Comments
 (0)