Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/pre-commit/spelling_allowlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ QuTiP
Quake
Quantinuum
RDMA
REPL
RHEL
RPC
RSA
Expand Down
9 changes: 7 additions & 2 deletions python/cudaq/kernel/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import textwrap
from typing import Optional, Type

from .utils import get_function_source_or_raise


class FunctionDefVisitor(ast.NodeVisitor):
"""
Expand Down Expand Up @@ -108,7 +110,8 @@ def _getChildFuncNames(func_obj: object,
if name is None:
name = func_obj.__name__

tree = ast.parse(textwrap.dedent(inspect.getsource(func_obj)))
src, _ = get_function_source_or_raise(func_obj)
tree = ast.parse(src)
vis = FindDepFuncsVisitor()
visit_set.add(name)
vis.visit(tree)
Expand Down Expand Up @@ -141,7 +144,9 @@ def fetch(func_obj: object):
else:
this_func_obj = FetchDepFuncsSourceCode._getFuncObj(
funcName, callingFrame)
src = textwrap.dedent(inspect.getsource(this_func_obj))
if this_func_obj is None:
continue
src, _ = get_function_source_or_raise(this_func_obj)

code += src + '\n'

Expand Down
12 changes: 3 additions & 9 deletions python/cudaq/kernel/kernel_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from .analysis import FunctionDefVisitor
from .kernel_signature import CapturedLinkedKernel, CapturedVariable, KernelSignature
from .ast_bridge import compile_to_mlir
from .utils import (emitFatalError, emitErrorIfInvalidPauli, get_module_name,
from .utils import (emitFatalError, emitErrorIfInvalidPauli,
get_function_source_or_raise, get_module_name,
globalRegisteredTypes, mlirTypeFromPyType, mlirTypeToPyType,
nvqppPrefix, getMLIRContext, recover_func_op,
recover_value_of)
Expand Down Expand Up @@ -736,14 +737,7 @@ def isa_kernel_decorator(object):
def _get_source(function):
if function is None:
return None, None
# Get the function source location
location = (inspect.getfile(function), inspect.getsourcelines(function)[1])
# Get the function source
src = inspect.getsource(function)
# Strip off the extra tabs
leadingSpaces = len(src) - len(src.lstrip())
src = '\n'.join([line[leadingSpaces:] for line in src.split('\n')])
return src, location
return get_function_source_or_raise(function)


def _recover_defining_frame():
Expand Down
58 changes: 58 additions & 0 deletions python/cudaq/kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,64 @@ def emitWarning(msg):
Color.END + '\n\nOffending code:\n' + offendingSrc[0])


def _format_missing_source_error(function, filename):
"""
Build a user-facing diagnostic explaining why source for `function` could
not be retrieved. Distinguishes between three buckets:
- Interactive interpreter-defined (`<stdin>` or `<python-input-...>`).
- Other synthetic filenames (code compiled with a non-file name).
- Real paths that failed to read (missing file, frozen module,
compiled extension).
"""
qualname = getattr(function, '__qualname__',
getattr(function, '__name__', '<unknown>'))
if filename is None:
return (f"@cudaq.kernel could not determine a source location for "
f"function `{qualname}`. `@cudaq.kernel` requires source that "
f"Python's `inspect` module can recover. Move the kernel into "
f"a `.py` module.")
is_repl = filename == '<stdin>' or filename.startswith('<python-input')
is_synthetic = filename.startswith('<') and filename.endswith('>')
if is_repl:
return (f"@cudaq.kernel could not retrieve source for function "
f"`{qualname}` because it is defined in the Python REPL, "
f"which does not preserve source code that `inspect` can "
f"recover. To use `@cudaq.kernel`, either run from a "
f"Jupyter/IPython session (which preserves source via "
f"`linecache`) or move the kernel into a `.py` module.")
if is_synthetic:
return (f"@cudaq.kernel could not retrieve source for function "
f"`{qualname}`: it is defined in a non-file context "
f"(`{filename}`). `@cudaq.kernel` requires source that "
f"`inspect` can recover. Move the kernel into a `.py` "
f"module.")
return (f"@cudaq.kernel could not read source for function "
f"`{qualname}` at `{filename}` (the file may be missing, "
f"frozen, or a compiled extension).")


def get_function_source_or_raise(function):
"""
Return `(dedented_source, (filename, first_lineno))` for `function`.
Wraps `inspect.getfile`, `inspect.getsourcelines`, and
`inspect.getsource`. If any fail (most commonly because `function` was
defined in the interactive Python interpreter), raise `RuntimeError`
with a diagnostic
tailored to the failure mode, chained from the underlying exception.
"""
filename = None
try:
filename = inspect.getfile(function)
first_line = inspect.getsourcelines(function)[1]
src = inspect.getsource(function)
except OSError as e:
raise RuntimeError(_format_missing_source_error(function,
filename)) from e
leadingSpaces = len(src) - len(src.lstrip())
src = '\n'.join([line[leadingSpaces:] for line in src.split('\n')])
return src, (filename, first_line)


def mlirTryCreateStructType(mlirEleTypes, name=None, context=None):
"""
Creates either a `quake.StruqType` or a `cc.StructType` used to represent
Expand Down
141 changes: 141 additions & 0 deletions python/tests/kernel/test_kernel_repl_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# ============================================================================ #
# Copyright (c) 2022 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #
"""
Regression tests for GitHub issue #2593: decorating a function with
`@cudaq.kernel` in the standard Python REPL previously raised an opaque
`OSError: could not get source code`. The fix replaces that with a
`RuntimeError` whose message explains the cause and suggests workarounds.
"""

