Skip to content

Commit d113818

Browse files
committed
Add indexer parity test and fix kernel issue
1 parent e887876 commit d113818

6 files changed

Lines changed: 326 additions & 18 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,8 @@ moba_topk: 8
357357
# DeepSeek Sparse Attention (DSA)
358358
# deepseek3.2 introduces indexer in MLA
359359
use_sparse_indexer: False
360+
# Whether to use Pallas kernel for indexer computation
361+
use_kernel_indexer: True
360362
index_head_dim: 128
361363
index_n_heads: 64
362364
index_topk: 2048

src/maxtext/configs/models/deepseek3.2-671b.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ base_num_query_heads: 128
2020
base_num_kv_heads: 128
2121
base_mlp_dim: 18432
2222
base_moe_mlp_dim: 2048
23-
base_num_decoder_layers: 61 #6
23+
base_num_decoder_layers: 6 #61 #6
2424
first_num_dense_layers: 3
2525
mlp_activations: ["silu","linear"]
2626
vocab_size: 129280
2727
enable_dropout: False
2828
logits_via_embedding: False
2929
normalization_layer_epsilon: 1.0e-6
30-
num_experts: 256 #64
30+
num_experts: 64 #256 #64
3131
num_experts_per_tok: 8
3232
shared_experts: 1
3333
routed_scaling_factor: 2.5

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,7 @@ class AttentionIndexer(BaseModel):
532532
"""Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer."""
533533

534534
use_sparse_indexer: bool = Field(False, description="Whether to use sparse indexer for MLA.")
535+
use_kernel_indexer: bool = Field(True, description="Whether to use Pallas kernel for indexer computation.")
535536
index_head_dim: NonNegativeInt = Field(128, description="Head dim for indexer query and key.")
536537
index_n_heads: NonNegativeInt = Field(64, description="Number of query heads in indexer.")
537538
index_topk: NonNegativeInt = Field(2048, description="Number of tokens selected by the query token in indexer.")

src/maxtext/layers/attention_mla.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import math
1818
from typing import Any, Optional, Tuple
1919
import copy
20+
import functools
2021

