Skip to content

Commit adb892d

Browse files
committed
Enhance tuple unpacking
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 181a487 commit adb892d

File tree

6 files changed

+73
-7
lines changed

6 files changed

+73
-7
lines changed

changelog.d/tuple-unpacking.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Added a missing check for unpacking a tuple with too many values. For example, ``a, b = 1, 2, 3``
5+
now raises an error, instead of silently discarding the extra value.
6+
- Added support for unpacking nested tuples (e.g, ``a, (b, c) = t``) , as well as using square
7+
brackets for unpacking (e.g., ``[a, b] = 1, 2``).
8+
- Fixed the missing column indicator in error messages when the underlined text is only one
9+
character wide.

src/cuda/tile/_exception.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _format_location_frame(loc: Loc) -> str:
9999

100100
visual_col = _wcwidth(line_bytes[:loc.col].decode())
101101
if end_col == loc.col + 1:
102-
end_visual_col = visual_col
102+
end_visual_col = visual_col + 1
103103
cols_str = f"col {visual_col + 1}"
104104
else:
105105
end_visual_col = _wcwidth(line_bytes[:end_col].decode())

src/cuda/tile/_ir/hir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def __call__(self, x: Operand) -> str:
221221
def if_else(cond, then_block, else_block, /): ...
222222
def loop(body, iterable, /): ... # infinite if `iterable` is None
223223
def build_tuple(*items): ... # Makes a tuple (i.e. returns `items`)
224+
def unpack(iterable, expected_len, /): ...
224225
def identity(x): ... # Identity function (i.e. returns `x`)
225226
def store_var(name, value, /): ... # Store into a named variable
226227
def load_var(name, /): ... # Load from a named variable

src/cuda/tile/_ir/ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,22 @@ def build_tuple(items: tuple[Var, ...]) -> Var:
13411341
return res
13421342

13431343

1344+
@impl(hir.unpack)
1345+
def unpack_impl(iterable: Var, expected_len: Var) -> Var:
1346+
ty = iterable.get_type()
1347+
# Don't use the require_tuple_type() helper because we'd like to customize the error message
1348+
if not isinstance(ty, TupleTy):
1349+
raise TileTypeError("Expected a tuple", iterable.loc)
1350+
expected_len = require_constant_int(expected_len)
1351+
if len(ty.value_types) != expected_len:
1352+
few_many = "few" if len(ty.value_types) < expected_len else "many"
1353+
raise TileValueError(f"Too {few_many} values to unpack"
1354+
f" (expected {expected_len}, got {len(ty.value_types)})")
1355+
# Return the input tuple. If we add support for additional iterables,
1356+
# the idea is to cast them to a tuple here.
1357+
return iterable
1358+
1359+
13441360
class Unary(TypedOperation):
13451361
def __init__(self, fn: str, operand: Var,
13461362
rounding_mode: Optional[RoundingMode], flush_to_zero: bool,

src/cuda/tile/_passes/ast2hir.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -409,13 +409,11 @@ def _do_assign(value: hir.Operand, target, ctx: _Context):
409409
with ctx.change_loc(target):
410410
if isinstance(target, ast.Name):
411411
ctx.store(target.id, value)
412-
elif isinstance(target, ast.Tuple):
412+
elif isinstance(target, ast.Tuple | ast.List):
413+
value = ctx.call(hir.unpack, (value, len(target.elts)))
413414
for i, el in enumerate(target.elts):
414-
with ctx.change_loc(el):
415-
if not isinstance(el, ast.Name):
416-
raise ctx.unsupported_syntax()
417-
item_var = ctx.call(operator.getitem, (value, i), )
418-
ctx.store(el.id, item_var)
415+
item_var = ctx.call(operator.getitem, (value, i), )
416+
_do_assign(item_var, el, ctx)
419417
else:
420418
raise ctx.unsupported_syntax()
421419

test/test_unpacking.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import pytest
5+
import torch.cuda
6+
7+
import cuda.tile as ct
8+
from cuda.tile import TileValueError
9+
10+
11+
def test_too_few_values_to_unpack():
12+
@ct.kernel
13+
def kernel():
14+
t = 1, 2
15+
a, b, c = t
16+
with pytest.raises(TileValueError, match="Too few values to unpack"):
17+
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
18+
19+
20+
def test_too_many_values_to_unpack():
21+
@ct.kernel
22+
def kernel():
23+
t = 1, 2, 3
24+
a, b = t
25+
with pytest.raises(TileValueError, match="Too many values to unpack"):
26+
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
27+
28+
29+
def test_unpack_nested_tuple():
30+
@ct.kernel
31+
def kernel(x):
32+
t = (1, (2, 3, 4)), 5
33+
[(a, [b, c, d]), e] = t
34+
ct.scatter(x, 0, a)
35+
ct.scatter(x, 1, b)
36+
ct.scatter(x, 2, c)
37+
ct.scatter(x, 3, d)
38+
ct.scatter(x, 4, e)
39+
40+
x = torch.zeros((5,), dtype=torch.int32, device="cuda")
41+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x,))
42+
assert x.tolist() == [1, 2, 3, 4, 5]

0 commit comments

Comments
 (0)