Skip to content

Commit fbf7a7a

Browse files
committed
cuda-bindings: handle CUDA header discovery errors
Report missing or unreadable CUDA headers directly from the build script and discover CUDA headers and libraries in both the standard toolkit layout and target-specific layouts such as targets/x86_64-linux and targets/sbsa-linux. Signed-off-by: yagna-1 <yagna-1@users.noreply.github.com>
1 parent 76a1d88 commit fbf7a7a

1 file changed

Lines changed: 42 additions & 18 deletions

File tree

crates/cuda-bindings/build.rs

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,47 +32,71 @@ fn run() -> Result<(), Box<dyn Error>> {
3232
println!("cargo::rustc-check-cfg=cfg(cuda_has_cuEventElapsedTime_v2)");
3333

3434
let toolkit = cuda_toolkit_dir();
35-
let cuda_h = Path::new(&toolkit).join("include/cuda.h");
35+
let include_paths = collect_include_paths(&toolkit);
36+
let cuda_h = include_paths
37+
.iter()
38+
.map(|include| include.join("cuda.h"))
39+
.find(|path| path.is_file())
40+
.ok_or_else(|| {
41+
format!(
42+
"cuda-bindings: could not find cuda.h under {}. Set CUDA_TOOLKIT_PATH or CUDA_HOME to a CUDA Toolkit install that contains include/cuda.h or targets/x86_64-linux/include/cuda.h.",
43+
toolkit
44+
)
45+
})?;
3646
println!("cargo:rerun-if-changed={}", cuda_h.display());
3747

38-
match std::fs::read_to_string(&cuda_h) {
39-
Ok(contents) => {
40-
if contents.contains("cuEventElapsedTime_v2") {
41-
println!("cargo:rustc-cfg=cuda_has_cuEventElapsedTime_v2");
42-
}
43-
}
44-
Err(err) => {
45-
println!(
46-
"cargo:warning=cuda-bindings: Could not read cuda.h at {}: {}",
47-
cuda_h.display(),
48-
err
49-
);
50-
}
48+
let contents = std::fs::read_to_string(&cuda_h).map_err(|err| {
49+
format!(
50+
"cuda-bindings: could not read cuda.h at {}: {}",
51+
cuda_h.display(),
52+
err
53+
)
54+
})?;
55+
if contents.contains("cuEventElapsedTime_v2") {
56+
println!("cargo:rustc-cfg=cuda_has_cuEventElapsedTime_v2");
5157
}
5258

5359
for path in collect_lib_paths(&toolkit) {
5460
println!("cargo:rustc-link-search=native={}", path.display());
5561
}
5662
println!("cargo:rustc-link-lib=dylib=cuda");
5763

58-
bindgen::builder()
64+
let mut builder = bindgen::builder()
5965
.header("wrapper.h")
60-
.clang_arg(format!("-I{}/include", toolkit))
6166
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
6267
// CUDA 13.2+ adds types to CUlaunchAttributeValue that bindgen/libclang
6368
// cannot translate, collapsing the struct to a 1-byte opaque blob while the
6469
// size assertion still expects the real C size. Making both the struct and its
6570
// inner union opaque produces correctly-sized byte blobs across CUDA versions.
6671
// launch_kernel_ex in cuda-core constructs this struct via raw pointer writes.
6772
.opaque_type("CUlaunchAttribute_st")
68-
.opaque_type("CUlaunchAttributeValue_union")
73+
.opaque_type("CUlaunchAttributeValue_union");
74+
for include_path in include_paths {
75+
builder = builder.clang_arg(format!("-I{}", include_path.display()));
76+
}
77+
builder
6978
.generate()
70-
.expect("Unable to generate CUDA bindings")
79+
.map_err(|err| format!("cuda-bindings: unable to generate CUDA bindings: {err}"))?
7180
.write_to_file(Path::new(&env::var("OUT_DIR")?).join("bindings.rs"))?;
7281

7382
Ok(())
7483
}
7584

85+
/// Candidate CUDA include directories for both standard and redistributable layouts.
86+
fn collect_include_paths(toolkit: &str) -> Vec<PathBuf> {
87+
let base = PathBuf::from(toolkit);
88+
let mut paths = Vec::new();
89+
90+
paths.push(base.join("include"));
91+
92+
let targets_include = base.join("targets/x86_64-linux/include");
93+
if targets_include.join("cuda.h").is_file() {
94+
paths.push(targets_include);
95+
}
96+
97+
paths
98+
}
99+
76100
/// Candidate directories for `rustc-link-search=native` when linking against the driver library.
77101
///
78102
/// Adds `{toolkit}/lib64` and `{toolkit}/lib64/stubs` when `lib64` exists. If

0 commit comments

Comments
 (0)