Skip to content

Commit 47547c6

Browse files
committed
fix
1 parent 5d3b9f7 commit 47547c6

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

cuda_core/examples/mdspan_verify_args.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,29 @@
1818
#
1919
# ################################################################################
2020

21+
import os, sys
2122
import cupy as cp
2223
from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch
2324

25+
# prepare include
26+
cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME"))
27+
if cuda_path is None:
28+
print("this demo requires a valid CUDA_PATH environment variable set", file=sys.stderr)
29+
sys.exit(0)
30+
cuda_include = os.path.join(cuda_path, "include")
31+
assert os.path.isdir(cuda_include)
32+
include_path = [cuda_include]
33+
cccl_include = os.path.join(cuda_include, "cccl")
34+
if os.path.isdir(cccl_include):
35+
include_path.insert(0, cccl_include)
36+
2437
# ################################################################################
2538
# C++ Kernel Code for Verifying mdspan Arguments
2639
# ################################################################################
2740

2841
# Verification kernels that print mdspan properties using printf
2942
code_verify = """
3043
#include <cuda/std/mdspan>
31-
#include <cstdio>
3244
3345
// Kernel to verify layout_right (C-order) mdspan arguments
3446
template<typename T>
@@ -236,6 +248,7 @@ def verify_layout_right():
236248
program_options = ProgramOptions(
237249
std="c++17",
238250
arch=f"sm_{dev.arch}",
251+
include_path=include_path,
239252
)
240253
prog = Program(code_verify, code_type="c++", options=program_options)
241254

@@ -295,6 +308,7 @@ def verify_layout_left():
295308
program_options = ProgramOptions(
296309
std="c++17",
297310
arch=f"sm_{dev.arch}",
311+
include_path=include_path,
298312
)
299313
prog = Program(code_verify, code_type="c++", options=program_options)
300314

@@ -354,6 +368,7 @@ def verify_layout_stride():
354368
program_options = ProgramOptions(
355369
std="c++17",
356370
arch=f"sm_{dev.arch}",
371+
include_path=include_path,
357372
)
358373
prog = Program(code_verify, code_type="c++", options=program_options)
359374

0 commit comments

Comments
 (0)