Skip to content

Commit 691b838

Browse files
Rollup merge of rust-lang#151640 - ZuseZ4:cleanup-datatransfer, r=nnethercote
Cleanup offload datatransfer There are 3 steps to run code on a GPU: Copy data from the host to the device, launch the kernel, and move it back. At the moment, we have a single variable describing the memory handling to do in each step, but that makes it hard for LLVM's opt pass to understand what's going on. We therefore split it into three variables, each only including the bits relevant for the corresponding stage. cc @jdoerfert @kevinsala r? compiler
2 parents 1ef0803 + 6de0591 commit 691b838

3 files changed

Lines changed: 92 additions & 53 deletions

File tree

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 58 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use std::ffi::CString;
22

3+
use bitflags::Flags;
34
use llvm::Linkage::*;
45
use rustc_abi::Align;
56
use rustc_codegen_ssa::common::TypeKind;
67
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
78
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
89
use rustc_middle::bug;
9-
use rustc_middle::ty::offload_meta::OffloadMetadata;
10+
use rustc_middle::ty::offload_meta::{MappingFlags, OffloadMetadata};
1011

1112
use crate::builder::Builder;
1213
use crate::common::CodegenCx;
@@ -28,10 +29,6 @@ pub(crate) struct OffloadGlobals<'ll> {
2829
pub mapper_fn_ty: &'ll llvm::Type,
2930

3031
pub ident_t_global: &'ll llvm::Value,
31-
32-
// FIXME(offload): Drop this, once we fully automated our offload compilation pipeline, since
33-
// LLVM will initialize them for us if it sees gpu kernels being registered.
34-
pub init_rtls: &'ll llvm::Value,
3532
}
3633

3734
impl<'ll> OffloadGlobals<'ll> {
@@ -42,9 +39,6 @@ impl<'ll> OffloadGlobals<'ll> {
4239
let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
4340
let ident_t_global = generate_at_one(cx);
4441

45-
let init_ty = cx.type_func(&[], cx.type_void());
46-
let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
47-
4842
// We want LLVM's openmp-opt pass to pick up and optimize this module, since it covers both
4943
// openmp and offload optimizations.
5044
llvm::add_module_flag_u32(cx.llmod(), llvm::ModuleFlagMergeBehavior::Max, "openmp", 51);
@@ -58,7 +52,6 @@ impl<'ll> OffloadGlobals<'ll> {
5852
end_mapper,
5953
mapper_fn_ty,
6054
ident_t_global,
61-
init_rtls,
6255
}
6356
}
6457
}
@@ -91,6 +84,11 @@ pub(crate) fn register_offload<'ll>(cx: &CodegenCx<'ll, '_>) {
9184
let atexit = cx.type_func(&[cx.type_ptr()], cx.type_i32());
9285
let atexit_fn = declare_offload_fn(cx, "atexit", atexit);
9386

87+
// FIXME(offload): Drop this, once we fully automated our offload compilation pipeline, since
88+
// LLVM will initialize them for us if it sees gpu kernels being registered.
89+
let init_ty = cx.type_func(&[], cx.type_void());
90+
let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
91+
9492
let desc_ty = cx.type_func(&[], cx.type_void());
9593
let reg_name = ".omp_offloading.descriptor_reg";
9694
let unreg_name = ".omp_offloading.descriptor_unreg";
@@ -104,12 +102,14 @@ pub(crate) fn register_offload<'ll>(cx: &CodegenCx<'ll, '_>) {
104102
// define internal void @.omp_offloading.descriptor_reg() section ".text.startup" {
105103
// entry:
106104
// call void @__tgt_register_lib(ptr @.omp_offloading.descriptor)
105+
// call void @__tgt_init_all_rtls()
107106
// %0 = call i32 @atexit(ptr @.omp_offloading.descriptor_unreg)
108107
// ret void
109108
// }
110109
let bb = Builder::append_block(cx, desc_reg_fn, "entry");
111110
let mut a = Builder::build(cx, bb);
112111
a.call(reg_lib_decl, None, None, register_lib, &[omp_descriptor], None, None);
112+
a.call(init_ty, None, None, init_rtls, &[], None, None);
113113
a.call(atexit, None, None, atexit_fn, &[desc_unreg_fn], None, None);
114114
a.ret_void();
115115

@@ -345,7 +345,9 @@ impl KernelArgsTy {
345345
#[derive(Copy, Clone)]
346346
pub(crate) struct OffloadKernelGlobals<'ll> {
347347
pub offload_sizes: &'ll llvm::Value,
348-
pub memtransfer_types: &'ll llvm::Value,
348+
pub memtransfer_begin: &'ll llvm::Value,
349+
pub memtransfer_kernel: &'ll llvm::Value,
350+
pub memtransfer_end: &'ll llvm::Value,
349351
pub region_id: &'ll llvm::Value,
350352
}
351353

