Skip to content

Commit a5cb8a7

Browse files
committed
update dp ep mpt
1 parent c4ffb58 commit a5cb8a7

2 files changed

Lines changed: 24 additions & 3 deletions

File tree

lmdeploy/pytorch/spec_decode/spec_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,8 @@ def _prepare_inputs_from_main(self, model_inputs: ModelInputs, extra_inputs: Ext
204204
history_lengths = model_inputs.history_lengths.clone()
205205

206206
if not model_inputs.is_chunk:
207-
if not model_inputs.is_dummy and not model_inputs.is_decoding:
207+
local_is_decoding = model_inputs.is_decoding
208+
if not model_inputs.is_dummy and not local_is_decoding:
208209
# Non-chunk prefill starts an independent stream. Dummy DP
209210
# placeholders and interleaved decode must not clear a pending
210211
# long-chunk carry.

tests/pytorch/spec_decode/test_spec_agent.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from lmdeploy.pytorch.model_inputs import ModelInputs
5+
from lmdeploy.pytorch.model_inputs import DPMeta, ModelInputs
66
from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent, _expand_sampling_inputs
77
from lmdeploy.pytorch.strategies.ar_spec.model_agent import ARSpecExtraInputs
88

@@ -316,7 +316,13 @@ def test_slice_sampling_inputs_prefill():
316316
assert result is sampling_inputs
317317

318318

319-
def _model_inputs(input_ids, *, is_decoding=False, is_chunk=False, is_first_chunk=False, is_last_chunk=False):
319+
def _model_inputs(input_ids,
320+
*,
321+
is_decoding=False,
322+
is_chunk=False,
323+
is_first_chunk=False,
324+
is_last_chunk=False,
325+
dp_meta=None):
320326
input_ids = torch.tensor([input_ids])
321327
seq_length = torch.tensor([input_ids.size(1)])
322328
history_lengths = torch.tensor([0])
@@ -334,6 +340,7 @@ def _model_inputs(input_ids, *, is_decoding=False, is_chunk=False, is_first_chun
334340
is_chunk=is_chunk,
335341
is_first_chunk=is_first_chunk,
336342
is_last_chunk=is_last_chunk,
343+
dp_meta=dp_meta,
337344
)
338345

339346

@@ -374,3 +381,16 @@ def test_prepare_inputs_from_main_clears_chunk_carry_on_non_chunk_prefill():
374381
agent._prepare_inputs_from_main(prefill, _extra([[1, 10], [2, 20], [3, 30]]))
375382

376383
assert agent._prev_chunk_last == {}
384+
385+
386+
def test_prepare_inputs_from_main_keeps_chunk_carry_for_dp_local_decode_global_prefill():
387+
agent = SpecModelAgent.__new__(SpecModelAgent)
388+
saved = torch.ones(1, 1, 2)
389+
agent._prev_chunk_last = {'hidden_states': saved.clone()}
390+
agent.proposer = _DummyProposer()
391+
392+
dp_meta = DPMeta(dp_batches=[1, 1], dp_is_decoding=False)
393+
inputs = _model_inputs([90, 91, 92], is_decoding=True, dp_meta=dp_meta)
394+
agent._prepare_inputs_from_main(inputs, _extra([[9, 90], [8, 80], [7, 70]]))
395+
396+
assert torch.equal(agent._prev_chunk_last['hidden_states'], saved)

0 commit comments

Comments
 (0)