1+ use core:: mem:: MaybeUninit ;
12use cuda_std:: address_space;
23use cuda_std:: kernel;
34use cuda_std:: thread;
@@ -38,11 +39,19 @@ pub unsafe fn gemm_tiled(
3839 beta : f32 ,
3940) {
4041 const TILE_SIZE : usize = 16 ;
42+ const TILE_SIZE_2D : usize = TILE_SIZE * TILE_SIZE ;
4143
44+ // Shared GPU memory is modelled with `#[address_space(shared)] static mut`. Unlike a normal
45+ // `static mut`, it is not initialized, and it only exists for the duration of the kernel's
46+ // (multi-)execution. Because it is not initialized, each element must be wrapped in
47+ // `MaybeUninit`, written with `write` (in unsafe blocks, because accessing a `static mut` is
48+ // unsafe), and subsequently read with `assume_init`.
4249 #[ address_space( shared) ]
43- static mut TILE_A : [ f32 ; TILE_SIZE * TILE_SIZE ] = [ 0. ; TILE_SIZE * TILE_SIZE ] ;
50+ static mut TILE_A : [ MaybeUninit < f32 > ; TILE_SIZE_2D ] =
51+ [ const { MaybeUninit :: uninit ( ) } ; TILE_SIZE_2D ] ;
4452 #[ address_space( shared) ]
45- static mut TILE_B : [ f32 ; TILE_SIZE * TILE_SIZE ] = [ 0. ; TILE_SIZE * TILE_SIZE ] ;
53+ static mut TILE_B : [ MaybeUninit < f32 > ; TILE_SIZE_2D ] =
54+ [ const { MaybeUninit :: uninit ( ) } ; TILE_SIZE_2D ] ;
4655
4756 // Thread indices within the block.
4857 let tx = thread:: thread_idx_x ( ) as usize ;
@@ -57,20 +66,30 @@ pub unsafe fn gemm_tiled(
5766 for kk in ( 0 ..k) . step_by ( TILE_SIZE ) {
5867 // Collaborative loading of tiles into shared memory.
5968 if row < m && ( kk + tx) < k {
60- unsafe { TILE_A [ ty * TILE_SIZE + tx] = mat_a[ row * k + ( kk + tx) ] } ;
69+ unsafe {
70+ TILE_A [ ty * TILE_SIZE + tx] . write ( mat_a[ row * k + ( kk + tx) ] ) ;
71+ }
6172 } else {
62- unsafe { TILE_A [ ty * TILE_SIZE + tx] = 0.0f32 } ;
73+ unsafe {
74+ TILE_A [ ty * TILE_SIZE + tx] . write ( 0.0f32 ) ;
75+ }
6376 }
6477 if col < n && ( kk + ty) < k {
65- unsafe { TILE_B [ ty * TILE_SIZE + tx] = mat_b[ ( kk + ty) * n + col] } ;
78+ unsafe {
79+ TILE_B [ ty * TILE_SIZE + tx] . write ( mat_b[ ( kk + ty) * n + col] ) ;
80+ }
6681 } else {
67- unsafe { TILE_B [ ty * TILE_SIZE + tx] = 0.0f32 } ;
82+ unsafe {
83+ TILE_B [ ty * TILE_SIZE + tx] . write ( 0.0f32 ) ;
84+ }
6885 }
6986 thread:: sync_threads ( ) ;
7087
7188 // Perform the computation on the tile.
7289 for i in 0 ..TILE_SIZE {
73- sum += unsafe { TILE_A [ ty * TILE_SIZE + i] * TILE_B [ i * TILE_SIZE + tx] } ;
90+ sum += unsafe {
91+ TILE_A [ ty * TILE_SIZE + i] . assume_init ( ) * TILE_B [ i * TILE_SIZE + tx] . assume_init ( )
92+ } ;
7493 }
7594 thread:: sync_threads ( ) ;
7695 }
0 commit comments