Skip to content

Commit d37b24d

Browse files
committed
Add support for **kwargs
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 64ad5db commit d37b24d

9 files changed

Lines changed: 236 additions & 18 deletions

File tree

changelog.d/dict.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- Added limited support for dictionaries, variadic keyword parameters in user-defined functions
2+
(e.g, `def foo(**kwargs)`), and dictionary unpacking (e.g., `foo(x, **y)`).

src/cuda/tile/_ir/core_ops.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import functools
66
import operator
77
from dataclasses import dataclass
8-
from types import MethodType, FunctionType, BuiltinFunctionType
8+
from types import MethodType, FunctionType, BuiltinFunctionType, MappingProxyType
99
from typing import Any, Optional
1010

1111
from typing_extensions import override
@@ -27,7 +27,7 @@
2727
DataclassInfo, DataclassTy, DataclassValue, BoundMethodValue, BoundMethodTy, InvalidType, \
2828
ContextManagerTy, ContextManagerLifecycle, LiveCapturedScope, ClosureTy, ClosureValue, \
2929
RangeIterType, RangeValue, TypeTy, ModuleTy, NONE, SliceType, StringTy, FormattedStringTy, \
30-
StringFormat, FormattedStringValue, FormattedPiece
30+
StringFormat, FormattedStringValue, FormattedPiece, DictTy, DictValue
3131
from cuda.tile._ir.typing_support import type_of_constant_python_value, \
3232
loose_type_of_constant_python_value, get_dataclass_info, as_third_party_dtype_spec
3333
from cuda.tile._ir2bytecode import BytecodeContext
@@ -65,6 +65,7 @@ def core_impl_registry() -> ImplRegistry:
6565
@overload_dispatcher(operator.lshift, fixed_args=["<<"])
6666
@overload_dispatcher(operator.rshift, fixed_args=[">>"])
6767
@overload_dispatcher(operator.matmul, fixed_args=["@"])
68+
@overload_dispatcher(hir_stubs.is_contained_in, fixed_args=["'in'"])
6869
@overload_dispatcher(min, fixed_args=["min"])
6970
@overload_dispatcher(max, fixed_args=["max"])
7071
def binop_overload_dispatcher(name: str, x: Var, y: Var):
@@ -76,6 +77,13 @@ def binop_overload_dispatcher(name: str, x: Var, y: Var):
7677
raise TileTypeError(f"Unsupported operand types for {name}: {x_ty} and {y_ty}")
7778

7879

80+
@impl(hir_stubs.is_not_contained_in)
81+
async def is_not_contained_in_impl(x: Var, y: Var):
82+
from .._passes.hir2ir import call_function
83+
contained = await call_function(hir_stubs.is_contained_in, x, y)
84+
return await call_function(operator.not_, contained)
85+
86+
7987
def comparison_operator_impl(registry: ImplRegistry, lhs_ty: type[Type], rhs_ty: type[Type]):
8088
def decorate(func):
8189
for name in ("eq", "ne", "lt", "le", "gt", "ge"):
@@ -346,6 +354,71 @@ def len_tuple_impl(x: Var[TupleTy]) -> Var:
346354
return loosely_typed_const(len(x.get_type()))
347355

348356

