@@ -328,48 +328,83 @@ def run(
328328 # Warmup
329329 # ---------------------------------------------------------------------------- #
330330 if cfg .warmup :
331- warmup_steps = 1
332-
333- # warmup cache capacity
334- warmup_cache_len = 128
335- warmup_batch = len (test .input_ids_list )
336-
337- test .model .reset_cache (
338- StaticKVCacheConfig (
339- max_batch_size = warmup_batch ,
340- max_cache_len = warmup_cache_len ,
331+ print ("=================== warmup start ===================" )
332+ # -------------------------------------------------------- #
333+ # reset cache before warmup
334+ # support both paged cache and static cache
335+ # -------------------------------------------------------- #
336+ if cache_config is not None :
337+ # Paged KVCache
338+ test .model .reset_cache (cache_config )
339+ else :
340+ # Static KVCache
341+ max_batch_size = max (c ["batch_size" ] for _ , c in cases_dict .items ())
342+ max_cache_len = max (
343+ c ["input_len" ] + c ["output_len" ]
344+ for _ , c in cases_dict .items ()
341345 )
342- )
343346
344- avg_prompt_len = min ( 64 , max ( len ( ids ) for ids in test .input_ids_list ))
345-
346- warmup_ids = [
347- ids [: avg_prompt_len ] if len ( ids ) >= avg_prompt_len else ids
348- for ids in test . input_ids_list
349- ]
347+ test .model . reset_cache (
348+ StaticKVCacheConfig (
349+ max_batch_size = max_batch_size ,
350+ max_cache_len = max_cache_len ,
351+ )
352+ )
350353
351- input_ids_infini = infinicore .from_list (warmup_ids )
354+ warmup_shapes = []
355+ seen = set ()
356+ for _ , case in cases_dict .items ():
357+ key = (case ["batch_size" ], case ["input_len" ])
358+ if key in seen :
359+ continue
360+ seen .add (key )
361+ warmup_shapes .append ((case ["batch_size" ], case ["input_len" ]))
362+
363+ for w_batch , w_input_len in warmup_shapes :
364+ tqdm .write (
365+ f"\033 [93m[warmup] batch={ w_batch } , input_len={ w_input_len } , "
366+ f"will prefill + 3 decode steps\033 [0m"
367+ )
352368
353- print ("=================== warmup start ===================" )
369+ warmup_ids = repeat_prompt (test .input_ids_list [0 ], target_length = w_input_len )
370+ warmup_ids_list = [warmup_ids ] * w_batch
371+ warmup_input = infinicore .from_list (warmup_ids_list )
354372
355- for _ in range (warmup_steps ):
356373 _ = test .model .generate (
357- input_ids_infini ,
374+ warmup_input ,
358375 GenerationConfig (
359- max_new_tokens = 5 , # decode kernel warmup
360- temperature = cfg . temperature ,
376+ max_new_tokens = 3 ,
377+ eos_token_id = [] ,
361378 top_k = cfg .top_k ,
362379 top_p = cfg .top_p ,
380+ temperature = cfg .temperature ,
363381 stop_on_eos = False ,
364382 ),
365383 _measure_and_log_time = False ,
366384 )
367385
368386 print ("=================== warmup done ====================" )
369-
370- # reset cache back to benchmark config
387+ # -------------------------------------------------------- #
388+ # reset cache back to benchmark config
389+ # support both paged cache and static cache
390+ # -------------------------------------------------------- #
371391 if cache_config is not None :
392+ # Paged KVCache
372393 test .model .reset_cache (cache_config )
394+ else :
395+ # Static KVCache
396+ max_batch_size = max (c ["batch_size" ] for _ , c in cases_dict .items ())
397+ max_cache_len = max (
398+ c ["input_len" ] + c ["output_len" ]
399+ for _ , c in cases_dict .items ()
400+ )
401+
402+ test .model .reset_cache (
403+ StaticKVCacheConfig (
404+ max_batch_size = max_batch_size ,
405+ max_cache_len = max_cache_len ,
406+ )
407+ )
373408
374409 # ---------------------------------------------------------------------------- #
375410 # Warmup done
0 commit comments