@@ -423,18 +425,38 @@ pub(crate) fn gen_define_handling<'ll>(
423425

424426
let offload_entry_ty = offload_globals.offload_entry_ty;
425427

426-
// FIXME(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
427428
let (sizes, transfer): (Vec<_>, Vec<_>) =
428-
metadata.iter().map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();
429+
metadata.iter().map(|m| (m.payload_size, m.mode)).unzip();
430+
// Our begin mapper should only see simplified information about which args have to be
431+
// transferred to the device, the end mapper only about which args should be transferred back.
432+
// Any information beyond that makes it harder for LLVM's opt pass to evaluate whether it can
433+
// safely move (=optimize) the LLVM-IR location of this data transfer. Only the mapping types
434+
// mentioned below are handled, so make sure that we don't generate any other ones.
435+
let handled_mappings = MappingFlags::TO
436+
| MappingFlags::FROM
437+
| MappingFlags::TARGET_PARAM
438+
| MappingFlags::LITERAL
439+
| MappingFlags::IMPLICIT;
440+
for arg in &transfer {
441+
debug_assert!(!arg.contains_unknown_bits());
442+
debug_assert!(handled_mappings.contains(*arg));
443+
}
444+
445+
let valid_begin_mappings = MappingFlags::TO | MappingFlags::LITERAL | MappingFlags::IMPLICIT;
446+
let transfer_to: Vec<u64> =
447+
transfer.iter().map(|m| m.intersection(valid_begin_mappings).bits()).collect();
448+
let transfer_from: Vec<u64> =
449+
transfer.iter().map(|m| m.intersection(MappingFlags::FROM).bits()).collect();
450+
// FIXME(offload): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
451+
let transfer_kernel = vec![MappingFlags::TARGET_PARAM.bits(); transfer_to.len()];
429452

430453
let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &sizes);
431-
// Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
432-
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
433-
// A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
434-
// will be 2. For now, everything is 3, until we have our frontend set up.
435-
// 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
436-
let memtransfer_types =
437-
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &transfer);
454+
let memtransfer_begin =
455+
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}.begin"), &transfer_to);
456+
let memtransfer_kernel =
457+
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}.kernel"), &transfer_kernel);
458+
let memtransfer_end =
459+
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}.end"), &transfer_from);
438460

439461
// Next: For each function, generate these three entries. A weak constant,
440462
// the llvm.rodata entry name, and the llvm_offload_entries value
@@ -469,7 +491,13 @@ pub(crate) fn gen_define_handling<'ll>(
469491

470492
cx.add_compiler_used_global(offload_entry);
471493

472-
let result = OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id };
494+
let result = OffloadKernelGlobals {
495+
offload_sizes,
496+
memtransfer_begin,
497+
memtransfer_kernel,
498+
memtransfer_end,
499+
region_id,
500+
};
473501

474502
// FIXME(Sa4dUs): use this global for constant offload sizes
475503
cx.add_compiler_used_global(result.offload_sizes);
@@ -535,7 +563,13 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
535563
offload_dims: &OffloadKernelDims<'ll>,
536564
) {
537565
let cx = builder.cx;
538-
let OffloadKernelGlobals { memtransfer_types, region_id, .. } = offload_data;
566+
let OffloadKernelGlobals {
567+
memtransfer_begin,
568+
memtransfer_kernel,
569+
memtransfer_end,
570+
region_id,
571+
..
572+
} = offload_data;
539573
let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
540574
offload_dims;
541575

@@ -608,12 +642,6 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
608642
geps.push(gep);
609643
}
610644

611-
let init_ty = cx.type_func(&[], cx.type_void());
612-
let init_rtls_decl = offload_globals.init_rtls;
613-
614-
// call void @__tgt_init_all_rtls()
615-
builder.call(init_ty, None, None, init_rtls_decl, &[], None, None);
616-
617645
for i in 0..num_args {
618646
let idx = cx.get_const_i32(i);
619647
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
@@ -668,14 +696,14 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
668696
generate_mapper_call(
669697
builder,
670698
geps,
671-
memtransfer_types,
699+
memtransfer_begin,
672700
begin_mapper_decl,
673701
fn_ty,
674702
num_args,
675703
s_ident_t,
676704
);
677705
let values =
678-
KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims);
706+
KernelArgsTy::new(&cx, num_args, memtransfer_kernel, geps, workgroup_dims, thread_dims);
679707

