@@ -104,7 +104,7 @@ def _pick_kernel_block_size(cache_block_size: int) -> int:
104104
105105
106106def _build_block_tables (
107- raw_block_tables : list [ list [ int ]] ,
107+ ctx : PagedAttentionContext ,
108108 cache_block_size : int ,
109109) -> tuple [mx .array , int ]:
110110 """Build kernel-compatible block tables, translating if necessary.
@@ -117,14 +117,23 @@ def _build_block_tables(
117117 Returns:
118118 (block_tables, kernel_block_size)
119119 """
120+ cached = ctx .block_tables_cache .get (cache_block_size )
121+ if cached is not None :
122+ return cached
123+
124+ raw_block_tables = ctx .block_tables
120125 if not raw_block_tables :
121- return mx .zeros ((0 , 0 ), dtype = mx .int32 ), cache_block_size
126+ result = (mx .zeros ((0 , 0 ), dtype = mx .int32 ), cache_block_size )
127+ ctx .block_tables_cache [cache_block_size ] = result
128+ return result
122129
123130 if cache_block_size in _KERNEL_BLOCK_SIZES :
124131 # Fast path — no translation needed.
125132 max_blocks = max (len (bt ) for bt in raw_block_tables )
126133 padded = [bt + [0 ] * (max_blocks - len (bt )) for bt in raw_block_tables ]
127- return mx .array (padded , dtype = mx .int32 ), cache_block_size
134+ result = (mx .array (padded , dtype = mx .int32 ), cache_block_size )
135+ ctx .block_tables_cache [cache_block_size ] = result
136+ return result
128137
129138 # Hybrid path — translate large block_size to a kernel-compatible one.
130139 # Vectorized: each vLLM block b → [b*ratio, b*ratio+1, …, b*ratio+ratio-1].
@@ -139,7 +148,9 @@ def _build_block_tables(
139148 expanded = (bt_arr [:, :, None ] * ratio + offsets [None , None , :]).reshape (
140149 bt_arr .shape [0 ], - 1
141150 )
142- return expanded , kernel_bs
151+ result = (expanded , kernel_bs )
152+ ctx .block_tables_cache [cache_block_size ] = result
153+ return result
143154
144155
145156# === Q/K/V preparation (YOCO, K-eq-V, v_norm variants) ===
@@ -424,20 +435,24 @@ def sdpa_forward(
424435 k_3d = mx .contiguous (keys [0 ].transpose (1 , 0 , 2 ).astype (kv_cache .dtype ))
425436 v_3d = mx .contiguous (values [0 ].transpose (1 , 0 , 2 ).astype (kv_cache .dtype ))
426437
427- slot_mapping = mx .array (ctx .slot_mapping , dtype = mx .int64 )
428- seq_lens = mx .array (ctx .context_lens , dtype = mx .int32 )
429- cu_seqlens_q = mx .array (ctx .cu_seqlens , dtype = mx .int32 )
430- max_seq_len = max (ctx .context_lens )
438+ slot_mapping = ctx .slot_mapping_mx
439+ if slot_mapping is None :
440+ slot_mapping = mx .array (ctx .slot_mapping , dtype = mx .int64 )
441+ seq_lens = ctx .context_lens_mx
442+ if seq_lens is None :
443+ seq_lens = mx .array (ctx .context_lens , dtype = mx .int32 )
444+ cu_seqlens_q = ctx .cu_seqlens_mx
445+ if cu_seqlens_q is None :
446+ cu_seqlens_q = mx .array (ctx .cu_seqlens , dtype = mx .int32 )
447+ max_seq_len = ctx .max_context_len or max (ctx .context_lens )
431448
432449 # --- Block tables (with hybrid block-size translation) ---
433450 # vLLM may inflate block_size (e.g. 544) to align attention pages with
434451 # mamba pages in hybrid models. The Metal kernel only supports small
435452 # block sizes (8, 16, 32). _build_block_tables handles the translation:
436453 # it expands each vLLM block into multiple kernel blocks and returns the
437454 # kernel-compatible block_size. The cache is reshaped to match (zero-copy).
438- block_tables , kernel_block_size = _build_block_tables (
439- ctx .block_tables , kv_cache .block_size
440- )
455+ block_tables , kernel_block_size = _build_block_tables (ctx , kv_cache .block_size )
441456
442457 if shared_kv is not None :
443458 # YOCO shared layer: the reference layer already scattered the
0 commit comments