Skip to content

Commit 1bd757c

Browse files
committed
add fmha
1 parent 9178f78 commit 1bd757c

1 file changed

Lines changed: 208 additions & 0 deletions

File tree

examples/fmha.jl

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Fused multi-head attention (FMHA) example - Julia port of cuTile Python's AttentionFMHA.py sample
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
using CUDA
6+
import cuTile as ct
7+
8+
import NNlib
9+
10+
# 1 / ln(2) as a Float32. fmha_kernel folds this factor into qk_scale so that
# the softmax exponentials can be evaluated with exp2 instead of exp.
const INV_LOG_2 = Float32(1 / log(2))
11+
# Aliases for the ct.Constant parameter types used in fmha_kernel's signature.
const ConstInt = ct.Constant{Int}
12+
const ConstBool = ct.Constant{Bool}
13+
14+
# TODO: "latency"
15+
16+
# cuTile kernel for Fused Multi-Head Attention
17+
# Q layout: (D, SeqLen_Q, Heads, Batch)
18+
"""
    fmha_kernel(Q, K, V, Out, qk_scale, input_pos, TILE_D, H, TILE_M, TILE_N,
                QUERY_GROUP_SIZE, CAUSAL, EVEN_K)

cuTile kernel computing fused multi-head attention with an online softmax.

Each block handles one `TILE_M`-wide chunk of queries (block index 1) for one
`(head, batch)` pair (block index 2, decomposed with `H`). Tiles are indexed as
`(feature, sequence, head, batch)`; scores are held as `(TILE_N, TILE_M)` tiles
and the output accumulator as a `(TILE_D, TILE_M)` tile.

# Arguments
- `Q`, `K`, `V`, `Out`: 4-D tile arrays sharing element type `T`.
- `qk_scale`: scale applied to the raw `K'Q` scores (combined with `INV_LOG_2`
  so that `exp2` can be used).
- `input_pos`: offset added to every query position.
- `TILE_D`: feature tile size used for Q, K, V and Out.
- `H`: number of query heads (used to split block index 2 into head / batch).
- `TILE_M`, `TILE_N`: query and key/value tile sizes along the sequence axis.
- `QUERY_GROUP_SIZE`: number of query heads sharing one K/V head.
- `CAUSAL`: mask out keys whose position exceeds the query position.
- `EVEN_K`: `true` when the K/V sequence length is a multiple of `TILE_N`
  (no tail masking needed).

Block and tile indices are treated as 1-based throughout, matching their direct
use as `ct.load` / `ct.store` indices.

NOTE(review): positions are computed as `(tile_index - 1) * tile_size + ct.arange(...)`
and compared with `offs_n .< k_seqlen`; this assumes `ct.arange` starts at 0 as in
the Python sample. If it starts at 1, the bound must become `<=` - confirm.
"""
function fmha_kernel(
    Q::ct.TileArray{T,4}, K::ct.TileArray{T,4}, V::ct.TileArray{T,4}, Out::ct.TileArray{T,4},
    qk_scale::AbstractFloat,
    input_pos::Integer,
    TILE_D::ConstInt,
    H::ConstInt, # number of query heads
    TILE_M::ConstInt,
    TILE_N::ConstInt,
    QUERY_GROUP_SIZE::ConstInt,
    CAUSAL::ConstBool,
    EVEN_K::ConstBool
) where T
    bid_x = ct.bid(1)
    bid_y = ct.bid(2)
    batch_idx = cld(bid_y, H[])
    head_idx = mod1(bid_y, H[])
    off_kv_h = cld(head_idx, QUERY_GROUP_SIZE[])

    # Fold 1/ln(2) into the scale so the softmax can use exp2.
    qk_scale = Float32(qk_scale) * Float32(INV_LOG_2)

    # Query positions of this tile, laid out along dim 2 to line up with the
    # (TILE_N, TILE_M) score tile. bid_x is a 1-based tile index, hence the -1.
    offs_m = (bid_x - 1) * TILE_M[] .+ ct.arange((TILE_M[],), Int32) .+ input_pos
    offs_m = ct.reshape(offs_m, (1, TILE_M[]))

    # Local key/value offsets, laid out along dim 1 of the score tile.
    offs_n_tile = ct.reshape(ct.arange((TILE_N[],), Int32), (TILE_N[], 1))

    # Online softmax accumulators in Float32 for stability.
    m_i = ct.full((1, TILE_M[]), -Inf32, Float32)
    l_i = ct.zeros((1, TILE_M[]), Float32)
    acc = ct.zeros((TILE_D[], TILE_M[]), Float32)

    # Query tile for this batch, head, and M-chunk.
    q = ct.load(Q, (1, bid_x, head_idx, batch_idx), (TILE_D[], TILE_M[], 1, 1))
    q = ct.reshape(q, (TILE_D[], TILE_M[]))

    # Exclusive end of this tile's query positions (bid_x is 1-based).
    m_end = input_pos + bid_x * TILE_M[]
    k_seqlen = K.sizes[2]
    # 1-based index of the first K/V tile that can extend past k_seqlen.
    # For x >= 0, cld(x + 1, n) == fld(x, n) + 1.
    tail_start = cld(k_seqlen + 1, TILE_N[])
    if CAUSAL[]
        # First K/V tile that can contain a key position beyond the first
        # query position of this tile.
        mask_start = cld(input_pos + (bid_x - 1) * TILE_M[] + 1, TILE_N[])
        mask_start = min(mask_start, tail_start)
        Tc = cld(min(m_end, k_seqlen), TILE_N[])
    else
        Tc = cld(k_seqlen, TILE_N[])
        mask_start = tail_start
    end

    # Loop over K, V blocks (N-dimension chunks).
    j = Int32(1)
    while j <= Tc
        k = ct.load(K, (1, j, off_kv_h, batch_idx), (TILE_D[], TILE_N[], 1, 1))
        k = ct.reshape(k, (TILE_D[], TILE_N[]))
        k = ct.transpose(k)

        # Scores: [TILE_N, TILE_D] * [TILE_D, TILE_M] -> [TILE_N, TILE_M].
        qk = ct.zeros((TILE_N[], TILE_M[]), Float32)
        qk = ct.muladd(k, q, qk)

        if (CAUSAL[] || !EVEN_K[]) && j >= mask_start
            # Key positions of tile j (1-based tile index).
            offs_n = (j - Int32(1)) * TILE_N[] .+ offs_n_tile
            keep = ct.full((TILE_N[], TILE_M[]), true, Bool)
            if !EVEN_K[]
                keep = keep .& (offs_n .< k_seqlen)
            end
            if CAUSAL[]
                keep = keep .& (offs_m .>= offs_n)
            end
            # Additive bias: 0 where the score is kept, -Inf where it is masked.
            bias = ct.where(keep,
                            ct.zeros((TILE_N[], TILE_M[]), Float32),
                            ct.full((TILE_N[], TILE_M[]), -Inf32, Float32))
            qk = qk .+ bias
        end

        # qk_scale is applied after the max reduction.
        m_ij = max.(m_i, (ct.reduce_max(qk, 1) * qk_scale))
        qk = qk * qk_scale .- m_ij

        # Attention weights [TILE_N, TILE_M].
        p = exp2.(qk) # might need to expose "flush_to_zero"
        l_ij = ct.reduce_sum(p, 1)
        alpha = exp2.(m_i .- m_ij) # flush to zero?

        # Rescale the running sum and accumulator to the new running max.
        l_i = l_i .* alpha .+ l_ij
        acc = acc .* alpha

        v = ct.load(V, (1, j, off_kv_h, batch_idx), (TILE_D[], TILE_N[], 1, 1))
        v = ct.reshape(v, (TILE_D[], TILE_N[]))
        p = ct.astype(p, eltype(q))
        acc = ct.muladd(v, p, acc) # [TILE_D, TILE_M]
        m_i = m_ij

        j += Int32(1)
    end

    acc = acc ./ l_i # flush to zero? rounding mode?
    acc = ct.reshape(acc, (TILE_D[], TILE_M[], 1, 1))
    # Cast the Float32 accumulator to the element type shared by Q and Out.
    acc = ct.astype(acc, eltype(q))
    ct.store(Out, (1, bid_x, head_idx, batch_idx), acc)

    return
