diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index e82eefbdcb..f49448c961 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -320,6 +320,7 @@ def prepare_inputs_for_generation( def reset(self): """Remove all graphs to prevent hanging on exit.""" + super().reset() self._runner_map.clear() if get_deepep_state().enabled(): from dlblas.layers.moe.token_dispatcher import DeepEPBuffer diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index c977e6bc63..b8f4c2ebb6 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -95,7 +95,7 @@ def get_input_processor(self): def reset(self): """Remove all graphs to prevent hanging on exit.""" - pass + self._runner_meta.padding_batch_size = None def get_meta(self): """Get graphrunner meta.""" diff --git a/lmdeploy/pytorch/engine/engine_loop.py b/lmdeploy/pytorch/engine/engine_loop.py index 234ee67d14..599d4cd1b1 100644 --- a/lmdeploy/pytorch/engine/engine_loop.py +++ b/lmdeploy/pytorch/engine/engine_loop.py @@ -478,6 +478,7 @@ async def __no_running_warning(): # stale after the drain point. forward_inputs = None next_running = None + self.inputs_maker.clear_for_sleep() # Acknowledge that no new forward input will be scheduled until # wakeup resumes this loop. self._main_sleep_drain_event.set() diff --git a/lmdeploy/pytorch/engine/inputs_maker.py b/lmdeploy/pytorch/engine/inputs_maker.py index ab58826069..24dbea32ed 100644 --- a/lmdeploy/pytorch/engine/inputs_maker.py +++ b/lmdeploy/pytorch/engine/inputs_maker.py @@ -285,6 +285,10 @@ def __init__( # long context chunker self.long_context_chunker = LongContextChunker(config.max_prefill_token_num) + def clear_for_sleep(self): + """Clear transient scheduling state before engine sleep.""" + self.long_context_chunker.clear() + def _init_do_prefill(self, config: InputsMakerConfig): if config.role == EngineRole.Prefill: self.do_prefill = self.do_prefill_pnode diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index f9c28e4159..300b7d8790 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -403,6 +403,7 @@ def warmup(self): if dp > 1: num_tokens = inputs.input_ids.numel() inputs.build_dp_meta([num_tokens] * world_size) + inputs.dp_meta.dp_is_decoding = False logger.debug('Warmup prefill start.') self._forward_impl(inputs) torch.cuda.synchronize() @@ -423,6 +424,7 @@ def warmup(self): if dp > 1: num_tokens = inputs.input_ids.numel() inputs.build_dp_meta([num_tokens] * world_size) + inputs.dp_meta.dp_is_decoding = True logger.debug(f'Warmup decoding num_tokens={num_tokens} start.') self._forward_impl(inputs) torch.cuda.synchronize() @@ -1146,6 +1148,8 @@ def get_input_processor(self): def reset_graph_runner(self): """Reset graph runner to prevent tp hanging.""" with self.all_context(): + self._prev_chunk_output = None + self._prev_chunk_last_logit = None if hasattr(self.patched_model, 'reset'): self.patched_model.reset() diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py index 1894328169..579fc0ad65 100644 --- a/lmdeploy/pytorch/spec_decode/spec_agent.py +++ b/lmdeploy/pytorch/spec_decode/spec_agent.py @@ -472,6 +472,7 @@ def __build_dp_meta(inputs: ModelInputs): padding_batch_size = max(dp_meta.dp_batches) new_dpmeta = DPMeta.build(inputs.input_ids.numel(), dp_meta.dp_batches) + new_dpmeta.dp_is_decoding = dp_meta.dp_is_decoding return new_dpmeta, padding_batch_size def _update_dp_model_inputs(inputs: ModelInputs, dp_meta: DPMeta, padding_batch_size: int | None): @@ -541,6 +542,20 @@ async def async_model_forward( def warmup(self, max_batches: int, target_model_config: ModelConfig): """warmup.""" + + def add_warmup_dp_meta(inputs: ModelInputs, is_decoding: bool): + dist_config = self.draft_dist_ctx.dist_config + if dist_config.dp <= 1: + return + + num_tokens = inputs.input_ids.numel() + batch_size = inputs.seq_length.numel() + world_size = dist_config.world_size + with self.draft_context(): + inputs.build_dp_meta([num_tokens] * world_size) + inputs.dp_meta.dp_batches = [batch_size] * world_size + inputs.dp_meta.dp_is_decoding = is_decoding + target_hidden_size = self.proposer.get_target_hidden_size(target_model_config) # warmup prefill @@ -551,6 +566,7 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig): target_hidden_size=target_hidden_size, target_dtype=self.model_config.dtype, meta=self.make_dummy_meta) + add_warmup_dp_meta(inputs, is_decoding=False) # warmup prefill self._forward_impl(inputs) @@ -569,6 +585,7 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig): target_hidden_size=target_hidden_size, target_dtype=self.model_config.dtype, meta=self.make_dummy_meta) + add_warmup_dp_meta(inputs, is_decoding=True) self._forward_impl(inputs) # decode 1 tokens per sequence inputs = self.inputs_strategy.make_dummy(batch_size, @@ -579,11 +596,13 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig): target_hidden_size=self.model_config.hidden_size, target_dtype=self.model_config.dtype, meta=self.make_dummy_meta) + add_warmup_dp_meta(inputs, is_decoding=True) self._forward_impl(inputs) def reset_graph_runner(self): """Reset graph runner.""" with self.draft_context(): + self._prev_chunk_last.clear() if self.proposer.model is not None and hasattr(self.proposer.model, 'reset'): self.proposer.model.reset() diff --git a/tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py b/tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py index 01cb783cd5..5e9628ce4c 100644 --- a/tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py +++ b/tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py @@ -3,6 +3,7 @@ from lmdeploy.messages import PytorchEngineConfig from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner +from lmdeploy.pytorch.backends.graph_runner import GraphRunnerMeta from lmdeploy.pytorch.config import CacheConfig from lmdeploy.pytorch.engine.config_builder import ConfigBuilder @@ -53,3 +54,17 @@ def test_graph_runner_defensively_normalizes_capture_batch_sizes(): runner.cache_config = cache_config assert runner.get_capture_batch_sizes() == [1, 4, 8] + + +def test_graph_runner_reset_clears_padding_batch_size(monkeypatch): + from lmdeploy.pytorch.backends.cuda import graph_runner as cuda_graph_runner + + runner = object.__new__(CUDAGraphRunner) + runner._runner_meta = GraphRunnerMeta(padding_batch_size=1) + runner._runner_map = {'stale': object()} + monkeypatch.setattr(cuda_graph_runner.get_deepep_state(), 'enabled', lambda: False) + + runner.reset() + + assert runner.get_meta().padding_batch_size is None + assert runner._runner_map == {} diff --git a/tests/pytorch/engine/test_inputs_maker.py b/tests/pytorch/engine/test_inputs_maker.py index 667be525b3..ec9c5e8a5a 100644 --- a/tests/pytorch/engine/test_inputs_maker.py +++ b/tests/pytorch/engine/test_inputs_maker.py @@ -168,6 +168,41 @@ async def get_output_async(self): assert not block_trie.pinned +def test_engine_loop_sleep_drain_clears_long_context_chunker(): + seq = _DummySeq( + history_ids=512, + token_ids=2048, + all_multimodals={}, + input_multimodals={}, + ) + seq.status = MessageStatus.RUNNING + + maker = InputsMakerAsync.__new__(InputsMakerAsync) + maker.long_context_chunker = LongContextChunker(max_prefill_token_num=512) + maker.long_context_chunker.set_seq(seq) + assert maker.long_context_chunker.enabled() + + async def _run_sleep_drain(): + loop = EngineLoop.__new__(EngineLoop) + loop.stop_event = asyncio.Event() + loop.has_runable_event = asyncio.Event() + loop._sleep_requested = True + loop._main_sleep_drain_event = asyncio.Event() + loop._sleep_resume_event = asyncio.Event() + loop.scheduler = SimpleNamespace() + loop.inputs_maker = maker + + task = asyncio.create_task(loop.main_loop()) + await asyncio.wait_for(loop._main_sleep_drain_event.wait(), timeout=1) + loop.stop_event.set() + loop._sleep_resume_event.set() + await asyncio.wait_for(task, timeout=1) + + asyncio.run(_run_sleep_drain()) + + assert not maker.long_context_chunker.enabled() + + def test_long_context_chunker_uses_cached_multimodal_size_for_chunk_limit(): image = _DummyMultiModal(start=512, end=5888) seq = _DummySeq( diff --git a/tests/pytorch/engine/test_model_agent.py b/tests/pytorch/engine/test_model_agent.py index 4b92d925dc..ffd2511414 100644 --- a/tests/pytorch/engine/test_model_agent.py +++ b/tests/pytorch/engine/test_model_agent.py @@ -241,6 +241,8 @@ def reset_graph_runner(self): agent = BaseModelAgent.__new__(BaseModelAgent) agent.patched_model = _PatchedModel() agent.spec_agent = _SpecAgent() + agent._prev_chunk_output = {'model_metas': object()} + agent._prev_chunk_last_logit = torch.ones(1, 2) @contextmanager def _all_context(): @@ -258,6 +260,8 @@ def _all_context(): 'spec_reset', 'exit_all_context', ] + assert agent._prev_chunk_output is None + assert agent._prev_chunk_last_logit is None def test_spec_agent_reset_graph_runner_uses_draft_context(self): from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent @@ -271,6 +275,7 @@ def reset(self): agent = SpecModelAgent.__new__(SpecModelAgent) agent.proposer = type('Proposer', (), {'model': _Model()})() + agent._prev_chunk_last = {'hidden_states': torch.ones(1, 1, 2)} @contextmanager def _draft_context(): @@ -287,10 +292,105 @@ def _draft_context(): 'reset', 'exit_draft_context', ] + assert agent._prev_chunk_last == {} class TestModelAgentWakeup: + def test_sleep_clears_middle_chunk_carryover_state(self, event_loop, monkeypatch): + from lmdeploy.pytorch.engine.model_agent.agent import BaseModelAgent, SleepWakeupState + from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent + + events = [] + + class _Moveable: + + def __init__(self, name): + self.name = name + + def to(self, *args, **kwargs): + events.append((self.name, 'to', args, kwargs)) + return self + + class _PatchedModel: + + def __init__(self): + self.model = _Moveable('main_model') + + def reset(self): + events.append('main_reset') + + def get_model(self): + return self.model + + class _SpecGraphRunner: + + def __init__(self): + self.model = _Moveable('spec_model') + + def reset(self): + events.append('spec_reset') + + def get_model(self): + return self.model + + spec_agent = SpecModelAgent.__new__(SpecModelAgent) + spec_agent.proposer = type('Proposer', (), {'model': _SpecGraphRunner()})() + spec_agent._prev_chunk_last = {'hidden_states': torch.ones(1, 1, 2)} + spec_agent.cache_engine = object() + + @contextmanager + def _draft_context(): + events.append('enter_draft_context') + yield + events.append('exit_draft_context') + + spec_agent.draft_context = _draft_context + + model_agent = BaseModelAgent.__new__(BaseModelAgent) + model_agent.state = SleepWakeupState() + model_agent.dist_config = SimpleNamespace(dp=1) + model_agent.cache_engine = object() + model_agent.state_cache_engine = object() + model_agent.patched_model = _PatchedModel() + model_agent.spec_agent = spec_agent + model_agent._prev_chunk_output = {'model_metas': object()} + model_agent._prev_chunk_last_logit = torch.ones(1, 2) + model_agent._pre_in_que = asyncio.Queue() + model_agent._in_que = asyncio.Queue() + model_agent._out_que = asyncio.Queue() + model_agent._pre_in_que.put_nowait('stale_middle_chunk_input') + model_agent._in_que.put_nowait('stale_middle_chunk_cuda_input') + model_agent._out_que.put_nowait('stale_middle_chunk_output') + model_agent._update_params_ipc_tensor = object() + model_agent._update_params_ipc_event = object() + + @contextmanager + def _all_context(): + events.append('enter_all_context') + yield + events.append('exit_all_context') + + model_agent.all_context = _all_context + monkeypatch.setattr(torch.cuda, 'synchronize', lambda: events.append('cuda_synchronize')) + monkeypatch.setattr(torch.cuda, 'empty_cache', lambda: events.append('cuda_empty_cache')) + + event_loop.run_until_complete(model_agent.sleep(level=1)) + + assert model_agent._prev_chunk_output is None + assert model_agent._prev_chunk_last_logit is None + assert spec_agent._prev_chunk_last == {} + assert model_agent.cache_engine is None + assert model_agent.state_cache_engine is None + assert spec_agent.cache_engine is None + assert model_agent._pre_in_que.empty() + assert model_agent._in_que.empty() + assert model_agent._out_que.empty() + assert model_agent._update_params_ipc_tensor is None + assert model_agent._update_params_ipc_event is None + assert 'main_reset' in events + assert 'spec_reset' in events + def test_dp_kv_cache_wakeup_warms_before_releasing_forward_task(self): from lmdeploy.pytorch.engine.model_agent.agent import BaseModelAgent, SleepWakeupState diff --git a/tests/pytorch/spec_decode/test_spec_agent.py b/tests/pytorch/spec_decode/test_spec_agent.py index d121bed2a3..72ad8b81a0 100644 --- a/tests/pytorch/spec_decode/test_spec_agent.py +++ b/tests/pytorch/spec_decode/test_spec_agent.py @@ -1,4 +1,5 @@ import asyncio +from types import SimpleNamespace import torch @@ -44,12 +45,15 @@ class Meta: def __init__(self): self.meta = self.Meta() self.update_inputs_calls = 0 + self.update_inputs_dp_is_decoding = [] def get_meta(self): return self.meta def update_inputs(self, inputs): self.update_inputs_calls += 1 + if inputs.dp_meta is not None: + self.update_inputs_dp_is_decoding.append(inputs.dp_meta.dp_is_decoding) return inputs @@ -258,6 +262,150 @@ def _forward_impl(_inputs): assert agent.proposer.model.update_inputs_calls == agent.num_spec_tokens - 1 +def test_async_model_forward_preserves_dp_global_decoding_in_draft_loop(monkeypatch): + """Rebuilt draft-loop DPMeta must keep DP-global decode state.""" + import lmdeploy.pytorch.spec_decode.spec_agent as spec_agent_mod + from lmdeploy.pytorch.model_inputs import DPMeta + from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent + + monkeypatch.setattr(spec_agent_mod.DPMeta, 'build', staticmethod(lambda seqlen, num_tokens: DPMeta())) + inputs, extra_inputs = _make_non_last_chunk_inputs(dp_meta=DPMeta(dp_batches=[2, 2], dp_is_decoding=True)) + + agent = object.__new__(SpecModelAgent) + agent.num_spec_tokens = 3 + agent.rank = 0 + agent.proposer = _DummyProposer() + forward_calls = 0 + + def _forward_impl(_inputs): + nonlocal forward_calls + forward_calls += 1 + return {'call': forward_calls} + + agent._forward_impl = _forward_impl + + asyncio.run(agent._async_model_forward(inputs, extra_inputs, sampling_inputs=None)) + + assert agent.proposer.model.update_inputs_dp_is_decoding == [True, True] + + +def test_spec_model_agent_warmup_adds_dp_meta_for_draft_capture(monkeypatch): + """Draft warmup must mark decode graph captures as DP-global decode.""" + import lmdeploy.pytorch.spec_decode.spec_agent as spec_agent_mod + from lmdeploy.pytorch.config import DistConfig + from lmdeploy.pytorch.distributed import DistContext, DistGroup + from lmdeploy.pytorch.model_inputs import DPMeta, ModelInputs + from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent + + class DummyInputsStrategy: + + def make_dummy(self, + batch_size: int, + is_decoding: bool, + device: str = 'cpu', + vocab_size: int = 1, + max_q_seqlen: int = 1, + target_hidden_size: int = None, + target_dtype: torch.dtype = torch.float32, + meta=None): + input_ids = torch.zeros((1, batch_size * max_q_seqlen), dtype=torch.long) + seq_length = torch.full((batch_size, ), max_q_seqlen, dtype=torch.long) + inputs = ModelInputs(input_ids=input_ids, + seq_length=seq_length, + history_lengths=torch.zeros(batch_size, dtype=torch.long), + block_offsets=torch.zeros((batch_size, 1), dtype=torch.long), + is_decoding=is_decoding, + num_ignored_history=torch.zeros(batch_size, dtype=torch.long), + max_q_seqlen=max_q_seqlen, + max_kv_seqlen=max_q_seqlen, + sum_kv_seqlen=batch_size * max_q_seqlen) + if target_hidden_size is not None: + inputs.target_hidden_states = torch.zeros((1, batch_size * max_q_seqlen, target_hidden_size), + dtype=target_dtype) + return inputs + + class DummyDraftModel: + + def get_capture_batch_sizes(self): + return [2] + + class DummyProposer: + + def __init__(self): + self.model = DummyDraftModel() + + def get_target_hidden_size(self, target_model_config): + return 4 + + build_calls = [] + + def fake_dp_meta_build(seqlen, num_tokens): + build_calls.append((seqlen, list(num_tokens))) + return DPMeta(tp_sizes=[seqlen], moe_tp_sizes=[seqlen]) + + monkeypatch.setattr(spec_agent_mod.DPMeta, 'build', staticmethod(fake_dp_meta_build)) + + dist_config = DistConfig(dp=2, ep=2) + draft_dist_ctx = DistContext(rank=0, + dp_rank=0, + dist_config=dist_config, + attn_tp_group=DistGroup(rank=0), + mlp_tp_group=DistGroup(rank=0), + moe_tp_group=DistGroup(rank=0), + tp_group=DistGroup(rank=0)) + agent = object.__new__(SpecModelAgent) + agent.draft_dist_ctx = draft_dist_ctx + agent.inputs_strategy = DummyInputsStrategy() + agent.proposer = DummyProposer() + agent.model_config = SimpleNamespace(vocab_size=11, dtype=torch.float32, hidden_size=8) + agent.num_spec_tokens = 3 + agent.make_dummy_meta = None + + forwarded = [] + + def forward_impl(inputs): + forwarded.append({ + 'num_tokens': inputs.input_ids.numel(), + 'batch_size': inputs.seq_length.numel(), + 'is_decoding': inputs.is_decoding, + 'dp_batches': inputs.dp_meta.dp_batches, + 'dp_is_decoding': inputs.dp_meta.dp_is_decoding, + 'global_is_decoding': inputs.global_is_decoding(), + }) + + agent._forward_impl = forward_impl + + agent.warmup(max_batches=4, target_model_config=SimpleNamespace()) + + assert build_calls == [(4, [4, 4]), (8, [8, 8]), (2, [2, 2])] + assert forwarded == [ + { + 'num_tokens': 4, + 'batch_size': 4, + 'is_decoding': False, + 'dp_batches': [4, 4], + 'dp_is_decoding': False, + 'global_is_decoding': False, + }, + { + 'num_tokens': 8, + 'batch_size': 2, + 'is_decoding': True, + 'dp_batches': [2, 2], + 'dp_is_decoding': True, + 'global_is_decoding': True, + }, + { + 'num_tokens': 2, + 'batch_size': 2, + 'is_decoding': True, + 'dp_batches': [2, 2], + 'dp_is_decoding': True, + 'global_is_decoding': True, + }, + ] + + def test_slice_sampling_inputs_decode(): """Test _slice_sampling_inputs with decoding (num_tokens_per_batch > 1).""" from lmdeploy.pytorch.engine.logits_process import SamplingInputs