Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 136 additions & 29 deletions docs/src/api/kernel_programming.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,81 +71,126 @@ target = ROCArray(zeros(UInt32, bins))
## Wave Matrix Multiply Accumulate (WMMA)

Perform following computation `D = A ⋅ B + C`.
Currently only RDNA 3 is supported and following types:

### RDNA 3 (gfx1100-gfx1199)

Currently RDNA 3 supports the following types:
- `FP16 ⋅ FP16 + FP32 -> FP32`;
- `BFP16 ⋅ BFP16 + FP32 -> FP32`.

All WMMA functionality is in the `AMDGPU.Device.WMMA` submodule.
The tile dimensions are fixed at 16×16×16 (`WMMA.M`, `WMMA.N`, `WMMA.K`).
All WMMA functionality for RDNA 3 is in the `AMDGPU.Device.WMMA` submodule.
The tile dimensions are fixed at 16×16×16 (`WMMA_RDNA3.M`, `WMMA_RDNA3.N`, `WMMA_RDNA3.K`).

### RDNA 4 (gfx1200+)

RDNA 4 introduces a simplified VGPR layout for WMMA operations with the following improvements:
- Cleaner data distribution with no duplication (128-bit vs 256-bit in RDNA 3)
- Each lane handles 8 elements for A and B fragments (vs 16 with duplication in RDNA 3)
- Support for FP8 and BF8 types (requires LLVM 18+ and ROCm 6.0+)
- New intrinsic names with `_gfx12` suffix

All WMMA functionality for RDNA 4 is in the `AMDGPU.Device.WMMA_RDNA4` submodule.
The tile dimensions remain at 16×16×16 (`WMMA_RDNA4.M`, `WMMA_RDNA4.N`, `WMMA_RDNA4.K`).

**Supported types on RDNA 4:**
- `FP16 ⋅ FP16 + FP32 -> FP32`
- `BFP16 ⋅ BFP16 + FP32 -> FP32`
- `FP8 ⋅ FP8 + FP32 -> FP32` (experimental)
- `BF8 ⋅ BF8 + FP32 -> FP32` (experimental)

### Common Features

Both RDNA 3 and RDNA 4 support the following layout types:

### Layout types

Two layout types control how matrices are read from and written to memory:

- `WMMA.ColMajor` — column-major (Julia/Fortran) order: element `(row, col)` is at `ptr[col * stride + row]`.
- `WMMA.RowMajor` — row-major (C) order: element `(row, col)` is at `ptr[row * stride + col]`.
- `WMMA_RDNA3.ColMajor` / `WMMA_RDNA4.ColMajor` — column-major (Julia/Fortran) order: element `(row, col)` is at `ptr[col * stride + row]`.
- `WMMA_RDNA3.RowMajor` / `WMMA_RDNA4.RowMajor` — row-major (C) order: element `(row, col)` is at `ptr[row * stride + col]`.

### API

#### RDNA 3 API

```@docs
AMDGPU.Device.WMMA.Fragment
AMDGPU.Device.WMMA.fill_c
AMDGPU.Device.WMMA.load_a
AMDGPU.Device.WMMA.load_b
AMDGPU.Device.WMMA.load_c
AMDGPU.Device.WMMA.store_d
AMDGPU.Device.WMMA.mma
AMDGPU.Device.WMMA_RDNA3.Fragment
AMDGPU.Device.WMMA_RDNA3.fill_c
AMDGPU.Device.WMMA_RDNA3.load_a
AMDGPU.Device.WMMA_RDNA3.load_b
AMDGPU.Device.WMMA_RDNA3.load_c
AMDGPU.Device.WMMA_RDNA3.store_d
AMDGPU.Device.WMMA_RDNA3.mma
```

#### RDNA 4 API

```@docs
AMDGPU.Device.WMMA_RDNA4.Fragment
AMDGPU.Device.WMMA_RDNA4.fill_c
AMDGPU.Device.WMMA_RDNA4.load_a
AMDGPU.Device.WMMA_RDNA4.load_b
AMDGPU.Device.WMMA_RDNA4.load_c
AMDGPU.Device.WMMA_RDNA4.store_d
AMDGPU.Device.WMMA_RDNA4.mma
```

`load_c` and `store_d` accept pointer types `Float32`, `Float16`, and `BFloat16`.
When `T` is `Float16` or `BFloat16`, values are widened to `Float32` on load and
narrowed back on store, so the `FragmentC_F32` accumulator type is always `Float32`
regardless of the backing buffer type.

