Skip to content

Commit 41e3f67

Browse files
committed
feat: Add DeepSeek-V4 Compressed Attention
1 parent f47c3f4 commit 41e3f67

6 files changed

Lines changed: 1636 additions & 2 deletions

File tree

src/maxtext/common/common_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ class AttentionType(enum.Enum):
120120
LOCAL_SLIDING = "local_sliding"
121121
CHUNK = "chunk"
122122
MLA = "mla"
123+
COMPRESSED = "compressed"
123124
FULL = "full"
124125

125126

src/maxtext/configs/base.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,12 @@ qk_nope_head_dim: 128
411411
qk_rope_head_dim: 64
412412
v_head_dim: 128
413413

414+
# Compressed Attention parameters
415+
o_lora_rank: 0
416+
o_groups: 0
417+
compress_ratios: []
418+
compressed_rope_max_timescale: 160_000 # Timescale for Compressed Sparse/Heavy Attention
419+
414420
# QK-Clip (Muon Clip) Configuration
415421
use_qk_clip: false # Enable QK-Clip (supported in MLA with DotProduct or Tokamax Splash)
416422
qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper)

src/maxtext/configs/types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,17 @@ class MlaAttention(BaseModel):
636636
v_head_dim: NonNegativeInt = Field(128, description="Dimension of V heads in MLA.")
637637

638638

639+
class CompressedAttention(BaseModel):
640+
"""Configuration for Compressed Attention."""
641+
642+
o_lora_rank: NonNegativeInt = Field(0, description="Output LoRA rank for Compressed Attention.")
643+
o_groups: NonNegativeInt = Field(0, description="Output groups for Compressed Attention.")
644+
compress_ratios: list[int] = Field(default_factory=list, description="Per-layer compression ratios (0, 4, 128, etc).")
645+
compressed_rope_max_timescale: int = Field(
646+
160000, description="If positive, used for Compressed Sparse/Heavy Attention."
647+
)
648+
649+
639650
class AttentionIndexer(BaseModel):
640651
"""Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer."""
641652

@@ -2254,6 +2265,7 @@ class MaxTextConfig(
22542265
# Attention Mechanisms
22552266
Attention,
22562267
MlaAttention,
2268+
CompressedAttention,
22572269
MoBa,
22582270
AttentionIndexer,
22592271
Llama4Attention,

0 commit comments

Comments
 (0)