680708
// Step 3)
681709
// Here we fill the KernelArgsTy, see the documentation above
@@ -701,7 +729,7 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
701729
generate_mapper_call(
702730
builder,
703731
geps,
704-
memtransfer_types,
732+
memtransfer_end,
705733
end_mapper_decl,
706734
fn_ty,
707735
num_args,

tests/codegen-llvm/gpu_offload/control_flow.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
// CHECK: br label %bb3
2020
// CHECK-NOT define
2121
// CHECK: bb3
22-
// CHECK: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo, ptr null, ptr null)
22+
// CHECK: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo.begin, ptr null, ptr null)
2323
// CHECK: %10 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @.foo.region_id, ptr nonnull %kernel_args)
24-
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo, ptr null, ptr null)
24+
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo.end, ptr null, ptr null)
2525
#[unsafe(no_mangle)]
2626
unsafe fn main() {
2727
let A = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0];

tests/codegen-llvm/gpu_offload/gpu_host.rs

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,20 @@
1414
#[unsafe(no_mangle)]
1515
fn main() {
1616
let mut x = [3.0; 256];
17-
kernel_1(&mut x);
17+
let y = [1.0; 256];
18+
kernel_1(&mut x, &y);
1819
core::hint::black_box(&x);
20+
core::hint::black_box(&y);
1921
}
2022

