22
33import torch
44
5- from lmdeploy .pytorch .model_inputs import ModelInputs
5+ from lmdeploy .pytorch .model_inputs import DPMeta , ModelInputs
66from lmdeploy .pytorch .spec_decode .spec_agent import SpecModelAgent , _expand_sampling_inputs
77from lmdeploy .pytorch .strategies .ar_spec .model_agent import ARSpecExtraInputs
88
@@ -312,7 +312,13 @@ def test_slice_sampling_inputs_prefill():
312312 assert result is sampling_inputs
313313
314314
315- def _model_inputs (input_ids , * , is_decoding = False , is_chunk = False , is_first_chunk = False , is_last_chunk = False ):
315+ def _model_inputs (input_ids ,
316+ * ,
317+ is_decoding = False ,
318+ is_chunk = False ,
319+ is_first_chunk = False ,
320+ is_last_chunk = False ,
321+ dp_meta = None ):
316322 input_ids = torch .tensor ([input_ids ])
317323 seq_length = torch .tensor ([input_ids .size (1 )])
318324 history_lengths = torch .tensor ([0 ])
@@ -330,6 +336,7 @@ def _model_inputs(input_ids, *, is_decoding=False, is_chunk=False, is_first_chun
330336 is_chunk = is_chunk ,
331337 is_first_chunk = is_first_chunk ,
332338 is_last_chunk = is_last_chunk ,
339+ dp_meta = dp_meta ,
333340 )
334341
335342
@@ -370,3 +377,16 @@ def test_prepare_inputs_from_main_clears_chunk_carry_on_non_chunk_prefill():
370377 agent ._prepare_inputs_from_main (prefill , _extra ([[1 , 10 ], [2 , 20 ], [3 , 30 ]]))
371378
372379 assert agent ._prev_chunk_last == {}
380+
381+
def test_prepare_inputs_from_main_keeps_chunk_carry_for_dp_local_decode_global_prefill():
    """Check that the stored chunk carry is left intact for this DP input mix.

    The model inputs are built with ``is_decoding=False`` while the attached
    ``DPMeta`` is created with ``is_decoding=True`` and
    ``dp_is_decoding=False``. After ``_prepare_inputs_from_main`` runs, the
    ``'hidden_states'`` entry seeded into ``_prev_chunk_last`` must still
    equal the original tensor.
    """
    expected_hidden = torch.ones(1, 1, 2)

    # Create the agent without running __init__; only the attributes this
    # test assigns are populated.
    spec_agent = SpecModelAgent.__new__(SpecModelAgent)
    # Seed with a clone so the comparison below is against an untouched tensor.
    spec_agent._prev_chunk_last = {'hidden_states': expected_hidden.clone()}
    spec_agent.proposer = _DummyProposer()

    model_inputs = _model_inputs(
        [90, 91, 92],
        is_decoding=False,
        dp_meta=DPMeta(dp_batches=[1, 1], is_decoding=True, dp_is_decoding=False),
    )
    extra_inputs = _extra([[9, 90], [8, 80], [7, 70]])
    spec_agent._prepare_inputs_from_main(model_inputs, extra_inputs)

    carried = spec_agent._prev_chunk_last['hidden_states']
    assert torch.equal(carried, expected_hidden)
0 commit comments