Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions lmdeploy/pytorch/engine/inputs_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,15 @@ def create_model_inputs_delta(self):
else:
num_ignored_history = torch.zeros(len(valid_seqs), dtype=torch.long)

# num_all_ids can be one decode step stale here: EngineLoop prefetches
# the next inputs before _finish_forward_output() advances the sequence,
# so +max_q_seqlen recovers this forward's kv length. The bug was adding
# max_q_seqlen AGAIN in the reductions, plus using batch_size (which
# counts scheduler-dropped invalid seqs) instead of reducing over the
# valid seqs only (#4024).
kv_seqlens = [seq.num_all_ids + max_q_seqlen for seq in valid_seqs]
sum_kv_seqlen = sum(kv_seqlens) + batch_size * max_q_seqlen
max_kv_seqlen = max(kv_seqlens) + max_q_seqlen
sum_kv_seqlen = sum(kv_seqlens)
max_kv_seqlen = max(kv_seqlens)

output = ModelInputsDelta(
indices=None,
Expand Down Expand Up @@ -650,13 +656,15 @@ def create_model_inputs_delta_valid_only(self):

num_decode_tokens = self.engine_strategy.get_num_decode_tokens()
max_q_seqlen = num_decode_tokens
# Keep +max_q_seqlen (num_all_ids may be one decode step stale), but do
# not add it a second time in the reductions or use batch_size (#4024).
kv_seqlens = [seq.num_all_ids + max_q_seqlen for seq in valid_seqs]
if len(kv_seqlens) == 0:
sum_kv_seqlen = 0
max_kv_seqlen = 0
else:
sum_kv_seqlen = sum(kv_seqlens) + batch_size * max_q_seqlen
max_kv_seqlen = max(kv_seqlens) + max_q_seqlen
sum_kv_seqlen = sum(kv_seqlens)
max_kv_seqlen = max(kv_seqlens)

output = ModelInputsDelta(
indices=None,
Expand Down
32 changes: 32 additions & 0 deletions tests/pytorch/engine/test_inputs_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from types import SimpleNamespace

import pytest

from lmdeploy.pytorch.disagg.config import EngineRole
from lmdeploy.pytorch.engine.engine_loop import EngineLoop
from lmdeploy.pytorch.engine.inputs_maker import (
Expand Down Expand Up @@ -333,3 +335,33 @@ def test_state_prefix_cache_save_offsets_are_compact():

assert src_offsets == (5, 6)
assert dst_offsets == (21, 22)


@pytest.mark.parametrize('max_q_seqlen', [1, 4])  # standard decode, then spec/MTP
def test_create_model_inputs_delta_valid_only_matches_one_decode_advance(max_q_seqlen):
    """Regression test for #4024: the delta's kv lengths advance by one decode step.

    The delta is built from the (stale) scheduler sequences and applied after
    the model-agent's StepInputs has advanced one decode step.  Its
    ``max_kv_seqlen`` / ``sum_kv_seqlen`` must therefore equal the valid
    sequences' ``num_all_ids`` advanced by exactly one decode step, i.e. the
    same invariant as ``ModelInputs.step`` (model_inputs.py) and
    ``get_model_inputs_next_decoding`` (strategies/ar/model_inputs.py):

        max_kv_seqlen += max_q_seqlen
        sum_kv_seqlen += num_valid_seqs * max_q_seqlen

    Running with more than one ``max_q_seqlen`` distinguishes a single offset
    from the old double offset (``num_all_ids + 2 * max_q_seqlen``) and from
    no offset at all (``num_all_ids`` alone).
    """
    # kv lengths of the valid sequences at the (stale) build state
    base_kv = [100, 250]
    live_seqs = [SimpleNamespace(status=MessageStatus.RUNNING, num_all_ids=kv) for kv in base_kv]
    # a sequence the scheduler has dropped; it must not contribute to the sums
    dropped_seq = SimpleNamespace(status=MessageStatus.STOPPED, num_all_ids=70)

    # bypass __init__: only the attributes assigned below are provided
    maker = InputsMakerAsync.__new__(InputsMakerAsync)
    maker.engine_strategy = SimpleNamespace(get_num_decode_tokens=lambda: max_q_seqlen)
    maker.running_seqs = [*live_seqs, dropped_seq]

    delta, kept, rejected = maker.create_model_inputs_delta_valid_only()

    assert [seq.num_all_ids for seq in kept] == base_kv
    assert len(rejected) == 1

    # base kv at the (stale) build state + one canonical decode advance
    expected_max = max(base_kv) + max_q_seqlen
    expected_sum = sum(base_kv) + len(kept) * max_q_seqlen
    assert delta.max_kv_seqlen == expected_max
    assert delta.sum_kv_seqlen == expected_sum
Loading