21-
pub fn kernel_1(x: &mut [f32; 256]) {
22-
core::intrinsics::offload(kernel_1, [256, 1, 1], [32, 1, 1], (x,))
23+
pub fn kernel_1(x: &mut [f32; 256], y: &[f32; 256]) {
24+
core::intrinsics::offload(_kernel_1, [256, 1, 1], [32, 1, 1], (x, y))
2325
}
2426

25-
#[unsafe(no_mangle)]
2627
#[inline(never)]
27-
pub fn _kernel_1(x: &mut [f32; 256]) {
28+
pub fn _kernel_1(x: &mut [f32; 256], y: &[f32; 256]) {
2829
for i in 0..256 {
29-
x[i] = 21.0;
30+
x[i] = 21.0 + y[i];
3031
}
3132
}
3233

@@ -39,8 +40,10 @@ pub fn _kernel_1(x: &mut [f32; 256]) {
3940

4041
// CHECK-DAG: @.omp_offloading.descriptor = internal constant { i32, ptr, ptr, ptr } zeroinitializer
4142
// CHECK-DAG: @llvm.global_ctors = appending constant [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 101, ptr @.omp_offloading.descriptor_reg, ptr null }]
42-
// CHECK-DAG: @.offload_sizes.[[K:[^ ]*kernel_1]] = private unnamed_addr constant [1 x i64] [i64 1024]
43-
// CHECK-DAG: @.offload_maptypes.[[K]] = private unnamed_addr constant [1 x i64] [i64 35]
43+
// CHECK-DAG: @.offload_sizes.[[K:[^ ]*kernel_1]] = private unnamed_addr constant [2 x i64] [i64 1024, i64 1024]
44+
// CHECK-DAG: @.offload_maptypes.[[K]].begin = private unnamed_addr constant [2 x i64] [i64 1, i64 1]
45+
// CHECK-DAG: @.offload_maptypes.[[K]].kernel = private unnamed_addr constant [2 x i64] [i64 32, i64 32]
46+
// CHECK-DAG: @.offload_maptypes.[[K]].end = private unnamed_addr constant [2 x i64] [i64 2, i64 0]
4447
// CHECK-DAG: @.[[K]].region_id = internal constant i8 0
4548
// CHECK-DAG: @.offloading.entry_name.[[K]] = internal unnamed_addr constant [{{[0-9]+}} x i8] c"[[K]]{{\\00}}", section ".llvm.rodata.offloading", align 1
4649
// CHECK-DAG: @.offloading.entry.[[K]] = internal constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.[[K]].region_id, ptr @.offloading.entry_name.[[K]], i64 0, i64 0, ptr null }, section "llvm_offload_entries", align 8
@@ -49,28 +52,35 @@ pub fn _kernel_1(x: &mut [f32; 256]) {
4952

5053
// CHECK-LABEL: define{{( dso_local)?}} void @main()
5154
// CHECK-NEXT: start:
52-
// CHECK-NEXT: %0 = alloca [8 x i8], align 8
53-
// CHECK-NEXT: %x = alloca [1024 x i8], align 16
54-
// CHECK-NEXT: %.offload_baseptrs = alloca [1 x ptr], align 8
55-
// CHECK-NEXT: %.offload_ptrs = alloca [1 x ptr], align 8
56-
// CHECK-NEXT: %.offload_sizes = alloca [1 x i64], align 8
55+
// CHECK-NEXT: %0 = alloca [8 x i8], align 8
56+
// CHECK-NEXT: %1 = alloca [8 x i8], align 8
57+
// CHECK-NEXT: %y = alloca [1024 x i8], align 16
58+
// CHECK-NEXT: %x = alloca [1024 x i8], align 16
59+
// CHECK-NEXT: %.offload_baseptrs = alloca [2 x ptr], align 8
60+
// CHECK-NEXT: %.offload_ptrs = alloca [2 x ptr], align 8
61+
// CHECK-NEXT: %.offload_sizes = alloca [2 x i64], align 8
5762
// CHECK-NEXT: %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
58-
// CHECK: call void @__tgt_init_all_rtls()
59-
// CHECK-NEXT: store ptr %x, ptr %.offload_baseptrs, align 8
63+
// CHECK: store ptr %x, ptr %.offload_baseptrs, align 8
6064
// CHECK-NEXT: store ptr %x, ptr %.offload_ptrs, align 8
6165
// CHECK-NEXT: store i64 1024, ptr %.offload_sizes, align 8
62-
// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.[[K]], ptr null, ptr null)
66+
// CHECK-NEXT: [[BPTRS_1:%.*]] = getelementptr inbounds nuw i8, ptr %.offload_baseptrs, i64 8
67+
// CHECK-NEXT: store ptr %y, ptr [[BPTRS_1]], align 8
68+
// CHECK-NEXT: [[PTRS_1:%.*]] = getelementptr inbounds nuw i8, ptr %.offload_ptrs, i64 8
69+
// CHECK-NEXT: store ptr %y, ptr [[PTRS_1]], align 8
70+
// CHECK-NEXT: [[SIZES_1:%.*]] = getelementptr inbounds nuw i8, ptr %.offload_sizes, i64 8
71+
// CHECK-NEXT: store i64 1024, ptr [[SIZES_1]], align 8
72+
// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 2, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.[[K]].begin, ptr null, ptr null)
6373
// CHECK-NEXT: store i32 3, ptr %kernel_args, align 8
6474
// CHECK-NEXT: [[P4:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 4
65-
// CHECK-NEXT: store i32 1, ptr [[P4]], align 4
75+
// CHECK-NEXT: store i32 2, ptr [[P4]], align 4
6676
// CHECK-NEXT: [[P8:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 8
6777
// CHECK-NEXT: store ptr %.offload_baseptrs, ptr [[P8]], align 8
6878
// CHECK-NEXT: [[P16:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 16
6979
// CHECK-NEXT: store ptr %.offload_ptrs, ptr [[P16]], align 8
7080
// CHECK-NEXT: [[P24:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 24
7181
// CHECK-NEXT: store ptr %.offload_sizes, ptr [[P24]], align 8
7282
// CHECK-NEXT: [[P32:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 32
73-
// CHECK-NEXT: store ptr @.offload_maptypes.[[K]], ptr [[P32]], align 8
83+
// CHECK-NEXT: store ptr @.offload_maptypes.[[K]].kernel, ptr [[P32]], align 8
7484
// CHECK-NEXT: [[P40:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 40
7585
// CHECK-NEXT: [[P72:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 72
7686
// CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 8 dereferenceable(32) [[P40]], i8 0, i64 32, i1 false)
@@ -81,9 +91,9 @@ pub fn _kernel_1(x: &mut [f32; 256]) {
8191
// CHECK-NEXT: store i32 1, ptr [[P92]], align 4
8292
// CHECK-NEXT: [[P96:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 96
8393
// CHECK-NEXT: store i32 0, ptr [[P96]], align 8
84-
// CHECK-NEXT: {{%[^ ]+}} = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @.[[K]].region_id, ptr nonnull %kernel_args)
85-
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.[[K]], ptr null, ptr null)
86-
// CHECK: ret void
94+
// CHECK-NEXT: [[TGT_RET:%.*]] = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @.[[K]].region_id, ptr nonnull %kernel_args)
95+
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 2, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.[[K]].end, ptr null, ptr null)
96+
// CHECK: ret void
8797
// CHECK-NEXT: }
8898

8999
// CHECK: declare void @__tgt_register_lib(ptr) local_unnamed_addr
@@ -92,6 +102,7 @@ pub fn _kernel_1(x: &mut [f32; 256]) {
92102
// CHECK-LABEL: define internal void @.omp_offloading.descriptor_reg() section ".text.startup" {
93103
// CHECK-NEXT: entry:
94104
// CHECK-NEXT: call void @__tgt_register_lib(ptr nonnull @.omp_offloading.descriptor)
105+
// CHECK-NEXT: call void @__tgt_init_all_rtls()
95106
// CHECK-NEXT: %0 = {{tail }}call i32 @atexit(ptr nonnull @.omp_offloading.descriptor_unreg)
96107
// CHECK-NEXT: ret void
97108
// CHECK-NEXT: }

0 commit comments

Comments (0)