Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/cuda/graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ def prepare_inputs_for_generation(

def reset(self):
"""Remove all graphs to prevent hanging on exit."""
super().reset()
self._runner_map.clear()
if get_deepep_state().enabled():
from dlblas.layers.moe.token_dispatcher import DeepEPBuffer
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/backends/graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_input_processor(self):

def reset(self):
    """Remove all graphs to prevent hanging on exit.

    This implementation only clears the padding batch size recorded on
    ``self._runner_meta`` by setting it to ``None``.
    """
    self._runner_meta.padding_batch_size = None

def get_meta(self):
"""Get graphrunner meta."""
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/engine/engine_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ async def __no_running_warning():
# stale after the drain point.
forward_inputs = None
next_running = None
self.inputs_maker.clear_for_sleep()
# Acknowledge that no new forward input will be scheduled until
# wakeup resumes this loop.
self._main_sleep_drain_event.set()
Expand Down
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/engine/inputs_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ def __init__(
# long context chunker
self.long_context_chunker = LongContextChunker(config.max_prefill_token_num)

def clear_for_sleep(self):
    """Clear transient scheduling state before engine sleep.

    Delegates to ``self.long_context_chunker.clear()``; no other state is
    touched here.
    """
    self.long_context_chunker.clear()

def _init_do_prefill(self, config: InputsMakerConfig):
if config.role == EngineRole.Prefill:
self.do_prefill = self.do_prefill_pnode
Expand Down
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/engine/model_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ def warmup(self):
if dp > 1:
num_tokens = inputs.input_ids.numel()
inputs.build_dp_meta([num_tokens] * world_size)
inputs.dp_meta.dp_is_decoding = False
logger.debug('Warmup prefill start.')
self._forward_impl(inputs)
torch.cuda.synchronize()
Expand All @@ -423,6 +424,7 @@ def warmup(self):
if dp > 1:
num_tokens = inputs.input_ids.numel()
inputs.build_dp_meta([num_tokens] * world_size)
inputs.dp_meta.dp_is_decoding = True
logger.debug(f'Warmup decoding num_tokens={num_tokens} start.')
self._forward_impl(inputs)
torch.cuda.synchronize()
Expand Down Expand Up @@ -1146,6 +1148,8 @@ def get_input_processor(self):
def reset_graph_runner(self):
"""Reset graph runner to prevent tp hanging."""
with self.all_context():
self._prev_chunk_output = None
self._prev_chunk_last_logit = None
if hasattr(self.patched_model, 'reset'):
self.patched_model.reset()

Expand Down
19 changes: 19 additions & 0 deletions lmdeploy/pytorch/spec_decode/spec_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ def __build_dp_meta(inputs: ModelInputs):

padding_batch_size = max(dp_meta.dp_batches)
new_dpmeta = DPMeta.build(inputs.input_ids.numel(), dp_meta.dp_batches)
new_dpmeta.dp_is_decoding = dp_meta.dp_is_decoding
return new_dpmeta, padding_batch_size

def _update_dp_model_inputs(inputs: ModelInputs, dp_meta: DPMeta, padding_batch_size: int | None):
Expand Down Expand Up @@ -541,6 +542,20 @@ async def async_model_forward(

def warmup(self, max_batches: int, target_model_config: ModelConfig):
"""warmup."""

def add_warmup_dp_meta(inputs: ModelInputs, is_decoding: bool):
dist_config = self.draft_dist_ctx.dist_config
if dist_config.dp <= 1:
return

num_tokens = inputs.input_ids.numel()
batch_size = inputs.seq_length.numel()
world_size = dist_config.world_size
with self.draft_context():
inputs.build_dp_meta([num_tokens] * world_size)
inputs.dp_meta.dp_batches = [batch_size] * world_size
inputs.dp_meta.dp_is_decoding = is_decoding

target_hidden_size = self.proposer.get_target_hidden_size(target_model_config)

# warmup prefill
Expand All @@ -551,6 +566,7 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig):
target_hidden_size=target_hidden_size,
target_dtype=self.model_config.dtype,
meta=self.make_dummy_meta)
add_warmup_dp_meta(inputs, is_decoding=False)

# warmup prefill
self._forward_impl(inputs)
Expand All @@ -569,6 +585,7 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig):
target_hidden_size=target_hidden_size,
target_dtype=self.model_config.dtype,
meta=self.make_dummy_meta)
add_warmup_dp_meta(inputs, is_decoding=True)
self._forward_impl(inputs)
# decode 1 tokens per sequence
inputs = self.inputs_strategy.make_dummy(batch_size,
Expand All @@ -579,11 +596,13 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig):
target_hidden_size=self.model_config.hidden_size,
target_dtype=self.model_config.dtype,
meta=self.make_dummy_meta)
add_warmup_dp_meta(inputs, is_decoding=True)
self._forward_impl(inputs)

