Skip to content

Commit e1854cc

Browse files
committed
Don't rely on co_varnames to get the list of locals
With Python 3.12, a list comprehension like [i*i for i in range(3)] will add `i` to co_varnames, even though `i` is still not a local variable. Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 86af9ba commit e1854cc

3 files changed

Lines changed: 21 additions & 4 deletions

File tree

src/cuda/tile/_passes/ast2hir.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,9 @@ def _call_static_eval(expr: ast.expr, kind: hir.StaticEvalKind, ctx: _Context) -
297297
inner_lambda = outer_lambda(*tuple(None for _ in local_names))
298298

299299
# Make sure the function doesn't store any locals, e.g. using the walrus operator
300-
if len(inner_lambda.__code__.co_varnames) > 0:
301-
name = inner_lambda.__code__.co_varnames[0]
300+
stored_locals = ast_get_all_local_names(inner_lambda_ast).local_names
301+
if len(stored_locals) > 0:
302+
name = min(stored_locals)
302303
raise TileSyntaxError(f"static_eval() expression attempted"
303304
f" to modify a local variable '{name}'")
304305

test/test_ast_util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,14 @@ async def nested_async_def(): pass
5353
print(a, b, c, sneaky_1, sneaky_2, sneaky_3, sneaky_4, local_i)
5454

5555

56-
5756
def test_get_all_local_names():
5857
parsed_ast = ast.parse(inspect.getsource(func))
5958
func_ast = parsed_ast.body[0]
6059
local_names, global_names, nonlocal_names = ast_get_all_local_names(func_ast)
61-
expected = sorted(func.__code__.co_varnames + func.__code__.co_cellvars)
60+
expected = ['a', 'b', 'c', 'cd', 'e', 'getcwd', 'local_i', 'match_alias', 'nested_async_def',
61+
'nested_func', 'nonexistent', 'p1', 'p2', 'p3',
62+
'pat1', 'pat2', 'pat3', 'pat4', 'pat5', 'pat6',
63+
'sneaky_1', 'sneaky_2', 'sneaky_3', 'sneaky_4', 'walrus']
6264
assert sorted(local_names) == expected
6365
assert global_names == {"g"}
6466
assert nonlocal_names == set()

test/test_static_eval.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,20 @@ def kernel(y):
2525
assert y.tolist() == [6, 6]
2626

2727

28+
def test_list_comprehension():
29+
@ct.kernel
30+
def kernel(y):
31+
tup = (1, 2, 3)
32+
s1 = ct.static_eval(sum([i*i for i in tup]))
33+
s2 = cuda.tile.static_eval(sum([i*i for i in tup]))
34+
ct.scatter(y, 0, s1)
35+
ct.scatter(y, 1, s2)
36+
37+
y = torch.zeros((2,), dtype=torch.int32, device="cuda")
38+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (y,))
39+
assert y.tolist() == [1*1 + 2*2 + 3*3, 1*1 + 2*2 + 3*3]
40+
41+
2842
def test_mixed_tuple():
2943
@ct.kernel
3044
def kernel(y):

0 commit comments

Comments
 (0)