11use std:: ffi:: CString ;
22
3+ use bitflags:: Flags ;
34use llvm:: Linkage :: * ;
45use rustc_abi:: Align ;
56use rustc_codegen_ssa:: common:: TypeKind ;
67use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
78use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
89use rustc_middle:: bug;
9- use rustc_middle:: ty:: offload_meta:: OffloadMetadata ;
10+ use rustc_middle:: ty:: offload_meta:: { MappingFlags , OffloadMetadata } ;
1011
1112use crate :: builder:: Builder ;
1213use crate :: common:: CodegenCx ;
@@ -28,10 +29,6 @@ pub(crate) struct OffloadGlobals<'ll> {
2829 pub mapper_fn_ty : & ' ll llvm:: Type ,
2930
3031 pub ident_t_global : & ' ll llvm:: Value ,
31-
32- // FIXME(offload): Drop this, once we fully automated our offload compilation pipeline, since
33- // LLVM will initialize them for us if it sees gpu kernels being registered.
34- pub init_rtls : & ' ll llvm:: Value ,
3532}
3633
3734impl < ' ll > OffloadGlobals < ' ll > {
@@ -42,9 +39,6 @@ impl<'ll> OffloadGlobals<'ll> {
4239 let ( begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers ( cx) ;
4340 let ident_t_global = generate_at_one ( cx) ;
4441
45- let init_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
46- let init_rtls = declare_offload_fn ( cx, "__tgt_init_all_rtls" , init_ty) ;
47-
4842 // We want LLVM's openmp-opt pass to pick up and optimize this module, since it covers both
4943 // openmp and offload optimizations.
5044 llvm:: add_module_flag_u32 ( cx. llmod ( ) , llvm:: ModuleFlagMergeBehavior :: Max , "openmp" , 51 ) ;
@@ -58,7 +52,6 @@ impl<'ll> OffloadGlobals<'ll> {
5852 end_mapper,
5953 mapper_fn_ty,
6054 ident_t_global,
61- init_rtls,
6255 }
6356 }
6457}
@@ -91,6 +84,11 @@ pub(crate) fn register_offload<'ll>(cx: &CodegenCx<'ll, '_>) {
9184 let atexit = cx. type_func ( & [ cx. type_ptr ( ) ] , cx. type_i32 ( ) ) ;
9285 let atexit_fn = declare_offload_fn ( cx, "atexit" , atexit) ;
9386
87+ // FIXME(offload): Drop this, once we fully automated our offload compilation pipeline, since
88+ // LLVM will initialize them for us if it sees gpu kernels being registered.
89+ let init_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
90+ let init_rtls = declare_offload_fn ( cx, "__tgt_init_all_rtls" , init_ty) ;
91+
9492 let desc_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
9593 let reg_name = ".omp_offloading.descriptor_reg" ;
9694 let unreg_name = ".omp_offloading.descriptor_unreg" ;
@@ -104,12 +102,14 @@ pub(crate) fn register_offload<'ll>(cx: &CodegenCx<'ll, '_>) {
104102 // define internal void @.omp_offloading.descriptor_reg() section ".text.startup" {
105103 // entry:
106104 // call void @__tgt_register_lib(ptr @.omp_offloading.descriptor)
105+ // call void @__tgt_init_all_rtls()
107106 // %0 = call i32 @atexit(ptr @.omp_offloading.descriptor_unreg)
108107 // ret void
109108 // }
110109 let bb = Builder :: append_block ( cx, desc_reg_fn, "entry" ) ;
111110 let mut a = Builder :: build ( cx, bb) ;
112111 a. call ( reg_lib_decl, None , None , register_lib, & [ omp_descriptor] , None , None ) ;
112+ a. call ( init_ty, None , None , init_rtls, & [ ] , None , None ) ;
113113 a. call ( atexit, None , None , atexit_fn, & [ desc_unreg_fn] , None , None ) ;
114114 a. ret_void ( ) ;
115115
@@ -345,7 +345,9 @@ impl KernelArgsTy {
/// Per-kernel LLVM globals that `gen_define_handling` builds and `gen_call_handling` consumes.
///
/// `offload_sizes` is the private array created from the `payload_size` of each metadata entry.
///
/// In this change the single `memtransfer_types` map-type array is replaced by three arrays,
/// one per stage of a kernel launch:
/// - `memtransfer_begin` is handed to the begin mapper call; each entry is the argument mode
///   masked down to `TO | LITERAL | IMPLICIT`.
/// - `memtransfer_kernel` is handed to `KernelArgsTy::new`; every entry is `TARGET_PARAM`.
/// - `memtransfer_end` is handed to the end mapper call; each entry is the argument mode
///   masked down to `FROM`.
///
/// NOTE(review): how `region_id` is produced is not visible in this excerpt.
345345#[ derive( Copy , Clone ) ]
346346pub ( crate ) struct OffloadKernelGlobals < ' ll > {
347347 pub offload_sizes : & ' ll llvm:: Value ,
348- pub memtransfer_types : & ' ll llvm:: Value ,
348+ pub memtransfer_begin : & ' ll llvm:: Value ,
349+ pub memtransfer_kernel : & ' ll llvm:: Value ,
350+ pub memtransfer_end : & ' ll llvm:: Value ,
349351 pub region_id : & ' ll llvm:: Value ,
350352}
351353
@@ -423,18 +425,38 @@ pub(crate) fn gen_define_handling<'ll>(
423425
424426 let offload_entry_ty = offload_globals. offload_entry_ty ;
425427
426- // FIXME(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
427428 let ( sizes, transfer) : ( Vec < _ > , Vec < _ > ) =
428- metadata. iter ( ) . map ( |m| ( m. payload_size , m. mode . bits ( ) | 0x20 ) ) . unzip ( ) ;
429+ metadata. iter ( ) . map ( |m| ( m. payload_size , m. mode ) ) . unzip ( ) ;
430+ // Our begin mapper should only see simplified information about which args have to be
431+ // transferred to the device, the end mapper only about which args should be transferred back.
432+ // Any information beyond that makes it harder for LLVM's opt pass to evaluate whether it can
433+ // safely move (=optimize) the LLVM-IR location of this data transfer. Only the mapping types
434+ // mentioned below are handled, so make sure that we don't generate any other ones.
435+ let handled_mappings = MappingFlags :: TO
436+ | MappingFlags :: FROM
437+ | MappingFlags :: TARGET_PARAM
438+ | MappingFlags :: LITERAL
439+ | MappingFlags :: IMPLICIT ;
440+ for arg in & transfer {
441+ debug_assert ! ( !arg. contains_unknown_bits( ) ) ;
442+ debug_assert ! ( handled_mappings. contains( * arg) ) ;
443+ }
444+
445+ let valid_begin_mappings = MappingFlags :: TO | MappingFlags :: LITERAL | MappingFlags :: IMPLICIT ;
446+ let transfer_to: Vec < u64 > =
447+ transfer. iter ( ) . map ( |m| m. intersection ( valid_begin_mappings) . bits ( ) ) . collect ( ) ;
448+ let transfer_from: Vec < u64 > =
449+ transfer. iter ( ) . map ( |m| m. intersection ( MappingFlags :: FROM ) . bits ( ) ) . collect ( ) ;
450+ // FIXME(offload): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
451+ let transfer_kernel = vec ! [ MappingFlags :: TARGET_PARAM . bits( ) ; transfer_to. len( ) ] ;
429452
430453 let offload_sizes = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{symbol}" ) , & sizes) ;
431- // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
432- // or both to and from the gpu (=3). Other values shouldn't affect us for now.
433- // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
434- // will be 2. For now, everything is 3, until we have our frontend set up.
435- // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
436- let memtransfer_types =
437- add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{symbol}" ) , & transfer) ;
454+ let memtransfer_begin =
455+ add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{symbol}.begin" ) , & transfer_to) ;
456+ let memtransfer_kernel =
457+ add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{symbol}.kernel" ) , & transfer_kernel) ;
458+ let memtransfer_end =
459+ add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{symbol}.end" ) , & transfer_from) ;
438460
439461 // Next: For each function, generate these three entries. A weak constant,
440462 // the llvm.rodata entry name, and the llvm_offload_entries value
@@ -469,7 +491,13 @@ pub(crate) fn gen_define_handling<'ll>(
469491
470492 cx. add_compiler_used_global ( offload_entry) ;
471493
472- let result = OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id } ;
494+ let result = OffloadKernelGlobals {
495+ offload_sizes,
496+ memtransfer_begin,
497+ memtransfer_kernel,
498+ memtransfer_end,
499+ region_id,
500+ } ;
473501
474502 // FIXME(Sa4dUs): use this global for constant offload sizes
475503 cx. add_compiler_used_global ( result. offload_sizes ) ;
@@ -535,7 +563,13 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
535563 offload_dims : & OffloadKernelDims < ' ll > ,
536564) {
537565 let cx = builder. cx ;
538- let OffloadKernelGlobals { memtransfer_types, region_id, .. } = offload_data;
566+ let OffloadKernelGlobals {
567+ memtransfer_begin,
568+ memtransfer_kernel,
569+ memtransfer_end,
570+ region_id,
571+ ..
572+ } = offload_data;
539573 let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
540574 offload_dims;
541575
@@ -608,12 +642,6 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
608642 geps. push ( gep) ;
609643 }
610644
611- let init_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
612- let init_rtls_decl = offload_globals. init_rtls ;
613-
614- // call void @__tgt_init_all_rtls()
615- builder. call ( init_ty, None , None , init_rtls_decl, & [ ] , None , None ) ;
616-
617645 for i in 0 ..num_args {
618646 let idx = cx. get_const_i32 ( i) ;
619647 let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, idx] ) ;
@@ -668,14 +696,14 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
668696 generate_mapper_call (
669697 builder,
670698 geps,
671- memtransfer_types ,
699+ memtransfer_begin ,
672700 begin_mapper_decl,
673701 fn_ty,
674702 num_args,
675703 s_ident_t,
676704 ) ;
677705 let values =
678- KernelArgsTy :: new ( & cx, num_args, memtransfer_types , geps, workgroup_dims, thread_dims) ;
706+ KernelArgsTy :: new ( & cx, num_args, memtransfer_kernel , geps, workgroup_dims, thread_dims) ;
679707
680708 // Step 3)
681709 // Here we fill the KernelArgsTy, see the documentation above
@@ -701,7 +729,7 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
701729 generate_mapper_call (
702730 builder,
703731 geps,
704- memtransfer_types ,
732+ memtransfer_end ,
705733 end_mapper_decl,
706734 fn_ty,
707735 num_args,
0 commit comments