Skip to content

Commit af88166

Browse files
authored
Merge pull request #430 from InfiniTensor/issue/429
issue/429 - feat: adjust warmup
2 parents 1b372bd + c41cf39 commit af88166

1 file changed

Lines changed: 24 additions & 13 deletions

File tree

examples/bench.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -336,32 +336,43 @@ def run(
336336
warmup_steps = 1
337337

338338
# warmup cache capacity
339-
warmup_cache_len = 128
340-
warmup_batch = len(test.input_ids_list)
341-
342-
test.model.reset_cache(
343-
StaticKVCacheConfig(
339+
warmup_case = next(iter(cases_dict.values()))
340+
warmup_batch = warmup_case["batch_size"]
341+
warmup_input_len = warmup_case["input_len"]
342+
warmup_decode_len = 5
343+
344+
if enable_paged_attn:
345+
warmup_num_blocks = (
346+
(warmup_input_len + warmup_decode_len + paged_kv_block_size - 1)
347+
// paged_kv_block_size
348+
) * warmup_batch
349+
warmup_cache_config = PagedKVCacheConfig(
350+
warmup_num_blocks, paged_kv_block_size
351+
)
352+
else:
353+
warmup_cache_config = StaticKVCacheConfig(
344354
max_batch_size=warmup_batch,
345-
max_cache_len=warmup_cache_len,
355+
max_cache_len=warmup_input_len + warmup_decode_len,
346356
)
347-
)
348357

349-
avg_prompt_len = min(64, max(len(ids) for ids in test.input_ids_list))
358+
test.model.reset_cache(warmup_cache_config)
350359

351-
warmup_ids = [
352-
ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
353-
for ids in test.input_ids_list
354-
]
360+
warmup_prompt_ids = repeat_prompt(test.input_ids_list[0], warmup_input_len)
361+
warmup_ids = [warmup_prompt_ids] * warmup_batch
355362

356363
input_ids_infini = infinicore.from_list(warmup_ids, dtype=infinicore.int64)
357364

365+
print(
366+
f"\033[93m[warmup] batch={warmup_batch}, input_len={warmup_input_len}, "
367+
f"will prefill + {warmup_decode_len} decode steps\033[0m"
368+
)
358369
print("=================== warmup start ===================")
359370

360371
for _ in range(warmup_steps):
361372
_ = test.model.generate(
362373
input_ids_infini,
363374
GenerationConfig(
364-
max_new_tokens=5, # decode kernel warmup
375+
max_new_tokens=warmup_decode_len, # decode kernel warmup
365376
temperature=cfg.temperature,
366377
top_k=cfg.top_k,
367378
top_p=cfg.top_p,

0 commit comments

Comments
 (0)