Skip to content

Commit 5aced39

Browse files
committed
Support Enum Type construction and comparision
Signed-off-by: Qiqi Xiao <qiqix@nvidia.com>
1 parent c168893 commit 5aced39

6 files changed

Lines changed: 255 additions & 6 deletions

File tree

changelog.d/enum.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Added comparison and constructor support for Python ``enum.Enum`` inside kernels. Enum members can now be compared with ``==`` / ``!=`` and constructed from a constant value (e.g. ``Color(0)``).

experimental/cuda-lang/test/examples/test_wmma.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def kernel(a, b, c, d):
338338
if tid >= 32:
339339
return
340340

341-
if static_eval(config.acc_layout == wmma.Layout.COL_MAJOR):
341+
if config.acc_layout == wmma.Layout.COL_MAJOR:
342342
c_ptr = c.get_element_pointer((tile_col, tile_row))
343343
acc_ldm = global_m
344344
else:
@@ -353,14 +353,14 @@ def kernel(a, b, c, d):
353353
)
354354

355355
for kk in range(0, global_k, wmma_k):
356-
if static_eval(config.a_layout == wmma.Layout.COL_MAJOR):
356+
if config.a_layout == wmma.Layout.COL_MAJOR:
357357
a_ptr = a.get_element_pointer((kk, tile_row))
358358
a_ldm = global_m
359359
else:
360360
a_ptr = a.get_element_pointer((tile_row, kk))
361361
a_ldm = global_k
362362

363-
if static_eval(config.b_layout == wmma.Layout.ROW_MAJOR):
363+
if config.b_layout == wmma.Layout.ROW_MAJOR:
364364
b_ptr = b.get_element_pointer((kk, tile_col))
365365
b_ldm = global_n
366366
else:
@@ -374,7 +374,7 @@ def kernel(a, b, c, d):
374374
config.satf,
375375
)
376376

377-
if static_eval(config.acc_layout == wmma.Layout.COL_MAJOR):
377+
if config.acc_layout == wmma.Layout.COL_MAJOR:
378378
d_ptr = d.get_element_pointer((tile_col, tile_row))
379379
else:
380380
d_ptr = d.get_element_pointer((tile_row, tile_col))

src/cuda/tile/_ir/core_ops.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
DataclassInfo, DataclassTy, DataclassValue, BoundMethodValue, BoundMethodTy, InvalidType, \
2828
ContextManagerTy, ContextManagerLifecycle, LiveCapturedScope, ClosureTy, ClosureValue, \
2929
RangeIterType, RangeValue, TypeTy, ModuleTy, NONE, SliceType, StringTy, FormattedStringTy, \
30-
StringFormat, FormattedStringValue, FormattedPiece, DictTy, DictValue
30+
StringFormat, FormattedStringValue, FormattedPiece, DictTy, DictValue, EnumTy
3131
from cuda.tile._ir.typing_support import type_of_constant_python_value, \
3232
loose_type_of_constant_python_value, get_dataclass_info, as_third_party_dtype_spec
3333
from cuda.tile._ir2bytecode import BytecodeContext
@@ -648,6 +648,16 @@ async def getattr_dataclass_impl(object: Var, name: Var):
648648
else:
649649
return sym2var(cls_attr, constant_only=True)
650650

651+
652+
@impl(getattr, overload=(EnumTy, "name"))
653+
def getattr_enum_name_impl(object: Var, name: Var):
654+
return sym2var(object.get_constant().name)
655+
656+
657+
@impl(getattr, overload=(EnumTy, "value"))
658+
def getattr_enum_value_impl(object: Var, name: Var):
659+
return sym2var(object.get_constant().value)
660+
651661
# ===========================================================================================
652662

653663

@@ -707,6 +717,12 @@ def comparison_string_impl(fn: str, x: Var, y: Var):
707717
return binop_propagate_constant(fn, x.get_type().value, y.get_type().value, None)
708718

709719

