s(x) = e^{m(x_{1})-m(x)}s_{1}(x_1) + e^{m(x_2)-m(x)}s_{1}(x_2)\\
Softmax(x) = \frac{l(x)}{s(x)}
\end{aligned}
$$

Figure :numref:`ch-deploy/flashattn` shows a brief overview of FlashAttention with two blocks. Following this decomposition, Softmax calculations can be executed block by block. Therefore, **K, Q** and **V** are first divided into blocks. Next, the block-wise Softmax values are computed together with the respective $s(x)$ and $m(x)$. Finally, the output **O** blocks are assembled by multiplying the block-wise Softmax results with the corresponding **V** blocks and aggregating them. To make these steps efficient, all the matrix blocks required for the current step are loaded from HBM into the on-chip SRAM, and all the calculations take place on-chip, that is, within the SRAM. Because the on-chip SRAM has a capacity of 20MB, the block sizes must be chosen carefully so that all required blocks fit. For **K, Q** and **V** $\in\mathbb{R}^{N \times d}$, the block size is set to $\lfloor \frac{M}{4d} \rfloor$, where $M$ is the SRAM size, and the output block size is set to $\min(\lfloor \frac{M}{4d} \rfloor, d)$ [@dao2022flashattention]. After each block is computed, the resulting output block, together with the corresponding $s(x)$ and $m(x)$, is transferred back to HBM. These blocks are small enough that the reads/writes do not cause significant latency; in addition, all related computations are implemented in one CUDA kernel using **kernel fusion**, which avoids repeatedly reading from and writing to HBM.
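
To make the block-wise decomposition concrete, the following NumPy sketch (illustrative only; the helper names and toy shapes are our own, and the real FlashAttention runs as a fused CUDA kernel operating in SRAM) merges two score blocks using their $m(x)$ and $s(x)$ statistics and checks the result against a one-shot Softmax:

```python
import numpy as np

def block_stats(x_block):
    """Per-block statistics: row-wise max m(x_i), exp-sum s(x_i), and exp values."""
    m = x_block.max(axis=-1, keepdims=True)
    p = np.exp(x_block - m)              # unnormalized block probabilities
    s = p.sum(axis=-1, keepdims=True)
    return m, s, p

def merge_two_blocks(x1, x2, v1, v2):
    """Combine two score blocks [x1 | x2] and their V blocks into the exact
    softmax(x) @ V output without ever materializing the full row."""
    m1, s1, p1 = block_stats(x1)
    m2, s2, p2 = block_stats(x2)
    m = np.maximum(m1, m2)                               # global row max m(x)
    s = np.exp(m1 - m) * s1 + np.exp(m2 - m) * s2        # merged denominator s(x)
    return (np.exp(m1 - m) * (p1 @ v1) + np.exp(m2 - m) * (p2 @ v2)) / s

# Sanity check against the standard, one-shot Softmax.
rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(4, 8)), rng.normal(size=(16, 8)), rng.normal(size=(16, 8))
scores = q @ k.T
p = np.exp(scores - scores.max(-1, keepdims=True))
reference = (p / p.sum(-1, keepdims=True)) @ v
assert np.allclose(reference, merge_two_blocks(scores[:, :8], scores[:, 8:], v[:8], v[8:]))
```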

<figure id="fig:ch-deploy/memory">
<div class="center">
<img src="../img/ch08/Memory hierarchy.png" style="width:80.0%" />
</div>
<figcaption>Memory Hierarchy Overview</figcaption>
</figure>

<figure id="fig:ch-deploy/flashattn">
<div class="center">
<img src="../img/ch08/flashattn.png" style="width:80.0%" />
</div>
<figcaption>FlashAttention Overview with Two Blocks</figcaption>
</figure>

**Recomputation**:

Standard attention requires $O(N^2)$ memory to store the intermediate matrices **S** and **P** for gradient computation w.r.t. **Q, K, V** in the backward pass. In FlashAttention, **S** and **P** can easily be recomputed block by block in SRAM from the loaded **Q** and **K** blocks together with the HBM-stored $s(x)$, $m(x)$ and **O**, so only $O(N)$ extra memory is required. Furthermore, FlashAttention performs fewer HBM accesses than standard attention, which results in a faster runtime [@dao2022flashattention].
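
As a sketch of this recomputation (reusing the toy NumPy setting from above; the block shapes and names are our own, not the kernel's), a probability block can be rebuilt in the backward pass from a **Q** block, a **K** block and the saved row statistics:

```python
import numpy as np

def recompute_p_block(q_blk, k_blk, m_row, s_row):
    """Rebuild one probability block P_ij on-chip instead of storing the
    O(N^2) matrices S and P in HBM. m_row and s_row are the row-wise
    statistics saved during the forward pass."""
    s_blk = q_blk @ k_blk.T               # recompute the score block S_ij
    return np.exp(s_blk - m_row) / s_row  # exact softmax probabilities for this block
```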

The standard FlashAttention implementation does not eliminate the redundant computation of zero elements within the attention mechanism. To address this, a block mask is incorporated into FlashAttention so that computation focuses exclusively on non-zero blocks. Termed Block-Sparse FlashAttention, this approach is also discussed in [@dao2022flashattention]. By exploiting sparsity, Block-Sparse FlashAttention reduces the dominant component of the I/O complexity, leading to a direct improvement in performance.
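
A minimal sketch of this idea, again in toy NumPy (the mask layout, names and running-statistics loop are our own simplification of the kernel): one query block iterates over its key/value blocks and simply skips those the block mask marks as zero.

```python
import numpy as np

def block_sparse_attention_row(q_blk, k_blocks, v_blocks, block_mask):
    """Attention for one query block over a list of K/V blocks; masked-out
    blocks are never loaded or computed, which is where the I/O saving comes from."""
    m = np.full((q_blk.shape[0], 1), -np.inf)   # running row-wise max
    s = np.zeros((q_blk.shape[0], 1))           # running exp-sum
    o = np.zeros((q_blk.shape[0], v_blocks[0].shape[1]))
    for keep, k_blk, v_blk in zip(block_mask, k_blocks, v_blocks):
        if not keep:
            continue                            # skip an all-zero block entirely
        scores = q_blk @ k_blk.T
        m_new = np.maximum(m, scores.max(axis=-1, keepdims=True))
        p = np.exp(scores - m_new)
        s = np.exp(m - m_new) * s + p.sum(axis=-1, keepdims=True)
        o = np.exp(m - m_new) * o + p @ v_blk   # rescale the running output, then add
        m = m_new
    return o / s
```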

However, FlashAttention is not fully optimized. Dao noted that its inefficiency stems from suboptimal work distribution among the thread blocks and warps on the GPU, which leads to either low occupancy or unnecessary shared-memory reads and writes. Dao therefore proposed **FlashAttention-2** [@dao2023flashattention2], which has better parallelism and work partitioning.

FlashAttention-2 includes several tweaks to reduce the non-matmul operations:

1. Keep the output **O** blocks unscaled until the very end of the loop.

2. Instead of saving both $s(x)$ and $m(x)$ in HBM, save only $\mathrm{logsumexp}_{i} = m_{i} + \log(s_{i})$, which is enough for the backward pass (a small sketch of this trick follows the list).

3. For blocks whose column indices are greater than their row indices, which occupy about half of the blocks in causal attention over long sequences, computation is skipped. This yields a 1.7-1.8X speedup over the version without this skip.

4. In the backward pass, use only the row-wise $\mathrm{logsumexp}$ instead of both the row-wise max $m(x)$ and the row-wise sum $s(x)$ of exponentials in the softmax.
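
The following toy check (plain NumPy; the variable names are ours) illustrates why storing only $\mathrm{logsumexp}$ is sufficient: subtracting $L = m + \log(s)$ from the scores and exponentiating reproduces the softmax probabilities exactly.

```python
import numpy as np

x = np.random.default_rng(0).normal(size=(4, 16))    # one block of attention scores
m = x.max(axis=-1, keepdims=True)                    # row-wise max m(x)
s = np.exp(x - m).sum(axis=-1, keepdims=True)        # row-wise exp-sum s(x)
L = m + np.log(s)                                    # the single statistic kept in HBM

p_from_L = np.exp(x - L)                             # exp(x - m - log s) = exp(x - m) / s
p_reference = np.exp(x - m) / s
assert np.allclose(p_from_L, p_reference)
```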

For parallelism, the original version of FlashAttention parallelized over the batch size and the number of heads, with one thread block processing one attention head; there are therefore as many thread blocks as the product of the batch size and the number of heads. This works well on an A100 GPU, which has 108 Streaming Multiprocessors (SMs), as long as the number of thread blocks is large enough to keep most of the SMs busy, say 80 or more.

However, for long sequences, which typically come with small batch sizes, this is less efficient because there are too few thread blocks. FlashAttention-2 therefore introduces additional parallelization over the sequence-length dimension, which significantly speeds up these cases by improving GPU occupancy, i.e. the fraction of GPU resources being used.
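
A rough way to see the occupancy argument (the function and parameter values below are purely illustrative): counting the independent thread blocks each scheme launches shows why splitting the sequence helps when the batch is small.

```python
def num_thread_blocks(batch, heads, seq_len, block_q, parallelize_seq):
    """How many independent thread blocks the kernel launch produces."""
    if not parallelize_seq:
        return batch * heads                      # FlashAttention: one block per head
    q_blocks = (seq_len + block_q - 1) // block_q
    return batch * heads * q_blocks               # FlashAttention-2: also split the sequence

# Long sequence, small batch: 8 thread blocks cannot keep 108 SMs busy,
# whereas splitting the sequence into query blocks yields hundreds of them.
print(num_thread_blocks(batch=1, heads=8, seq_len=8192, block_q=128, parallelize_seq=False))  # 8
print(num_thread_blocks(batch=1, heads=8, seq_len=8192, block_q=128, parallelize_seq=True))   # 512
```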

In the forward pass, the method schedules different parts of the sequence length on different thread blocks, which operate independently of one another. The backward pass also parallelizes over the sequence length; to update the gradient of the query matrix, **dQ**, it uses atomic additions to synchronize updates between different thread blocks.
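
The sketch below (a toy, sequential model of this accumulation; on the GPU each loop iteration is a separate thread block and the `+=` becomes an atomic add) shows why **dQ** needs such synchronization: every column block of **K** contributes to the same rows of **dQ**.

```python
import numpy as np

def accumulate_dq(d_scores, k, block_c):
    """dQ = dS @ K, accumulated one key/value column block at a time."""
    dq = np.zeros((d_scores.shape[0], k.shape[1]))
    for j0 in range(0, k.shape[0], block_c):      # one iteration = one thread block
        cols = slice(j0, j0 + block_c)
        dq += d_scores[:, cols] @ k[cols]         # an atomic add on the GPU
    return dq
```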

Within each thread block, the work partitioning across warps also matters; usually, 4 to 8 warps are allocated to each thread block. FlashAttention-2 improves this partitioning in both the forward and backward passes. In the forward pass, FlashAttention splits **K** and **V** across 4 warps (the "split-K" scheme), which leads to inefficient shared-memory operations; FlashAttention-2 instead splits **Q** across the warps while keeping **K** and **V** accessible to all of them. This change eliminates the need for inter-warp communication and reduces shared-memory reads/writes, resulting in a faster runtime: each warp directly multiplies its slice of **Q** with **K** and then with **V**, producing its slice of the output on its own. In the backward pass, FlashAttention-2 likewise avoids the "split-K" scheme, arranging the warps so as to minimize shared-memory operations. Although some synchronization is still required because of the more complex dependencies among inputs and gradients, this approach still yields a speedup by reducing shared-memory reads/writes.

FlashAttention has attracted significant attention in industry for its strong performance, accelerating attention computation in both the forward and backward passes while also reducing memory I/O complexity. The enhanced version, FlashAttention-2, achieves a further 2X speedup over standard FlashAttention [@dao2022flashattention; @dao2023flashattention2]. Moreover, optimization efforts are ongoing, promising an even more capable version of FlashAttention in the future.