def reset_graph_runner(self):
    """Reset graph runner.

    Everything runs inside ``self.draft_context()``. The container held in
    ``self._prev_chunk_last`` is emptied in place, and the proposer model's
    ``reset`` is called when the model is not ``None`` and exposes that
    method.
    """
    with self.draft_context():
        # Drop any carried-over chunk state before resetting the model.
        self._prev_chunk_last.clear()
        if self.proposer.model is not None and hasattr(self.proposer.model, 'reset'):
            self.proposer.model.reset()

Expand Down
15 changes: 15 additions & 0 deletions tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
from lmdeploy.pytorch.backends.graph_runner import GraphRunnerMeta
from lmdeploy.pytorch.config import CacheConfig
from lmdeploy.pytorch.engine.config_builder import ConfigBuilder

Expand Down Expand Up @@ -53,3 +54,17 @@ def test_graph_runner_defensively_normalizes_capture_batch_sizes():
runner.cache_config = cache_config

assert runner.get_capture_batch_sizes() == [1, 4, 8]


def test_graph_runner_reset_clears_padding_batch_size(monkeypatch):
    """``CUDAGraphRunner.reset`` clears the padding batch size and the runner map."""
    from lmdeploy.pytorch.backends.cuda import graph_runner as cuda_graph_runner

    # Bypass __init__ and attach only the attributes this test sets up and checks.
    runner = object.__new__(CUDAGraphRunner)
    runner._runner_meta = GraphRunnerMeta(padding_batch_size=1)
    runner._runner_map = {'stale': object()}
    # Make the DeepEP state report itself as disabled for this test.
    monkeypatch.setattr(cuda_graph_runner.get_deepep_state(), 'enabled', lambda: False)

    runner.reset()

    assert runner.get_meta().padding_batch_size is None
    assert runner._runner_map == {}
35 changes: 35 additions & 0 deletions tests/pytorch/engine/test_inputs_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,41 @@ async def get_output_async(self):
assert not block_trie.pinned


def test_engine_loop_sleep_drain_clears_long_context_chunker():
    """The engine loop's sleep drain leaves the long-context chunker disabled."""
    seq = _DummySeq(
        history_ids=512,
        token_ids=2048,
        all_multimodals={},
        input_multimodals={},
    )
    seq.status = MessageStatus.RUNNING

    # Build the inputs maker without running __init__; only the chunker is needed.
    maker = InputsMakerAsync.__new__(InputsMakerAsync)
    maker.long_context_chunker = LongContextChunker(max_prefill_token_num=512)
    maker.long_context_chunker.set_seq(seq)
    # Precondition: the chunker is active before the loop drains for sleep.
    assert maker.long_context_chunker.enabled()

    async def _run_sleep_drain():
        # Build the loop without running __init__ and attach the attributes by hand.
        loop = EngineLoop.__new__(EngineLoop)
        loop.stop_event = asyncio.Event()
        loop.has_runable_event = asyncio.Event()
        loop._sleep_requested = True
        loop._main_sleep_drain_event = asyncio.Event()
        loop._sleep_resume_event = asyncio.Event()
        loop.scheduler = SimpleNamespace()
        loop.inputs_maker = maker

        task = asyncio.create_task(loop.main_loop())
        # Wait for the loop to signal the sleep drain, then stop and resume it
        # so the task can finish within the timeout.
        await asyncio.wait_for(loop._main_sleep_drain_event.wait(), timeout=1)
        loop.stop_event.set()
        loop._sleep_resume_event.set()
        await asyncio.wait_for(task, timeout=1)

    asyncio.run(_run_sleep_drain())

    assert not maker.long_context_chunker.enabled()


def test_long_context_chunker_uses_cached_multimodal_size_for_chunk_limit():
image = _DummyMultiModal(start=512, end=5888)
seq = _DummySeq(
Expand Down
100 changes: 100 additions & 0 deletions tests/pytorch/engine/test_model_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ def reset_graph_runner(self):
agent = BaseModelAgent.__new__(BaseModelAgent)
agent.patched_model = _PatchedModel()
agent.spec_agent = _SpecAgent()
agent._prev_chunk_output = {'model_metas': object()}
agent._prev_chunk_last_logit = torch.ones(1, 2)

@contextmanager
def _all_context():
Expand All @@ -258,6 +260,8 @@ def _all_context():
'spec_reset',
'exit_all_context',
]
assert agent._prev_chunk_output is None
assert agent._prev_chunk_last_logit is None

def test_spec_agent_reset_graph_runner_uses_draft_context(self):
from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent
Expand All @@ -271,6 +275,7 @@ def reset(self):