end
116+
117+
"""
    cutile_fmha(Q, K, V; qk_scale=nothing, input_pos=0, tile_m=128, tile_n=128,
                query_group_size=1, causal=false)

Validate the inputs, allocate the output and launch `fmha_kernel`.

Arrays are laid out as `(feature, sequence, head, batch)`.

# Keywords
- `qk_scale`: score scale; defaults to `1 / sqrt(size(Q, 1))` when `nothing`.
- `input_pos`: offset added to query positions inside the kernel.
- `tile_m`, `tile_n`: query and key/value tile sizes (must be positive).
- `query_group_size`: number of query heads per K/V head (must be positive).
- `causal`: enable causal masking.

# Returns
A CUDA array of size `(size(V, 1), size(Q, 2), size(Q, 3), size(Q, 4))`.

# Throws
- `ArgumentError` if tile sizes or `query_group_size` are not positive, or if
  the shapes of `Q`, `K` and `V` are inconsistent.
"""
function cutile_fmha(Q::AbstractArray{T,4}, K::AbstractArray{T,4}, V::AbstractArray{T,4};
    qk_scale::Union{AbstractFloat,Nothing} = nothing,
    input_pos::Integer = 0,
    tile_m::Integer = 128,
    tile_n::Integer = 128,
    query_group_size::Integer = 1,
    causal::Bool = false,
) where T
    tile_m > 0 || throw(ArgumentError("tile_m must be positive, got $tile_m."))
    tile_n > 0 || throw(ArgumentError("tile_n must be positive, got $tile_n."))
    query_group_size > 0 ||
        throw(ArgumentError("query_group_size must be positive, got $query_group_size."))

    if size(Q, 4) != size(K, 4) || size(Q, 4) != size(V, 4)
        throw(ArgumentError("Batch dimensions must match for Q, K, V."))
    end
    if size(Q, 3) % query_group_size != 0
        throw(ArgumentError("Number of query heads must be divisible by query_group_size."))
    end
    if size(K, 3) * query_group_size != size(Q, 3)
        throw(ArgumentError("K_heads * query_group_size must equal Q_heads."))
    end
    if size(K, 3) != size(V, 3)
        throw(ArgumentError("KV head counts (dim 3 of K and V) must match."))
    end
    if size(Q, 1) != size(K, 1)
        throw(ArgumentError("D_k (first dim of Q and K) must match."))
    end
    # The kernel uses a single TILE_D (= D_k) for the Q, K, V and Out tiles.
    if size(V, 1) != size(Q, 1)
        throw(ArgumentError("D_v (first dim of V) must equal D_k (first dim of Q)."))
    end
    if size(K, 2) != size(V, 2)
        throw(ArgumentError("SeqLen_KV (dim 2 of K and V) must match."))
    end

    D_k, SeqLen_Q, Heads, Batch = size(Q)
    D_v, SeqLen_KV, _, _ = size(V)
    even_k = (SeqLen_KV % tile_n) == 0

    scale = isnothing(qk_scale) ? 1 / sqrt(D_k) : qk_scale

    Out = CUDA.zeros(T, D_v, SeqLen_Q, Heads, Batch)
    # Nothing to compute; also avoids launching a grid with a zero dimension.
    isempty(Out) && return Out

    grid_x = cld(SeqLen_Q, tile_m)
    grid_y = Heads * Batch
    grid = (grid_x, grid_y, 1)

    # The kernel signature expects ct.Constant{Int}, so normalise Integer
    # keyword arguments (e.g. Int32) to Int before wrapping them.
    ct.launch(fmha_kernel, grid,
              Q, K, V, Out,
              scale, input_pos,
              ct.Constant(D_k),
              ct.Constant(Heads),
              ct.Constant(Int(tile_m)),
              ct.Constant(Int(tile_n)),
              ct.Constant(Int(query_group_size)),
              ct.Constant(causal),
              ct.Constant(even_k))

    return Out