**Note:** For RDNA 4, the same behavior applies, but the underlying LLVM intrinsics
use the new `_gfx12` suffix and have a simplified VGPR layout.

### Example

Below is a matrix multiplication kernel using WMMA with column-major inputs.
Pass `WMMA.RowMajor` instead to load from row-major (C-style) buffers.
Pass `WMMA_RDNA3.RowMajor` instead to load from row-major (C-style) buffers.

```@example wmma-matmul
!!! note "Hardware Requirements"
WMMA instructions require RDNA 3 (gfx11) or newer GPUs. This code will only execute
successfully on compatible hardware with appropriate ROCm/LLVM support.

```@example
using AMDGPU
using AMDGPU.Device: WMMA
using AMDGPU.Device: WMMA_RDNA3

function wmma_kernel!(C, A::AbstractArray{T}, B, M::Int32, N::Int32, K::Int32, layout) where T
tile_row = (workgroupIdx().x - Int32(1)) * Int32(WMMA.M)
tile_col = (workgroupIdx().y - Int32(1)) * Int32(WMMA.N)
tile_row = (workgroupIdx().x - Int32(1)) * Int32(WMMA_RDNA3.M)
tile_col = (workgroupIdx().y - Int32(1)) * Int32(WMMA_RDNA3.N)

C_ptr = pointer(C)
A_ptr = pointer(A)
B_ptr = pointer(B)

c_frag = WMMA.fill_c(Float32, 0f0)
c_frag = WMMA_RDNA3.fill_c(Float32, 0f0)
k = Int32(0)
while k < K
a_ptr, a_stride = _a_tile(A_ptr, layout, tile_row, k, M, K, T)
b_ptr, b_stride = _b_tile(B_ptr, layout, tile_col, k, N, K, T)

a_frag = WMMA.load_a(a_ptr, a_stride, layout)
b_frag = WMMA.load_b(b_ptr, b_stride, layout)
c_frag = WMMA.mma(a_frag, b_frag, c_frag)
a_frag = WMMA_RDNA3.load_a(a_ptr, a_stride, layout)
b_frag = WMMA_RDNA3.load_b(b_ptr, b_stride, layout)
c_frag = WMMA_RDNA3.mma(a_frag, b_frag, c_frag)

k += Int32(WMMA.K)
k += Int32(WMMA_RDNA3.K)
end

c_ptr = C_ptr + (tile_col * M + tile_row) * Int32(sizeof(Float32))
WMMA.store_d(c_ptr, c_frag, M, WMMA.ColMajor)
WMMA_RDNA3.store_d(c_ptr, c_frag, M, WMMA_RDNA3.ColMajor)
return
end

# Tile pointer + stride helpers — dispatched on layout, DCE'd by the compiler.
_a_tile(ptr, ::Type{WMMA.ColMajor}, tile_row, k, M, K, ::Type{T}) where T =
_a_tile(ptr, ::Type{WMMA_RDNA3.ColMajor}, tile_row, k, M, K, ::Type{T}) where T =
ptr + (k * M + tile_row) * Int32(sizeof(T)), M
_a_tile(ptr, ::Type{WMMA.RowMajor}, tile_row, k, M, K, ::Type{T}) where T =
_a_tile(ptr, ::Type{WMMA_RDNA3.RowMajor}, tile_row, k, M, K, ::Type{T}) where T =
ptr + (tile_row * K + k) * Int32(sizeof(T)), K

_b_tile(ptr, ::Type{WMMA.ColMajor}, tile_col, k, N, K, ::Type{T}) where T =
_b_tile(ptr, ::Type{WMMA_RDNA3.ColMajor}, tile_col, k, N, K, ::Type{T}) where T =
ptr + (tile_col * K + k) * Int32(sizeof(T)), K
_b_tile(ptr, ::Type{WMMA.RowMajor}, tile_col, k, N, K, ::Type{T}) where T =
_b_tile(ptr, ::Type{WMMA_RDNA3.RowMajor}, tile_col, k, N, K, ::Type{T}) where T =
ptr + (k * N + tile_col) * Int32(sizeof(T)), N

M, N, K = 32, 32, 32
Expand All @@ -154,9 +199,71 @@ B_host = Float16.(rand(K, N))
A, B = ROCArray(A_host), ROCArray(B_host)
C = ROCArray(zeros(Float32, M, N))

tiles_m, tiles_n = M ÷ WMMA.M, N ÷ WMMA.N
tiles_m, tiles_n = M ÷ WMMA_RDNA3.M, N ÷ WMMA_RDNA3.N
@roc gridsize=(tiles_m, tiles_n) groupsize=32 wmma_kernel!(
C, A, B, Int32(M), Int32(N), Int32(K), WMMA.ColMajor)
C, A, B, Int32(M), Int32(N), Int32(K), WMMA_RDNA3.ColMajor)

