@@ -101,6 +101,11 @@ def update_block_id_groups(self, new_block_id_groups: tuple[list[int], ...] | No
101101 for group_state , new_blocks in zip (self .group_states , new_block_id_groups ):
102102 group_state .block_ids .extend (new_blocks )
103103
def advance_stored_idx(self, num_offloadable_tokens: int) -> None:
    """Advance each group's next-block-to-store index up to a token count.

    For every KV cache group, the number of whole offloaded blocks covered
    by ``num_offloadable_tokens`` is derived from that group's
    ``offloaded_block_size`` and recorded as ``next_stored_block_idx``.

    The index is only ever moved forward. If the given token count covers
    fewer blocks than the group's current index, the index is left as is
    rather than rewound, so blocks that were already handled are not
    considered again on a later pass.

    Args:
        num_offloadable_tokens: Number of request tokens whose blocks are
            eligible for offloading.
    """
    for group_config, group_state in zip(self.config.kv_group_configs, self.group_states):
        num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
        # Clamp with the current value: "advance" must never move backwards.
        group_state.next_stored_block_idx = max(group_state.next_stored_block_idx, num_blocks)
108+
104109
105110class OffloadingConnectorScheduler :
106111 """Implementation of Scheduler side methods"""
@@ -333,16 +338,14 @@ def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, num_
333338 if self ._blocks_being_loaded is not None :
334339 self ._blocks_being_loaded .update (req_blocks_being_loaded )
335340
336- def _get_reqs_to_store (self , scheduler_output : SchedulerOutput ):
337- # Below assertion will be removed once this function supports HMA
338- assert len (self .config .kv_group_configs ) == 1
339- group_config = self .config .kv_group_configs [0 ]
340-
341+ def _get_reqs_to_store (self , scheduler_output : SchedulerOutput ) -> dict [ReqId , TransferSpec ]:
342+ block_size_factor = self .config .block_size_factor
341343 reqs_to_store : dict [ReqId , TransferSpec ] = {}
342344 # iterate over both new and cached requests
343345 for req_id , new_block_id_groups , preempted in yield_req_data (scheduler_output ):
344346 req_status = self ._req_status [req_id ]
345347 req_status .update_offload_keys ()
348+ req = req_status .req
346349
347350 if preempted :
348351 for group_state in req_status .group_states :
@@ -351,64 +354,95 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
351354 if new_block_id_groups :
352355 req_status .update_block_id_groups (new_block_id_groups )
353356
354- # Below assertion will be removed once this function supports HMA
355- assert len (req_status .group_states ) == 1
356- group_state = req_status .group_states [0 ]
357-
358- block_ids = group_state .block_ids
359-
360- req = req_status .req
361- new_tokens = scheduler_output .num_scheduled_tokens [req_id ]
362- expected_tokens = req .num_computed_tokens + new_tokens
357+ num_scheduled_tokens = scheduler_output .num_scheduled_tokens [req_id ]
358+ num_tokens_after_batch = req .num_computed_tokens + num_scheduled_tokens
363359 # with async scheduling, some tokens may be missing
364- total_tokens = min (expected_tokens , req .num_tokens )
365- num_blocks = total_tokens // group_config .offloaded_block_size
366- start_block_idx = group_state .next_stored_block_idx
367- num_new_blocks = num_blocks - start_block_idx
368-
369- if num_new_blocks <= 0 :
360+ num_offloadable_tokens = min (num_tokens_after_batch , req .num_tokens )
361+
362+ # Filter out blocks skipped due to sliding window attention / SSM
363+ new_offload_keys : list [OffloadKey ] = []
364+ for group_config , group_state in zip (self .config .kv_group_configs , req_status .group_states ):
365+ num_blocks = num_offloadable_tokens // group_config .offloaded_block_size
366+ start_block_idx = group_state .next_stored_block_idx
367+ if num_blocks <= start_block_idx :
368+ continue
369+ offload_keys = group_state .offload_keys [start_block_idx :num_blocks ]
370+ # For each block to offload, take the last corresponding GPU block.
371+ # e.g. if block size factor is 3 and GPU block IDs are
372+ # 1 5 6 7 2 4 9 3 8 then we'll take blocks 6 4 8.
373+ # We will use these GPU blocks to determine if the block needs
374+ # offloading, or (if the GPU block ID is 0) this block should
375+ # be skipped due to sliding window attention / SSM.
376+ # We know that if a block is skipped, then all the previous blocks
377+ # are skipped as well. This is why we take the last of each block.
378+ offload_block_ids = group_state .block_ids [
379+ start_block_idx * block_size_factor + block_size_factor - 1 : num_blocks
380+ * block_size_factor : block_size_factor
381+ ]
382+ assert len (offload_keys ) == len (offload_block_ids )
383+
384+ for offload_key , block_id in zip (offload_keys , offload_block_ids ):
385+ if block_id != 0 :
386+ new_offload_keys .append (offload_key )
387+
388+ if not new_offload_keys :
389+ req_status .advance_stored_idx (num_offloadable_tokens )
370390 continue
371391
372- num_gpu_blocks = num_blocks * self .config .block_size_factor
373- assert len (req .block_hashes ) >= num_gpu_blocks
374-
375- new_offload_keys = group_state .offload_keys [start_block_idx :num_blocks ]
376392 store_output = self .manager .prepare_store (new_offload_keys , req_status .req_context )
377393 if store_output is None :
378- logger .warning ("Request %s: cannot store %s blocks" , req_id , num_new_blocks )
394+ logger .warning ("Request %s: cannot store blocks" , req_id )
379395 continue
380396
381- group_state .next_stored_block_idx = num_blocks
382-
383397 if not store_output .keys_to_store :
398+ req_status .advance_stored_idx (num_offloadable_tokens )
384399 continue
385- keys_to_store = set (store_output .keys_to_store )
386400
387- self .manager .touch (group_state .offload_keys [:num_blocks ])
401+ for group_state in req_status .group_states :
402+ self .manager .touch (group_state .offload_keys )
388403
389- dst_spec = store_output .store_spec
404+ keys_to_store = set (store_output .keys_to_store )
405+
406+ group_sizes : list [int ] = []
407+ block_indices : list [int ] = []
390408 src_block_ids : list [int ] = []
391- for idx , key in enumerate (new_offload_keys ):
392- if key not in keys_to_store :
393- continue
394- offloaded_block_idx = start_block_idx + idx
395- gpu_block_idx = offloaded_block_idx * self .config .block_size_factor
396- for i in range (self .config .block_size_factor ):
397- src_block_ids .append (block_ids [gpu_block_idx + i ])
398- src_spec = GPULoadStoreSpec (
399- src_block_ids ,
400- group_sizes = (len (src_block_ids ),),
401- block_indices = (0 ,),
402- )
409+ for group_config , group_state in zip (self .config .kv_group_configs , req_status .group_states ):
410+ num_blocks = num_offloadable_tokens // group_config .offloaded_block_size
411+ start_block_idx = group_state .next_stored_block_idx
412+ block_ids = group_state .block_ids
413+ num_group_blocks = 0
414+ start_gpu_block_idx : int | None = None
415+ for idx , offload_key in enumerate (group_state .offload_keys [start_block_idx :num_blocks ]):
416+ if offload_key not in keys_to_store :
417+ continue
418+
419+ offloaded_block_idx = start_block_idx + idx
420+ gpu_block_idx = offloaded_block_idx * block_size_factor
421+ num_group_blocks += block_size_factor
422+ for i in range (block_size_factor ):
423+ block_id = block_ids [gpu_block_idx + i ]
424+ if block_id == 0 :
425+ # skipped blocks cannot appear after non-skipped blocks
426+ assert start_gpu_block_idx is None
427+ continue
428+ elif start_gpu_block_idx is None :
429+ start_gpu_block_idx = gpu_block_idx + i
430+ src_block_ids .append (block_id )
431+ group_sizes .append (num_group_blocks )
432+ block_indices .append (start_gpu_block_idx or 0 )
433+ group_state .next_stored_block_idx = num_blocks
434+
435+ src_spec = GPULoadStoreSpec (src_block_ids , group_sizes = group_sizes , block_indices = block_indices )
436+ dst_spec = store_output .store_spec
403437
404438 reqs_to_store [req_id ] = (src_spec , dst_spec )
405439 self ._reqs_being_stored [req_id ] |= keys_to_store
406440
407441 logger .debug (
408- "Request %s offloading %s blocks starting from block #%d " ,
442+ "Request %s offloading %s blocks upto %d tokens " ,
409443 req_id ,
410444 len (keys_to_store ),
411- start_block_idx ,
445+ num_offloadable_tokens ,
412446 )
413447
414448 return reqs_to_store
0 commit comments