357+
# ===========================================================================================
358+
# Dictionary
359+
# ===========================================================================================
360+
361+
def build_dict(keys: tuple[str, ...], values: tuple[Var, ...]) -> Var:
362+
keys = tuple(keys)
363+
values = tuple(values)
364+
assert len(keys) == len(values)
365+
366+
ty = DictTy(keys, tuple(x.get_type() for x in values))
367+
loose_ty = DictTy(keys, tuple(x.get_loose_type() for x in values))
368+
res = make_aggregate(DictValue(values), ty, loose_ty)
369+
if all(x.is_constant() for x in values):
370+
items = [(k, v.get_constant()) for k, v in zip(keys, values, strict=True)]
371+
res.set_constant(MappingProxyType(dict(items)))
372+
return res
373+
374+
375+
def _find_dict_key_index(key: Var, dict_ty: DictTy) -> int | None:
376+
key_ty = key.get_type()
377+
if not isinstance(key_ty, StringTy):
378+
# Python would happily report that the key is not found when a "wrong" key type is passed,
379+
# but we can add a stronger check here.
380+
raise TileTypeError(f"Dictionary keys must be strings, not {key_ty}")
381+
382+
return dict_ty.keys.index(key_ty.value) if key_ty.value in dict_ty.keys else None
383+
384+
385+
@impl(hir_stubs.is_contained_in, overload=(WILDCARD, DictTy))
386+
async def is_contained_in_dict_impl(x: Var, y: Var[DictTy]):
387+
return loosely_typed_const(_find_dict_key_index(x, y.get_type()) is not None)
388+
389+
390+
@impl(getattr, overload=(DictTy, "get"))
391+
def getattr_dict_method(object: Var, name: Var):
392+
name = require_constant_str(name)
393+
unbound_func = getattr(dict, name)
394+
return bind_method(object, unbound_func)
395+
396+
397+
@impl(operator.getitem, overload=(DictTy, WILDCARD))
398+
def getitem_dict_impl(object: Var[DictTy], key: Var):
399+
idx = _find_dict_key_index(key, object.get_type())
400+
if idx is None:
401+
raise TileTypeError(f"Key '{key.get_type().value}' not found in dictionary")
402+
dict_value = object.get_aggregate()
403+
assert isinstance(dict_value, DictValue)
404+
return dict_value.values[idx]
405+
406+
407+
@impl(dict.get)
408+
def dict_get_impl(self: Var, key: Var, default: Var):
409+
dict_ty = self.get_type()
410+
if not isinstance(dict_ty, DictTy):
411+
raise TileTypeError(f"dict.get() expects a dictionary as `self`, got {dict_ty}")
412+
413+
idx = _find_dict_key_index(key, dict_ty)
414+
if idx is None:
415+
return default
416+
417+
dict_value = self.get_aggregate()
418+
assert isinstance(dict_value, DictValue)
419+
return dict_value.values[idx]
420+
421+
349422
# ===========================================================================================
350423
# Dataclass
351424
# ===========================================================================================

src/cuda/tile/_ir/hir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class Call:
7676
result: Value | None
7777
callee: Operand
7878
args: tuple[Operand | Starred, ...]
79-
kwargs: tuple[tuple[str, Operand], ...]
79+
kwargs: tuple[tuple[str | None, Operand], ...] # None means an **unpacked argument
8080
loc: Loc
8181

8282
def __str__(self):

src/cuda/tile/_ir/hir_stubs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,11 @@ def enter_context(manager, /): ...
6565

6666
@stub
6767
def pop_context(): ...
68+
69+
70+
@stub
71+
def is_contained_in(x, y, /): ... # "return x in y"
72+
73+
74+
@stub
75+
def is_not_contained_in(x, y, /): ... # return "x not in y"

src/cuda/tile/_ir/type.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import os
88
from dataclasses import dataclass
99
from enum import EnumMeta
10-
from types import ModuleType, FunctionType, BuiltinFunctionType, MethodType
10+
from types import ModuleType, FunctionType, BuiltinFunctionType, MethodType, MappingProxyType
1111
from typing import Any, Callable, Optional, Sequence, Tuple, Iterator, Mapping
1212
from functools import reduce
1313
import operator
@@ -269,6 +269,43 @@ def as_tuple(self) -> tuple["Var", ...]:
269269
return self.items
270270

271271

272+
# ============== Dictionary (limited support for **kwargs) ==============
273+
274+
@dataclass(frozen=True)
275+
class DictTy(Type):
276+
keys: tuple[str, ...]
277+
value_types: tuple[Type, ...]
278+
279+
def make_symbol(self, var: "Var"):
280+
dict_val = var.get_aggregate()
281+
assert isinstance(dict_val, DictValue)
282+
items = [(k, var2sym(v)) for k, v in zip(self.keys, dict_val.values, strict=True)]
283+
return MappingProxyType(dict(items))
284+
285+
def is_aggregate(self) -> bool:
286+
return True
287+
288+
def aggregate_item_types(self) -> tuple["Type", ...]:
289+
return self.value_types
290+
291+
def make_aggregate_value(self, values: tuple["Var", ...]) -> "AggregateValue":
292+
assert len(values) == len(self.keys)
293+
return DictValue(values)
294+
295+
def __str__(self):
296+
return (
297+
"dict["
298+
+ ", ".join(f"{k}: {ty}"
299+
for k, ty in zip(self.keys, self.value_types, strict=True))
300+
+ "]"
301+
)
302+
303+
304+
@dataclass
305+
class DictValue(AggregateValue):
306+
values: tuple["Var", ...]
307+
308+
272309
# ============== Dataclass ===============
273310

