File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1005,6 +1005,8 @@ pagedattn_max_pages_per_group: -1 # defaults to number of pages needed to reach
10051005# Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
10061006# TPUs, the head_dim is padded to the nearest multiple of 128.
10071007pagedattn_head_dim_alignment : 128
1008+ pagedattn_num_kv_pages_per_block : -1 # <= 0 (default -1) = auto-tune via tpu_inference; set to a positive value to override for specific TPU types
1009+ pagedattn_num_queries_per_block : -1 # <= 0 (default -1) = auto-tune via tpu_inference; set to a positive value to override for specific TPU types
10081010
10091011
10101012# Chunked Prefill Parameters
Original file line number Diff line number Diff line change @@ -612,6 +612,8 @@ class PagedAttention(BaseModel):
612612 # Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
613613 # TPUs, the head_dim is padded to the nearest multiple of 128.
614614 pagedattn_head_dim_alignment : int = Field (128 , description = "Alignment of head_dim to the nearest multiple." )
615+ pagedattn_num_kv_pages_per_block : int = Field (- 1 , description = "Number of KV pages per compute block; -1 = auto-tune via tpu_inference." )
616+ pagedattn_num_queries_per_block : int = Field (- 1 , description = "Number of queries per compute block; -1 = auto-tune via tpu_inference." )
615617
616618
617619class MoEGeneral (BaseModel ):
Original file line number Diff line number Diff line change @@ -974,6 +974,8 @@ def forward_serve_vllm(
974974
975975 md = rpa_metadata
976976
977+ num_kv_pages_per_block = self .config .pagedattn_num_kv_pages_per_block
978+ num_queries_per_block = self .config .pagedattn_num_queries_per_block
977979 output , kv_cache = rpa_ops (
978980 self .mesh ,
979981 query ,
@@ -990,6 +992,8 @@ def forward_serve_vllm(
990992 q_scale ,
991993 k_scale ,
992994 v_scale ,
995+ num_kv_pages_per_block = num_kv_pages_per_block if num_kv_pages_per_block > 0 else None ,
996+ num_queries_per_block = num_queries_per_block if num_queries_per_block > 0 else None ,
993997 )
994998 return kv_cache , output
995999
You can’t perform that action at this time.
0 commit comments