Skip to content

Commit fed6e8a

Browse files
committed
Add compiletests
These tests check that code compiles (or doesn't) as expected. The framework also has the capability to check the generated ptx (see the tests in `dis/`. These tests are not run, so they do not validate runtime behavior. This also means they can run on machines without a GPU, like GitHub Actions runners. This is largely modeled / structured after rust-gpu's compiletests in anticipation of a project merge one day.
1 parent aa7e615 commit fed6e8a

30 files changed

+1691
-3
lines changed

.cargo/config.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
[alias]
22
xtask = "run -p xtask --bin xtask --"
3+
compiletest = "run --release -p compiletests --"

.github/workflows/ci_linux.yml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,18 @@ jobs:
153153
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
154154
run: |
155155
echo "Stubbed out"
156+
compiletest:
157+
name: Compile tests
158+
runs-on: ubuntu-latest
159+
container:
160+
image: "ghcr.io/rust-gpu/rust-cuda-ubuntu24-cuda12:latest"
161+
steps:
162+
- name: Checkout repository
163+
uses: actions/checkout@v4
164+
- name: Run cargo version
165+
run: cargo --version
166+
- name: Rustfmt compiletests
167+
shell: bash
168+
run: shopt -s globstar && rustfmt --check tests/compiletests/ui/**/*.rs
169+
- name: Compiletest
170+
run: cargo run -p compiletests --release --no-default-features -- --target-arch compute_61,compute_70,compute_90

.github/workflows/ci_windows.yml

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,18 @@ jobs:
2626
target: x86_64-pc-windows-msvc
2727
cuda: "12.8.1"
2828
linux-local-args: []
29-
sub-packages: ["nvcc", "nvrtc", "nvrtc_dev", "cuda_profiler_api", "cudart", "cublas", "cublas_dev", "curand", "curand_dev"]
29+
sub-packages:
30+
[
31+
"nvcc",
32+
"nvrtc",
33+
"nvrtc_dev",
34+
"cuda_profiler_api",
35+
"cudart",
36+
"cublas",
37+
"cublas_dev",
38+
"curand",
39+
"curand_dev",
40+
]
3041

3142
steps:
3243
- name: Checkout repository
@@ -41,7 +52,7 @@ jobs:
4152
linux-local-args: ${{ toJson(matrix.linux-local-args) }}
4253
use-local-cache: false
4354
sub-packages: ${{ toJson(matrix.sub-packages) }}
44-
log-file-suffix: '${{matrix.os}}-${{matrix.cuda}}'
55+
log-file-suffix: "${{matrix.os}}-${{matrix.cuda}}"
4556

4657
- name: Verify CUDA installation
4758
run: nvcc --version
@@ -76,3 +87,6 @@ jobs:
7687
env:
7788
RUSTDOCFLAGS: -Dwarnings
7889
run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix*" --exclude "path-tracer" --exclude "denoiser" --exclude "vecadd*" --exclude "gemm*" --exclude "ex*" --exclude "cudnn*" --exclude "cust_raw"
90+
# Disabled due to dll issues, someone with Windows knowledge needed
91+
# - name: Compiletest
92+
# run: cargo run -p compiletests --release --no-default-features -- --target-arch compute_61,compute_70,compute_90

Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ members = [
1616
"examples/cuda/path_tracer/kernels",
1717

1818
"examples/optix/*",
19+
"tests/compiletests",
20+
"tests/compiletests/deps-helper",
1921
]
2022

2123
exclude = [
@@ -24,3 +26,7 @@ exclude = [
2426

2527
[profile.dev.package.rustc_codegen_nvvm]
2628
opt-level = 3
29+
30+
[workspace.dependencies]
31+
cuda_std = { path = "crates/cuda_std" }
32+
cuda_builder = { path = "crates/cuda_builder" }

crates/rustc_codegen_nvvm/src/context.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,10 @@ pub struct CodegenArgs {
557557
pub override_libm: bool,
558558
pub use_constant_memory_space: bool,
559559
pub final_module_path: Option<PathBuf>,
560+
pub disassemble: bool,
561+
pub disassemble_fn: Option<String>,
562+
pub disassemble_entry: Option<String>,
563+
pub disassemble_globals: bool,
560564
}
561565

562566
impl CodegenArgs {
@@ -569,7 +573,13 @@ impl CodegenArgs {
569573
// TODO: replace this with a "proper" arg parser.
570574
let mut cg_args = Self::default();
571575

576+
let mut skip_next = false;
572577
for (idx, arg) in args.iter().enumerate() {
578+
if skip_next {
579+
skip_next = false;
580+
continue;
581+
}
582+
573583
if let Ok(flag) = NvvmOption::from_str(arg) {
574584
cg_args.nvvm_options.push(flag);
575585
} else if arg == "--override-libm" {
@@ -580,6 +590,29 @@ impl CodegenArgs {
580590
cg_args.final_module_path = Some(PathBuf::from(
581591
args.get(idx + 1).expect("No path for --final-module-path"),
582592
));
593+
skip_next = true;
594+
} else if arg == "--disassemble" {
595+
cg_args.disassemble = true;
596+
} else if arg == "--disassemble-globals" {
597+
cg_args.disassemble_globals = true;
598+
} else if arg == "--disassemble-fn" {
599+
cg_args.disassemble_fn = Some(
600+
args.get(idx + 1)
601+
.expect("No function name for --disassemble-fn")
602+
.clone(),
603+
);
604+
skip_next = true;
605+
} else if let Some(func) = arg.strip_prefix("--disassemble-fn=") {
606+
cg_args.disassemble_fn = Some(func.to_string());
607+
} else if arg == "--disassemble-entry" {
608+
cg_args.disassemble_entry = Some(
609+
args.get(idx + 1)
610+
.expect("No entry name for --disassemble-entry")
611+
.clone(),
612+
);
613+
skip_next = true;
614+
} else if let Some(entry) = arg.strip_prefix("--disassemble-entry=") {
615+
cg_args.disassemble_entry = Some(entry.to_string());
583616
}
584617
}
585618

crates/rustc_codegen_nvvm/src/lib.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ mod lto;
5252
mod mono_item;
5353
mod nvvm;
5454
mod override_fns;
55+
mod ptx_filter;
5556
mod target;
5657
mod ty;
5758

@@ -216,14 +217,30 @@ impl CodegenBackend for NvvmCodegenBackend {
216217
let cmdline = sess.opts.cg.target_feature.split(',');
217218
let cfg = sess.target.options.features.split(',');
218219

219-
let target_features: Vec<_> = cfg
220+
let mut target_features: Vec<_> = cfg
220221
.chain(cmdline)
221222
.filter(|l| l.starts_with('+'))
222223
.map(|l| &l[1..])
223224
.filter(|l| !l.is_empty())
224225
.map(rustc_span::Symbol::intern)
225226
.collect();
226227

228+
// Add backend-synthesized features (e.g., hierarchical compute capabilities)
229+
// Parse CodegenArgs to get the architecture from llvm-args
230+
let args = context::CodegenArgs::from_session(sess);
231+
for opt in &args.nvvm_options {
232+
if let ::nvvm::NvvmOption::Arch(arch) = opt {
233+
// Add all features up to and including the current architecture
234+
let backend_features = arch.all_target_features();
235+
target_features.extend(
236+
backend_features
237+
.iter()
238+
.map(|f| rustc_span::Symbol::intern(f)),
239+
);
240+
break;
241+
}
242+
}
243+
227244
// For NVPTX, all target features are stable
228245
let unstable_target_features = target_features.clone();
229246

crates/rustc_codegen_nvvm/src/link.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use tracing::{debug, trace};
3030

3131
use crate::LlvmMod;
3232
use crate::context::CodegenArgs;
33+
use crate::ptx_filter::{PtxFilter, PtxFilterConfig};
3334

3435
pub(crate) struct NvvmMetadataLoader;
3536

@@ -305,6 +306,34 @@ fn codegen_into_ptx_file(
305306
}
306307
};
307308

309+
// If disassembly is requested, print PTX to stderr
310+
if (args.disassemble
311+
|| args.disassemble_globals
312+
|| args.disassemble_fn.is_some()
313+
|| args.disassemble_entry.is_some())
314+
&& let Ok(ptx_str) = std::str::from_utf8(&ptx_bytes)
315+
{
316+
let config = PtxFilterConfig::from_codegen_args(&args);
317+
let filter = PtxFilter::new(config);
318+
let output = filter.filter(ptx_str);
319+
if !output.is_empty() {
320+
// Check if we're in JSON mode by checking the error format
321+
use rustc_session::config::ErrorOutputType;
322+
match sess.opts.error_format {
323+
ErrorOutputType::Json { .. } => {
324+
sess.dcx()
325+
.err("PTX disassembly output in JSON mode is not supported");
326+
}
327+
_ => {
328+
// In normal mode, just print to stderr
329+
// Replace tabs with spaces for cleaner output
330+
let output = output.replace('\t', " ");
331+
eprintln!("{output}");
332+
}
333+
}
334+
}
335+
}
336+
308337
std::fs::write(out_filename, ptx_bytes)
309338
}
310339

0 commit comments

Comments
 (0)