@@ -131,6 +131,11 @@ def get_args():
131131 default = 1.0 ,
132132 help = "sampling temperature" ,
133133 )
134+ parser .add_argument (
135+ "--warmup" ,
136+ action = "store_true" ,
137+ help = "Perform a warmup run before benchmarking/inference."
138+ )
134139
135140 return parser .parse_args ()
136141
@@ -239,39 +244,40 @@ def test(
239244 # ---------------------------------------------------------------------------- #
240245 # Warmup
241246 # ---------------------------------------------------------------------------- #
242- warmup_steps = 1
243-
244- # Choose a length that approximates the real workload.
245- # It should be long enough to trigger the correct kernel paths,
246- # but not so long that warmup becomes unnecessarily expensive.
247- avg_prompt_len = min (64 , max (len (ids ) for ids in input_ids_list ))
248-
249- # Use truncated versions of real prompts for warmup
250- warmup_ids = [
251- ids [:avg_prompt_len ] if len (ids ) >= avg_prompt_len else ids
252- for ids in input_ids_list
253- ]
247+ if args .warmup :
248+ warmup_steps = 1
249+
250+ # Choose a length that approximates the real workload.
251+ # It should be long enough to trigger the correct kernel paths,
252+ # but not so long that warmup becomes unnecessarily expensive.
253+ avg_prompt_len = min (64 , max (len (ids ) for ids in input_ids_list ))
254+
255+ # Use truncated versions of real prompts for warmup
256+ warmup_ids = [
257+ ids [:avg_prompt_len ] if len (ids ) >= avg_prompt_len else ids
258+ for ids in input_ids_list
259+ ]
254260
255- input_ids_infini = infinicore .from_list (warmup_ids )
261+ input_ids_infini = infinicore .from_list (warmup_ids )
256262
257- print ("=================== warmup start ===================" )
263+ print ("=================== warmup start ===================" )
258264
259- for _ in range (warmup_steps ):
260- _ = model .generate (
261- input_ids_infini ,
262- GenerationConfig (
263- max_new_tokens = 2 , # warmup decode kernel
264- temperature = 1 ,
265- top_k = 1 ,
266- top_p = 0.8 ,
267- ),
268- _measure_and_log_time = False ,
269- )
265+ for _ in range (warmup_steps ):
266+ _ = model .generate (
267+ input_ids_infini ,
268+ GenerationConfig (
269+ max_new_tokens = 2 , # warmup decode kernel
270+ temperature = temperature ,
271+ top_k = top_k ,
272+ top_p = top_p ,
273+ ),
274+ _measure_and_log_time = False ,
275+ )
270276
271- print ("=================== warmup done ====================" )
277+ print ("=================== warmup done ====================" )
272278
273- # Reset KV cache
274- model .reset_cache (cache_config )
279+ # Reset KV cache
280+ model .reset_cache (cache_config )
275281
276282 # ---------------------------------------------------------------------------- #
277283 # Generate
0 commit comments