Skip to content

Commit a247d40

Browse files
committed
issue/224 - feat: add --warmup flag and disable warmup by default
1 parent d304ba0 commit a247d40

File tree

1 file changed

+34
-28
lines changed

1 file changed

+34
-28
lines changed

examples/jiuge.py

Lines changed: 34 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -131,6 +131,11 @@ def get_args():
131131
default=1.0,
132132
help="sampling temperature",
133133
)
134+
parser.add_argument(
135+
"--warmup",
136+
action="store_true",
137+
help="Perform a warmup run before benchmarking/inference."
138+
)
134139

135140
return parser.parse_args()
136141

@@ -239,39 +244,40 @@ def test(
239244
# ---------------------------------------------------------------------------- #
240245
# Warmup
241246
# ---------------------------------------------------------------------------- #
242-
warmup_steps = 1
243-
244-
# Choose a length that approximates the real workload.
245-
# It should be long enough to trigger the correct kernel paths,
246-
# but not so long that warmup becomes unnecessarily expensive.
247-
avg_prompt_len = min(64, max(len(ids) for ids in input_ids_list))
248-
249-
# Use truncated versions of real prompts for warmup
250-
warmup_ids = [
251-
ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
252-
for ids in input_ids_list
253-
]
247+
if args.warmup:
248+
warmup_steps = 1
249+
250+
# Choose a length that approximates the real workload.
251+
# It should be long enough to trigger the correct kernel paths,
252+
# but not so long that warmup becomes unnecessarily expensive.
253+
avg_prompt_len = min(64, max(len(ids) for ids in input_ids_list))
254+
255+
# Use truncated versions of real prompts for warmup
256+
warmup_ids = [
257+
ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
258+
for ids in input_ids_list
259+
]
254260

255-
input_ids_infini = infinicore.from_list(warmup_ids)
261+
input_ids_infini = infinicore.from_list(warmup_ids)
256262

257-
print("=================== warmup start ===================")
263+
print("=================== warmup start ===================")
258264

259-
for _ in range(warmup_steps):
260-
_ = model.generate(
261-
input_ids_infini,
262-
GenerationConfig(
263-
max_new_tokens=2, # warmup decode kernel
264-
temperature=1,
265-
top_k=1,
266-
top_p=0.8,
267-
),
268-
_measure_and_log_time=False,
269-
)
265+
for _ in range(warmup_steps):
266+
_ = model.generate(
267+
input_ids_infini,
268+
GenerationConfig(
269+
max_new_tokens=2, # warmup decode kernel
270+
temperature=temperature,
271+
top_k=top_k,
272+
top_p=top_p,
273+
),
274+
_measure_and_log_time=False,
275+
)
270276

271-
print("=================== warmup done ====================")
277+
print("=================== warmup done ====================")
272278

273-
# Reset KV cache
274-
model.reset_cache(cache_config)
279+
# Reset KV cache
280+
model.reset_cache(cache_config)
275281

276282
# ---------------------------------------------------------------------------- #
277283
# Generate

0 commit comments

Comments
 (0)