@@ -141,6 +141,11 @@ def get_args():
141141 default = 1.0 ,
142142 help = "sampling temperature" ,
143143 )
144+ parser .add_argument (
145+ "--warmup" ,
146+ action = "store_true" ,
147+ help = "Perform a warmup run before benchmarking/inference."
148+ )
144149
145150 parser .add_argument (
146151 "--attn" ,
@@ -263,39 +268,40 @@ def test(
263268 # ---------------------------------------------------------------------------- #
264269 # Warmup
265270 # ---------------------------------------------------------------------------- #
266- warmup_steps = 1
267-
268- # Choose a length that approximates the real workload.
269- # It should be long enough to trigger the correct kernel paths,
270- # but not so long that warmup becomes unnecessarily expensive.
271- avg_prompt_len = min (64 , max (len (ids ) for ids in input_ids_list ))
272-
273- # Use truncated versions of real prompts for warmup
274- warmup_ids = [
275- ids [:avg_prompt_len ] if len (ids ) >= avg_prompt_len else ids
276- for ids in input_ids_list
277- ]
271+ if args .warmup :
272+ warmup_steps = 1
273+
274+ # Choose a length that approximates the real workload.
275+ # It should be long enough to trigger the correct kernel paths,
276+ # but not so long that warmup becomes unnecessarily expensive.
277+ avg_prompt_len = min (64 , max (len (ids ) for ids in input_ids_list ))
278+
279+ # Use truncated versions of real prompts for warmup
280+ warmup_ids = [
281+ ids [:avg_prompt_len ] if len (ids ) >= avg_prompt_len else ids
282+ for ids in input_ids_list
283+ ]
278284
279- input_ids_infini = infinicore .from_list (warmup_ids )
285+ input_ids_infini = infinicore .from_list (warmup_ids )
280286
281- print ("=================== warmup start ===================" )
287+ print ("=================== warmup start ===================" )
282288
283- for _ in range (warmup_steps ):
284- _ = model .generate (
285- input_ids_infini ,
286- GenerationConfig (
287- max_new_tokens = 2 , # warmup decode kernel
288- temperature = 1 ,
289- top_k = 1 ,
290- top_p = 0.8 ,
291- ),
292- _measure_and_log_time = False ,
293- )
289+ for _ in range (warmup_steps ):
290+ _ = model .generate (
291+ input_ids_infini ,
292+ GenerationConfig (
293+ max_new_tokens = 2 , # warmup decode kernel
294+ temperature = temperature ,
295+ top_k = top_k ,
296+ top_p = top_p ,
297+ ),
298+ _measure_and_log_time = False ,
299+ )
294300
295- print ("=================== warmup done ====================" )
301+ print ("=================== warmup done ====================" )
296302
297- # Reset KV cache
298- model .reset_cache (cache_config )
303+ # Reset KV cache
304+ model .reset_cache (cache_config )
299305
300306 # ---------------------------------------------------------------------------- #
301307 # Generate
0 commit comments