Skip to content

Commit 64f1100

Browse files
committed
Clean up constant creation logic
- Allow only non-aggregate types in loosely_typed_constant(). - Aggregate types must go through sym2var(), which now gets an optional `constant_only` flag. This makes sure we also support dataclasses and other aggregate types. - This enables frozen dataclasses as global constants. - Get rid of get_constant_value(), unify with typeof_pyval(). Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent c0e176f commit 64f1100

8 files changed

Lines changed: 107 additions & 62 deletions

File tree

changelog.d/global-dataclass.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Allow frozen dataclass instances to be used as globals in device code.

src/cuda/tile/_ir/ops.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@
5959
)
6060
from .scope import Scope, JumpInfo, ControlFlowInfo
6161
from .typing_support import (
62-
typeof_pyval, loose_type_of_pyval, get_constant_value, get_dataclass_info,
62+
get_dataclass_info, as_third_party_dtype_spec, type_of_constant_python_value,
63+
loose_type_of_constant_python_value,
6364
)
6465
from .type import (
6566
PartitionViewTy, StridedViewTy, GatherScatterViewTy, TupleTy, TileTy, NoneType,
@@ -72,7 +73,7 @@
7273
)
7374
from cuda.tile._datatype import (
7475
DType, is_integral, is_float, is_signed, is_boolean, is_pointer_dtype, PointerInfo,
75-
opaque_pointer_dtype, pointer_dtype,
76+
opaque_pointer_dtype, pointer_dtype
7677
)
7778
from cuda.tile._ir2bytecode import (
7879
BytecodeContext, typeid,
@@ -651,21 +652,29 @@ def loosely_typed_const(value: Any,
651652
ty: Optional[Type] = None,
652653
loose_ty: Optional[Type] = None,
653654
name: str | None = None) -> Var:
655+
builder = Builder.get_current()
654656
if ty is None:
655-
if isinstance(value, tuple):
656-
return build_tuple(tuple(loosely_typed_const(item) for item in value))
657-
ty = typeof_pyval(value)
658-
ret = strictly_typed_const(value, ty, name=name)
657+
ty = type_of_constant_python_value(value, builder.ir_ctx.typing_hooks)
658+
assert not ty.is_aggregate(), "Use sym2var(value, constant_only=True) instead"
659+
660+
# Normalize third party dtype spec objects (e.g. torch.float32 -> ct.float32)
661+
if isinstance(ty, DTypeSpec):
662+
value = ty.dtype
663+
664+
ret = _strictly_typed_const_inner(builder, value, ty, name=name)
659665
if loose_ty is None:
660-
loose_ty = loose_type_of_pyval(value)
666+
loose_ty = loose_type_of_constant_python_value(value, builder.ir_ctx.typing_hooks)
661667
ret.set_loose_type(loose_ty)
662668
return ret
663669

664670

665671
def strictly_typed_const(value: Any, ty: Type, name: str | None = None) -> Var:
666-
builder = Builder.get_current()
667-
result = None if name is None else builder.ir_ctx.make_var(name, builder.loc)
672+
return _strictly_typed_const_inner(Builder.get_current(), value, ty, name)
668673

674+
675+
def _strictly_typed_const_inner(builder: Builder,
676+
value: Any, ty: Type, name: str | None = None) -> Var:
677+
result = None if name is None else builder.ir_ctx.make_var(name, builder.loc)
669678
ret = builder.add_operation(TypedConst, ty, dict(value=value), result=result)
670679
if not isinstance(ty, TileTy) or ty.ndim == 0:
671680
# We currently don't have a way to represent an N-dimensional tile constant
@@ -1893,7 +1902,7 @@ def getattr_tile_dtype_impl(object: Var, name: Var):
18931902

18941903
@impl(getattr, overload=(TileTy, "shape"))
18951904
def getattr_tile_shape_impl(object: Var, name: Var):
1896-
return loosely_typed_const(object.get_type().shape)
1905+
return sym2var(object.get_type().shape, constant_only=True)
18971906

18981907

18991908
@impl(getattr, overload=(TileTy, "ndim"))
@@ -1924,12 +1933,12 @@ def getattr_tiled_view_dtype_impl(object: Var, name: Var):
19241933

19251934
@impl(getattr, overload=(TiledViewTy, "tile_shape"))
19261935
def getattr_tiled_view_tile_shape_impl(object: Var, name: Var):
1927-
return loosely_typed_const(object.get_type().tile_shape)
1936+
return sym2var(object.get_type().tile_shape, constant_only=True)
19281937

19291938

19301939
@impl(getattr, overload=(TiledViewTy, "traversal_steps"))
19311940
def getattr_tiled_view_traversal_steps_impl(object: Var, name: Var):
1932-
return loosely_typed_const(object.get_type().traversal_steps)
1941+
return sym2var(object.get_type().traversal_steps, constant_only=True)
19331942

19341943

19351944
@impl(getattr, overload=(TiledViewTy, "num_tiles"))
@@ -1981,7 +1990,7 @@ def getattr_module_impl(object: Var, name: Var):
19811990
ty = object.get_type()
19821991
attr_name = require_constant_str(name)
19831992
try:
1984-
return loosely_typed_const(getattr(ty.py_mod, attr_name))
1993+
return sym2var(getattr(ty.py_mod, attr_name), constant_only=True)
19851994
except AttributeError:
19861995
raise TileTypeError(f"Module '{ty.py_mod.__name__}' has no attribute '{attr_name}'")
19871996

@@ -1991,7 +2000,7 @@ def getattr_type_impl(object: Var, name: Var):
19912000
ty = object.get_type()
19922001
attr_name = require_constant_str(name)
19932002
try:
1994-
return loosely_typed_const(getattr(ty.ty, attr_name))
2003+
return sym2var(getattr(ty.ty, attr_name), constant_only=True)
19952004
except AttributeError:
19962005
raise TileTypeError(f"'{ty.ty.__name__}' object has no attribute '{attr_name}'")
19972006

@@ -2023,7 +2032,7 @@ async def getattr_dataclass_impl(object: Var, name: Var):
20232032
getter = loosely_typed_const(cls_attr.fget)
20242033
return await call(getter, (object,), {})
20252034
else:
2026-
return loosely_typed_const(cls_attr)
2035+
return sym2var(cls_attr, constant_only=True)
20272036

20282037

20292038
# ===========================================================================================
@@ -2058,7 +2067,11 @@ def assign(value: Var, res: Var) -> None:
20582067
@impl(hir_stubs.identity)
20592068
def identity_impl(x: Var) -> Var:
20602069
if x.is_constant():
2061-
return loosely_typed_const(x.get_constant(), x.get_type(), x.get_loose_type())
2070+
ty = x.get_type()
2071+
if ty.is_aggregate():
2072+
return make_aggregate(x.get_aggregate(), ty, x.get_loose_type())
2073+
else:
2074+
return loosely_typed_const(x.get_constant(), x.get_type(), x.get_loose_type())
20622075
else:
20632076
return x
20642077

@@ -5268,8 +5281,7 @@ def load_var_impl(name):
52685281
return ret
52695282
elif rn.index >= 0:
52705283
val = scope.func_hir.frozen_global_values[rn.index]
5271-
val = get_constant_value(val)
5272-
return loosely_typed_const(val)
5284+
return sym2var(val, constant_only=True)
52735285
else:
52745286
raise TileSyntaxError(f"Undefined variable {name} used")
52755287

@@ -5450,30 +5462,38 @@ async def static_foreach_impl(body: hir.Block, items: Var):
54505462
await dispatch_hir_block(body)
54515463

54525464

5453-
def sym2var(x: Any) -> Var:
5465+
def sym2var(x: Any, constant_only: bool = False) -> Var:
54545466
# TODO: verify we don't have a stale closure
54555467

54565468
if isinstance(x, Symbol):
5469+
if constant_only:
5470+
raise TileTypeError("Cannot create a constant from a symbolic value")
54575471
return x._var
54585472

54595473
if isinstance(x, tuple):
5460-
return build_tuple(tuple(sym2var(item) for item in x))
5474+
return build_tuple(tuple(sym2var(item, constant_only=constant_only) for item in x))
54615475

54625476
cls = type(x)
54635477
if dataclasses.is_dataclass(cls):
54645478
info = get_dataclass_info(cls)
5465-
field_vars = tuple(sym2var(getattr(x, f.name))
5479+
field_vars = tuple(sym2var(getattr(x, f.name), constant_only=constant_only)
54665480
for f in dataclasses.fields(cls))
54675481
return build_dataclass_instance(field_vars, info)
54685482

54695483
if isinstance(x, MethodType):
5470-
self_var = sym2var(x.__self__)
5484+
self_var = sym2var(x.__self__, constant_only=constant_only)
54715485
if not isinstance(x.__func__, FunctionType | BuiltinFunctionType):
54725486
raise TileTypeError(f"Object of type {type(x).__name__}"
54735487
f" cannot be used as a function for binding a method")
54745488
return bind_method(self_var, x.__func__)
54755489

5476-
x = get_constant_value(x)
5490+
# Transform a third party typed scalar (e.g., np.int16(5)) into a strictly typed constant
5491+
dtype_spec = as_third_party_dtype_spec(type(x))
5492+
if dtype_spec is not None:
5493+
pyval = datatype.numeric_dtype_category(dtype_spec.dtype).pytype(x)
5494+
ty = Builder.get_current().ir_ctx.typing_hooks.get_tensor_like_type(dtype_spec.dtype, ())
5495+
return strictly_typed_const(pyval, ty)
5496+
54775497
return loosely_typed_const(x)
54785498

54795499

src/cuda/tile/_ir/typing_support.py

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44
import inspect
55
import operator
66
import dataclasses
7-
from enum import Enum, IntEnum
7+
from enum import Enum
88
from functools import lru_cache
99
from types import ModuleType, FunctionType
1010
from typing import Any, Callable, Mapping, Union
1111

1212
from cuda.tile import _datatype as datatype, DType
1313
from cuda.tile._exception import TileTypeError, TileValueError
14+
from .ir import TypingHooks
1415
from .type import DataclassInfo, PointerInfoTy
1516

16-
from .type import Type, TupleTy, DTypeConstructor, DTypeSpec, NONE, StringTy, \
17+
from .type import Type, DTypeConstructor, DTypeSpec, NONE, StringTy, \
1718
ELLIPSIS, SLICE, ModuleTy, FunctionTy, EnumTy, TypeTy, LooselyTypedScalar, \
1819
TileTy
1920

@@ -52,6 +53,10 @@ def is_dtype(x: Any):
5253
return isinstance(x, DType) or _safe_get(_dtype_registry, x) is not None
5354

5455

56+
def as_third_party_dtype_spec(x: Any) -> DTypeSpec | None:
57+
return _safe_get(_dtype_registry, x)
58+
59+
5560
def _is_dtype_allowed_as_constructor(dtype: DType) -> bool:
5661
# Only allow byte aligned numeric dtypes as constructors
5762
return datatype.is_numeric(dtype) and (dtype.bitwidth % 8 == 0)
@@ -158,19 +163,15 @@ def dtype_of_constant_scalar(val: bool | int | float) -> DType:
158163
raise TypeError(f'Python value {val} of type {type(val)} is not supported.')
159164

160165

161-
def typeof_pyval(val) -> Type:
166+
def type_of_constant_python_value(val, typing_hooks: TypingHooks) -> Type:
162167
if val is None:
163168
return NONE
164-
if (t := _safe_get(_dtype_registry, type(val))):
165-
return TileTy(t.dtype)
166169
if isinstance(val, bool | int | float):
167-
return TileTy(dtype_of_constant_scalar(val))
170+
return typing_hooks.get_tensor_like_type(dtype_of_constant_scalar(val), ())
168171
if isinstance(val, Enum):
169172
return EnumTy(type(val))
170173
if isinstance(val, str):
171174
return StringTy(val)
172-
if isinstance(val, tuple):
173-
return TupleTy(tuple(typeof_pyval(v) for v in val))
174175
if val is Ellipsis:
175176
return ELLIPSIS
176177
if isinstance(val, slice):
@@ -186,40 +187,23 @@ def typeof_pyval(val) -> Type:
186187
return DTypeConstructor(val)
187188
else:
188189
return DTypeSpec(val)
189-
if (t := _safe_get(_dtype_registry, val)) is not None:
190+
if (t := as_third_party_dtype_spec(val)) is not None:
190191
return t
191192
if isinstance(val, datatype.PointerInfo):
192193
return PointerInfoTy(val)
193-
194194
if isinstance(val, type):
195195
return TypeTy(val)
196196

197-
# TODO: should we add dlpack?
198-
raise TypeError(f'Python value {val} of type {type(val)} is not supported.')
197+
ty = type(val)
198+
prefix = "" if ty.__module__ == "builtins" else f"{ty.__module__}."
199+
raise TileTypeError(f"Cannot create constant from value of type {prefix}{ty.__qualname__}.")
199200

200201

201-
def loose_type_of_pyval(value: Any) -> Type:
202+
def loose_type_of_constant_python_value(value: Any, typing_hooks: TypingHooks) -> Type:
202203
if isinstance(value, bool | int | float):
203204
return LooselyTypedScalar(value)
204-
elif isinstance(value, tuple):
205-
return TupleTy(tuple(loose_type_of_pyval(x) for x in value))
206205
else:
207-
return typeof_pyval(value)
208-
209-
210-
_SUPPORTED_CONST_TYPES = (int, float, bool, str, ModuleType, FunctionType, type, Enum, IntEnum)
211-
212-
213-
def get_constant_value(val: Any) -> Any:
214-
if val is None or isinstance(val, _SUPPORTED_CONST_TYPES) or is_supported_builtin_func(val):
215-
return val
216-
if is_dtype(val):
217-
return to_dtype(val)
218-
if isinstance(val, tuple):
219-
return tuple(get_constant_value(x) for x in val)
220-
typ = type(val)
221-
prefix = "" if typ.__module__ == "builtins" else f"{typ.__module__}."
222-
raise TileTypeError(f"Cannot create constant from value of type {prefix}{typ.__qualname__}.")
206+
return type_of_constant_python_value(value, typing_hooks)
223207

224208

225209
@lru_cache

src/cuda/tile/_passes/hir2ir.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .._ir.ir import Var, IRContext
1919
from .._ir.op_impl import ImplRegistry
2020
from .._ir.ops import loosely_typed_const, end_branch, return_, continue_, \
21-
break_, store_var, build_dataclass_instance, build_tuple, dtype_constructor
21+
break_, store_var, build_dataclass_instance, build_tuple, dtype_constructor, sym2var
2222
from .._ir.scope import Scope, LocalScope, IntMap
2323
from .._ir.type import FunctionTy, BoundMethodTy, DTypeConstructor, ClosureTy, \
2424
ClosureDefaultPlaceholder, StringFormat, TypeTy, TupleTy, BoundMethodValue, TupleValue, \
@@ -390,7 +390,7 @@ def _resolve_operand(x: hir.Operand, scope: Scope) \
390390
elif isinstance(x, hir.Block | hir.Function | hir.StaticEvalExpression | StringFormat):
391391
return x
392392
else:
393-
return loosely_typed_const(x)
393+
return sym2var(x, constant_only=True)
394394

395395

396396
def _bind_args(sig: inspect.Signature, func_name: str, args, kwargs,
@@ -411,6 +411,6 @@ def _bind_args(sig: inspect.Signature, func_name: str, args, kwargs,
411411
assert closure_defaults is not None
412412
default = closure_defaults[param.default.default_value_index]
413413
else:
414-
default = loosely_typed_const(param.default)
414+
default = sym2var(param.default, constant_only=True)
415415
ret.append(default)
416416
return ret

test/test_dataclass.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,18 @@ def kern(x):
4949
assert x.tolist() == [2, 7, 30, 40, 5]
5050

5151

52+
def test_dataclass_global_capture():
53+
fb = FooBar(2, 7)
54+
55+
@ct.kernel
56+
def kern(x):
57+
ct.scatter(x, (), fb.foo)
58+
59+
x = torch.zeros((), dtype=torch.int32, device="cuda")
60+
ct.launch(torch.cuda.current_stream(), (1,), kern, (x,))
61+
assert x.item() == 2
62+
63+
5264
def test_dataclass_with_field_named_self():
5365
@dataclass(frozen=True)
5466
class Selfish:

test/test_ir_types.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44
import copy
5+
import functools
56
import pickle
67

78
import pytest
@@ -23,7 +24,7 @@
2324
IntegerInfo, opaque_pointer_dtype, pointer_dtype, PointerInfo,
2425
)
2526
from cuda.tile._ir.ops_utils import promote_dtypes, check_implicit_cast
26-
from cuda.tile._ir.typing_support import to_dtype, typeof_pyval
27+
from cuda.tile._ir.typing_support import to_dtype
2728
import torch
2829
import numpy as np
2930

@@ -338,12 +339,12 @@ def test_torch_dtype_support():
338339
assert to_dtype(torch.float8_e8m0fnu) == float8_e8m0fnu
339340

340341

341-
def test_typeof_pyval():
342-
tp = typeof_pyval
342+
def test_type_of_constant_python_value():
343+
from cuda.tile._ir.typing_support import type_of_constant_python_value
344+
from cuda.tile._compile import _TileTypingHooks
345+
tp = functools.partial(type_of_constant_python_value, typing_hooks=_TileTypingHooks())
343346
assert tp(1) == TileTy(int32)
344347
assert tp(1.) == TileTy(float32)
345-
assert tp(np.int16(1)) == TileTy(int16)
346-
assert tp(np.float64(1.0)) == TileTy(float64)
347348
assert tp(True) == TileTy(bool_)
348349
assert tp(None) == NONE
349350

test/test_tuple.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,16 @@ def kernel(x, M: ct.Constant[int], N: ct.Constant[int]):
265265
assert x.item() == 1
266266
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x, 4, 9))
267267
assert x.item() == -1
268+
269+
270+
def test_tuple_global_capture():
271+
tup = (100, (101, 102), (103, (104, 105)))
272+
273+
@ct.kernel
274+
def kernel(x):
275+
ct.scatter(x, 0, tup[0])
276+
ct.scatter(x, 1, tup[2][1][1])
277+
278+
x = torch.zeros(2, dtype=torch.int32, device="cuda")
279+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x,))
280+
assert x.tolist() == [100, 105]

test/test_typeinfer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,3 +470,17 @@ def kernel():
470470
with pytest.raises(TileTypeError,
471471
match="Tiles are immutable: item assignment is not supported"):
472472
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
473+
474+
475+
def test_typeof_numpy_scalar():
476+
import numpy as np
477+
val = np.int16(4)
478+
479+
@ct.kernel
480+
def kernel():
481+
val_dtype = val.dtype
482+
ct.static_assert(val_dtype == ct.int16)
483+
x = val
484+
ct.static_assert(x == 4)
485+
486+
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())

0 commit comments

Comments
 (0)