720+
@comparison_operator_impl(_registry, EnumTy, EnumTy)
721+
def comparison_enum_impl(fn: str, x: Var, y: Var):
722+
from cuda.tile._ir.arithmetic_ops import binop_propagate_constant
723+
return binop_propagate_constant(fn, x.get_constant(), y.get_constant(), None)
724+
725+
710726
# ===========================================================================================
711727
# Print
712728
# ===========================================================================================
@@ -787,6 +803,9 @@ def _expand_var(var: Var | str, format_spec: str | None = None,
787803
leaf_vars.append(var)
788804
elif isinstance(ty, DTypeSpec):
789805
format_parts.append(str(ty.dtype))
806+
elif isinstance(ty, EnumTy):
807+
member = var.get_constant()
808+
format_parts.append(f"{ty.enum_ty.__name__}.{member.name}")
790809
else:
791810
raise TileTypeError(f"Can't print value of type {ty}")
792811

src/cuda/tile/_passes/hir2ir.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
import sys
66
from contextlib import contextmanager
77
import dataclasses
8+
9+
from enum import Enum
810
from typing import Sequence, Mapping, Callable
911

1012
from .ast2hir import get_function_hir
1113
from .. import TileTypeError
1214
from .._coroutine_util import resume_after, run_coroutine
1315
from .._datatype import PointerInfo
14-
from .._exception import Loc, FunctionDesc, TileInternalError, TileError, TileRecursionError
16+
from .._exception import Loc, FunctionDesc, TileInternalError, TileError, TileRecursionError, \
17+
TileValueError
1518
from .._execution import is_stub
1619
from .._ir import hir, ir
1720
from .._ir.ir import Var, IRContext
@@ -302,6 +305,17 @@ async def call(callee_var: Var, args, kwargs) -> Var | None:
302305
elif isinstance(callee_ty, TypeTy) and callee_ty.ty is PointerInfo:
303306
arg_list = _bind_args(_POINTER_INFO_SIGNATURE, "PointerInfo", args, kwargs)
304307
return await _call_builtin(PointerInfo, arg_list, builder)
308+
elif isinstance(callee_ty, TypeTy) and issubclass(callee_ty.ty, Enum):
309+
if len(args) != 1 or kwargs:
310+
raise TileTypeError("Enum constructor takes exactly one positional argument")
311+
arg = args[0]
312+
if not arg.is_constant():
313+
raise TileTypeError("Enum constructor argument must be a constant")
314+
val = arg.get_constant()
315+
try:
316+
return loosely_typed_const(callee_ty.ty(val))
317+
except ValueError:
318+
raise TileValueError(f"{val!r} is not a valid {callee_ty.ty.__name__}")
305319
else:
306320
raise TileTypeError(f"Cannot call an object of type {callee_ty}")
307321

test/test_enum.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from enum import Enum, IntEnum
6+
7+
import pytest
8+
import torch
9+
10+
import cuda.tile as ct
11+
from cuda.tile import TileTypeError, TileValueError
12+
13+
14+
class Color(Enum):
15+
RED = 0
16+
GREEN = 1
17+
BLUE = 2
18+
19+
20+
class Status(Enum):
21+
OK = "ok"
22+
ERROR = "error"
23+
24+
25+
class Weight(Enum):
26+
LIGHT = 0.5
27+
HEAVY = 2.0
28+
29+
30+
class Priority(IntEnum):
31+
LOW = 0
32+
MEDIUM = 1
33+
HIGH = 2
34+
35+
36+
def test_comparison_eq():
37+
@ct.kernel
38+
def kernel(out):
39+
x = Color.RED
40+
if x == Color.RED:
41+
ct.scatter(out, (), 1)
42+
else:
43+
ct.scatter(out, (), -1)
44+
45+
out = torch.zeros((), dtype=torch.int32, device="cuda")
46+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (out,))
47+
assert out.item() == 1
48+
49+
50+
def test_comparison_not_equal():
51+
@ct.kernel
52+
def kernel(out):
53+
x = Color.RED
54+
if x != Color.GREEN:
55+
ct.scatter(out, (), 1)
56+
else:
57+
ct.scatter(out, (), -1)
58+
59+
out = torch.zeros((), dtype=torch.int32, device="cuda")
60+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (out,))
61+
assert out.item() == 1
62+
63+
64+
def test_construction_from_known_int():
65+
@ct.kernel
66+
def kernel(out):
67+
i = 0
68+
x = Color(i)
69+
if x == Color.RED:
70+
ct.scatter(out, (), 10)
71+
elif x == Color.GREEN:
72+
ct.scatter(out, (), 20)
73+
else:
74+
ct.scatter(out, (), 30)
75+
76+
out = torch.zeros((), dtype=torch.int32, device="cuda")
77+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (out,))
78+
assert out.item() == 10
79+
80+
81+
def test_construction_from_string_value():
82+
@ct.kernel
83+
def kernel(out):
84+
x = Status("ok")
85+
if x == Status.OK:
86+
ct.scatter(out, (), 1)
87+
else:
88+
ct.scatter(out, (), 0)
89+
90+
out = torch.zeros((), dtype=torch.int32, device="cuda")
91+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (out,))
92+
assert out.item() == 1
93+
94+
95+
def test_construction_from_float_value():
96+
@ct.kernel
97+
def kernel(out):
98+
x = Weight(0.5)
99+
if x == Weight.LIGHT:
100+
ct.scatter(out, (), 1)
101+
else:
102+
ct.scatter(out, (), 0)
103+
104+
out = torch.zeros((), dtype=torch.int32, device="cuda")
105+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (out,))
106+
assert out.item() == 1
107+
108+
109+
def test_intenum_ordering():
110+
@ct.kernel
111+
def kernel(out):
112+
if Priority.LOW < Priority.HIGH:
113+
ct.scatter(out, (), 1)
114+
else:
115+
ct.scatter(out, (), -1)
116+
117+
out = torch.zeros((), dtype=torch.int32, device="cuda")
118+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (out,))
119+
assert out.item() == 1
120+
121+
122+
# ===========================================================================
123+
# Error cases
124+
# ===========================================================================
125+
126+
def test_construction_from_runtime_value_raises():
127+
@ct.kernel
128+
def kernel(x, out):
129+
bid = ct.bid(0)
130+
_ = Color(bid)
131+
ct.scatter(out, (), 0)
132+
133+
x = torch.zeros(1, device="cuda")
134+
out = torch.zeros((), dtype=torch.int32, device="cuda")
135+
with pytest.raises(TileTypeError):
136+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x, out))
137+
138+
139+
@pytest.mark.parametrize("invalid_value", ["foo", 99])
140+
def test_construction_from_invalid_type_or_value_raises(invalid_value):
141+
@ct.kernel
142+
def kernel(out):
143+
_ = Color(invalid_value)
144+
ct.scatter(out, (), 0)
145+
146+
out = torch.zeros((), dtype=torch.int32, device="cuda")
147+
with pytest.raises(TileValueError):
148+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (out,))
149+
150+
151+
@pytest.mark.parametrize("enum_value", [Color.BLUE, Status.ERROR, Weight.HEAVY])
152+
def test_name_attribute(enum_value):
153+
name = enum_value.name
154+
155+
@ct.kernel
156+
def kernel(out):
157+
if enum_value.name == name:
158+
ct.scatter(out, (), 1)
159+
160+
out = torch.zeros((), dtype=torch.int32, device="cuda")
161+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (out,))
162+
assert out.item() == 1
163+
164+
165+
@pytest.mark.parametrize("enum_value", [Color.BLUE, Status.ERROR, Weight.HEAVY])
166+
def test_value_attribute(enum_value):
167+
value = enum_value.value
168+
169+
@ct.kernel
170+
def kernel(out):
171+
if enum_value.value == value:
172+
ct.scatter(out, (), 1)
173+
174+
out = torch.zeros((), dtype=torch.int32, device="cuda")
175+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (out,))
176+
assert out.item() == 1
177+
178+
179+
def test_enum_ordering_raises():
180+
@ct.kernel
181+
def kernel(out):
182+
if Color.RED < Color.GREEN:
183+
ct.scatter(out, (), 1)
184+
else:
185+
ct.scatter(out, (), -1)
186+
187+
out = torch.zeros((), dtype=torch.int32, device="cuda")
188+
with pytest.raises(TileTypeError):
189+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (out,))