274311

src/cuda/tile/_ir/typing_support.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def __missing__(self, key: TypeKey) -> TypeHandler:
121121
bool: lambda x=False, /: None,
122122
print: lambda *args, sep=' ', end='\n': None,
123123
dataclasses.replace: dataclasses.replace,
124+
dict.get: dict.get,
124125
}
125126

126127

src/cuda/tile/_passes/ast2hir.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,7 @@ def _binop_expr(binop: ast.BinOp, ctx: _Context) -> hir.Value:
423423
_cmp_map = {
424424
ast.Eq: operator.eq, ast.NotEq: operator.ne, ast.Lt: operator.lt, ast.LtE: operator.le,
425425
ast.Gt: operator.gt, ast.GtE: operator.ge, ast.Is: operator.is_, ast.IsNot: operator.is_not,
426+
ast.In: hir_stubs.is_contained_in, ast.NotIn: hir_stubs.is_not_contained_in
426427
}
427428

428429

@@ -975,11 +976,6 @@ def _get_all_parameters(func_def: ast.FunctionDef | ast.Lambda, ctx: _Context) -
975976
for a in (func_def.args.vararg, func_def.args.kwarg):
976977
if a is not None:
977978
raise ctx.syntax_error("Variadic kernel parameters are not supported", a)
978-
else:
979-
if func_def.args.kwarg is not None:
980-
raise ctx.syntax_error(
981-
"Variadic keyword parameters in user-defined functions are not supported",
982-
func_def.args.kwarg)
983979

984980
all_args = []
985981
for arg in func_def.args.posonlyargs:
@@ -990,6 +986,8 @@ def _get_all_parameters(func_def: ast.FunctionDef | ast.Lambda, ctx: _Context) -
990986
all_args.append(func_def.args.vararg)
991987
for arg in func_def.args.kwonlyargs:
992988
all_args.append(arg)
989+
if func_def.args.kwarg is not None:
990+
all_args.append(func_def.args.kwarg)
993991
return all_args
994992

995993

src/cuda/tile/_passes/hir2ir.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,20 @@
1111
from .. import TileTypeError
1212
from .._coroutine_util import resume_after, run_coroutine
1313
from .._datatype import PointerInfo
14-
from .._exception import Loc, FunctionDesc, TileSyntaxError, TileInternalError, TileError, \
15-
TileRecursionError
14+
from .._exception import Loc, FunctionDesc, TileInternalError, TileError, TileRecursionError
1615
from .._execution import is_stub
1716
from .._ir import hir, ir
1817
from .._ir.ir import Var, IRContext
1918
from .._ir.op_impl import ImplRegistry
2019
from .._ir.control_flow_ops import end_branch, return_, continue_, break_
2120
from .._ir.core_ops import (
22-
loosely_typed_const, build_dataclass_instance, build_tuple, sym2var, store_var
21+
loosely_typed_const, build_dataclass_instance, build_tuple, sym2var, store_var, build_dict
2322
)
2423
from .._ir.arithmetic_ops import dtype_constructor
2524
from .._ir.scope import Scope, LocalScope, IntMap
2625
from .._ir.type import FunctionTy, BoundMethodTy, DTypeConstructor, ClosureTy, \
2726
ClosureDefaultPlaceholder, StringFormat, TypeTy, TupleTy, BoundMethodValue, TupleValue, \
28-
ClosureValue
27+
ClosureValue, DictTy, DictValue
2928
from .._ir.typing_support import get_signature, is_supported_builtin_func, \
3029
get_dataclass_info
3130

@@ -167,7 +166,20 @@ async def _dispatch_call(hir_call: hir.Call, scope: Scope):
167166
args.extend(tup_value.items)
168167
else:
169168
args.append(_resolve_operand(x, scope))
170-
kwargs = {k: _resolve_operand(v, scope) for k, v in hir_call.kwargs}
169+
kwargs = {}
170+
for k, v in hir_call.kwargs:
171+
resolved_val = _resolve_operand(v, scope)
172+
if k is None:
173+
assert isinstance(resolved_val, Var)
174+
dict_ty = resolved_val.get_type()
175+
if not isinstance(dict_ty, DictTy):
176+
raise TileTypeError(f"Expected a dictionary after **, got {dict_ty}")
177+
dict_value = resolved_val.get_aggregate()
178+
assert isinstance(dict_value, DictValue)
179+
for item_key, item_value in zip(dict_ty.keys, dict_value.values, strict=True):
180+
kwargs[item_key] = item_value
181+
else:
182+
kwargs[k] = resolved_val
171183
retval = await call(callee_var, args, kwargs)
172184
if hir_call.result is not None and retval is not None:
173185
scope.hir2ir_varmap[hir_call.result.id] = retval
@@ -178,10 +190,6 @@ async def _call_user_defined(callee_hir: hir.Function,
178190
builder: ir.Builder,
179191
parent_scopes: tuple[LocalScope, ...] = ()):
180192
_check_recursive_call(builder.loc)
181-
for param_name, param in callee_hir.signature.parameters.items():
182-
if param.kind == inspect.Parameter.VAR_KEYWORD:
183-
raise TileSyntaxError("Variadic keyword parameters in user-defined"
184-
" functions are not supported")
185193

