Skip to content

Commit f2c390f

Browse files
committed
issue/224 - feat: add --warmup flag and disable warmup by default
1 parent f71b115 commit f2c390f

File tree

1 file changed

+34
-28
lines changed

1 file changed

+34
-28
lines changed

examples/jiuge.py

Lines changed: 34 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -141,6 +141,11 @@ def get_args():
141141
default=1.0,
142142
help="sampling temperature",
143143
)
144+
parser.add_argument(
145+
"--warmup",
146+
action="store_true",
147+
help="Perform a warmup run before benchmarking/inference."
148+
)
144149

145150
parser.add_argument(
146151
"--attn",
@@ -263,39 +268,40 @@ def test(
263268
# ---------------------------------------------------------------------------- #
264269
# Warmup
265270
# ---------------------------------------------------------------------------- #
266-
warmup_steps = 1
267-
268-
# Choose a length that approximates the real workload.
269-
# It should be long enough to trigger the correct kernel paths,
270-
# but not so long that warmup becomes unnecessarily expensive.
271-
avg_prompt_len = min(64, max(len(ids) for ids in input_ids_list))
272-
273-
# Use truncated versions of real prompts for warmup
274-
warmup_ids = [
275-
ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
276-
for ids in input_ids_list
277-
]
271+
if args.warmup:
272+
warmup_steps = 1
273+
274+
# Choose a length that approximates the real workload.
275+
# It should be long enough to trigger the correct kernel paths,
276+
# but not so long that warmup becomes unnecessarily expensive.
277+
avg_prompt_len = min(64, max(len(ids) for ids in input_ids_list))
278+
279+
# Use truncated versions of real prompts for warmup
280+
warmup_ids = [
281+
ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
282+
for ids in input_ids_list
283+
]
278284

279-
input_ids_infini = infinicore.from_list(warmup_ids)
285+
input_ids_infini = infinicore.from_list(warmup_ids)
280286

281-
print("=================== warmup start ===================")
287+
print("=================== warmup start ===================")
282288

283-
for _ in range(warmup_steps):
284-
_ = model.generate(
285-
input_ids_infini,
286-
GenerationConfig(
287-
max_new_tokens=2, # warmup decode kernel
288-
temperature=1,
289-
top_k=1,
290-
top_p=0.8,
291-
),
292-
_measure_and_log_time=False,
293-
)
289+
for _ in range(warmup_steps):
290+
_ = model.generate(
291+
input_ids_infini,
292+
GenerationConfig(
293+
max_new_tokens=2, # warmup decode kernel
294+
temperature=temperature,
295+
top_k=top_k,
296+
top_p=top_p,
297+
),
298+
_measure_and_log_time=False,
299+
)
294300

295-
print("=================== warmup done ====================")
301+
print("=================== warmup done ====================")
296302

297-
# Reset KV cache
298-
model.reset_cache(cache_config)
303+
# Reset KV cache
304+
model.reset_cache(cache_config)
299305

300306
# ---------------------------------------------------------------------------- #
301307
# Generate

0 commit comments

Comments (0)