Skip to content

Commit 2f4fd1d

Browse files
brandonros and claude committed
fix(cuda_std): pack warp shuffle/match return into i64 for LLVM 19 verifier
LLVM 19's verifier rejects the `align N` return attribute that rustc's C ABI lowering attaches to calls returning small aggregates like { i32, i8 } (`align` is only valid on pointer returns). Three intrinsic wrappers in libintrinsics.ll triggered this:

- __nvvm_warp_shuffle
- __nvvm_warp_match_all_32
- __nvvm_warp_match_all_64

Switch their return type from { i32, i8 } to a packed i64 (low 32 bits = value, bit 32 = predicate). Primitive integer return ⇒ no struct ABI ⇒ no spurious return attribute.

Uses only LLVM 1.0-era IR primitives (zext/shl/or), so it's safe under both LLVM 7 (CUDA 12.x libnvvm) and LLVM 19 (CUDA 13.x libnvvm).

Removes the now-redundant WarpShuffleResult struct.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ab52ac7 commit 2f4fd1d

2 files changed

Lines changed: 53 additions & 44 deletions

File tree

crates/cuda_std/src/warp.rs

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -313,20 +313,20 @@ unsafe fn match_any_64(mask: u32, value: u64) -> u32 {
313313
#[inline(always)]
314314
unsafe fn match_all_32(mask: u32, value: u32) -> (u32, bool) {
315315
unsafe extern "C" {
316-
#[allow(improper_ctypes)]
317-
fn __nvvm_warp_match_all_32(mask: u32, value: u32) -> (u32, bool);
316+
// see libintrinsics.ll — packs (value, predicate) into i64
317+
fn __nvvm_warp_match_all_32(mask: u32, value: u32) -> u64;
318318
}
319-
unsafe { __nvvm_warp_match_all_32(mask, value) }
319+
unpack_warp_result(unsafe { __nvvm_warp_match_all_32(mask, value) })
320320
}
321321

322322
#[gpu_only]
323323
#[inline(always)]
324324
unsafe fn match_all_64(mask: u32, value: u64) -> (u32, bool) {
325325
unsafe extern "C" {
326-
#[allow(improper_ctypes)]
327-
fn __nvvm_warp_match_all_64(mask: u32, value: u64) -> (u32, bool);
326+
// see libintrinsics.ll — packs (value, predicate) into i64
327+
fn __nvvm_warp_match_all_64(mask: u32, value: u64) -> u64;
328328
}
329-
unsafe { __nvvm_warp_match_all_64(mask, value) }
329+
unpack_warp_result(unsafe { __nvvm_warp_match_all_64(mask, value) })
330330
}
331331

332332
/// Synchronizes a subset of threads in a warp then performs a reduce-and-broadcast
@@ -741,14 +741,16 @@ pub enum WarpShuffleMode {
741741
Xor = 3,
742742
}
743743

744-
// C-compatible struct to match LLVM IR's {i32, i8} return type
745-
// This fixes an ABI mismatch where Rust would represent (u32, bool) as [2 x i32]
746-
// but the LLVM intrinsic returns {i32, i8} (a struct, not an array)
747-
#[doc(hidden)]
748-
#[repr(C)]
749-
pub struct WarpShuffleResult {
750-
value: u32,
751-
predicate: u8,
744+
// The libintrinsics.ll wrappers pack their (value, predicate) result into a
745+
// single i64: low 32 bits = value, bit 32 = predicate. Returning a primitive
746+
// integer avoids the small-aggregate ABI path where rustc attaches `align N`
747+
// to the call's return value — an attribute LLVM 19's verifier rejects on
748+
// non-pointer returns.
749+
// Unused on host targets — every caller is `#[gpu_only]`.
750+
#[allow(dead_code)]
751+
#[inline(always)]
752+
fn unpack_warp_result(packed: u64) -> (u32, bool) {
753+
(packed as u32, (packed >> 32) & 1 != 0)
752754
}
753755

754756
#[gpu_only]
@@ -761,8 +763,7 @@ unsafe fn warp_shuffle_32(
761763
) -> (u32, bool) {
762764
unsafe extern "C" {
763765
// see libintrinsics.ll
764-
// Returns {i32, i8} in LLVM IR, which maps to our WarpShuffleResult struct
765-
fn __nvvm_warp_shuffle(mask: u32, mode: u32, a: u32, b: u32, c: u32) -> WarpShuffleResult;
766+
fn __nvvm_warp_shuffle(mask: u32, mode: u32, a: u32, b: u32, c: u32) -> u64;
766767
}
767768

768769
assert!(
@@ -776,7 +777,7 @@ unsafe fn warp_shuffle_32(
776777
c |= (32 - width) << 8;
777778

778779
let result = unsafe { __nvvm_warp_shuffle(mask, mode as u32, value, b, c) };
779-
(result.value, result.predicate != 0)
780+
unpack_warp_result(result)
780781
}
781782

782783
unsafe fn warp_shuffle_128(

crates/rustc_codegen_nvvm/libintrinsics.ll

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -152,44 +152,52 @@ start:
152152
}
153153
declare {i16, i1} @llvm.umul.with.overflow.i16(i16, i16) #0
154154

155-
; Required because we need to explicitly generate { i32, i1 } for the following intrinsics
156-
; except rustc will not generate them (it will make { i32, i8 }) which libnvvm rejects.
157-
158-
define { i32, i8 } @__nvvm_warp_shuffle(i32, i32, i32, i32, i32) #1 {
155+
; NVVM intrinsics return { i32, i1 }, but rustc lowering of (u32, bool) — or any
156+
; small two-field aggregate — produces { i32, i8 }, which libnvvm rejects. We
157+
; used to bridge by re-packing into { i32, i8 } here, but that aggregate return
158+
; causes rustc's call-site ABI to attach `align N` to the return value, which
159+
; LLVM 19's verifier rejects (align is only valid on pointer returns). So we
160+
; pack into a plain i64 instead: low 32 bits = value, bit 32 = predicate.
161+
; Primitive integer return ⇒ no struct ABI ⇒ no spurious return-attribute.
162+
163+
define i64 @__nvvm_warp_shuffle(i32, i32, i32, i32, i32) #1 {
159164
start:
160-
%5 = call { i32, i1 } @llvm.nvvm.shfl.sync.i32(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4)
161-
%6 = extractvalue { i32, i1 } %5, 1
162-
%7 = zext i1 %6 to i8
163-
%8 = extractvalue { i32, i1 } %5, 0
164-
%9 = insertvalue { i32, i8 } undef, i32 %8, 0
165-
%10 = insertvalue { i32, i8 } %9, i8 %7, 1
166-
ret { i32, i8 } %10
165+
%r = call { i32, i1 } @llvm.nvvm.shfl.sync.i32(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4)
166+
%val = extractvalue { i32, i1 } %r, 0
167+
%pred = extractvalue { i32, i1 } %r, 1
168+
%val64 = zext i32 %val to i64
169+
%pred64 = zext i1 %pred to i64
170+
%pred_hi = shl i64 %pred64, 32
171+
%packed = or i64 %val64, %pred_hi
172+
ret i64 %packed
167173
}
168174

169175
declare { i32, i1 } @llvm.nvvm.shfl.sync.i32(i32, i32, i32, i32, i32) #1
170176

171-
define { i32, i8 } @__nvvm_warp_match_all_32(i32, i32) {
177+
define i64 @__nvvm_warp_match_all_32(i32, i32) {
172178
start:
173-
%2 = call { i32, i1 } @llvm.nvvm.match.all.sync.i32(i32 %0, i32 %1)
174-
%3 = extractvalue { i32, i1 } %2, 1
175-
%4 = zext i1 %3 to i8
176-
%5 = extractvalue { i32, i1 } %2, 0
177-
%6 = insertvalue { i32, i8 } undef, i32 %5, 0
178-
%7 = insertvalue { i32, i8 } %6, i8 %4, 1
179-
ret { i32, i8 } %7
179+
%r = call { i32, i1 } @llvm.nvvm.match.all.sync.i32(i32 %0, i32 %1)
180+
%val = extractvalue { i32, i1 } %r, 0
181+
%pred = extractvalue { i32, i1 } %r, 1
182+
%val64 = zext i32 %val to i64
183+
%pred64 = zext i1 %pred to i64
184+
%pred_hi = shl i64 %pred64, 32
185+
%packed = or i64 %val64, %pred_hi
186+
ret i64 %packed
180187
}
181188

182189
declare { i32, i1 } @llvm.nvvm.match.all.sync.i32(i32, i32) #1
183190

184-
define { i32, i8 } @__nvvm_warp_match_all_64(i32, i64) {
191+
define i64 @__nvvm_warp_match_all_64(i32, i64) {
185192
start:
186-
%2 = call { i32, i1 } @llvm.nvvm.match.all.sync.i64(i32 %0, i64 %1)
187-
%3 = extractvalue { i32, i1 } %2, 1
188-
%4 = zext i1 %3 to i8
189-
%5 = extractvalue { i32, i1 } %2, 0
190-
%6 = insertvalue { i32, i8 } undef, i32 %5, 0
191-
%7 = insertvalue { i32, i8 } %6, i8 %4, 1
192-
ret { i32, i8 } %7
193+
%r = call { i32, i1 } @llvm.nvvm.match.all.sync.i64(i32 %0, i64 %1)
194+
%val = extractvalue { i32, i1 } %r, 0
195+
%pred = extractvalue { i32, i1 } %r, 1
196+
%val64 = zext i32 %val to i64
197+
%pred64 = zext i1 %pred to i64
198+
%pred_hi = shl i64 %pred64, 32
199+
%packed = or i64 %val64, %pred_hi
200+
ret i64 %packed
193201
}
194202

195203
declare { i32, i1 } @llvm.nvvm.match.all.sync.i64(i32, i64) #1

0 commit comments

Comments
 (0)