Skip to content

Commit aebda8a

Browse files
committed
bypass block sizes for mla
1 parent 1d2eefa commit aebda8a

3 files changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,8 @@ pagedattn_max_pages_per_group: -1 # defaults to number of pages needed to reach
985985
# Alignment of head_dim to the nearest multiple of this value; set to 0 to disable alignment. On
986986
# TPUs, the head_dim is padded to the nearest multiple of 128.
987987
pagedattn_head_dim_alignment: 128
988+
pagedattn_num_kv_pages_per_block: -1 # -1 (any value <= 0) = auto-tune via tpu_inference; set a positive value to override for specific TPU types
989+
pagedattn_num_queries_per_block: -1 # -1 (any value <= 0) = auto-tune via tpu_inference; set a positive value to override for specific TPU types
988990

989991

990992
# Chunked Prefill Parameters

src/maxtext/configs/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,8 @@ class PagedAttention(BaseModel):
602602
# Alignment of head_dim to the nearest multiple of this value; set to 0 to disable alignment. On
603603
# TPUs, the head_dim is padded to the nearest multiple of 128.
604604
pagedattn_head_dim_alignment: int = Field(128, description="Alignment of head_dim to the nearest multiple.")
605+
pagedattn_num_kv_pages_per_block: int = Field(-1, description="Number of KV pages per compute block; -1 = auto-tune via tpu_inference.")
606+
pagedattn_num_queries_per_block: int = Field(-1, description="Number of queries per compute block; -1 = auto-tune via tpu_inference.")
605607

606608

607609
class MoEGeneral(BaseModel):

src/maxtext/layers/attentions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,8 @@ def forward_serve_vllm(
974974

975975
md = rpa_metadata
976976

977+
num_kv_pages_per_block = self.config.pagedattn_num_kv_pages_per_block
978+
num_queries_per_block = self.config.pagedattn_num_queries_per_block
977979
output, kv_cache = rpa_ops(
978980
self.mesh,
979981
query,
@@ -990,6 +992,8 @@ def forward_serve_vllm(
990992
q_scale,
991993
k_scale,
992994
v_scale,
995+
num_kv_pages_per_block=num_kv_pages_per_block if num_kv_pages_per_block > 0 else None,
996+
num_queries_per_block=num_queries_per_block if num_queries_per_block > 0 else None,
993997
)
994998
return kv_cache, output
995999

0 commit comments

Comments
 (0)