From 3c61d262cef6eef6b362635b483abcd9b75dcb67 Mon Sep 17 00:00:00 2001 From: waynehacking8 Date: Wed, 17 Jun 2026 06:14:32 +0800 Subject: [PATCH] [Bugfix] Fix double-counted max_q_seqlen in decode delta kv_seqlens create_model_inputs_delta / create_model_inputs_delta_valid_only build kv_seqlens as [seq.num_all_ids + max_q_seqlen]. num_all_ids can be one decode step stale here -- EngineLoop prefetches the next inputs before _finish_forward_output() advances the sequence -- so the +max_q_seqlen recovers this forward's kv length. But the reductions then added max_q_seqlen a SECOND time and used batch_size = len(self.running_seqs), which counts scheduler-dropped invalid seqs: sum_kv_seqlen = sum(kv_seqlens) + batch_size * max_q_seqlen max_kv_seqlen = max(kv_seqlens) + max_q_seqlen so max_kv_seqlen / sum_kv_seqlen were over-counted (max by max_q_seqlen, scaling with spec/MTP num_decode_tokens), over-allocating the attention grid + kv-cache resources. Reduce over kv_seqlens directly; the +max_q_seqlen is already applied once in the comprehension. Fixes #4024 Co-authored-by: Claude --- lmdeploy/pytorch/engine/inputs_maker.py | 16 +++++++++--- tests/pytorch/engine/test_inputs_maker.py | 32 +++++++++++++++++++++++ 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/lmdeploy/pytorch/engine/inputs_maker.py b/lmdeploy/pytorch/engine/inputs_maker.py index 3f6babc35f..488680f5b0 100644 --- a/lmdeploy/pytorch/engine/inputs_maker.py +++ b/lmdeploy/pytorch/engine/inputs_maker.py @@ -603,9 +603,15 @@ def create_model_inputs_delta(self): else: num_ignored_history = torch.zeros(len(valid_seqs), dtype=torch.long) + # num_all_ids can be one decode step stale here: EngineLoop prefetches + # the next inputs before _finish_forward_output() advances the sequence, + # so +max_q_seqlen recovers this forward's kv length. The bug was adding + # max_q_seqlen AGAIN in the reductions, plus using batch_size (which + # counts scheduler-dropped invalid seqs) instead of reducing over the + # valid seqs only (#4024). kv_seqlens = [seq.num_all_ids + max_q_seqlen for seq in valid_seqs] - sum_kv_seqlen = sum(kv_seqlens) + batch_size * max_q_seqlen - max_kv_seqlen = max(kv_seqlens) + max_q_seqlen + sum_kv_seqlen = sum(kv_seqlens) + max_kv_seqlen = max(kv_seqlens) output = ModelInputsDelta( indices=None, @@ -650,13 +656,15 @@ def create_model_inputs_delta_valid_only(self): num_decode_tokens = self.engine_strategy.get_num_decode_tokens() max_q_seqlen = num_decode_tokens + # Keep +max_q_seqlen (num_all_ids may be one decode step stale), but do + # not add it a second time in the reductions or use batch_size (#4024). kv_seqlens = [seq.num_all_ids + max_q_seqlen for seq in valid_seqs] if len(kv_seqlens) == 0: sum_kv_seqlen = 0 max_kv_seqlen = 0 else: - sum_kv_seqlen = sum(kv_seqlens) + batch_size * max_q_seqlen - max_kv_seqlen = max(kv_seqlens) + max_q_seqlen + sum_kv_seqlen = sum(kv_seqlens) + max_kv_seqlen = max(kv_seqlens) output = ModelInputsDelta( indices=None, diff --git a/tests/pytorch/engine/test_inputs_maker.py b/tests/pytorch/engine/test_inputs_maker.py index 5923722877..1f9771d6fb 100644 --- a/tests/pytorch/engine/test_inputs_maker.py +++ b/tests/pytorch/engine/test_inputs_maker.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from types import SimpleNamespace +import pytest + from lmdeploy.pytorch.disagg.config import EngineRole from lmdeploy.pytorch.engine.engine_loop import EngineLoop from lmdeploy.pytorch.engine.inputs_maker import ( @@ -333,3 +335,33 @@ def test_state_prefix_cache_save_offsets_are_compact(): assert src_offsets == (5, 6) assert dst_offsets == (21, 22) + + +@pytest.mark.parametrize('max_q_seqlen', [1, 4]) # standard decode, then spec/MTP +def test_create_model_inputs_delta_valid_only_matches_one_decode_advance(max_q_seqlen): + # Regression for #4024. The delta is built from the (stale) scheduler seqs + # at the current state, then applied after the model-agent's StepInputs has + # advanced one decode step. So delta.max/sum_kv_seqlen must equal the base + # kv (num_all_ids of the valid seqs) advanced by EXACTLY one decode step -- + # the invariant the engine uses in ModelInputs.step (model_inputs.py) and + # get_model_inputs_next_decoding (strategies/ar/model_inputs.py): + # max_kv_seqlen += max_q_seqlen + # sum_kv_seqlen += num_valid_seqs * max_q_seqlen + # Parametrizing max_q_seqlen proves the offset is one max_q_seqlen, not the + # old double (num_all_ids + 2 * max_q_seqlen) nor zero (num_all_ids alone). + num_all_ids = [100, 250] # valid seqs' kv at the (stale) build state + maker = InputsMakerAsync.__new__(InputsMakerAsync) + maker.engine_strategy = SimpleNamespace(get_num_decode_tokens=lambda: max_q_seqlen) + maker.running_seqs = [ + SimpleNamespace(status=MessageStatus.RUNNING, num_all_ids=num_all_ids[0]), + SimpleNamespace(status=MessageStatus.RUNNING, num_all_ids=num_all_ids[1]), + SimpleNamespace(status=MessageStatus.STOPPED, num_all_ids=70), # dropped + ] + + output, valid_seqs, invalid_seqs = maker.create_model_inputs_delta_valid_only() + + assert [seq.num_all_ids for seq in valid_seqs] == num_all_ids + assert len(invalid_seqs) == 1 + # base kv at the (stale) build state + one canonical decode advance + assert output.max_kv_seqlen == max(num_all_ids) + max_q_seqlen + assert output.sum_kv_seqlen == sum(num_all_ids) + len(valid_seqs) * max_q_seqlen