agent = SpecModelAgent.__new__(SpecModelAgent)
agent.proposer = type('Proposer', (), {'model': _Model()})()
agent._prev_chunk_last = {'hidden_states': torch.ones(1, 1, 2)}

@contextmanager
def _draft_context():
Expand All @@ -287,10 +292,105 @@ def _draft_context():
'reset',
'exit_draft_context',
]
assert agent._prev_chunk_last == {}


class TestModelAgentWakeup:

def test_sleep_clears_middle_chunk_carryover_state(self, event_loop, monkeypatch):
    """``BaseModelAgent.sleep(level=1)`` clears carry-over state, queues and caches.

    The agents are created with ``__new__`` so no real initialisation runs;
    only the attributes assigned below are present. Lightweight stand-ins
    record their calls into ``events``.
    """
    from lmdeploy.pytorch.engine.model_agent.agent import BaseModelAgent, SleepWakeupState
    from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent

    events = []

    class _Moveable:
        """Stand-in model object that records every ``to`` call."""

        def __init__(self, name):
            self.name = name

        def to(self, *args, **kwargs):
            events.append((self.name, 'to', args, kwargs))
            return self

    class _PatchedModel:
        """Stand-in for the main patched model; records ``reset``."""

        def __init__(self):
            self.model = _Moveable('main_model')

        def reset(self):
            events.append('main_reset')

        def get_model(self):
            return self.model

    class _SpecGraphRunner:
        """Stand-in for the proposer model; records ``reset``."""

        def __init__(self):
            self.model = _Moveable('spec_model')

        def reset(self):
            events.append('spec_reset')

        def get_model(self):
            return self.model

    spec_agent = SpecModelAgent.__new__(SpecModelAgent)
    spec_agent.proposer = type('Proposer', (), {'model': _SpecGraphRunner()})()
    spec_agent._prev_chunk_last = {'hidden_states': torch.ones(1, 1, 2)}
    spec_agent.cache_engine = object()

    @contextmanager
    def _draft_context():
        events.append('enter_draft_context')
        yield
        events.append('exit_draft_context')

    spec_agent.draft_context = _draft_context

    model_agent = BaseModelAgent.__new__(BaseModelAgent)
    model_agent.state = SleepWakeupState()
    model_agent.dist_config = SimpleNamespace(dp=1)
    model_agent.cache_engine = object()
    model_agent.state_cache_engine = object()
    model_agent.patched_model = _PatchedModel()
    model_agent.spec_agent = spec_agent
    # Carry-over state that must be gone after sleep.
    model_agent._prev_chunk_output = {'model_metas': object()}
    model_agent._prev_chunk_last_logit = torch.ones(1, 2)
    # Pre-fill every queue with one stale item; each is asserted empty below.
    model_agent._pre_in_que = asyncio.Queue()
    model_agent._in_que = asyncio.Queue()
    model_agent._out_que = asyncio.Queue()
    model_agent._pre_in_que.put_nowait('stale_middle_chunk_input')
    model_agent._in_que.put_nowait('stale_middle_chunk_cuda_input')
    model_agent._out_que.put_nowait('stale_middle_chunk_output')
    model_agent._update_params_ipc_tensor = object()
    model_agent._update_params_ipc_event = object()

    @contextmanager
    def _all_context():
        events.append('enter_all_context')
        yield
        events.append('exit_all_context')

    model_agent.all_context = _all_context
    # Replace the CUDA calls with recorders so the test does not touch a device.
    monkeypatch.setattr(torch.cuda, 'synchronize', lambda: events.append('cuda_synchronize'))
    monkeypatch.setattr(torch.cuda, 'empty_cache', lambda: events.append('cuda_empty_cache'))

    event_loop.run_until_complete(model_agent.sleep(level=1))

    # Chunk carry-over state is cleared on both agents.
    assert model_agent._prev_chunk_output is None
    assert model_agent._prev_chunk_last_logit is None
    assert spec_agent._prev_chunk_last == {}
    # Cache engines are released.
    assert model_agent.cache_engine is None
    assert model_agent.state_cache_engine is None
    assert spec_agent.cache_engine is None
    # Stale queue items are dropped.
    assert model_agent._pre_in_que.empty()
    assert model_agent._in_que.empty()
    assert model_agent._out_que.empty()
    # IPC handles for parameter updates are released.
    assert model_agent._update_params_ipc_tensor is None
    assert model_agent._update_params_ipc_event is None
    # Both the main and the spec graph runners were reset.
    assert 'main_reset' in events
    assert 'spec_reset' in events

def test_dp_kv_cache_wakeup_warms_before_releasing_forward_task(self):
from lmdeploy.pytorch.engine.model_agent.agent import BaseModelAgent, SleepWakeupState

Expand Down
Loading
Loading