Skip to content

Commit f3768bf

Browse files
committed
Support is not
Signed-off-by: Qiqi Xiao <qiqix@nvidia.com>
1 parent 71d35f4 commit f3768bf

5 files changed

Lines changed: 52 additions & 22 deletions

File tree

src/cuda/tile/_ast2ir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def _binop_expr(binop: ast.BinOp, block: ir.Block, ctx: _Context) -> ir.Var:
186186

187187
_cmp_map = {
188188
ast.Eq: operator.eq, ast.NotEq: operator.ne, ast.Lt: operator.lt, ast.LtE: operator.le,
189-
ast.Gt: operator.gt, ast.GtE: operator.ge, ast.Is: operator.is_
189+
ast.Gt: operator.gt, ast.GtE: operator.ge, ast.Is: operator.is_, ast.IsNot: operator.is_not,
190190
}
191191

192192

src/cuda/tile/_ir/ops.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -726,16 +726,22 @@ def comparison(fn: str, x: Var, y: Var) -> Var:
726726
return raw_comparison(fn, x, y)
727727

728728

729+
def _is_none_compare(x: Var, y: Var, *, negate: bool, op_name: str) -> Var:
730+
x_is_none = x.get_type() is NONE
731+
y_is_none = y.get_type() is NONE
732+
if not (x_is_none or y_is_none):
733+
raise TileTypeError(f"Operator '{op_name}' expects one of the operands to be None")
734+
return loosely_typed_const((x_is_none == y_is_none) ^ negate)
735+
736+
729737
@impl(operator.is_)
730738
def operator_is_impl(x: Var, y: Var):
731-
x_ty = x.get_type()
732-
y_ty = y.get_type()
733-
if x_ty is NONE:
734-
return loosely_typed_const(y_ty is NONE)
735-
elif y_ty is NONE:
736-
return loosely_typed_const(x_ty is NONE)
737-
else:
738-
raise TileTypeError("Operator 'is' expects one of the operands to be None")
739+
return _is_none_compare(x, y, negate=False, op_name="is")
740+
741+
742+
@impl(operator.is_not)
743+
def operator_is_not_impl(x: Var, y: Var):
744+
return _is_none_compare(x, y, negate=True, op_name="is not")
739745

740746

741747
@impl(operator.eq, fixed_args=["eq"])

src/cuda/tile/_ir/typing_support.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def wrapped(handler: TypeHandler):
103103
operator.gt: lambda x, y, /: None,
104104
operator.ge: lambda x, y, /: None,
105105
operator.is_: lambda x, y, /: None,
106+
operator.is_not: lambda x, y, /: None,
106107
operator.invert: lambda x, /: None,
107108
operator.not_: lambda x, /: None,
108109
operator.pos: lambda x, /: None,

test/test_binary_elementwise.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,34 @@ def test_array_compare(shape, tile, dtype, op_symbol, op_func, tmp_path):
278278
assert_equal(z, ref)
279279

280280

281+
def make_is_operator_kernel(cmp):
282+
@ct.kernel
283+
def is_operator(x):
284+
bid = ct.bid(0)
285+
a = 1 if cmp is None else -1
286+
ct.store(x, index=(bid,), tile=a)
287+
return is_operator
288+
289+
290+
def make_is_not_operator_kernel(cmp):
291+
@ct.kernel
292+
def is_not_operator(x):
293+
bid = ct.bid(0)
294+
a = -1 if cmp is not None else 1
295+
ct.store(x, index=(bid,), tile=a)
296+
return is_not_operator
297+
298+
299+
@pytest.mark.parametrize("make_kernel", [make_is_operator_kernel, make_is_not_operator_kernel])
300+
@pytest.mark.parametrize("cmp", [None, 1])
301+
def test_is_or_not_operator(make_kernel, cmp):
302+
x = torch.zeros((1,), dtype=torch.int32, device='cuda')
303+
kernel = make_kernel(cmp)
304+
ct.launch(torch.cuda.current_stream(), (1, 1, 1), kernel, (x, ))
305+
ref = 1 if cmp is None else -1
306+
assert_equal(x, torch.tensor([ref], dtype=torch.int32, device='cuda'))
307+
308+
281309
@pytest.mark.parametrize("max_func", ["max", "ct.maximum"])
282310
@pytest.mark.parametrize("dtype", int_dtypes + float_dtypes, ids=dtype_id)
283311
def test_array_max(shape, tile, dtype, tmp_path, max_func):

test/test_constfold.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,24 +95,19 @@ def kernel():
9595
compile_tile(kernel, (), CompilerOptions())
9696

9797

98-
def test_is_op_on_constant():
99-
100-
def kernel():
101-
None is None
102-
None is ct.bid(0)
103-
ct.bid(0) is None
104-
105-
compile_tile(kernel, (), CompilerOptions())
106-
107-
108-
def test_is_op_on_none_constant():
98+
@pytest.mark.parametrize("negate", [False, True])
99+
def test_is_or_not_op_on_none_constant(negate):
109100

110101
def kernel():
111102
tx = ct.full((1,), 0, ct.float32)
112103
ty = ct.full((1,), 0, ct.float32)
113-
tx is ty
104+
if negate:
105+
tx is not ty
106+
else:
107+
tx is ty
114108

115-
msg = re.escape("Operator 'is' expects one of the operands to be None")
109+
op_name = 'is not' if negate else 'is'
110+
msg = re.escape(f"Operator '{op_name}' expects one of the operands to be None")
116111
with pytest.raises(TileTypeError, match=msg):
117112
compile_tile(kernel, (), CompilerOptions())
118113

0 commit comments

Comments
 (0)