Skip to content

Commit 0f3284c

Browse files
committed
update dp ep mpt
1 parent 50bc81f commit 0f3284c

2 files changed

Lines changed: 26 additions & 3 deletions

File tree

lmdeploy/pytorch/spec_decode/spec_agent.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -205,7 +205,10 @@ def _prepare_inputs_from_main(self, model_inputs: ModelInputs, extra_inputs: Ext
205205
history_lengths = model_inputs.history_lengths.clone()
206206

207207
if not model_inputs.is_chunk:
208-
if not model_inputs.is_dummy and not model_inputs.is_decoding:
208+
local_is_decoding = model_inputs.is_decoding
209+
if model_inputs.dp_meta is not None:
210+
local_is_decoding = model_inputs.dp_meta.is_decoding
211+
if not model_inputs.is_dummy and not local_is_decoding:
209212
# Non-chunk prefill starts an independent stream. Dummy DP
210213
# placeholders and interleaved decode must not clear a pending
211214
# long-chunk carry.

tests/pytorch/spec_decode/test_spec_agent.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from lmdeploy.pytorch.model_inputs import ModelInputs
5+
from lmdeploy.pytorch.model_inputs import DPMeta, ModelInputs
66
from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent, _expand_sampling_inputs
77
from lmdeploy.pytorch.strategies.ar_spec.model_agent import ARSpecExtraInputs
88

@@ -312,7 +312,13 @@ def test_slice_sampling_inputs_prefill():
312312
assert result is sampling_inputs
313313

314314

315-
def _model_inputs(input_ids, *, is_decoding=False, is_chunk=False, is_first_chunk=False, is_last_chunk=False):
315+
def _model_inputs(input_ids,
316+
*,
317+
is_decoding=False,
318+
is_chunk=False,
319+
is_first_chunk=False,
320+
is_last_chunk=False,
321+
dp_meta=None):
316322
input_ids = torch.tensor([input_ids])
317323
seq_length = torch.tensor([input_ids.size(1)])
318324
history_lengths = torch.tensor([0])
@@ -330,6 +336,7 @@ def _model_inputs(input_ids, *, is_decoding=False, is_chunk=False, is_first_chun
330336
is_chunk=is_chunk,
331337
is_first_chunk=is_first_chunk,
332338
is_last_chunk=is_last_chunk,
339+
dp_meta=dp_meta,
333340
)
334341

335342

@@ -370,3 +377,16 @@ def test_prepare_inputs_from_main_clears_chunk_carry_on_non_chunk_prefill():
370377
agent._prepare_inputs_from_main(prefill, _extra([[1, 10], [2, 20], [3, 30]]))
371378

372379
assert agent._prev_chunk_last == {}
380+
381+
382+
def test_prepare_inputs_from_main_keeps_chunk_carry_for_dp_local_decode_global_prefill():
383+
agent = SpecModelAgent.__new__(SpecModelAgent)
384+
saved = torch.ones(1, 1, 2)
385+
agent._prev_chunk_last = {'hidden_states': saved.clone()}
386+
agent.proposer = _DummyProposer()
387+
388+
dp_meta = DPMeta(dp_batches=[1, 1], is_decoding=True, dp_is_decoding=False)
389+
inputs = _model_inputs([90, 91, 92], is_decoding=False, dp_meta=dp_meta)
390+
agent._prepare_inputs_from_main(inputs, _extra([[9, 90], [8, 80], [7, 70]]))
391+
392+
assert torch.equal(agent._prev_chunk_last['hidden_states'], saved)

0 commit comments

Comments
 (0)