@@ -325,48 +325,83 @@ def run(
325325 # Warmup
326326 # ---------------------------------------------------------------------------- #
327327 if cfg .warmup :
328- warmup_steps = 1
329-
330- # warmup cache capacity
331- warmup_cache_len = 128
332- warmup_batch = len (test .input_ids_list )
333-
334- test .model .reset_cache (
335- StaticKVCacheConfig (
336- max_batch_size = warmup_batch ,
337- max_cache_len = warmup_cache_len ,
328+ print ("=================== warmup start ===================" )
329+ # -------------------------------------------------------- #
330+ # reset cache before warmup
331+ # support both paged cache and static cache
332+ # -------------------------------------------------------- #
333+ if cache_config is not None :
334+ # Paged KVCache
335+ test .model .reset_cache (cache_config )
336+ else :
337+ # Static KVCache
338+ max_batch_size = max (c ["batch_size" ] for _ , c in cases_dict .items ())
339+ max_cache_len = max (
340+ c ["input_len" ] + c ["output_len" ]
341+ for _ , c in cases_dict .items ()
338342 )
339- )
340343
341- avg_prompt_len = min ( 64 , max ( len ( ids ) for ids in test .input_ids_list ))
342-
343- warmup_ids = [
344- ids [: avg_prompt_len ] if len ( ids ) >= avg_prompt_len else ids
345- for ids in test . input_ids_list
346- ]
344+ test .model . reset_cache (
345+ StaticKVCacheConfig (
346+ max_batch_size = max_batch_size ,
347+ max_cache_len = max_cache_len ,
348+ )
349+ )
347350
348- input_ids_infini = infinicore .from_list (warmup_ids )
351+ warmup_shapes = []
352+ seen = set ()
353+ for _ , case in cases_dict .items ():
354+ key = (case ["batch_size" ], case ["input_len" ])
355+ if key in seen :
356+ continue
357+ seen .add (key )
358+ warmup_shapes .append ((case ["batch_size" ], case ["input_len" ]))
359+
360+ for w_batch , w_input_len in warmup_shapes :
361+ tqdm .write (
362+ f"\033 [93m[warmup] batch={ w_batch } , input_len={ w_input_len } , "
363+ f"will prefill + 3 decode steps\033 [0m"
364+ )
349365
350- print ("=================== warmup start ===================" )
366+ warmup_ids = repeat_prompt (test .input_ids_list [0 ], target_length = w_input_len )
367+ warmup_ids_list = [warmup_ids ] * w_batch
368+ warmup_input = infinicore .from_list (warmup_ids_list )
351369
352- for _ in range (warmup_steps ):
353370 _ = test .model .generate (
354- input_ids_infini ,
371+ warmup_input ,
355372 GenerationConfig (
356- max_new_tokens = 5 , # decode kernel warmup
357- temperature = cfg . temperature ,
373+ max_new_tokens = 3 ,
374+ eos_token_id = [] ,
358375 top_k = cfg .top_k ,
359376 top_p = cfg .top_p ,
377+ temperature = cfg .temperature ,
360378 stop_on_eos = False ,
361379 ),
362380 _measure_and_log_time = False ,
363381 )
364382
365383 print ("=================== warmup done ====================" )
366-
367- # reset cache back to benchmark config
384+ # -------------------------------------------------------- #
385+ # reset cache back to benchmark config
386+ # support both paged cache and static cache
387+ # -------------------------------------------------------- #
368388 if cache_config is not None :
389+ # Paged KVCache
369390 test .model .reset_cache (cache_config )
391+ else :
392+ # Static KVCache
393+ max_batch_size = max (c ["batch_size" ] for _ , c in cases_dict .items ())
394+ max_cache_len = max (
395+ c ["input_len" ] + c ["output_len" ]
396+ for _ , c in cases_dict .items ()
397+ )
398+
399+ test .model .reset_cache (
400+ StaticKVCacheConfig (
401+ max_batch_size = max_batch_size ,
402+ max_cache_len = max_cache_len ,
403+ )
404+ )
370405
371406 # ---------------------------------------------------------------------------- #
372407 # Warmup done
0 commit comments