@@ -336,32 +336,43 @@ def run(
336336 warmup_steps = 1
337337
338338 # warmup cache capacity
339- warmup_cache_len = 128
340- warmup_batch = len (test .input_ids_list )
341-
342- test .model .reset_cache (
343- StaticKVCacheConfig (
339+ warmup_case = next (iter (cases_dict .values ()))
340+ warmup_batch = warmup_case ["batch_size" ]
341+ warmup_input_len = warmup_case ["input_len" ]
342+ warmup_decode_len = 5
343+
344+ if enable_paged_attn :
345+ warmup_num_blocks = (
346+ (warmup_input_len + warmup_decode_len + paged_kv_block_size - 1 )
347+ // paged_kv_block_size
348+ ) * warmup_batch
349+ warmup_cache_config = PagedKVCacheConfig (
350+ warmup_num_blocks , paged_kv_block_size
351+ )
352+ else :
353+ warmup_cache_config = StaticKVCacheConfig (
344354 max_batch_size = warmup_batch ,
345- max_cache_len = warmup_cache_len ,
355+ max_cache_len = warmup_input_len + warmup_decode_len ,
346356 )
347- )
348357
349- avg_prompt_len = min ( 64 , max ( len ( ids ) for ids in test .input_ids_list ) )
358+ test .model . reset_cache ( warmup_cache_config )
350359
351- warmup_ids = [
352- ids [:avg_prompt_len ] if len (ids ) >= avg_prompt_len else ids
353- for ids in test .input_ids_list
354- ]
360+ warmup_prompt_ids = repeat_prompt (test .input_ids_list [0 ], warmup_input_len )
361+ warmup_ids = [warmup_prompt_ids ] * warmup_batch
355362
356363 input_ids_infini = infinicore .from_list (warmup_ids , dtype = infinicore .int64 )
357364
365+ print (
366+ f"\033 [93m[warmup] batch={ warmup_batch } , input_len={ warmup_input_len } , "
367+ f"will prefill + { warmup_decode_len } decode steps\033 [0m"
368+ )
358369 print ("=================== warmup start ===================" )
359370
360371 for _ in range (warmup_steps ):
361372 _ = test .model .generate (
362373 input_ids_infini ,
363374 GenerationConfig (
364- max_new_tokens = 5 , # decode kernel warmup
375+ max_new_tokens = warmup_decode_len , # decode kernel warmup
365376 temperature = cfg .temperature ,
366377 top_k = cfg .top_k ,
367378 top_p = cfg .top_p ,
0 commit comments