Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 60 additions & 25 deletions crates/cuda-bindings/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

use std::{env, error::Error, path::Path, path::PathBuf, process::exit};
use std::{env, error::Error, fs, path::Path, path::PathBuf, process::exit};

/// Returns the CUDA toolkit install root: `CUDA_TOOLKIT_PATH` or `CUDA_HOME` if set,
/// otherwise `/usr/local/cuda`. Used for include paths, library search paths,
Expand Down Expand Up @@ -32,52 +32,88 @@ fn run() -> Result<(), Box<dyn Error>> {
println!("cargo::rustc-check-cfg=cfg(cuda_has_cuEventElapsedTime_v2)");

let toolkit = cuda_toolkit_dir();
let cuda_h = Path::new(&toolkit).join("include/cuda.h");
let include_paths = collect_include_paths(&toolkit);
let cuda_h = include_paths
.iter()
.map(|include| include.join("cuda.h"))
.find(|path| path.is_file())
.ok_or_else(|| {
format!(
"cuda-bindings: could not find cuda.h under {}. Set CUDA_TOOLKIT_PATH or CUDA_HOME to a CUDA Toolkit install that contains include/cuda.h or targets/*/include/cuda.h.",
toolkit
)
})?;
println!("cargo:rerun-if-changed={}", cuda_h.display());

match std::fs::read_to_string(&cuda_h) {
Ok(contents) => {
if contents.contains("cuEventElapsedTime_v2") {
println!("cargo:rustc-cfg=cuda_has_cuEventElapsedTime_v2");
}
}
Err(err) => {
println!(
"cargo:warning=cuda-bindings: Could not read cuda.h at {}: {}",
cuda_h.display(),
err
);
}
let contents = std::fs::read_to_string(&cuda_h).map_err(|err| {
format!(
"cuda-bindings: could not read cuda.h at {}: {}",
cuda_h.display(),
err
)
})?;
if contents.contains("cuEventElapsedTime_v2") {
println!("cargo:rustc-cfg=cuda_has_cuEventElapsedTime_v2");
}

for path in collect_lib_paths(&toolkit) {
println!("cargo:rustc-link-search=native={}", path.display());
}
println!("cargo:rustc-link-lib=dylib=cuda");

bindgen::builder()
let mut builder = bindgen::builder()
.header("wrapper.h")
.clang_arg(format!("-I{}/include", toolkit))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
// CUDA 13.2+ adds types to CUlaunchAttributeValue that bindgen/libclang
// cannot translate, collapsing the struct to a 1-byte opaque blob while the
// size assertion still expects the real C size. Making both the struct and its
// inner union opaque produces correctly-sized byte blobs across CUDA versions.
// launch_kernel_ex in cuda-core constructs this struct via raw pointer writes.
.opaque_type("CUlaunchAttribute_st")
.opaque_type("CUlaunchAttributeValue_union")
.opaque_type("CUlaunchAttributeValue_union");
for include_path in include_paths {
builder = builder.clang_arg(format!("-I{}", include_path.display()));
}
builder
.generate()
.expect("Unable to generate CUDA bindings")
.map_err(|err| format!("cuda-bindings: unable to generate CUDA bindings: {err}"))?
.write_to_file(Path::new(&env::var("OUT_DIR")?).join("bindings.rs"))?;

Ok(())
}

/// Candidate CUDA include directories for both standard and redistributable layouts.
fn collect_include_paths(toolkit: &str) -> Vec<PathBuf> {
let base = PathBuf::from(toolkit);
let mut paths = Vec::new();

paths.push(base.join("include"));
for target in collect_target_roots(&base) {
paths.push(target.join("include"));
}

paths
}

/// CUDA target-layout roots such as `targets/x86_64-linux` or `targets/sbsa-linux`.
fn collect_target_roots(base: &Path) -> Vec<PathBuf> {
let targets_dir = base.join("targets");
let mut roots: Vec<_> = fs::read_dir(targets_dir)
.into_iter()
.flatten()
.filter_map(Result::ok)
.map(|entry| entry.path())
.filter(|path| path.join("include/cuda.h").is_file())
.collect();
roots.sort();
roots
}

/// Candidate directories for `rustc-link-search=native` when linking against the driver library.
///
/// Adds `{toolkit}/lib64` and `{toolkit}/lib64/stubs` when `lib64` exists. If
/// `{toolkit}/targets/x86_64-linux/include/cuda.h` exists (redistributable / cross-layout install),
/// also adds `targets/x86_64-linux/lib` and `.../lib/stubs`. Order is preserved; duplicates are not
/// `{toolkit}/targets/*/include/cuda.h` exists (redistributable / cross-layout install),
/// also adds each matching target's `lib` and `lib/stubs`. Order is preserved; duplicates are not
/// filtered.
fn collect_lib_paths(toolkit: &str) -> Vec<PathBuf> {
let base = PathBuf::from(toolkit);
Expand All @@ -89,10 +125,9 @@ fn collect_lib_paths(toolkit: &str) -> Vec<PathBuf> {
paths.push(lib64.join("stubs"));
}

let targets = base.join("targets/x86_64-linux");
if targets.join("include/cuda.h").is_file() {
paths.push(targets.join("lib"));
paths.push(targets.join("lib/stubs"));
for target in collect_target_roots(&base) {
paths.push(target.join("lib"));
paths.push(target.join("lib/stubs"));
}

paths
Expand Down