File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -985,6 +985,8 @@ pagedattn_max_pages_per_group: -1 # defaults to number of pages needed to reach
985985# Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
986986# TPUs, the head_dim is padded to the nearest multiple of 128.
987987pagedattn_head_dim_alignment : 128
988+ pagedattn_num_kv_pages_per_block : -1 # -1 = auto-tune via tpu_inference; set explicitly to override for specific TPU types
989+ pagedattn_num_queries_per_block : -1 # -1 = auto-tune via tpu_inference; set explicitly to override for specific TPU types
988990
989991
990992# Chunked Prefill Parameters
Original file line number Diff line number Diff line change @@ -602,6 +602,8 @@ class PagedAttention(BaseModel):
602602 # Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
603603 # TPUs, the head_dim is padded to the nearest multiple of 128.
604604 pagedattn_head_dim_alignment : int = Field (128 , description = "Alignment of head_dim to the nearest multiple." )
605+ pagedattn_num_kv_pages_per_block : int = Field (- 1 , description = "Number of KV pages per compute block; -1 = auto-tune via tpu_inference." )
606+ pagedattn_num_queries_per_block : int = Field (- 1 , description = "Number of queries per compute block; -1 = auto-tune via tpu_inference." )
605607
606608
607609class MoEGeneral (BaseModel ):
Original file line number Diff line number Diff line change @@ -974,6 +974,8 @@ def forward_serve_vllm(
974974
975975 md = rpa_metadata
976976
977+ num_kv_pages_per_block = self .config .pagedattn_num_kv_pages_per_block
978+ num_queries_per_block = self .config .pagedattn_num_queries_per_block
977979 output , kv_cache = rpa_ops (
978980 self .mesh ,
979981 query ,
@@ -990,6 +992,8 @@ def forward_serve_vllm(
990992 q_scale ,
991993 k_scale ,
992994 v_scale ,
995+ num_kv_pages_per_block = num_kv_pages_per_block if num_kv_pages_per_block > 0 else None ,
996+ num_queries_per_block = num_queries_per_block if num_queries_per_block > 0 else None ,
993997 )
994998 return kv_cache , output
995999
You can’t perform that action at this time.
0 commit comments