Skip to content

Commit ca2eed6

Browse files
committed
Fix non-release kernel builds via CudaBuilder
First, we weren't handling all the types. Fixing that exposed a `libnvvm` crash, which is now worked around. I also noticed a type issue in one of the warp APIs used by vecadd, so I fixed that as well. Fixes #320
1 parent 6ffb344 commit ca2eed6

4 files changed

Lines changed: 21 additions & 2 deletions

File tree

crates/cuda_builder/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,12 @@ fn invoke_rustc(builder: &CudaBuilder) -> Result<PathBuf, CudaBuilderError> {
715715
rustflags.push(format!("--emit={string}"));
716716
}
717717

718+
if builder.debug == DebugInfo::None {
719+
// Default dev builds: strip debuginfo to avoid libnvvm crashes with unoptimized IR.
720+
// TODO: drop this once newer libnvvm toolchains are stable with debuginfo in opt=0 builds.
721+
rustflags.push("-Cdebuginfo=0".into());
722+
}
723+
718724
let mut llvm_args = vec![NvvmOption::Arch(builder.arch).to_string()];
719725

720726
if !builder.nvvm_opts {

crates/cuda_std/src/warp.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,9 @@ unsafe fn match_any_32(mask: u32, value: u32) -> u32 {
296296
unsafe fn match_any_64(mask: u32, value: u64) -> u32 {
297297
extern "C" {
298298
#[link_name = "llvm.nvvm.match.any.sync.i64"]
299-
fn __nvvm_warp_match_any_64(mask: u32, value: u64) -> u32;
299+
fn __nvvm_warp_match_any_64(mask: u32, value: u64) -> u64;
300300
}
301-
__nvvm_warp_match_any_64(mask, value)
301+
__nvvm_warp_match_any_64(mask, value) as u32
302302
}
303303

304304
#[gpu_only]

crates/rustc_codegen_nvvm/src/override_fns.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::context::CodegenCx;
88
use crate::llvm;
99
use rustc_codegen_ssa::mono_item::MonoItemExt;
1010
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
11+
use rustc_hir::def::DefKind;
1112
use rustc_hir::def_id::LOCAL_CRATE;
1213
use rustc_middle::mir::mono::{Linkage, MonoItem, MonoItemData, Visibility};
1314
use rustc_middle::ty::layout::FnAbiOf;
@@ -43,6 +44,12 @@ fn should_override<'tcx>(func: Instance<'tcx>, cx: &CodegenCx<'_, 'tcx>) -> bool
4344
return false;
4445
}
4546

47+
// Only try to override top-level/assoc functions; closures/anon fns cause ICE via item_name.
48+
match cx.tcx.def_kind(func.def_id()) {
49+
DefKind::Fn | DefKind::AssocFn => {}
50+
_ => return false,
51+
}
52+
4653
let sym = cx.tcx.item_name(func.def_id());
4754
let name = sym.as_str();
4855

crates/rustc_codegen_nvvm/src/ty.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,16 @@ impl<'ll, 'tcx> BaseTypeCodegenMethods for CodegenCx<'ll, 'tcx> {
228228

229229
fn float_width(&self, ty: &'ll Type) -> usize {
230230
match self.type_kind(ty) {
231+
TypeKind::Half => 16,
231232
TypeKind::Float => 32,
232233
TypeKind::Double => 64,
233234
TypeKind::X86_FP80 => 80,
234235
TypeKind::FP128 | TypeKind::PPC_FP128 => 128,
236+
TypeKind::BFloat => 16,
237+
TypeKind::Vector | TypeKind::ScalableVector => {
238+
// Recurse on element type for vector floats
239+
self.float_width(self.element_type(ty))
240+
}
235241
_ => bug!("llvm_float_width called on a non-float type"),
236242
}
237243
}

0 commit comments

Comments
 (0)