22
33import torch
44
5- from lmdeploy .pytorch .model_inputs import ModelInputs
5+ from lmdeploy .pytorch .model_inputs import DPMeta , ModelInputs
66from lmdeploy .pytorch .spec_decode .spec_agent import SpecModelAgent , _expand_sampling_inputs
77from lmdeploy .pytorch .strategies .ar_spec .model_agent import ARSpecExtraInputs
88
@@ -316,7 +316,13 @@ def test_slice_sampling_inputs_prefill():
316316 assert result is sampling_inputs
317317
318318
319- def _model_inputs (input_ids , * , is_decoding = False , is_chunk = False , is_first_chunk = False , is_last_chunk = False ):
319+ def _model_inputs (input_ids ,
320+ * ,
321+ is_decoding = False ,
322+ is_chunk = False ,
323+ is_first_chunk = False ,
324+ is_last_chunk = False ,
325+ dp_meta = None ):
320326 input_ids = torch .tensor ([input_ids ])
321327 seq_length = torch .tensor ([input_ids .size (1 )])
322328 history_lengths = torch .tensor ([0 ])
@@ -334,6 +340,7 @@ def _model_inputs(input_ids, *, is_decoding=False, is_chunk=False, is_first_chun
334340 is_chunk = is_chunk ,
335341 is_first_chunk = is_first_chunk ,
336342 is_last_chunk = is_last_chunk ,
343+ dp_meta = dp_meta ,
337344 )
338345
339346
@@ -374,3 +381,16 @@ def test_prepare_inputs_from_main_clears_chunk_carry_on_non_chunk_prefill():
374381 agent ._prepare_inputs_from_main (prefill , _extra ([[1 , 10 ], [2 , 20 ], [3 , 30 ]]))
375382
376383 assert agent ._prev_chunk_last == {}
384+
385+
386+ def test_prepare_inputs_from_main_keeps_chunk_carry_for_dp_local_decode_global_prefill ():
387+ agent = SpecModelAgent .__new__ (SpecModelAgent )
388+ saved = torch .ones (1 , 1 , 2 )
389+ agent ._prev_chunk_last = {'hidden_states' : saved .clone ()}
390+ agent .proposer = _DummyProposer ()
391+
392+ dp_meta = DPMeta (dp_batches = [1 , 1 ], dp_is_decoding = False )
393+ inputs = _model_inputs ([90 , 91 , 92 ], is_decoding = True , dp_meta = dp_meta )
394+ agent ._prepare_inputs_from_main (inputs , _extra ([[9 , 90 ], [8 , 80 ], [7 , 70 ]]))
395+
396+ assert torch .equal (agent ._prev_chunk_last ['hidden_states' ], saved )
0 commit comments