File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -260,6 +260,43 @@ def test(
260260
261261 model .reset_cache (cache_config )
262262
263+ # ---------------------------------------------------------------------------- #
264+ # Warmup
265+ # ---------------------------------------------------------------------------- #
266+ warmup_steps = 1
267+
268+ # Choose a warmup prompt length that approximates the real workload:
269+ # the longest prompt, capped at 64 tokens (a cap, not an average). It should be long enough
270+ # to trigger the correct kernel paths, but not so long that warmup becomes unnecessarily expensive.
271+ avg_prompt_len = min (64 , max (len (ids ) for ids in input_ids_list ))
272+
273+ # Use real prompts for warmup, truncated to at most avg_prompt_len tokens (shorter prompts are kept as-is)
274+ warmup_ids = [
275+ ids [:avg_prompt_len ] if len (ids ) >= avg_prompt_len else ids
276+ for ids in input_ids_list
277+ ]
278+
279+ input_ids_infini = infinicore .from_list (warmup_ids )
280+
281+ print ("=================== warmup start ===================" )
282+
283+ for _ in range (warmup_steps ):
284+ _ = model .generate (
285+ input_ids_infini ,
286+ GenerationConfig (
287+ max_new_tokens = 2 , # warmup decode kernel
288+ temperature = 1 ,
289+ top_k = 1 ,
290+ top_p = 0.8 ,
291+ ),
292+ _measure_and_log_time = False ,
293+ )
294+
295+ print ("=================== warmup done ====================" )
296+
297+ # Reset KV cache
298+ model .reset_cache (cache_config )
299+
263300 # ---------------------------------------------------------------------------- #
264301 # Generate
265302 # ---------------------------------------------------------------------------- #
You can’t perform that action at this time.
0 commit comments