You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
The current implementation of the MLA Indexer computation involves three Pallas kernels:
5
+
1.**Forward Kernel (`Indexer.kernel`)**: Computes attention scores using a shared Key (MQA-style) and weighted head aggregation.
6
+
2.**Backward Kernel 1 (`backward_qw_kernel`)**: Computes gradients for Query (`d_q`) and Head Weights (`d_w`).
7
+
3.**Backward Kernel 2 (`backward_k_kernel`)**: Computes gradients for Key (`d_k`).
8
+
9
+
**Identified Issues:**
10
+
-**Serialized Execution**: All kernels currently use a "start DMA -> wait DMA -> compute" pattern within their inner loops. This prevents overlap of memory transfer and computation, significantly reducing performance on TPU where HBM bandwidth is often the bottleneck.
11
+
-**Single Buffering**: Scratch buffers in VMEM are single-buffered, making it impossible to prefetch the next block while processing the current one.
12
+
-**Block Sizing**: `bS=256` and `bT=32` are hardcoded. While reasonable, they should be validated against the specific head dimensions and VMEM capacity.
13
+
14
+
## 2. Optimization Strategy
15
+
The primary optimization is to implement **Manual Software Pipelining (Double Buffering)** for all three kernels.
16
+
17
+
**Key Transformations:**
18
+
1.**Double Buffering**: Allocate scratch buffers of size `(2, ...)` in VMEM for all inputs that are iterated over (e.g., `K` blocks in forward pass).
19
+
2.**Pipelined Loop Structure**:
20
+
-**Prologue**: Initiate the load for the first block (buffer 0).
21
+
-**Body**:
22
+
- Wait for buffer `i % 2`.
23
+
- Initiate load for block `i+1` into buffer `(i+1) % 2` (if not last iteration).
24
+
- Compute using buffer `i % 2`.
25
+
-**Epilogue**: (Handled naturally by the loop condition).
26
+
3.**Async Copies**: Use `pltpu.make_async_copy` with explicit semaphores to manage synchronization.
27
+
28
+
## 3. Memory Layout and Tiling
29
+
30
+
### Forward Kernel (`Indexer.kernel`)
31
+
-**Grid**: `(B, T // bT)`
32
+
-**Loop**: Over `S // bS` blocks.
33
+
-**Stationary Data**: `q_block` (bT, H, D), `w_block` (bT, H) - Loaded once per program, stay in VMEM.
0 commit comments