s(x) = e^{m(x_{1})-m(x)}s_{1}(x_1) + e^{m(x_2)-m(x)}s_{1}(x_2)\\
Softmax(x) = \frac{l(x)}{s(x)}
\end{aligned}
$$

Figure :numref:`ch-deploy/flashattn` gives a brief overview of
FlashAttention with two blocks. Thanks to this decomposition, the
Softmax calculation can be executed block by block. **K**, **Q** and
**V** are therefore first divided into blocks. Next, the block-wise
Softmax values are computed together with the corresponding $s(x)$ and
$m(x)$. Finally, the **O** blocks, i.e. the block-wise Softmax results
multiplied by the corresponding **V** block vectors, are aggregated. To
make these steps efficient, all matrix blocks required by the current
step are loaded from HBM into on-chip SRAM, and all computation then
takes place on-chip, that is, within the SRAM. The block sizes must be
chosen carefully so that all required blocks fit within the on-chip
SRAM, which has a capacity of roughly 20MB on an A100 GPU. For
**K**, **Q** and **V** $\in\mathbb{R}^{N \times d}$, the block size is
set to $\lfloor \frac{M}{4d} \rfloor$, where $M$ is the SRAM size, and
the output block size is set to
$\min(\lfloor \frac{M}{4d} \rfloor, d)$ [@dao2022flashattention].
After each block has been computed, the resulting output block, along
with the corresponding $s(x)$ and $m(x)$, is written back to HBM. These
blocks are small enough that the reads and writes do not cause
significant latency; in addition, all related computations are
implemented in one CUDA kernel using **kernel fusion**, which avoids
repeatedly reading from and writing to HBM.

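The block-by-block procedure can be sketched in plain NumPy. This is a
simplified illustration of the decomposition above rather than the actual
CUDA implementation; the function name `flash_attention_blocked` and the
block size `Bc` are chosen here purely for illustration.

```python
import numpy as np

def flash_attention_blocked(Q, K, V, Bc=64):
    """Block-wise attention: K and V are processed in blocks of Bc rows,
    while running statistics m(x) and s(x) keep the softmax exact."""
    N, d = Q.shape
    O = np.zeros((N, d))            # unnormalized output accumulator
    m = np.full(N, -np.inf)         # running row-wise max  m(x)
    s = np.zeros(N)                 # running row-wise sum  s(x)
    for j in range(0, N, Bc):
        Kj, Vj = K[j:j + Bc], V[j:j + Bc]
        Sj = Q @ Kj.T / np.sqrt(d)              # scores against this block
        m_new = np.maximum(m, Sj.max(axis=1))   # updated row-wise max
        Pj = np.exp(Sj - m_new[:, None])        # block-local exponentials
        scale = np.exp(m - m_new)               # rescales the old statistics
        s = scale * s + Pj.sum(axis=1)
        O = scale[:, None] * O + Pj @ Vj
        m = m_new
    return O / s[:, None]                       # final normalization
```

The running statistics `m` and `s` are exactly the $m(x)$ and $s(x)$ of
the decomposition, and the result is identical to computing
$\mathrm{softmax}(QK^{T}/\sqrt{d})\,V$ directly, but each iteration only
touches one block of **K** and **V**.
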
<figure id="fig:ch-deploy/memory">
<div class="center">
<img src="../img/ch08/Memory hierarchy.png"
style="width:80.0%" />
</div>
<figcaption>Memory Hierarchy Overview</figcaption>
</figure>

<figure id="fig:ch-deploy/flashattn">
<div class="center">
<img src="../img/ch08/flashattn.png" style="width:80.0%" />
</div>
<figcaption>FlashAttention Overview with Two Blocks</figcaption>
</figure>

**Recomputation**:

Standard attention requires $O(N^2)$ memory to store the intermediate
matrices **S** and **P** for the gradient computation w.r.t. **Q**,
**K** and **V** in the backward pass. In FlashAttention, **S** and **P**
can easily be recomputed in SRAM from the HBM-stored $s(x)$, $m(x)$ and
**O**, so only $O(N)$ memory is required. Furthermore, FlashAttention
performs fewer HBM accesses than standard attention, which results in a
faster runtime [@dao2022flashattention].

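A minimal sketch of why recomputation works: given the saved row
statistics $m(x)$ and $s(x)$, any block of **S** and **P** can be rebuilt
on the fly from **Q** and **K**, so the $O(N^2)$ matrices never need to be
materialized in HBM. The function below is illustrative, not a library API.

```python
import numpy as np

def recompute_P_block(Q_block, K_block, m_rows, s_rows):
    """Rebuild one block of the attention probabilities P in the backward
    pass from Q, K and the saved softmax statistics m(x) and s(x),
    instead of storing the O(N^2) matrices S and P in HBM."""
    d = Q_block.shape[1]
    S_block = Q_block @ K_block.T / np.sqrt(d)       # recomputed scores
    return np.exp(S_block - m_rows[:, None]) / s_rows[:, None]
```
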
The standard FlashAttention implementation does not eliminate the
redundant computation of zero elements within the attention mechanism.
To address this, a block-level mask is incorporated into FlashAttention
so that computation focuses exclusively on non-zero blocks. Termed
Block-Sparse FlashAttention, this approach is also discussed in
[@dao2022flashattention]. By exploiting sparsity, Block-Sparse
FlashAttention reduces the dominant term of the I/O complexity, leading
to a direct improvement in performance.

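A rough sketch of the idea, assuming a Boolean block mask over
(query-block, key-block) tiles: tiles marked zero are skipped entirely,
so their **K**/**V** blocks are never read from HBM, which is where the
I/O saving comes from. The helper name and mask layout are illustrative.

```python
import numpy as np

def tiles_to_process(block_mask):
    """Return the (query-block, key-block) tiles that must actually be
    loaded and computed; every other tile is skipped outright."""
    return list(zip(*np.nonzero(block_mask)))

# Example: a causal pattern over 4x4 blocks keeps only the lower triangle,
# so roughly half of the K/V tiles are never touched.
causal = np.tril(np.ones((4, 4), dtype=bool))
print(len(tiles_to_process(causal)), "of", causal.size, "tiles")  # 10 of 16
```
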
However, FlashAttention is not fully optimized. Dao noted that its
inefficiency stems from suboptimal work distribution among the thread
blocks and warps on the GPU, which leads to either low occupancy or
unnecessary shared-memory reads and writes. Dao therefore proposed
**FlashAttention-2** [@dao2023flashattention2], which has better
parallelism and work partitioning.

FlashAttention-2 includes several tweaks to reduce the non-matmul
operations (the logsumexp trick of items 2 and 4 is illustrated in the
sketch after this list).

1. Keep the output **O** blocks un-scaled until the very end of the loop.

2. Instead of saving both $s(x)$ and $m(x)$ in HBM, save only
   $\mathrm{logsumexp}_{i} = m_{i} + \log(s_{i})$, which is enough for the
   backward pass.

3. For blocks whose column indices are greater than their row indices,
   which occupy about half of the blocks in long causal sequences,
   computation is skipped. This leads to a 1.7-1.8X speedup compared to
   not skipping them.

4. Use only the row-wise $\mathrm{logsumexp}$ instead of both the
   row-wise max $m(x)$ and the row-wise sum $s(x)$ of exponentials in
   the softmax.

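To make items 2 and 4 concrete, the following sketch (plain NumPy,
illustrative variable names) shows that the single statistic
$L = m + \log s$ is enough to recover the softmax normalization, so
$m(x)$ and $s(x)$ no longer need to be stored separately.

```python
import numpy as np

x = np.random.randn(8)                  # one row of attention scores
m = x.max()                             # row-wise max m(x)
s = np.exp(x - m).sum()                 # row-wise sum of exponentials s(x)
L = m + np.log(s)                       # the single logsumexp statistic

# The softmax row is recoverable from L alone, without keeping m and s:
p_from_m_and_s = np.exp(x - m) / s
p_from_L = np.exp(x - L)
assert np.allclose(p_from_m_and_s, p_from_L)
```

Deferring the division by $s$ (item 1) works for the same reason: the
un-scaled accumulator only has to be divided by the final row sum once,
after the last block.
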
Regarding parallelism, the original version of FlashAttention
parallelizes over the batch size and the number of heads, with one
thread block processing one attention head. There are as many thread
blocks as the product of the batch size and the number of heads. This
works well on an A100 GPU, which has 108 Streaming Multiprocessors
(SMs), as long as the number of thread blocks is large enough (say, 80
or more) to keep most of the SMs busy.

However, for long sequences (which usually implies small batch sizes),
this is less efficient because there are fewer thread blocks.
FlashAttention-2 introduces additional parallelization over the
sequence-length dimension, which significantly speeds up these cases by
improving GPU occupancy, i.e. the fraction of GPU resources being used.

In the forward pass, the method schedules different parts of the
sequence length on different thread blocks, which operate independently.
The backward pass is also parallelized over the sequence length; to
update the gradients of the query matrix **dQ**, it uses atomic
additions to synchronize updates between different thread blocks.

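As a rough illustration of why parallelizing over the sequence length is
possible (plain NumPy, not the actual CUDA scheduling): in the forward
pass each chunk of query rows produces its output independently of every
other chunk, so each chunk can be assigned to its own thread block. The
chunk size and helper name below are arbitrary.

```python
import numpy as np

def attention_rows(Q_chunk, K, V):
    """Exact attention for one chunk of query rows; different chunks do
    not interact, so they can run on different thread blocks."""
    S = Q_chunk @ K.T / np.sqrt(Q_chunk.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

N, d, chunk = 1024, 64, 128
Q, K, V = (np.random.randn(N, d) for _ in range(3))
# Each chunk of query rows is an independent unit of work.
O = np.concatenate([attention_rows(Q[i:i + chunk], K, V)
                    for i in range(0, N, chunk)])
```
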
Within each thread block, how the work is partitioned among warps also
matters. Usually, 4 to 8 warps are allocated to each thread block, and
FlashAttention-2 improves this partitioning in both the forward and
backward passes of the algorithm. In the forward pass, unlike
FlashAttention, which splits **K** and **V** across 4 warps (the
"split-K" scheme), leading to inefficient shared memory operations,
FlashAttention-2 splits **Q** across the warps while keeping **K** and
**V** accessible to all. This change eliminates the need for inter-warp
communication and reduces shared memory reads/writes, resulting in a
faster runtime. Each warp directly multiplies its slice of **Q** with
**K** and then with **V**, simplifying the computation of its output
slice. In the backward pass, FlashAttention-2 continues to avoid the
"split-K" scheme, arranging the warps in a way that minimizes shared
memory operations. Although some synchronization is still required due
to the more complex dependencies among inputs and gradients, this
approach still yields a speedup by reducing shared memory reads/writes.

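The contrast between the two schemes can be sketched at a high level in
NumPy (warp-level details omitted; both helpers below are illustrative,
not the actual kernels): under "split-K" each worker only sees a slice of
**K**/**V**, so its partial softmax statistics and partial outputs must
be merged afterwards, whereas under "split-Q" each worker owns a disjoint
set of output rows and needs no combination step.

```python
import numpy as np

def split_q(Q, K, V, workers=4):
    """'Split-Q': each worker gets a slice of Q and produces finished
    output rows on its own -- no inter-worker combination is needed."""
    outs = []
    for Qw in np.array_split(Q, workers):
        S = Qw @ K.T / np.sqrt(Q.shape[1])
        P = np.exp(S - S.max(axis=1, keepdims=True))
        outs.append((P / P.sum(axis=1, keepdims=True)) @ V)
    return np.concatenate(outs)

def split_k(Q, K, V, workers=4):
    """'Split-K': each worker sees only a slice of K/V, so its softmax
    statistics and partial outputs must be merged afterwards -- the extra
    traffic and synchronization that the split-Q layout avoids."""
    parts = []
    for Kw, Vw in zip(np.array_split(K, workers), np.array_split(V, workers)):
        S = Q @ Kw.T / np.sqrt(Q.shape[1])
        m = S.max(axis=1, keepdims=True)
        parts.append((m, np.exp(S - m).sum(axis=1, keepdims=True),
                      np.exp(S - m) @ Vw))
    m_all = np.max([m for m, _, _ in parts], axis=0)         # merge step:
    s_all = sum(np.exp(m - m_all) * s for m, s, _ in parts)  # rescale sums
    o_all = sum(np.exp(m - m_all) * o for m, _, o in parts)  # rescale outputs
    return o_all / s_all
```

Both functions return the same result; the difference is that `split_k`
requires the final merge, which is exactly the kind of extra shared-memory
traffic and synchronization the split-Q layout removes.
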
FlashAttention has gained significant attention in industry for its
remarkable performance, accelerating attention computation in both the
forward and backward passes while also reducing memory I/O complexity.
The enhanced version, FlashAttention-2, achieves a further roughly 2X
speedup over the original FlashAttention [@dao2023flashattention2].
Moreover, optimization efforts are ongoing, promising an even more
potent version of FlashAttention in the future.