@assert maximum(abs.(Float32.(C) .- (Float32.(A) * Float32.(B)))) < 0.1
```

### RDNA 4 Example

Here's the same example adapted for RDNA 4:

!!! note "Hardware Requirements"
WMMA instructions for RDNA 4 require gfx1200+ GPUs. This code will only execute
successfully on compatible hardware with ROCm 6.0+ and LLVM 18+.

```julia
using AMDGPU
using AMDGPU.Device: WMMA_RDNA4

function wmma_rdna4_kernel!(C, A::AbstractArray{T}, B, M::Int32, N::Int32, K::Int32, layout) where T
tile_row = (workgroupIdx().x - Int32(1)) * Int32(WMMA_RDNA4.M)
tile_col = (workgroupIdx().y - Int32(1)) * Int32(WMMA_RDNA4.N)

C_ptr = pointer(C)
A_ptr = pointer(A)
B_ptr = pointer(B)

c_frag = WMMA_RDNA4.fill_c(Float32, 0f0)
k = Int32(0)
while k < K
a_ptr, a_stride = _a_tile(A_ptr, layout, tile_row, k, M, K, T)
b_ptr, b_stride = _b_tile(B_ptr, layout, tile_col, k, N, K, T)

a_frag = WMMA_RDNA4.load_a(a_ptr, a_stride, layout)
b_frag = WMMA_RDNA4.load_b(b_ptr, b_stride, layout)
c_frag = WMMA_RDNA4.mma(a_frag, b_frag, c_frag)

k += Int32(WMMA_RDNA4.K)
end

c_ptr = C_ptr + (tile_col * M + tile_row) * Int32(sizeof(Float32))
WMMA_RDNA4.store_d(c_ptr, c_frag, M, WMMA_RDNA4.ColMajor)
return
end

# Tile pointer + stride helpers — dispatched on layout, DCE'd by the compiler.
_a_tile(ptr, ::Type{WMMA_RDNA4.ColMajor}, tile_row, k, M, K, ::Type{T}) where T =
ptr + (k * M + tile_row) * Int32(sizeof(T)), M
_a_tile(ptr, ::Type{WMMA_RDNA4.RowMajor}, tile_row, k, M, K, ::Type{T}) where T =
ptr + (tile_row * K + k) * Int32(sizeof(T)), K

_b_tile(ptr, ::Type{WMMA_RDNA4.ColMajor}, tile_col, k, N, K, ::Type{T}) where T =
ptr + (tile_col * K + k) * Int32(sizeof(T)), K
_b_tile(ptr, ::Type{WMMA_RDNA4.RowMajor}, tile_col, k, N, K, ::Type{T}) where T =
ptr + (k * N + tile_col) * Int32(sizeof(T)), N

M, N, K = 32, 32, 32
A_host = Float16.(rand(M, K))
B_host = Float16.(rand(K, N))
A, B = ROCArray(A_host), ROCArray(B_host)
C = ROCArray(zeros(Float32, M, N))

tiles_m, tiles_n = M ÷ WMMA_RDNA4.M, N ÷ WMMA_RDNA4.N
@roc groupsize=32 gridsize=(tiles_m, tiles_n) wmma_rdna4_kernel!(
C, A, B, Int32(M), Int32(N), Int32(K), WMMA_RDNA4.ColMajor)

@assert maximum(abs.(Float32.(C) .- (Float32.(A) * Float32.(B)))) < 0.1
```
Expand Down
3 changes: 2 additions & 1 deletion src/device/gcn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ include(joinpath("gcn", "execution_control.jl"))
include(joinpath("gcn", "hostcall.jl"))
include(joinpath("gcn", "output.jl"))
include(joinpath("gcn", "memory_dynamic.jl"))
include(joinpath("gcn", "wmma.jl"))
include(joinpath("gcn", "wmma_rdna3.jl"))
include(joinpath("gcn", "wmma_rdna4.jl"))
4 changes: 2 additions & 2 deletions src/device/gcn/wmma.jl → src/device/gcn/wmma_rdna3.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# WMMA (Wavefront Matrix Multiply-Accumulate) intrinsics for RDNA 3 (GFX11)
# https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma_32.ll

export WMMA
export WMMA_RDNA3

module WMMA
module WMMA_RDNA3

export Fragment, M, N, K
export ColMajor, RowMajor
Expand Down
Loading