@@ -325,49 +325,51 @@ def run(
325325 # Warmup
326326 # ---------------------------------------------------------------------------- #
327327 if cfg .warmup :
328- warmup_steps = 1
328+ print ( "=================== warmup start ===================" )
329329
330- # warmup cache capacity
331- warmup_cache_len = 128
332- warmup_batch = len (test .input_ids_list )
330+ import time
331+ # time.sleep(5) # ← creates a 5-second gap that is clearly visible in the msys Timeline
333332
334- test .model .reset_cache (
335- StaticKVCacheConfig (
336- max_batch_size = warmup_batch ,
337- max_cache_len = warmup_cache_len ,
338- )
339- )
340-
341- avg_prompt_len = min (64 , max (len (ids ) for ids in test .input_ids_list ))
333+ # 使用实际的 cache_config 触发 mate JIT
334+ if cache_config is not None :
335+ test .model .reset_cache (cache_config )
342336
343- warmup_ids = [
344- ids [:avg_prompt_len ] if len (ids ) >= avg_prompt_len else ids
345- for ids in test .input_ids_list
346- ]
337+ # 收集所有不同的 (batch_size, input_len) shape
338+ warmup_shapes = []
339+ seen = set ()
340+ for _ , case in cases_dict .items ():
341+ key = (case ["batch_size" ], case ["input_len" ])
342+ if key in seen :
343+ continue
344+ seen .add (key )
345+ warmup_shapes .append ((case ["batch_size" ], case ["input_len" ]))
346+
347+ for w_batch , w_input_len in warmup_shapes :
348+ tqdm .write (
349+ f"\033 [93m[warmup] batch={ w_batch } , input_len={ w_input_len } , "
350+ f"will prefill + 3 decode steps\033 [0m"
351+ )
347352
348- input_ids_infini = infinicore .from_list (warmup_ids )
353+ warmup_ids = repeat_prompt (test .input_ids_list [0 ], target_length = w_input_len )
354+ warmup_ids_list = [warmup_ids ] * w_batch
355+ warmup_input = infinicore .from_list (warmup_ids_list )
349356
350- print ("=================== warmup start ===================" )
351-
352- for _ in range (warmup_steps ):
353357 _ = test .model .generate (
354- input_ids_infini ,
358+ warmup_input ,
355359 GenerationConfig (
356- max_new_tokens = 5 , # decode kernel warmup
357- temperature = cfg . temperature ,
360+ max_new_tokens = 3 ,
361+ eos_token_id = [] ,
358362 top_k = cfg .top_k ,
359363 top_p = cfg .top_p ,
364+ temperature = cfg .temperature ,
360365 stop_on_eos = False ,
361366 ),
362367 _measure_and_log_time = False ,
363368 )
364-
365- print ("=================== warmup done ====================" )
366-
367- # reset cache back to benchmark config
368369 if cache_config is not None :
369370 test .model .reset_cache (cache_config )
370-
371+ # time.sleep(5) # ← creates a 5-second gap that is clearly visible in the msys Timeline
372+ print ("=================== warmup done ====================" )
371373 # ---------------------------------------------------------------------------- #
372374 # Warmup done
373375 # ---------------------------------------------------------------------------- #
0 commit comments