@@ -3,8 +3,8 @@ use std::ffi::c_uint;
33use std:: { assert_matches, iter, ptr} ;
44
55use rustc_abi:: {
6- Align , BackendRepr , Float , HasDataLayout , Integer , NumScalableVectors , Primitive , Size ,
7- WrappingRange ,
6+ AddressSpace , Align , BackendRepr , Float , HasDataLayout , Integer , NumScalableVectors , Primitive ,
7+ Size , WrappingRange ,
88} ;
99use rustc_codegen_ssa:: base:: { compare_simd_types, wants_msvc_seh, wants_wasm_eh} ;
1010use rustc_codegen_ssa:: common:: { IntPredicate , TypeKind } ;
@@ -178,6 +178,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
178178 span : Span ,
179179 ) -> Result < ( ) , ty:: Instance < ' tcx > > {
180180 let tcx = self . tcx ;
181+ let llvm_version = crate :: llvm_util:: get_version ( ) ;
181182
182183 let name = tcx. item_name ( instance. def_id ( ) ) ;
183184 let fn_args = instance. args ;
@@ -194,7 +195,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
194195 | sym:: maximum_number_nsz_f64
195196 | sym:: maximum_number_nsz_f128
196197 // Need at least LLVM 22 for `min/maximumnum` to not crash LLVM.
197- if crate :: llvm_util :: get_version ( ) >= ( 22 , 0 , 0 ) =>
198+ if llvm_version >= ( 22 , 0 , 0 ) =>
198199 {
199200 let intrinsic_name = if name. as_str ( ) . starts_with ( "min" ) {
200201 "llvm.minimumnum"
@@ -420,7 +421,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
420421 }
421422
422423 // FIXME move into the branch below when LLVM 22 is the lowest version we support.
423- sym:: carryless_mul if crate :: llvm_util :: get_version ( ) >= ( 22 , 0 , 0 ) => {
424+ sym:: carryless_mul if llvm_version >= ( 22 , 0 , 0 ) => {
424425 let ty = args[ 0 ] . layout . ty ;
425426 if !ty. is_integral ( ) {
426427 tcx. dcx ( ) . emit_err ( InvalidMonomorphization :: BasicIntegerType {
@@ -620,6 +621,46 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
620621 return Ok ( ( ) ) ;
621622 }
622623
624+ sym:: gpu_launch_sized_workgroup_mem => {
625+ // Generate an anonymous global per call, with these properties:
626+ // 1. The global is in the address space for workgroup memory
627+ // 2. It is an `external` global
628+ // 3. It is correctly aligned for the pointee `T`
629+ // All instances of extern addrspace(gpu_workgroup) globals are merged in the LLVM backend.
630+ // The name is irrelevant.
631+ // See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared
632+ let name = if llvm_version < ( 23 , 0 , 0 ) && tcx. sess . target . arch == Arch :: Nvptx64 {
633+ // The auto-assigned name for extern shared globals in the nvptx backend does
634+ // not compile in ptxas. Work around this issue by assigning a name.
635+ // Fixed in LLVM 23.
636+ "gpu_launch_sized_workgroup_mem"
637+ } else {
638+ ""
639+ } ;
640+ let global = self . declare_global_in_addrspace (
641+ name,
642+ self . type_array ( self . type_i8 ( ) , 0 ) ,
643+ AddressSpace :: GPU_WORKGROUP ,
644+ ) ;
645+ let ty:: RawPtr ( inner_ty, _) = result. layout . ty . kind ( ) else { unreachable ! ( ) } ;
646+ // The alignment of the global is used to specify the *minimum* alignment that
647+ // must be obeyed by the GPU runtime.
648+ // When multiple of these global variables are used by a kernel, the maximum alignment is taken.
649+ // See https://github.com/llvm/llvm-project/blob/a271d07488a85ce677674bbe8101b10efff58c95/llvm/lib/Target/AMDGPU/AMDGPULowerModuleLDSPass.cpp#L821
650+ let alignment = self . align_of ( * inner_ty) . bytes ( ) as u32 ;
651+ unsafe {
652+ // FIXME: Work around the above issue by taking the maximum alignment if the global already existed
653+ if tcx. sess . target . arch == Arch :: Nvptx64 {
654+ if alignment > llvm:: LLVMGetAlignment ( global) {
655+ llvm:: LLVMSetAlignment ( global, alignment) ;
656+ }
657+ } else {
658+ llvm:: LLVMSetAlignment ( global, alignment) ;
659+ }
660+ }
661+ self . cx ( ) . const_pointercast ( global, self . type_ptr ( ) )
662+ }
663+
623664 sym:: amdgpu_dispatch_ptr => {
624665 let val = self . call_intrinsic ( "llvm.amdgcn.dispatch.ptr" , & [ ] , & [ ] ) ;
625666 // Relying on `LLVMBuildPointerCast` to produce an addrspacecast
0 commit comments