end
166+
167+
"""
    nnlib_fmha(Q, K, V; query_group_size=1, causal=false)

Reference attention computed with `NNlib.dot_product_attention`.

When `query_group_size > 1`, `K` and `V` are expanded along dim 3 with
`repeat(...; inner=(1, 1, query_group_size))` before the call. When `causal` is
set, a mask from `NNlib.make_causal_mask(Q; dims=2)` is passed along. Only the
first value returned by `dot_product_attention` is returned.
"""
function nnlib_fmha(Q::AbstractArray{T,4}, K::AbstractArray{T,4}, V::AbstractArray{T,4};
    query_group_size::Integer = 1,
    causal::Bool = false,
) where T
    attn_mask = if causal
        NNlib.make_causal_mask(Q; dims=2)
    else
        nothing
    end

    K_full, V_full = K, V
    if query_group_size > 1
        # Repeat each K/V head so every query head has its own copy.
        reps = (1, 1, query_group_size)
        K_full = repeat(K; inner=reps)
        V_full = repeat(V; inner=reps)
    end

    attended, _ = NNlib.dot_product_attention(Q, K_full, V_full; mask=attn_mask)
    return attended
end
178+
179+
180+
"""
    test_fmha(T, D_k, SeqLen_Q, Heads, Batch, D_v, SeqLen_KV, KV_heads,
              causal, tile_m, tile_n)

Run `cutile_fmha` on random CUDA inputs of element type `T` and compare the
result with `nnlib_fmha` evaluated on host copies of the same inputs.

Prints " passed" when the results agree within `rtol = atol = 1e-2`, otherwise
prints the maximum absolute difference. The query group size is derived as
`Heads ÷ KV_heads`.
"""
function test_fmha(::Type{T},
    D_k, SeqLen_Q, Heads, Batch,
    D_v, SeqLen_KV, KV_heads,
    causal, tile_m, tile_n,
) where T
    group = Heads ÷ KV_heads

    # Random device inputs, laid out as (feature, sequence, head, batch).
    Q = CUDA.randn(T, D_k, SeqLen_Q, Heads, Batch)
    K = CUDA.randn(T, D_k, SeqLen_KV, KV_heads, Batch)
    V = CUDA.randn(T, D_v, SeqLen_KV, KV_heads, Batch)

    device_out = cutile_fmha(Q, K, V;
                             causal=causal,
                             tile_m=tile_m, tile_n=tile_n,
                             query_group_size=group)

    # Reference result on host copies of the same inputs.
    expected = nnlib_fmha(Array(Q), Array(K), Array(V);
                          query_group_size=group, causal=causal)
    result = Array(device_out)

    if !isapprox(result, expected, rtol=1e-2, atol=1e-2)
        max_diff = maximum(abs, result .- expected)
        println(" FAILED (max diff: $max_diff)")
    else
        println(" passed")
    end
    return nothing
end

0 commit comments

Comments
 (0)