@@ -462,12 +462,12 @@ def update_cache_config(self, cache_config):
462462 main_process_metrics .free_gpu_block_num .set (self .num_gpu_blocks )
463463 main_process_metrics .available_gpu_resource .set (1.0 )
464464
465- def can_allocate_gpu_blocks (self , num_blocks : int ):
465+ def can_allocate_gpu_blocks (self , num_blocks : int , try_free_gpu_blocks : bool = True ):
466466 """
467467 Check if num_blocks gpu blocks can be allocated.
468468 """
469469 if len (self .gpu_free_block_list ) < num_blocks :
470- if self .cache_config .enable_prefix_caching :
470+ if self .cache_config .enable_prefix_caching and try_free_gpu_blocks :
471471 self .free_block_ids (num_blocks )
472472 if len (self .gpu_free_block_list ) < num_blocks :
473473 return False
@@ -814,7 +814,7 @@ def request_match_blocks(self, task: Request, block_size, *args):
814814 # 2. prepare cpu cache: allocate gpu cache for matched cpu blocks, wait for data transfer to complete
815815 gpu_recv_block_ids = []
816816 match_cpu_blocks_num = len (match_cpu_block_ids )
817- if self .can_allocate_gpu_blocks (num_blocks = match_cpu_blocks_num ):
817+ if self .can_allocate_gpu_blocks (num_blocks = match_cpu_blocks_num , try_free_gpu_blocks = False ):
818818 if match_cpu_blocks_num > 0 :
819819 logger .debug (
820820 f"request_match_blocks: req_id { req_id } , allocate { match_cpu_blocks_num } block to receive cpu cache"
@@ -845,7 +845,7 @@ def request_match_blocks(self, task: Request, block_size, *args):
845845 match_storage_block_ids = []
846846
847847 if self .kvcache_storage_backend and no_match_token_num >= block_size :
848- if not self .can_allocate_gpu_blocks (num_blocks = no_match_block_num ):
848+ if not self .can_allocate_gpu_blocks (num_blocks = no_match_block_num , try_free_gpu_blocks = False ):
849849 raise Exception (
850850 "request_match_blocks: Not enough GPU memory to allocate cache for matched Storage Cache"
851851 )
@@ -881,7 +881,7 @@ def request_match_blocks(self, task: Request, block_size, *args):
881881 read_storage_task = ReadStorageTask (
882882 task_id = req_id ,
883883 keys = no_match_block_keys ,
884- token_ids = input_token_ids ,
884+ token_ids = input_token_ids if self . kvcache_storage_backend == "attention_store" else None ,
885885 gpu_block_ids = gpu_recv_storage_block_ids ,
886886 start_read_block_idx = match_token_num // block_size ,
887887 )
@@ -1141,8 +1141,11 @@ def write_cache_to_storage(self, request: Request):
11411141 token_ids = request .prompt_token_ids
11421142 if isinstance (token_ids , np .ndarray ):
11431143 token_ids = token_ids .tolist ()
1144+
11441145 if self .config .cache_config .enable_output_caching :
1145- token_ids += request .output_token_ids
1146+ input_token_ids = token_ids + request .output_token_ids
1147+ else :
1148+ input_token_ids = token_ids
11461149
11471150 req_id = request .request_id
11481151 keys = []
@@ -1159,7 +1162,7 @@ def write_cache_to_storage(self, request: Request):
11591162 write_storage_task = WriteStorageTask (
11601163 task_id = req_id ,
11611164 keys = keys ,
1162- token_ids = token_ids ,
1165+ token_ids = input_token_ids if self . kvcache_storage_backend == "attention_store" else None ,
11631166 gpu_block_ids = gpu_block_ids ,
11641167 )
11651168 logger .debug (f"issue write storage task: { write_storage_task } " )
@@ -1193,16 +1196,18 @@ def write_cache_to_storage_decode(self, request: Request):
11931196 token_ids = list (token_ids )
11941197
11951198 if self .config .cache_config .enable_output_caching :
1196- token_ids = token_ids + request .output_token_ids
1199+ input_token_ids = token_ids + request .output_token_ids
1200+ else :
1201+ input_token_ids = token_ids
11971202
11981203 # 2. Calculate cache keys using chained hash (consistent with P instance)
11991204 keys = []
12001205 prefix_block_key = [] # Initial is empty list
12011206 block_size = self .config .cache_config .block_size
12021207 mm_idx = 0 # Multimodal index for tracking position in mm_inputs
12031208
1204- for i in range (0 , len (token_ids ), block_size ):
1205- block_token_ids = token_ids [i : i + block_size ]
1209+ for i in range (0 , len (input_token_ids ), block_size ):
1210+ block_token_ids = input_token_ids [i : i + block_size ]
12061211 if len (block_token_ids ) < block_size :
12071212 break # Do not cache incomplete block
12081213
@@ -1236,7 +1241,7 @@ def write_cache_to_storage_decode(self, request: Request):
12361241 write_storage_task = WriteStorageTask (
12371242 task_id = req_id ,
12381243 keys = keys ,
1239- token_ids = token_ids ,
1244+ token_ids = input_token_ids if self . kvcache_storage_backend == "attention_store" else None ,
12401245 gpu_block_ids = gpu_block_ids ,
12411246 )
12421247
@@ -2166,7 +2171,7 @@ def recv_data_transfer_result(self):
21662171 event_type = data [0 ]
21672172
21682173 if event_type .value == CacheStatus .STORAGE2GPU .value :
2169- logger .info (f"recv_data_transfer_result: { data } " )
2174+ logger .debug (f"recv_data_transfer_result: { data } " )
21702175 task_id , hash_keys , block_ids = data [1 :]
21712176 if task_id not in self .storage_prefetch_block_ids :
21722177 self .storage_prefetch_block_ids [task_id ] = []
@@ -2177,7 +2182,7 @@ def recv_data_transfer_result(self):
21772182 if task_id in self .task_prefetch_event :
21782183 self .task_prefetch_event [task_id ].set ()
21792184 elif event_type .value == CacheStatus .GPU2STORAGE .value :
2180- logger .info (f"recv_data_transfer_result: { data } " )
2185+ logger .debug (f"recv_data_transfer_result: { data } " )
21812186 task_id , hash_keys , block_ids = data [1 :]
21822187 if task_id in self .task_write_back_event :
21832188 self .task_write_back_event [task_id ].set ()
0 commit comments