Skip to content

Commit f2caad6

Browse files
committed
Merge branch 'unroll'
2 parents 5392e43 + 9fa20e5 commit f2caad6

File tree

3 files changed

+28
-3
lines changed

3 files changed

+28
-3
lines changed

crates/cuda_builder/src/lib.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,10 @@ pub struct CudaBuilder {
194194
/// An optional path at which to dump the LLVM IR of the final module that the codegen feeds to libnvvm. Usually
195195
/// used for debugging.
196196
pub final_module_path: Option<PathBuf>,
197+
/// The threshold for LLVM's loop unrolling optimization pass. Higher values allow more
198+
/// aggressive unrolling, which can improve performance but increases code size.
199+
/// When `None`, LLVM uses its default threshold.
200+
pub unroll_threshold: Option<u32>,
197201
}
198202

199203
impl CudaBuilder {
@@ -216,6 +220,7 @@ impl CudaBuilder {
216220
debug: DebugInfo::None,
217221
build_args: vec![],
218222
final_module_path: None,
223+
unroll_threshold: None,
219224
}
220225
}
221226

@@ -351,6 +356,13 @@ impl CudaBuilder {
351356
self
352357
}
353358

359+
/// Sets the threshold for LLVM's loop unrolling optimization pass. Higher values allow more
360+
/// aggressive unrolling, which can improve performance but increases code size. If this is never called, LLVM uses its default threshold.
361+
pub fn unroll_threshold(mut self, threshold: u32) -> Self {
362+
self.unroll_threshold = Some(threshold);
363+
self
364+
}
365+
354366
/// Runs rustc to build the codegen and codegens the gpu crate, returning the path of the final
355367
/// ptx file. If [`ptx_file_copy_path`](Self::ptx_file_copy_path) is set, this returns the copied path.
356368
pub fn build(self) -> Result<PathBuf, CudaBuilderError> {
@@ -748,6 +760,10 @@ fn invoke_rustc(builder: &CudaBuilder) -> Result<PathBuf, CudaBuilderError> {
748760
llvm_args.push(path.to_str().unwrap().to_string());
749761
}
750762

763+
if let Some(threshold) = builder.unroll_threshold {
764+
llvm_args.push(format!("-unroll-threshold={threshold}"));
765+
}
766+
751767
if builder.debug != DebugInfo::None {
752768
let (nvvm_flag, rustc_flag) = builder.debug.into_nvvm_and_rustc_options();
753769
llvm_args.push(nvvm_flag);

crates/rustc_codegen_nvvm/src/context.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,7 @@ pub struct CodegenArgs {
654654
pub use_constant_memory_space: bool,
655655
pub final_module_path: Option<PathBuf>,
656656
pub disassemble: Option<DisassembleMode>,
657+
pub unroll_threshold: Option<u32>,
657658
}
658659

659660
impl CodegenArgs {
@@ -712,6 +713,11 @@ impl CodegenArgs {
712713
skip_next = true;
713714
} else if let Some(entry) = arg.strip_prefix("--disassemble-entry=") {
714715
cg_args.disassemble = Some(DisassembleMode::Entry(entry.to_string()));
716+
} else if let Some(threshold) = arg.strip_prefix("-unroll-threshold=") {
717+
cg_args.unroll_threshold = Some(threshold.parse().unwrap_or_else(|_| {
718+
sess.dcx()
719+
.fatal("-unroll-threshold requires a valid integer value")
720+
}));
715721
} else {
716722
// Do this only after all the other flags above have been tried.
717723
match NvvmOption::from_str(arg) {

crates/rustc_codegen_nvvm/src/init.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,12 @@ unsafe fn configure_llvm(sess: &Session) {
107107
// Use non-zero `import-instr-limit` multiplier for cold callsites.
108108
add("-import-cold-multiplier=0.1", false);
109109

110-
// for arg in sess_args {
111-
// add(&(*arg), true);
112-
// }
110+
// Forward only the `-unroll-threshold=` entries found in `llvm_args`; every other entry is skipped here.
111+
for arg in &sess.opts.cg.llvm_args {
112+
if arg.starts_with("-unroll-threshold=") {
113+
add(arg, true);
114+
}
115+
}
113116
}
114117

115118
unsafe { llvm::LLVMInitializePasses() };

0 commit comments

Comments
 (0)