2122
import jax
2223
from jax.ad_checkpoint import checkpoint_name
@@ -306,7 +307,7 @@ def backward_computation(q: jnp.ndarray, k: jnp.ndarray, w: jnp.ndarray, d_score
306307

307308
# Block sizes
308309
bT = 32
309-
bS = 512
310+
bS = 256
310311

311312
# Padding
312313
pad_d = (128 - (D % 128)) % 128
@@ -445,13 +446,15 @@ def __init__(
445446
self,
446447
config: Any,
447448
rotary_embedding,
449+
mesh: Optional[Mesh] = None,
448450
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
449451
quant: Optional[Quant] = None,
450452
model_mode: str = MODEL_MODE_TRAIN,
451453
rngs: Optional[nnx.Rngs] = None,
452454
):
453455
self.config = config
454456
self.rotary_embedding = rotary_embedding
457+
self.mesh = mesh
455458
self.quant = quant
456459
self.kernel_init = kernel_init
457460
self.model_mode = model_mode
@@ -661,7 +664,7 @@ def _computation_impl(self, q: jnp.ndarray, k: jnp.ndarray, w: jnp.ndarray, mask
661664

662665
# Block sizes
663666
bT = 32
664-
bS = 512
667+
bS = 256
665668

666669
# Pad D to multiple of 128 (TPU vector alignment)
667670
# TPU vector registers are 8x128 (for f32). The last dimension should be 128-aligned.
@@ -713,7 +716,7 @@ def _computation_impl(self, q: jnp.ndarray, k: jnp.ndarray, w: jnp.ndarray, mask
713716
else:
714717
# Dummy mask to satisfy Pallas signature
715718
# Create a small dummy mask
716-
dummy_mask = jnp.zeros((1, 1), dtype=jnp.float32)
719+
dummy_mask = jnp.zeros((B, 1, 1), dtype=jnp.float32)
717720
mask_spec = pl.BlockSpec(memory_space=None)
718721

719722
# Outputs
@@ -738,15 +741,40 @@ def _computation_impl(self, q: jnp.ndarray, k: jnp.ndarray, w: jnp.ndarray, mask
738741
# If has_mask is False, we pass the dummy mask to the kernel
739742
mask_arg = mask if has_mask else dummy_mask
740743

741-
score = pl.pallas_call(
742-
kernel_fn,
743-
out_shape=out_shape,
744-
grid=grid,
745-
in_specs=[q_spec, k_spec, w_spec, mask_spec],
746-
out_specs=o_score_spec,
747-
scratch_shapes=scratch_shapes,
748-
compiler_params=pltpu.CompilerParams(dimension_semantics=("parallel", "parallel"))
749-
)(q, k, w, mask_arg)
744+
# Wrap in shard_map to avoid partitioning error on TPU
745+
# Map B to the first axis of the mesh (usually data/fsdp)
746+
from jax.sharding import PartitionSpec as P
747+
748+
# Use jax.shard_map if available (JAX 0.4.31+), otherwise fallback to experimental
749+
shard_map = getattr(jax, "shard_map", None)
750+
if shard_map is None:
751+
from jax.experimental.shard_map import shard_map
752+
kwargs = {}
753+
else:
754+
kwargs = {"check_vma": False}
755+
756+
# Infer sharding axis from mesh_axes if possible, otherwise assume the first one
757+
batch_axis = self.config.mesh_axes[1] if len(self.config.mesh_axes) > 1 else self.config.mesh_axes[0]
758+
759+
@functools.partial(
760+
shard_map,
761+
mesh=self.mesh,
762+
in_specs=(P(batch_axis, None, None, None), P(batch_axis, None, None), P(batch_axis, None, None), P(batch_axis, None, None)),
763+
out_specs=P(batch_axis, None, None),
764+
**kwargs
765+
)
766+
def sharded_pallas_call(q_s, k_s, w_s, m_s):
767+
return pl.pallas_call(
768+
kernel_fn,
769+
out_shape=jax.ShapeDtypeStruct((q_s.shape[0], T_padded, S_padded), dtype=jnp.float32),
770+
grid=(q_s.shape[0], T_padded // bT),
771+
in_specs=[q_spec, k_spec, w_spec, mask_spec],
772+
out_specs=o_score_spec,
773+
scratch_shapes=scratch_shapes,
774+
compiler_params=pltpu.CompilerParams(dimension_semantics=("parallel", "parallel"))
775+
)(q_s, k_s, w_s, m_s)
776+
777+
score = sharded_pallas_call(q, k, w, mask_arg)
750778

751779
# Slice back to original dimensions
752780
score = score[:, :T, :S]
@@ -852,14 +880,16 @@ def __call__(
852880
k = self.apply_partial_rope(k, inputs_positions=inputs_positions)
853881
k = k.squeeze(2) # [b, s, 1, d] -> [b, s, d]
854882

855-
if True:
883+
if self.config.use_kernel_indexer:
856884
# early return
857-
print("use kernel implementation")
858885
weights = self.weights_proj(inputs_q)
859886
weights = weights * (self.n_heads**-0.5) * self.softmax_scale
860-
return self.computation(q, k, weights, attention_mask, self.config.index_topk)
887+
indexer_score, topk_indices, _ = self.computation(q, k, weights, attention_mask, self.config.index_topk)
888+
indexer_mask = self.generate_mask(topk_indices, seqlen)
889+
if attention_mask is not None:
890+
indexer_mask += attention_mask
891+
return indexer_mask, topk_indices, indexer_score
861892

862-
print("use JAX implementation")
863893
# Compute Index Scores
864894
# QK product: relu(q @ k.T), [b, t, s, h]
865895
# Similar to MQA, each key is shared by h query head
@@ -1201,6 +1231,7 @@ def __init__(
12011231
config,
12021232
rngs=rngs,
12031233
rotary_embedding=indexer_rope,
1234+
mesh=mesh,
12041235
kernel_init=kernel_init,
12051236
quant=quant,
12061237
model_mode=model_mode,
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Kernel Optimization Plan: MLA Indexer Computation
2+
3+
## 1. Current Kernel Analysis
4+
The current implementation of the MLA Indexer computation involves three Pallas kernels:
5+
1. **Forward Kernel (`Indexer.kernel`)**: Computes attention scores using a shared Key (MQA-style) and weighted head aggregation.
6+
2. **Backward Kernel 1 (`backward_qw_kernel`)**: Computes gradients for Query (`d_q`) and Head Weights (`d_w`).
7+
3. **Backward Kernel 2 (`backward_k_kernel`)**: Computes gradients for Key (`d_k`).
8+
9+
**Identified Issues:**
10+
- **Serialized Execution**: All kernels currently use a "start DMA -> wait DMA -> compute" pattern within their inner loops. This prevents overlap of memory transfer and computation, significantly reducing performance on TPU where HBM bandwidth is often the bottleneck.
11+
- **Single Buffering**: Scratch buffers in VMEM are single-buffered, making it impossible to prefetch the next block while processing the current one.
12+
- **Block Sizing**: `bS=256` and `bT=32` are hardcoded. While reasonable, they should be validated against the specific head dimensions and VMEM capacity.
13+
14+
## 2. Optimization Strategy
15+
The primary optimization is to implement **Manual Software Pipelining (Double Buffering)** for all three kernels.
16+
17+
**Key Transformations:**
18+
1. **Double Buffering**: Allocate scratch buffers of size `(2, ...)` in VMEM for all inputs that are iterated over (e.g., `K` blocks in forward pass).
19+
2. **Pipelined Loop Structure**:
20+
- **Prologue**: Initiate the load for the first block (buffer 0).
21+
- **Body**:
22+
- Wait for buffer `i % 2`.
23+
- Initiate load for block `i+1` into buffer `(i+1) % 2` (if not last iteration).
24+
- Compute using buffer `i % 2`.
25+
- **Epilogue**: (Handled naturally by the loop condition).
26+
3. **Async Copies**: Use `pltpu.make_async_copy` with explicit semaphores to manage synchronization.
27+
28+
## 3. Memory Layout and Tiling
29+
30+
### Forward Kernel (`Indexer.kernel`)
31+
- **Grid**: `(B, T // bT)`
32+
- **Loop**: Over `S // bS` blocks.
33+
- **Stationary Data**: `q_block` (bT, H, D), `w_block` (bT, H) - Loaded once per program, stay in VMEM.
34+
- **Streaming Data**: `k_block` (bS, D), `mask_block` (bT, bS).
35+
- **Scratch Buffers**:
36+
- `k_scratch`: `(2, bS, D_padded)` in VMEM.
37+
- `mask_scratch`: `(2, bT, bS)` in VMEM.
38+
- `score_scratch`: `(bT, bS)` in VMEM (Accumulator, no need to double buffer if we write out once).
39+
40+
### Backward Kernel 1 (`backward_qw_kernel`)
41+
- **Grid**: `(B, T // bT)`
42+
- **Loop**: Over `S // bS` blocks.
43+
- **Stationary Data**: `q_block`, `w_block` (loaded once). `d_q_acc`, `d_w_acc` (accumulators in VMEM).
44+
- **Streaming Data**: `k_block`, `d_score_block`.
45+
- **Scratch Buffers**:
46+
- `k_scratch`: `(2, bS, D_padded)`
47+
- `d_score_scratch`: `(2, bT, bS)`
48+
49+
### Backward Kernel 2 (`backward_k_kernel`)
50+
- **Grid**: `(B, S // bS)`
51+
- **Loop**: Over `T // bT` blocks.
52+
- **Stationary Data**: `k_block` (loaded once). `d_k_acc` (accumulator).
53+
- **Streaming Data**: `q_block`, `w_block`, `d_score_block`.
54+
- **Scratch Buffers**:
55+
- `q_scratch`: `(2, bT, H, D_padded)`
56+
- `w_scratch`: `(2, bT, H_padded)`
57+
- `d_score_scratch`: `(2, bT, bS)`
58+
59+
## 4. TPU-Specific Optimizations
60+
- **Vector Alignment**: Ensure `D` and `H` are padded to multiples of 128 (already partially handled, will reinforce).
61+
- **Semaphores**: Use `pltpu.SemaphoreType.DMA` for async copy tracking.
62+
- **Predication**: Use `pl.when` to handle the conditional prefetch for the next iteration.
63+
64+
## 5. Implementation Details
65+
66+
### Pipeline Logic (Template)
67+
```python
68+
# Example for Forward Kernel Loop
69+
def body(i, _):
70+
curr_buff = i % 2
71+
next_buff = (i + 1) % 2
72+
73+
# 1. Wait for current block
74+
# (In first iteration, this waits for the copy started in prologue)
75+
# (In subsequent, it waits for copy started in previous body)
76+
# We need a semaphore per buffer to track "ready to read"
77+
78+
# Actually, simpler pattern:
79+
# Start 0.
80+
# Loop i:
81+
# Wait i%2.
82+
# Start (i+1)%2 if not last.
83+
# Compute i%2.
84+
```
85+
86+
### Block Sizes
87+
- `bT = 32`: Good balance for register pressure and T-dimension parallelism.
88+
- `bS = 256`: Streaming block size along S. With double buffering the scratch still fits comfortably in VMEM even for larger head dimensions.
89+
- VMEM check: `2 * 256 * 256 * 4 bytes` ≈ 512KB for the double-buffered K scratch — well within TPU VMEM capacity, so `bS = 256` is safe (`bS = 512` would also fit at ~1MB).
90+
- Final choice: `bS = 256` (approx 512KB for the double buffer).
91+
92+
## 6. Expected Performance Impact
93+
- **Latency**: Significant reduction due to hiding HBM latency.
94+
- **Throughput**: Higher utilization of MXU (Matrix Units) as they won't stall waiting for data.
95+
- **Speedup**: Estimated 1.5x - 2.0x improvement for memory-bound regimes.
96+
97+
## 7. Documentation Requirements
98+
- Annotate all scratch buffer shapes with `(2, ...)` to indicate double buffering.
99+
- Clearly comment the "Produce / Consume" pattern in the pipeline.
100+
- Document the memory hierarchy (HBM -> VMEM -> Registers).

0 commit comments

Comments
 (0)