@@ -214,6 +214,9 @@ def __init__(
214214 self .adapter_manager = adapter_manager
215215 self .config = config
216216 self .spec_decoding = config .spec_decoding
217+ self .cache_config = scheduler .cache_config
218+ self .kernel_blocks_per_kv = self .cache_config .block_size // self .cache_config .kernel_block_size
219+ self .kernel_block_arange = torch .arange (self .kernel_blocks_per_kv , dtype = self .torch_int_dtype )
217220
218221 # strategies
219222 self .engine_strategy = engine_strategy
@@ -322,6 +325,29 @@ def _set_adapter_ids(self, model_inputs: ModelInputs, messages: 'SeqList'):
322325 local_adapter_ids = model_inputs .seq_length .new_tensor (local_adapter_ids )
323326 model_inputs .local_adapter_ids = local_adapter_ids
324327
328+ def _map_to_kernel_block_offsets (self , block_offsets : torch .Tensor ):
329+ """Convert block-manager block_offsets to kernel block_offsets.
330+
331+ Example:
332+
333+ # block_manager block size: 32 tokens,
334+ # Kernel block size: 16 tokens
335+ # kernel_blocks_per_kv = 2
336+ >>> block_manager block offsets = [0, 1, 3]
337+ >>> Result kernel block offsets = [0, 1, 2, 3, 6, 7]
338+
339+ # Each block_manager block id maps to 2 kernel block ids:
340+ # block_manager block id 0 -> kernel block id [0, 1]
341+ # block_manager block id 1 -> kernel block id [2, 3]
342+ # block_manager block id 3 -> kernel block id [6, 7]
343+ """
344+ if self .kernel_blocks_per_kv == 1 :  # fast path: kernel block size equals manager block size, ids are identical
345+ return block_offsets
346+ batch_size = block_offsets .shape [0 ]  # assumes block_offsets is (batch, num_blocks) — TODO confirm against callers
347+ block_offsets = (block_offsets [:, :, None ] * self .kernel_blocks_per_kv +  # expand each manager id to its run of kernel ids
348+ self .kernel_block_arange [None , None , :]).reshape (batch_size , - 1 )  # then flatten runs per batch row
349+ return block_offsets
350+
325351 @torch .inference_mode ()
326352 @record_function ('create_model_inputs' )
327353 def create_model_inputs (self , messages : 'SeqList' , is_prefill : bool ):
@@ -355,6 +381,7 @@ def create_model_inputs(self, messages: 'SeqList', is_prefill: bool):
355381 # block offsets
356382 block_offsets = self .scheduler .get_block_tables (messages )
357383 block_offsets = _tensorlize_block_offsets (block_offsets , dtype = self .torch_int_dtype )
384+ block_offsets = self ._map_to_kernel_block_offsets (block_offsets )
358385
359386 # num_ignored_history
360387 num_ignored_history = torch .tensor ([msg .num_ignored_history for msg in messages ])
@@ -410,6 +437,7 @@ def create_model_inputs_long_context(self,
410437 # block offsets
411438 block_offsets = self .scheduler .get_block_tables ([seq ])
412439 block_offsets = torch .as_tensor (block_offsets [0 ], dtype = self .torch_int_dtype )[None ]
440+ block_offsets = self ._map_to_kernel_block_offsets (block_offsets )
413441
414442 # num_ignored_history
415443 num_ignored_history = torch .tensor ([seq .num_ignored_history ])
@@ -482,6 +510,7 @@ def create_model_inputs_delta(self):
482510 # block offsets
483511 block_offsets = self .scheduler .get_block_tables (valid_seqs )
484512 block_offsets = _tensorlize_block_offsets (block_offsets , dtype = self .torch_int_dtype )
513+ block_offsets = self ._map_to_kernel_block_offsets (block_offsets )
485514
486515 # sliding window
487516 if self .scheduler .cache_config .window_size > 0 :
0 commit comments