Skip to content

Commit 515228d

Browse files
committed
XXX: CUDA 13 kind of working
1 parent 2623e21 commit 515228d

2 files changed

Lines changed: 27 additions & 7 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/cuda/gemm/kernels/src/gemm_tiled.rs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use core::mem::MaybeUninit;
12
use cuda_std::address_space;
23
use cuda_std::kernel;
34
use cuda_std::thread;
@@ -38,11 +39,19 @@ pub unsafe fn gemm_tiled(
3839
beta: f32,
3940
) {
4041
const TILE_SIZE: usize = 16;
42+
const TILE_SIZE_2D: usize = TILE_SIZE * TILE_SIZE;
4143

44+
// Shared GPU memory is modelled with `#[address_space(shared)] static mut`. Unlike normal
45+
// `static mut`, it is not initialized, and only exists for the duration of the kernel's
46+
// (multi-)execution. Because it is not initialized, it must be marked with `MaybeUninit`,
47+
// written with `write` (in unsafe blocks because writing a `static mut` is unsafe), and
48+
// subsequently read with `assume_init`.
4249
#[address_space(shared)]
43-
static mut TILE_A: [f32; TILE_SIZE * TILE_SIZE] = [0.; TILE_SIZE * TILE_SIZE];
50+
static mut TILE_A: [MaybeUninit<f32>; TILE_SIZE_2D] =
51+
[const { MaybeUninit::uninit() }; TILE_SIZE_2D];
4452
#[address_space(shared)]
45-
static mut TILE_B: [f32; TILE_SIZE * TILE_SIZE] = [0.; TILE_SIZE * TILE_SIZE];
53+
static mut TILE_B: [MaybeUninit<f32>; TILE_SIZE_2D] =
54+
[const { MaybeUninit::uninit() }; TILE_SIZE_2D];
4655

4756
// Thread indices within the block.
4857
let tx = thread::thread_idx_x() as usize;
@@ -57,20 +66,30 @@ pub unsafe fn gemm_tiled(
5766
for kk in (0..k).step_by(TILE_SIZE) {
5867
// Collaborative loading of tiles into shared memory.
5968
if row < m && (kk + tx) < k {
60-
unsafe { TILE_A[ty * TILE_SIZE + tx] = mat_a[row * k + (kk + tx)] };
69+
unsafe {
70+
TILE_A[ty * TILE_SIZE + tx].write(mat_a[row * k + (kk + tx)]);
71+
}
6172
} else {
62-
unsafe { TILE_A[ty * TILE_SIZE + tx] = 0.0f32 };
73+
unsafe {
74+
TILE_A[ty * TILE_SIZE + tx].write(0.0f32);
75+
}
6376
}
6477
if col < n && (kk + ty) < k {
65-
unsafe { TILE_B[ty * TILE_SIZE + tx] = mat_b[(kk + ty) * n + col] };
78+
unsafe {
79+
TILE_B[ty * TILE_SIZE + tx].write(mat_b[(kk + ty) * n + col]);
80+
}
6681
} else {
67-
unsafe { TILE_B[ty * TILE_SIZE + tx] = 0.0f32 };
82+
unsafe {
83+
TILE_B[ty * TILE_SIZE + tx].write(0.0f32);
84+
}
6885
}
6986
thread::sync_threads();
7087

7188
// Perform the computation on the tile.
7289
for i in 0..TILE_SIZE {
73-
sum += unsafe { TILE_A[ty * TILE_SIZE + i] * TILE_B[i * TILE_SIZE + tx] };
90+
sum += unsafe {
91+
TILE_A[ty * TILE_SIZE + i].assume_init() * TILE_B[i * TILE_SIZE + tx].assume_init()
92+
};
7493
}
7594
thread::sync_threads();
7695
}

0 commit comments

Comments
 (0)