Skip to content

Commit 6653a56

Browse files
committed
Fix helper function with while loop and return
Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent 94f7824 commit 6653a56

3 files changed

Lines changed: 20 additions & 1 deletion

File tree

changelog.d/while-loop-return.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Fix "potentially undefined variable `$retval`" error when a helper function
5+
returns after a ``while`` loop that contains no early return.

src/cuda/tile/_passes/ast2hir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -929,8 +929,8 @@ def _ast2hir(func_def: ast.FunctionDef | ast.Lambda, ctx: _Context) -> hir.Block
929929
elif isinstance(func_def, ast.FunctionDef):
930930
# To enable early returns in a helper function, wrap the body in a loop.
931931
# Thus, we can use "break" to implement the return statement.
932-
ctx.store("$returning", False)
933932
with ctx.new_block() as body_block:
933+
ctx.store("$returning", False)
934934
_stmt_list(func_def.body, ctx)
935935
if body_block.jump is None:
936936
ctx.store("$retval", None)

test/test_helper_function.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,20 @@ def test_early_return_inside_for_loop():
249249
ct.launch(torch.cuda.current_stream(), (1,), kernel, (n, out))
250250

251251

252+
def return_after_while_loop(n):
253+
while n > 0:
254+
n = n - 1
255+
return n
256+
257+
258+
def test_return_after_while_loop():
259+
n = torch.tensor([3], dtype=torch.int32, device="cuda")
260+
out = torch.zeros_like(n)
261+
kernel = early_return_inside_loop(return_after_while_loop)
262+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (n, out))
263+
assert out.cpu().item() == 0
264+
265+
252266
def loops(n):
253267
a = 0
254268
for i in range(n):

0 commit comments

Comments
 (0)