Skip to content

Commit f03471d

Browse files
committed
Handle metadata automatically
1 parent b79ba43 commit f03471d

6 files changed

Lines changed: 45 additions & 27 deletions

File tree

cuda_core/examples/cuda_graphs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# ################################################################################
1212

1313
# /// script
14-
# dependencies = ["cuda_bindings", "cuda_core", "cupy-cuda13x"]
14+
# dependencies = ["cuda_bindings", "cuda_core", "nvidia-cuda-nvrtc", "cupy-cuda13x"]
1515
# ///
1616

1717
import sys

cuda_core/examples/jit_lto_fractal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# ################################################################################
1414

1515
# /// script
16-
# dependencies = ["cuda_bindings", "cuda_core", "cupy-cuda13x"]
16+
# dependencies = ["cuda_bindings", "cuda_core", "nvidia-cuda-nvrtc", "cupy-cuda13x"]
1717
# ///
1818

1919
import argparse

cuda_core/examples/memory_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# ################################################################################
1212

1313
# /// script
14-
# dependencies = ["cuda_bindings", "cuda_core", "cupy-cuda13x"]
14+
# dependencies = ["cuda_bindings", "cuda_core", "nvidia-cuda-nvrtc", "cupy-cuda13x"]
1515
# ///
1616

1717
import sys

cuda_core/examples/strided_memory_view_gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# ################################################################################
1111

1212
# /// script
13-
# dependencies = ["cuda_bindings", "cuda_core", "cupy-cuda13x"]
13+
# dependencies = ["cuda_bindings", "cuda_core", "nvidia-cuda-nvrtc", "cupy-cuda13x"]
1414
# ///
1515

1616
import string

cuda_core/examples/vector_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# ################################################################################
1111

1212
# /// script
13-
# dependencies = ["cuda_bindings", "cuda_core", "cupy-cuda13x"]
13+
# dependencies = ["cuda_bindings", "cuda_core", "nvidia-cuda-nvrtc", "cupy-cuda13x"]
1414
# ///
1515

1616
import cupy as cp

cuda_core/tests/example_tests/test_basic_examples.py

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# If we have subcategories of examples in the future, this file can be split along those lines
55

66
import glob
7+
import importlib.metadata
78
import os
89
import platform
910
import re
@@ -57,23 +58,15 @@ def has_cuda_path() -> bool:
5758
return os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME")) is not None
5859

5960

60-
PACKAGE_REQUIREMENTS = {
61-
"cuda_graphs.py": ["cupy"],
62-
"jit_lto_fractal.py": ["cupy"],
63-
"memory_ops.py": ["cupy"],
64-
"pytorch_example.py": ["torch"],
65-
"saxpy.py": ["cupy"],
66-
"simple_multi_gpu_example.py": ["cupy"],
67-
"strided_memory_view_cpu.py": ["cffi"],
68-
"strided_memory_view_gpu.py": ["cupy"],
69-
"tma_tensor_map.py": ["cupy"],
70-
"vector_add.py": ["cupy"],
71-
}
61+
# Specific system requirements for each of the examples.
7262

7363

7464
SYSTEM_REQUIREMENTS = {
7565
"gl_interop_plasma.py": has_display,
76-
"pytorch_example.py": is_x86_64, # PyTorch only provides CUDA support for x86_64
66+
"pytorch_example.py": lambda: (
67+
has_compute_capability_9_or_higher() and is_x86_64()
68+
), # PyTorch only provides CUDA support for x86_64
69+
"saxpy.py": has_compute_capability_9_or_higher,
7770
"simple_multi_gpu_example.py": has_multiple_devices,
7871
"strided_memory_view_cpu.py": is_not_windows,
7972
"thread_block_cluster.py": lambda: has_compute_capability_9_or_higher() and has_cuda_path(),
@@ -85,17 +78,44 @@ def has_cuda_path() -> bool:
8578
sample_files = [os.path.basename(x) for x in glob.glob(samples_path + "**/*.py", recursive=True)]
8679

8780

81+
def has_package_requirements_or_skip(example):
82+
with open(example, encoding="utf-8") as f:
83+
content = f.read()
84+
85+
# The canonical regex as defined in PEP 723
86+
pep723 = re.search(r"(?m)^# /// (?P<type>[a-zA-Z0-9-]+)$\s(?P<content>(^#(| .*)$\s)+)^# ///$", content)
87+
if not pep723:
88+
return
89+
90+
metadata = {}
91+
for line in pep723.group("content").splitlines():
92+
line = line.lstrip("# ").rstrip()
93+
if not line:
94+
continue
95+
key, value = line.split("=", 1)
96+
key = key.strip()
97+
value = value.strip()
98+
metadata[key] = value
99+
100+
if "dependencies" in metadata:
101+
dependencies = eval(metadata["dependencies"]) # noqa: S307
102+
for dependency in dependencies:
103+
name = re.match("[a-zA-Z0-9_-]+", dependency)
104+
try:
105+
importlib.metadata.distribution(name.string)
106+
except importlib.metadata.PackageNotFoundError:
107+
pytest.skip(f"Skipping {example} due to missing package requirement: {name}")
108+
109+
88110
@pytest.mark.parametrize("example", sample_files)
89111
def test_example(example):
90-
package_requirements = PACKAGE_REQUIREMENTS.get(example, [])
91-
for package in package_requirements:
92-
pytest.importorskip(package, reason=f"Skipping {example} due to missing package requirement: {package}")
112+
example_path = os.path.join(samples_path, example)
113+
has_package_requirements_or_skip(example_path)
93114

94115
system_requirement = SYSTEM_REQUIREMENTS.get(example, lambda: True)
95116
if not system_requirement():
96117
pytest.skip(f"Skipping {example} due to unmet system requirement")
97118

98-
example_path = os.path.join(samples_path, example)
99119
process = subprocess.run([sys.executable, example_path], capture_output=True) # noqa: S603
100120
if process.returncode != 0:
101121
if process.stdout:
@@ -106,21 +126,19 @@ def test_example(example):
106126

107127

108128
@pytest.mark.parametrize("example", sample_files)
109-
@pytest.mark.skipif(not uv_installed(), reason="uv is required to test PEP 723 metadata installation")
129+
@pytest.mark.skipif(not uv_installed(), reason="uv is required to test PEP 723 metadata")
110130
def test_example_pep723(example):
131+
example_path = os.path.join(samples_path, example)
132+
111133
system_requirement = SYSTEM_REQUIREMENTS.get(example, lambda: True)
112134
if not system_requirement():
113135
pytest.skip(f"Skipping {example} due to unmet system requirement")
114136

115-
example_path = os.path.join(samples_path, example)
116-
117137
# Have uv use the same version of Python that is running the test suite,
118138
# not because they have to match but to give Python version coverage in CI.
119139
version_info = sys.version_info
120140
py_version = f"{version_info.major}.{version_info.minor}"
121141

122-
print("Parent process environment:", os.environ)
123-
124142
process = subprocess.run(["uv", "run", "--python", py_version, example_path], capture_output=True) # noqa: S603, S607
125143
if process.returncode != 0:
126144
# This example requires a development version of cuda_core, so requirements can't be met.

0 commit comments

Comments
 (0)