Skip to content

Commit efbee87

Browse files
khatwanimohit authored and A9isha committed
bypass block sizes for mla
1 parent 16fc894 commit efbee87

3 files changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1005,6 +1005,8 @@ pagedattn_max_pages_per_group: -1 # defaults to number of pages needed to reach
10051005
# Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
10061006
# TPUs, the head_dim is padded to the nearest multiple of 128.
10071007
pagedattn_head_dim_alignment: 128
1008+
pagedattn_num_kv_pages_per_block: -1 # -1 = auto-tune via tpu_inference; set explicitly to override for specific TPU types
1009+
pagedattn_num_queries_per_block: -1 # -1 = auto-tune via tpu_inference; set explicitly to override for specific TPU types
10081010

10091011

10101012
# Chunked Prefill Parameters

src/maxtext/configs/types.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -612,6 +612,8 @@ class PagedAttention(BaseModel):
612612
# Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
613613
# TPUs, the head_dim is padded to the nearest multiple of 128.
614614
pagedattn_head_dim_alignment: int = Field(128, description="Alignment of head_dim to the nearest multiple.")
615+
pagedattn_num_kv_pages_per_block: int = Field(-1, description="Number of KV pages per compute block; -1 = auto-tune via tpu_inference.")
616+
pagedattn_num_queries_per_block: int = Field(-1, description="Number of queries per compute block; -1 = auto-tune via tpu_inference.")
615617

616618

617619
class MoEGeneral(BaseModel):

src/maxtext/layers/attentions.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -974,6 +974,8 @@ def forward_serve_vllm(
974974

975975
md = rpa_metadata
976976

977+
num_kv_pages_per_block = self.config.pagedattn_num_kv_pages_per_block
978+
num_queries_per_block = self.config.pagedattn_num_queries_per_block
977979
output, kv_cache = rpa_ops(
978980
self.mesh,
979981
query,
@@ -990,6 +992,8 @@ def forward_serve_vllm(
990992
q_scale,
991993
k_scale,
992994
v_scale,
995+
num_kv_pages_per_block=num_kv_pages_per_block if num_kv_pages_per_block > 0 else None,
996+
num_queries_per_block=num_queries_per_block if num_queries_per_block > 0 else None,
993997
)
994998
return kv_cache, output
995999

0 commit comments

Comments
 (0)