test/test_print.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import ctypes
77
import sys
88
import traceback
9+
from enum import Enum
910
import torch
1011
import numpy as np
1112
import pytest
@@ -239,6 +240,19 @@ def kernel_print_dtype(x, TILE: ct.Constant[int]):
239240
print(f"dtype = {tx.dtype}")
240241

241242

243+
class _PrintColor(Enum):
244+
RED = 0
245+
GREEN = 1
246+
247+
248+
@ct.kernel(opt_level=_OPT_LEVEL)
249+
def kernel_print_enum(x, TILE: ct.Constant[int]):
250+
ct.print(_PrintColor.GREEN)
251+
print(_PrintColor.GREEN)
252+
ct.print(f"color = {_PrintColor.RED}")
253+
print(f"color = {_PrintColor.RED}")
254+
255+
242256
_KERNELS_MAP_ = {
243257
"kernel_printf": kernel_printf,
244258
"kernel_print": kernel_print,
@@ -257,6 +271,7 @@ def kernel_print_dtype(x, TILE: ct.Constant[int]):
257271
"kernel_print_tuple_w_tile": kernel_print_tuple_w_tile,
258272
"kernel_print_tuple_tile_shape": kernel_print_tuple_tile_shape,
259273
"kernel_print_dtype": kernel_print_dtype,
274+
"kernel_print_enum": kernel_print_enum,
260275
}
261276

262277

@@ -471,6 +486,14 @@ def test_ct_print_dtype():
471486
assert actual_outs[3] == "dtype = float32" # print(f"dtype = {tx.dtype}")
472487

473488

489+
def test_ct_print_enum():
490+
actual_outs = _run_kernel(kernel_print_enum, (8,), "float32", 8)
491+
assert actual_outs[0] == "_PrintColor.GREEN" # ct.print(_PrintColor.GREEN)
492+
assert actual_outs[1] == "_PrintColor.GREEN" # print(_PrintColor.GREEN)
493+
assert actual_outs[2] == "color = _PrintColor.RED" # ct.print(f"color = {_PrintColor.RED}")
494+
assert actual_outs[3] == "color = _PrintColor.RED" # print(f"color = {_PrintColor.RED}")
495+
496+
474497
if __name__ == "__main__":
475498
if len(sys.argv) > 1 and sys.argv[1] == "start_kernel_runner":
476499
_kernel_runner_main()

0 commit comments

Comments
 (0)