import linecache

import pytest

import cudaq
from cudaq.kernel.analysis import FetchDepFuncsSourceCode
from cudaq.kernel.utils import get_function_source_or_raise


@pytest.fixture(autouse=True)
def _clear_registries():
yield
cudaq.__clearKernelRegistries()


def _make_synthetic_function(src, name, filename):
"""
Compile `src` with `filename` as its source-of-record (mimicking what
CPython does when it executes code typed at the REPL). Returns the
named callable from the resulting namespace.

The filename must not already be cached in `linecache` — otherwise
`inspect.getsource` could succeed unexpectedly and produce a false
negative for the tests below.
"""
assert filename not in linecache.cache, (
f"linecache already has an entry for {filename!r}; pick a unique name")
code = compile(src, filename, 'exec')
ns = {}
exec(code, ns)
return ns[name]


def test_repl_decoration_raises_clear_error():
"""
Direct reproduction of issue #2593: a function with `<stdin>` as its
source filename cannot be compiled, but the error must name the
function and point at Jupyter/file workarounds instead of surfacing a
raw `OSError`.
"""
fn = _make_synthetic_function(
"def my_repl_kernel(n: int):\n pass\n",
name='my_repl_kernel',
filename='<stdin>',
)

with pytest.raises(RuntimeError) as excinfo:
cudaq.kernel(fn)

msg = str(excinfo.value)
assert 'my_repl_kernel' in msg
assert 'REPL' in msg
assert 'Jupyter' in msg
# Original OSError preserved for debugging.
assert isinstance(excinfo.value.__cause__, OSError)


def test_synthetic_filename_raises_non_repl_message():
"""
A function whose source filename is synthetic but not the REPL
sentinel (e.g., `<generated>`) produces the non-file-context message,
not the REPL-specific one.
"""
fn = _make_synthetic_function(
"def generated_kernel(n: int):\n pass\n",
name='generated_kernel',
filename='<generated-test-src>',
)

with pytest.raises(RuntimeError) as excinfo:
get_function_source_or_raise(fn)

msg = str(excinfo.value)
assert 'generated_kernel' in msg
assert '<generated-test-src>' in msg
# Must not misidentify this as a REPL case.
assert 'REPL' not in msg


def test_dep_fetch_raises_clear_error_for_repl_helper():
"""
When a kernel calls a helper defined in the REPL, the dependency
fetcher in `analysis.py` must surface the same clear diagnostic,
naming the offending helper rather than blowing up with `OSError`.
"""
repl_helper = _make_synthetic_function(
"def repl_helper(x: int) -> int:\n return x + 1\n",
name='repl_helper',
filename='<python-input-1>',
)

def parent_kernel(x: int) -> int:
return repl_helper(x)

# Inject the helper into the calling frame's locals so
# FetchDepFuncsSourceCode can resolve it by name, then trigger the
# dep fetch. The failure happens inside analysis.py, not the decorator.
with pytest.raises(RuntimeError) as excinfo:
FetchDepFuncsSourceCode.fetch(parent_kernel)

msg = str(excinfo.value)
assert 'repl_helper' in msg
assert 'REPL' in msg
assert isinstance(excinfo.value.__cause__, OSError)


def test_normal_function_still_compiles():
"""
Regression guard: ensure the error-path wrapping did not break the
ordinary success path. A kernel defined in this test file (which
`inspect.getsource` can read) must compile without raising.
"""

@cudaq.kernel
def bell_pair():
q = cudaq.qvector(2)
h(q[0])
x.ctrl(q[0], q[1])

result = cudaq.sample(bell_pair, shots_count=100)
# The test passes if decoration and sampling succeed; specific counts
# are irrelevant here.
assert result is not None


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading