Skip to content

Commit c168893

Browse files
committed
Don't ignore decorators of user-defined functions
Currently, ast2hir unwraps any decorators before getting the source of the function. This is incorrect. For example: def decorate(func): @functools.wraps(func) def wrapper(x): return func(x + 3) return wrapper @decorate def decorated_helper(x): return x * 10 Calling decorated_helper(5) should return (5 + 3) * 10 == 80, but currently it returns 5 * 10 == 50 because the wrapper is discarded. This patch fixes this by only unwrapping the @ct.function decorator, and respecting all other wrappers. Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 14e9353 commit c168893

6 files changed

Lines changed: 71 additions & 6 deletions

File tree

changelog.d/fix-decorator.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- Fixed a bug where wrappers installed on a user-defined function
2+
(e.g., by a decorator using `functools.wraps()`) were previously ignored.

src/cuda/tile/_execution.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def decorator(func):
4949
def wrapped(*args, **kwargs):
5050
return DispatchMode.get_current().call_tile_function_from_host(
5151
wrapped, args, kwargs)
52+
wrapped._cutile_function_wrapper = True
5253
return wrapped
5354

5455
if func is None:
@@ -188,3 +189,7 @@ def is_stub(func) -> bool:
188189
func = getattr(func, "__wrapped__", None)
189190
if func is None:
190191
return False
192+
193+
194+
def is_function_wrapper(func) -> bool:
195+
return getattr(func, "_cutile_function_wrapper", False)

src/cuda/tile/_ir/op_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def decorate(func):
171171
if len(fixed_args) > 0:
172172
func = functools.partial(orig_func, *fixed_args)
173173

174-
func_sig = get_signature(func)
174+
func_sig = inspect.signature(func)
175175
_verify_params_match(stub_sig, func_sig)
176176
is_coroutine = inspect.iscoroutinefunction(func)
177177
if is_coroutine:

src/cuda/tile/_ir/typing_support.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from .type import Type, DTypeConstructor, DTypeSpec, NONE, StringTy, \
1818
ELLIPSIS, SLICE, ModuleTy, FunctionTy, EnumTy, TypeTy, LooselyTypedScalar
19-
19+
from .._execution import is_function_wrapper
2020

2121
# Store mapping from 3rd party dtype objects
2222
# e.g. np.float32 -> float32, torch.bfloat16 -> bfloat16
@@ -131,7 +131,13 @@ def get_signature(f) -> inspect.Signature:
131131
elif is_dtype_constructor(f):
132132
# Data type constructors
133133
f = lambda x=0, /: None # noqa: E731
134-
return inspect.signature(f)
134+
135+
if isinstance(f, type):
136+
return inspect.signature(f)
137+
138+
while is_function_wrapper(f):
139+
f = f.__wrapped__
140+
return inspect.signature(f, follow_wrapped=False)
135141

136142

137143
def is_supported_builtin_func(x: Any) -> bool:

src/cuda/tile/_passes/ast2hir.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from cuda.tile import _datatype as datatype
1515
from cuda.tile._exception import TileSyntaxError, Loc, FunctionDesc
16+
from cuda.tile._execution import is_function_wrapper
1617
from cuda.tile._ir.hir import make_value, ResolvedName, UNKNOWN_NAME
1718
from cuda.tile._ir import hir, hir_stubs
1819
from cuda.tile._ir.type import ClosureDefaultPlaceholder, FormattedPiece, StringFormat
@@ -22,10 +23,16 @@
2223

2324
@lru_cache
2425
def get_function_hir(pyfunc: Callable, entry_point: bool) -> hir.Function:
25-
# Get the original function from the decorated function if it exists.
26-
pyfunc = getattr(pyfunc, "__wrapped__", pyfunc)
26+
# Unwrap the @function decorator
27+
while is_function_wrapper(pyfunc):
28+
pyfunc = pyfunc.__wrapped__
29+
30+
# Use findsource() instead of getsourcelines() because the latter unwraps any decorators,
31+
# which we don't want.
32+
file_source_lines, first_line = inspect.findsource(pyfunc)
33+
source_lines = inspect.getblock(file_source_lines[first_line:])
34+
first_line += 1
2735

28-
source_lines, first_line = inspect.getsourcelines(pyfunc)
2936
# The source code of our function could be inside a class, an if-else block etc.
3037
# This means it can have extra indentation on the left. If we try to give it
3138
# to ast.parse() as is, we will get a parse error. The common workaround

test/test_helper_function.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
4+
import functools
45
import inspect
56

67
import pytest
@@ -364,3 +365,47 @@ def kernel(x): # Line +8
364365
)
365366
with pytest.raises(TileTypeError, match=msg_regex):
366367
ct.launch(torch.cuda.current_stream(), (1,), kernel, (x,))
368+
369+
370+
def decorate(func):
371+
@functools.wraps(func)
372+
def wrapper(x):
373+
return func(x + 3)
374+
return wrapper
375+
376+
377+
@decorate
378+
def decorated_helper(x):
379+
return x * 10
380+
381+
382+
def test_decorated_helper_function():
383+
@ct.kernel
384+
def kernel(y):
385+
t = decorated_helper(5)
386+
ct.scatter(y, (), t)
387+
y = torch.zeros((), dtype=torch.int32, device="cuda")
388+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (y,))
389+
assert y.item() == 80
390+
391+
392+
def forward(func):
393+
@functools.wraps(func)
394+
def wrapper(*args, **kwargs):
395+
return func(*args, **kwargs)
396+
return wrapper
397+
398+
399+
@forward
400+
def forward_helper(x):
401+
return x * 10
402+
403+
404+
def test_decorated_helper_function_forward():
405+
@ct.kernel
406+
def kernel(y):
407+
t = forward_helper(5)
408+
ct.scatter(y, (), t)
409+
y = torch.zeros((), dtype=torch.int32, device="cuda")
410+
ct.launch(torch.cuda.current_stream(), (1,), kernel, (y,))
411+
assert y.item() == 50

0 commit comments

Comments
 (0)