186194
# Activate a fresh Scope. Each inlining gets its own concretized
187195
# FunctionDesc so that DI never merges two specializations whose generated
@@ -197,6 +205,8 @@ async def _call_user_defined(callee_hir: hir.Function,
197205
if isinstance(arg, tuple):
198206
# Handle the *vararg parameter
199207
arg = build_tuple(arg)
208+
elif isinstance(arg, dict):
209+
arg = build_dict(tuple(arg.keys()), tuple(arg.values()))
200210
store_var(local_idx, arg, param_loc)
201211

202212
# Dispatch the function body. Use resume_after() to break the call stack
@@ -411,6 +421,8 @@ def _bind_args(sig: inspect.Signature, func_name: str, args, kwargs,
411421
ret.append(bound_args.arguments[name])
412422
elif param.kind == param.VAR_POSITIONAL:
413423
ret.append(())
424+
elif param.kind == param.VAR_KEYWORD:
425+
ret.append({})
414426
else:
415427
assert param.default is not param.empty
416428
if isinstance(param.default, ClosureDefaultPlaceholder):

test/test_dict.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import pytest
5+
import torch.cuda
6+
7+
import cuda.tile as ct
8+
from cuda.tile import TileTypeError
9+
10+
11+
def test_variadic_kwargs_in_helper_function():
12+
def helper(**kwargs):
13+
ct.static_assert(kwargs == {"foo": 123, "bar": 456})
14+
return 789
15+
16+
@ct.kernel
17+
def kernel():
18+
res = helper(foo=123, bar=456)
19+
ct.static_assert(res == 789)
20+
21+
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
22+
23+
24+
def test_forward_variadic_kwargs():
25+
def leaf(x, foo, bar):
26+
return x * 100 + foo * 10 + bar
27+
28+
def forward(f, **kwargs):
29+
return f(3, **kwargs)
30+
31+
@ct.kernel
32+
def kernel():
33+
res = forward(leaf, foo=4, bar=5)
34+
ct.static_assert(res == 345)
35+
36+
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
37+
38+
39+
def test_dict_access():
40+
def helper(**kwargs):
41+
foo1 = kwargs["foo"]
42+
ct.static_assert(foo1 == 123)
43+
44+
foo2 = kwargs.get("foo")
45+
ct.static_assert(foo2 == 123)
46+
47+
bar1 = kwargs["bar"]
48+
ct.static_assert(bar1 == 456)
49+
50+
bar2 = kwargs.get("bar")
51+
ct.static_assert(bar2 == 456)
52+
53+
qux1 = kwargs.get("qux")
54+
ct.static_assert(qux1 is None)
55+
56+
res1 = "foo" in kwargs
57+
ct.static_assert(res1)
58+
59+
res2 = "foo" not in kwargs
60+
ct.static_assert(not res2)
61+
62+
res3 = "qux" in kwargs
63+
ct.static_assert(not res3)
64+
65+
res4 = "qux" not in kwargs
66+
ct.static_assert(res4)
67+
68+
return 789
69+
70+
@ct.kernel
71+
def kernel():
72+
res = helper(foo=123, bar=456)
73+
ct.static_assert(res == 789)
74+
75+
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
76+
77+
78+
def test_dict_getitem_miss():
79+
def helper(**kwargs):
80+
return kwargs["qux"]
81+
82+
@ct.kernel
83+
def kernel():
84+
helper(foo=123, bar=456)
85+
86+
with pytest.raises(TileTypeError, match="Key 'qux' not found in dictionary"):
87+
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())

0 commit comments

Comments
 (0)