Skip to content

Commit df2a1b4

Browse files
waynehacking8 and claude committed
[Bugfix] Fix double-counted max_q_seqlen in decode delta kv_seqlens
create_model_inputs_delta / create_model_inputs_delta_valid_only build kv_seqlens as [seq.num_all_ids + max_q_seqlen]. num_all_ids can be one decode step stale here -- EngineLoop prefetches the next inputs before _finish_forward_output() advances the sequence -- so the +max_q_seqlen recovers this forward's kv length.

But the reductions then added max_q_seqlen a SECOND time and used batch_size = len(self.running_seqs), which counts scheduler-dropped invalid seqs:

    sum_kv_seqlen = sum(kv_seqlens) + batch_size * max_q_seqlen
    max_kv_seqlen = max(kv_seqlens) + max_q_seqlen

so max_kv_seqlen / sum_kv_seqlen were over-counted (max by max_q_seqlen, scaling with spec/MTP num_decode_tokens), over-allocating the attention grid + kv-cache resources.

Reduce over kv_seqlens directly; the +max_q_seqlen is already applied once in the comprehension.

Fixes #4024

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 18600ad commit df2a1b4

2 files changed

Lines changed: 35 additions & 4 deletions

File tree

lmdeploy/pytorch/engine/inputs_maker.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -603,9 +603,15 @@ def create_model_inputs_delta(self):
603603
else:
604604
num_ignored_history = torch.zeros(len(valid_seqs), dtype=torch.long)
605605

606+
# num_all_ids can be one decode step stale here: EngineLoop prefetches
607+
# the next inputs before _finish_forward_output() advances the sequence,
608+
# so +max_q_seqlen recovers this forward's kv length. The bug was adding
609+
# max_q_seqlen AGAIN in the reductions, plus using batch_size (which
610+
# counts scheduler-dropped invalid seqs) instead of reducing over the
611+
# valid seqs only (#4024).
606612
kv_seqlens = [seq.num_all_ids + max_q_seqlen for seq in valid_seqs]
607-
sum_kv_seqlen = sum(kv_seqlens) + batch_size * max_q_seqlen
608-
max_kv_seqlen = max(kv_seqlens) + max_q_seqlen
613+
sum_kv_seqlen = sum(kv_seqlens)
614+
max_kv_seqlen = max(kv_seqlens)
609615

610616
output = ModelInputsDelta(
611617
indices=None,
@@ -650,13 +656,15 @@ def create_model_inputs_delta_valid_only(self):
650656

651657
num_decode_tokens = self.engine_strategy.get_num_decode_tokens()
652658
max_q_seqlen = num_decode_tokens
659+
# Keep +max_q_seqlen (num_all_ids may be one decode step stale), but do
660+
# not add it a second time in the reductions or use batch_size (#4024).
653661
kv_seqlens = [seq.num_all_ids + max_q_seqlen for seq in valid_seqs]
654662
if len(kv_seqlens) == 0:
655663
sum_kv_seqlen = 0
656664
max_kv_seqlen = 0
657665
else:
658-
sum_kv_seqlen = sum(kv_seqlens) + batch_size * max_q_seqlen
659-
max_kv_seqlen = max(kv_seqlens) + max_q_seqlen
666+
sum_kv_seqlen = sum(kv_seqlens)
667+
max_kv_seqlen = max(kv_seqlens)
660668

661669
output = ModelInputsDelta(
662670
indices=None,

tests/pytorch/engine/test_inputs_maker.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,26 @@ def test_state_prefix_cache_save_offsets_are_compact():
333333

334334
assert src_offsets == (5, 6)
335335
assert dst_offsets == (21, 22)
336+
337+
338+
def test_create_model_inputs_delta_valid_only_kv_seqlen_no_double_count():
339+
# Regression for #4024: kv_seqlens = num_all_ids + max_q_seqlen is correct
340+
# (num_all_ids can be one decode step stale due to EngineLoop prefetch), but
341+
# the old code added max_q_seqlen a SECOND time in the reductions and used
342+
# batch_size (which counts scheduler-dropped invalid seqs), over-inflating
343+
# max_kv_seqlen / sum_kv_seqlen.
344+
maker = InputsMakerAsync.__new__(InputsMakerAsync)
345+
maker.engine_strategy = SimpleNamespace(get_num_decode_tokens=lambda: 4)
346+
maker.running_seqs = [
347+
SimpleNamespace(status=MessageStatus.RUNNING, num_all_ids=100),
348+
SimpleNamespace(status=MessageStatus.RUNNING, num_all_ids=250),
349+
SimpleNamespace(status=MessageStatus.STOPPED, num_all_ids=70), # dropped
350+
]
351+
352+
output, valid_seqs, invalid_seqs = maker.create_model_inputs_delta_valid_only()
353+
354+
assert [seq.num_all_ids for seq in valid_seqs] == [100, 250]
355+
assert len(invalid_seqs) == 1
356+
# kv_seqlens = [104, 254]; reduce over the valid seqs, add max_q_seqlen once
357+
assert output.max_kv_seqlen == 254 # old (buggy): 254 + 4 = 258
358+
assert output.sum_kv_seqlen == 358 # old (buggy): 358 + 3 * 4 = 370

0 commit comments

Comments
 (0)