44# If we have subcategories of examples in the future, this file can be split along those lines
55
66import glob
7+ import importlib .metadata
78import os
89import platform
910import re
@@ -57,23 +58,15 @@ def has_cuda_path() -> bool:
5758 return os .environ .get ("CUDA_PATH" , os .environ .get ("CUDA_HOME" )) is not None
5859
5960
60- PACKAGE_REQUIREMENTS = {
61- "cuda_graphs.py" : ["cupy" ],
62- "jit_lto_fractal.py" : ["cupy" ],
63- "memory_ops.py" : ["cupy" ],
64- "pytorch_example.py" : ["torch" ],
65- "saxpy.py" : ["cupy" ],
66- "simple_multi_gpu_example.py" : ["cupy" ],
67- "strided_memory_view_cpu.py" : ["cffi" ],
68- "strided_memory_view_gpu.py" : ["cupy" ],
69- "tma_tensor_map.py" : ["cupy" ],
70- "vector_add.py" : ["cupy" ],
71- }
61+ # Specific system requirements for each of the examples.
7262
7363
7464SYSTEM_REQUIREMENTS = {
7565 "gl_interop_plasma.py" : has_display ,
76- "pytorch_example.py" : is_x86_64 , # PyTorch only provides CUDA support for x86_64
66+ "pytorch_example.py" : lambda : (
67+ has_compute_capability_9_or_higher () and is_x86_64 ()
68+ ), # PyTorch only provides CUDA support for x86_64
69+ "saxpy.py" : has_compute_capability_9_or_higher ,
7770 "simple_multi_gpu_example.py" : has_multiple_devices ,
7871 "strided_memory_view_cpu.py" : is_not_windows ,
7972 "thread_block_cluster.py" : lambda : has_compute_capability_9_or_higher () and has_cuda_path (),
@@ -85,17 +78,44 @@ def has_cuda_path() -> bool:
8578sample_files = [os .path .basename (x ) for x in glob .glob (samples_path + "**/*.py" , recursive = True )]
8679
8780
81+ def has_package_requirements_or_skip (example ):
82+ with open (example , encoding = "utf-8" ) as f :
83+ content = f .read ()
84+
85+ # The canonical regex as defined in PEP 723
86+ pep723 = re .search (r"(?m)^# /// (?P<type>[a-zA-Z0-9-]+)$\s(?P<content>(^#(| .*)$\s)+)^# ///$" , content )
87+ if not pep723 :
88+ return
89+
90+ metadata = {}
91+ for line in pep723 .group ("content" ).splitlines ():
92+ line = line .lstrip ("# " ).rstrip ()
93+ if not line :
94+ continue
95+ key , value = line .split ("=" , 1 )
96+ key = key .strip ()
97+ value = value .strip ()
98+ metadata [key ] = value
99+
100+ if "dependencies" in metadata :
101+ dependencies = eval (metadata ["dependencies" ]) # noqa: S307
102+ for dependency in dependencies :
103+ name = re .match ("[a-zA-Z0-9_-]+" , dependency )
104+ try :
105+ importlib .metadata .distribution (name .string )
106+ except importlib .metadata .PackageNotFoundError :
107+ pytest .skip (f"Skipping { example } due to missing package requirement: { name } " )
108+
109+
88110@pytest .mark .parametrize ("example" , sample_files )
89111def test_example (example ):
90- package_requirements = PACKAGE_REQUIREMENTS .get (example , [])
91- for package in package_requirements :
92- pytest .importorskip (package , reason = f"Skipping { example } due to missing package requirement: { package } " )
112+ example_path = os .path .join (samples_path , example )
113+ has_package_requirements_or_skip (example_path )
93114
94115 system_requirement = SYSTEM_REQUIREMENTS .get (example , lambda : True )
95116 if not system_requirement ():
96117 pytest .skip (f"Skipping { example } due to unmet system requirement" )
97118
98- example_path = os .path .join (samples_path , example )
99119 process = subprocess .run ([sys .executable , example_path ], capture_output = True ) # noqa: S603
100120 if process .returncode != 0 :
101121 if process .stdout :
@@ -106,21 +126,19 @@ def test_example(example):
106126
107127
108128@pytest .mark .parametrize ("example" , sample_files )
109- @pytest .mark .skipif (not uv_installed (), reason = "uv is required to test PEP 723 metadata installation " )
129+ @pytest .mark .skipif (not uv_installed (), reason = "uv is required to test PEP 723 metadata" )
110130def test_example_pep723 (example ):
131+ example_path = os .path .join (samples_path , example )
132+
111133 system_requirement = SYSTEM_REQUIREMENTS .get (example , lambda : True )
112134 if not system_requirement ():
113135 pytest .skip (f"Skipping { example } due to unmet system requirement" )
114136
115- example_path = os .path .join (samples_path , example )
116-
117137 # Have uv use the same version of Python that is running the test suite,
118138 # not because they have to match but to give Python version coverage in CI.
119139 version_info = sys .version_info
120140 py_version = f"{ version_info .major } .{ version_info .minor } "
121141
122- print ("Parent process environment:" , os .environ )
123-
124142 process = subprocess .run (["uv" , "run" , "--python" , py_version , example_path ], capture_output = True ) # noqa: S603, S607
125143 if process .returncode != 0 :
126144 # This example requires a development version of cuda_core, so requirements can't be met.
0 commit comments