|
18 | 18 | # |
19 | 19 | # ################################################################################ |
20 | 20 |
|
| 21 | +import os, sys |
21 | 22 | import cupy as cp |
22 | 23 | from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch |
23 | 24 |
|
| 25 | +# prepare include |
| 26 | +cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME")) |
| 27 | +if cuda_path is None: |
| 28 | + print("this demo requires a valid CUDA_PATH environment variable set", file=sys.stderr) |
| 29 | + sys.exit(0) |
| 30 | +cuda_include = os.path.join(cuda_path, "include") |
| 31 | +assert os.path.isdir(cuda_include) |
| 32 | +include_path = [cuda_include] |
| 33 | +cccl_include = os.path.join(cuda_include, "cccl") |
| 34 | +if os.path.isdir(cccl_include): |
| 35 | + include_path.insert(0, cccl_include) |
| 36 | + |
24 | 37 | # ################################################################################ |
25 | 38 | # C++ Kernel Code for Verifying mdspan Arguments |
26 | 39 | # ################################################################################ |
27 | 40 |
|
28 | 41 | # Verification kernels that print mdspan properties using printf |
29 | 42 | code_verify = """ |
30 | 43 | #include <cuda/std/mdspan> |
31 | | -#include <cstdio> |
32 | 44 |
|
33 | 45 | // Kernel to verify layout_right (C-order) mdspan arguments |
34 | 46 | template<typename T> |
@@ -236,6 +248,7 @@ def verify_layout_right(): |
236 | 248 | program_options = ProgramOptions( |
237 | 249 | std="c++17", |
238 | 250 | arch=f"sm_{dev.arch}", |
| 251 | + include_path=include_path, |
239 | 252 | ) |
240 | 253 | prog = Program(code_verify, code_type="c++", options=program_options) |
241 | 254 |
|
@@ -295,6 +308,7 @@ def verify_layout_left(): |
295 | 308 | program_options = ProgramOptions( |
296 | 309 | std="c++17", |
297 | 310 | arch=f"sm_{dev.arch}", |
| 311 | + include_path=include_path, |
298 | 312 | ) |
299 | 313 | prog = Program(code_verify, code_type="c++", options=program_options) |
300 | 314 |
|
@@ -354,6 +368,7 @@ def verify_layout_stride(): |
354 | 368 | program_options = ProgramOptions( |
355 | 369 | std="c++17", |
356 | 370 | arch=f"sm_{dev.arch}", |
| 371 | + include_path=include_path, |
357 | 372 | ) |
358 | 373 | prog = Program(code_verify, code_type="c++", options=program_options) |
359 | 374 |
|
|
0 commit comments