@@ -3,14 +3,13 @@ use std::ffi::CString;
33use bitflags:: Flags ;
44use llvm:: Linkage :: * ;
55use rustc_abi:: Align ;
6- use rustc_codegen_ssa:: MemFlags ;
7- use rustc_codegen_ssa:: common:: TypeKind ;
86use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
97use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
108use rustc_middle:: bug;
119use rustc_middle:: ty:: offload_meta:: { MappingFlags , OffloadMetadata , OffloadSize } ;
1210
1311use crate :: builder:: Builder ;
12+ use crate :: builder:: gpu_helper:: * ;
1413use crate :: common:: CodegenCx ;
1514use crate :: llvm:: AttributePlace :: Function ;
1615use crate :: llvm:: { self , Linkage , Type , Value } ;
@@ -534,36 +533,6 @@ fn declare_offload_fn<'ll>(
534533 )
535534}
536535
537- pub ( crate ) fn scalar_width < ' ll > ( cx : & ' ll SimpleCx < ' _ > , ty : & ' ll Type ) -> u64 {
538- match cx. type_kind ( ty) {
539- TypeKind :: Half
540- | TypeKind :: Float
541- | TypeKind :: Double
542- | TypeKind :: X86_FP80
543- | TypeKind :: FP128
544- | TypeKind :: PPC_FP128 => cx. float_width ( ty) as u64 ,
545- TypeKind :: Integer => cx. int_width ( ty) ,
546- other => bug ! ( "scalar_width was called on a non scalar type {other:?}" ) ,
547- }
548- }
549-
550- fn get_runtime_size < ' ll , ' tcx > (
551- builder : & mut Builder < ' _ , ' ll , ' tcx > ,
552- args : & [ & ' ll Value ] ,
553- index : usize ,
554- meta : & OffloadMetadata ,
555- ) -> & ' ll Value {
556- match meta. payload_size {
557- OffloadSize :: Slice { element_size } => {
558- let length_idx = index + 1 ;
559- let length = args[ length_idx] ;
560- let length_i64 = builder. intcast ( length, builder. cx . type_i64 ( ) , false ) ;
561- builder. mul ( length_i64, builder. cx . get_const_i64 ( element_size) )
562- }
563- _ => bug ! ( "unexpected offload size {:?}" , meta. payload_size) ,
564- }
565- }
566-
567536// For each kernel *call*, we now use some of our previously declared globals to move data to and from
568537// the gpu. For now, we only handle the data transfer part of it.
569538// If two consecutive kernels use the same memory, we still move it to the host and back to the gpu.
@@ -613,136 +582,21 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
613582 let end_mapper_decl = offload_globals. end_mapper ;
614583 let fn_ty = offload_globals. mapper_fn_ty ;
615584
585+ let ( ty, ty2, a1, a2, a4) =
586+ preper_datatransfers ( builder, args, types, offload_sizes, metadata, has_dynamic) ;
616587 let num_args = types. len ( ) as u64 ;
617- let bb = builder . llbb ( ) ;
588+ assert_eq ! ( num_args as usize , args . len ( ) ) ;
618589
619- // Step 0)
590+ let bb = builder . llbb ( ) ;
620591 unsafe {
621592 llvm:: LLVMRustPositionBuilderPastAllocas ( & builder. llbuilder , builder. llfn ( ) ) ;
622593 }
623-
624- let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
625- // Baseptr are just the input pointer to the kernel, stored in a local alloca
626- let a1 = builder. direct_alloca ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
627- // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
628- let a2 = builder. direct_alloca ( ty, Align :: EIGHT , ".offload_ptrs" ) ;
629- // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
630- let ty2 = cx. type_array ( cx. type_i64 ( ) , num_args) ;
631-
632- let a4 = if has_dynamic {
633- let alloc = builder. direct_alloca ( ty2, Align :: EIGHT , ".offload_sizes" ) ;
634-
635- builder. memcpy (
636- alloc,
637- Align :: EIGHT ,
638- offload_sizes,
639- Align :: EIGHT ,
640- cx. get_const_i64 ( 8 * args. len ( ) as u64 ) ,
641- MemFlags :: empty ( ) ,
642- None ,
643- ) ;
644-
645- alloc
646- } else {
647- offload_sizes
648- } ;
649-
650594 //%kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
651595 let a5 = builder. direct_alloca ( tgt_kernel_decl, Align :: EIGHT , "kernel_args" ) ;
652-
653- // Step 1)
654596 unsafe {
655597 llvm:: LLVMPositionBuilderAtEnd ( & builder. llbuilder , bb) ;
656598 }
657599
658- // Now we allocate once per function param, a copy to be passed to one of our maps.
659- let mut vals = vec ! [ ] ;
660- let mut geps = vec ! [ ] ;
661- let i32_0 = cx. get_const_i32 ( 0 ) ;
662- for & v in args {
663- let ty = cx. val_ty ( v) ;
664- let ty_kind = cx. type_kind ( ty) ;
665- let ( base_val, gep_base) = match ty_kind {
666- TypeKind :: Pointer => ( v, v) ,
667- TypeKind :: Half | TypeKind :: Float | TypeKind :: Double | TypeKind :: Integer => {
668- // FIXME(Sa4dUs): check for `f128` support, latest NVIDIA cards support it
669- let num_bits = scalar_width ( cx, ty) ;
670-
671- let bb = builder. llbb ( ) ;
672- unsafe {
673- llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , builder. llfn ( ) ) ;
674- }
675- let addr = builder. direct_alloca ( cx. type_i64 ( ) , Align :: EIGHT , "addr" ) ;
676- unsafe {
677- llvm:: LLVMPositionBuilderAtEnd ( builder. llbuilder , bb) ;
678- }
679-
680- let cast = builder. bitcast ( v, cx. type_ix ( num_bits) ) ;
681- let value = builder. zext ( cast, cx. type_i64 ( ) ) ;
682- builder. store ( value, addr, Align :: EIGHT ) ;
683- ( value, addr)
684- }
685- other => bug ! ( "offload does not support {other:?}" ) ,
686- } ;
687-
688- let gep = builder. inbounds_gep ( cx. type_f32 ( ) , gep_base, & [ i32_0] ) ;
689-
690- vals. push ( base_val) ;
691- geps. push ( gep) ;
692- }
693-
694- for i in 0 ..num_args {
695- let idx = cx. get_const_i32 ( i) ;
696- let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, idx] ) ;
697- builder. store ( vals[ i as usize ] , gep1, Align :: EIGHT ) ;
698- let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, idx] ) ;
699- builder. store ( geps[ i as usize ] , gep2, Align :: EIGHT ) ;
700-
701- if !matches ! ( metadata[ i as usize ] . payload_size, OffloadSize :: Static ( _) ) {
702- let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, idx] ) ;
703- let size_val = get_runtime_size ( builder, args, i as usize , & metadata[ i as usize ] ) ;
704- builder. store ( size_val, gep3, Align :: EIGHT ) ;
705- }
706- }
707-
708- // For now we have a very simplistic indexing scheme into our
709- // offload_{baseptrs,ptrs,sizes}. We will probably improve this along with our gpu frontend pr.
710- fn get_geps < ' ll , ' tcx > (
711- builder : & mut Builder < ' _ , ' ll , ' tcx > ,
712- ty : & ' ll Type ,
713- ty2 : & ' ll Type ,
714- a1 : & ' ll Value ,
715- a2 : & ' ll Value ,
716- a4 : & ' ll Value ,
717- is_dynamic : bool ,
718- ) -> [ & ' ll Value ; 3 ] {
719- let cx = builder. cx ;
720- let i32_0 = cx. get_const_i32 ( 0 ) ;
721-
722- let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
723- let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
724- let gep3 = if is_dynamic { builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) } else { a4 } ;
725- [ gep1, gep2, gep3]
726- }
727-
728- fn generate_mapper_call < ' ll , ' tcx > (
729- builder : & mut Builder < ' _ , ' ll , ' tcx > ,
730- geps : [ & ' ll Value ; 3 ] ,
731- o_type : & ' ll Value ,
732- fn_to_call : & ' ll Value ,
733- fn_ty : & ' ll Type ,
734- num_args : u64 ,
735- s_ident_t : & ' ll Value ,
736- ) {
737- let cx = builder. cx ;
738- let nullptr = cx. const_null ( cx. type_ptr ( ) ) ;
739- let i64_max = cx. get_const_i64 ( u64:: MAX ) ;
740- let num_args = cx. get_const_i32 ( num_args) ;
741- let args =
742- vec ! [ s_ident_t, i64_max, num_args, geps[ 0 ] , geps[ 1 ] , geps[ 2 ] , o_type, nullptr, nullptr] ;
743- builder. call ( fn_ty, None , None , fn_to_call, & args, None , None ) ;
744- }
745-
746600 // Step 2)
747601 let s_ident_t = offload_globals. ident_t_global ;
748602 let geps = get_geps ( builder, ty, ty2, a1, a2, a4, has_dynamic) ;
@@ -767,6 +621,7 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
767621
768622 // Step 3)
769623 // Here we fill the KernelArgsTy, see the documentation above
624+ let i32_0 = cx. get_const_i32 ( 0 ) ;
770625 for ( i, value) in values. iter ( ) . enumerate ( ) {
771626 let ptr = builder. inbounds_gep ( tgt_kernel_decl, a5, & [ i32_0, cx. get_const_i32 ( i as u64 ) ] ) ;
772627 let name = std:: ffi:: CString :: new ( value. 1 ) . unwrap ( ) ;
0 commit comments