From 6f553a8601576c01e907546b589582a544a94275 Mon Sep 17 00:00:00 2001
From: zxy
Date: Tue, 24 Mar 2026 12:12:41 +0800
Subject: [PATCH 1/7] support update params for draft model

---
 lmdeploy/pytorch/engine/model_agent/agent.py | 75 +++++++++++++++-----
 1 file changed, 56 insertions(+), 19 deletions(-)

diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py
index 12c04c0b80..edbe1bb7ba 100644
--- a/lmdeploy/pytorch/engine/model_agent/agent.py
+++ b/lmdeploy/pytorch/engine/model_agent/agent.py
@@ -1160,6 +1160,12 @@ def reset_graph_runner(self):
         self.spec_agent.reset_graph_runner()
 
+    def _get_spec_model(self):
+        """Return the spec-decode draft model, or None if not enabled."""
+        if self.spec_agent.is_enabled() and self.spec_agent.proposer.model is not None:
+            return self.spec_agent.proposer.model.get_model()
+        return None
+
     @torch.inference_mode()
     def update_params(self, request: UpdateParamsRequest):
         """Update params."""
@@ -1172,32 +1178,48 @@ def _construct(item):
             # clone() seems necessary otherwise the producer can not release the memory
             return func(*args).clone()
 
+        def _deserialize_weights(serialized_data):
+            raw = ForkingPickler.loads(pybase64.b64decode(serialized_data))
+            if request.load_format == 'flattened_bucket':
+                metadata: list[FlattenedTensorMetadata] = raw['metadata']
+                if not metadata:
+                    return []
+                flattened_tensor: torch.Tensor = _construct(raw['flattened_tensor'])
+                bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata)
+                return list(bucket.reconstruct_tensors())
+            return [(k, _construct(v)) for k, v in raw]
+
+        def _split_main_and_draft(weights):
+            if not self.spec_agent.is_enabled():
+                return weights, []
+            main = [(name, weight) for name, weight in weights if not name.startswith('mtp.')]
+            draft = [(name, weight) for name, weight in weights if name.startswith('mtp.')]
+            return main, draft
+
         with self.all_context():
             serialized_data = request.serialized_named_tensors
             if isinstance(serialized_data, list):
                 serialized_data = serialized_data[self.dist_ctx.tp_group.rank]
+
             model = self.patched_model.get_model()
-            weights = ForkingPickler.loads(pybase64.b64decode(serialized_data))
-            if request.load_format == 'flattened_bucket':
-                metadata: list[FlattenedTensorMetadata] = weights['metadata']
-                if metadata:
-                    flattened_tensor: torch.Tensor = _construct(weights['flattened_tensor'])
-                    bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata)
-                    weights = bucket.reconstruct_tensors()
-                else:
-                    # empty data
-                    weights = []
-            else:
-                weights = [(k, _construct(v)) for k, v in weights]
+            spec_model = self._get_spec_model()
+
+            weights = _deserialize_weights(serialized_data)
+            main_weights, draft_weights = _split_main_and_draft(weights)
+
+            for m, w, tag in [(model, main_weights, 'main'), (spec_model, draft_weights, 'draft')]:
+                if m is None or not w:
+                    continue
-            weights = ModelWeightLoader._rename_weights_iterator(weights, model)
-            model.load_weights(weights)
+                w = list(ModelWeightLoader._rename_weights_iterator(w, m))
+                logger.info(f'Update_params: {tag}_num_tensors={len(w)}')
+                m.load_weights(iter(w))
 
         if request.finished:
-            for _, mod in model.named_modules():
-                if not hasattr(mod, 'update_weights'):
-                    continue
-                mod.update_weights()
+            for m in filter(None, [model, spec_model]):
+                for _, mod in m.named_modules():
+                    if hasattr(mod, 'update_weights'):
+                        mod.update_weights()
 
             torch.cuda.empty_cache()
@@ -1207,10 +1229,16 @@ async def sleep(self, level: int = 1):
         self.state.is_sleeping = True
         if self.dist_config.dp > 1:
             await self.state.to_sleep.wait()
+        device = 'cpu' if level == 1 else 'meta'
         self.cache_engine = None
         self.reset_graph_runner()
-        device = 'cpu' if level == 1 else 'meta'
         self.patched_model.get_model().to(device=device, non_blocking=True)
+
+        spec_model = self._get_spec_model()
+        if spec_model is not None:
+            self.spec_agent.cache_engine = None
+            spec_model.to(device=device, non_blocking=True)
+
         torch.cuda.synchronize()
         torch.cuda.empty_cache()
         self.state.to_sleep.clear()
@@ -1220,9 +1248,11 @@ def wakeup(self, tags: list[str] | None = None):
         """Wakeup."""
         if tags is None:
             tags = ['weights', 'kv_cache']
+
         if 'weights' in tags:
             device = next(self.patched_model.get_model().parameters()).device
             assert device.type in ['cpu', 'meta']
+
             if device.type == 'cpu':
                 self.patched_model.get_model().to(torch.cuda.current_device())
             else:
@@ -1233,6 +1263,13 @@
                 self.build_graph_runner()
                 self.misc_config.empty_init = old_empty_init
 
+            spec_model = self._get_spec_model()
+            if spec_model is not None:
+                spec_device = next(spec_model.parameters()).device
+                assert spec_device.type in ['cpu', 'meta']
+                if spec_device.type == 'cpu':
+                    spec_model.to(torch.cuda.current_device())
+
         if 'kv_cache' in tags:
             self.build_cache_engine()
         # wake up signal
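Aside on PATCH 1/7: the draft (MTP) weights are expected to arrive in the same flat list of (name, tensor) pairs as the main-model weights, distinguished only by an 'mtp.' name prefix, and _split_main_and_draft partitions the list so each model loads only its own tensors. A minimal, self-contained sketch of that behavior; the tensor names and shapes are illustrative, not taken from a real checkpoint:

    import torch

    def split_main_and_draft(weights, spec_enabled=True):
        """Partition (name, tensor) pairs by the 'mtp.' draft-model prefix."""
        if not spec_enabled:
            return weights, []
        main = [(n, w) for n, w in weights if not n.startswith('mtp.')]
        draft = [(n, w) for n, w in weights if n.startswith('mtp.')]
        return main, draft

    weights = [
        ('model.layers.0.mlp.gate_proj.weight', torch.zeros(8, 8)),
        ('mtp.layers.0.mlp.gate_proj.weight', torch.zeros(8, 8)),
    ]
    main, draft = split_main_and_draft(weights)
    assert [n for n, _ in main] == ['model.layers.0.mlp.gate_proj.weight']
    assert [n for n, _ in draft] == ['mtp.layers.0.mlp.gate_proj.weight']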
From e9449977b5d334cf68bbf828a2b271f40c965b47 Mon Sep 17 00:00:00 2001
From: zxy
Date: Tue, 24 Mar 2026 12:15:06 +0800
Subject: [PATCH 2/7] fix zmq rpc race condition

---
 lmdeploy/pytorch/engine/mp_engine/base.py | 4 ++--
 lmdeploy/serve/core/async_engine.py       | 7 +++++--
 lmdeploy/serve/openai/api_server.py       | 2 +-
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/lmdeploy/pytorch/engine/mp_engine/base.py b/lmdeploy/pytorch/engine/mp_engine/base.py
index a5c16dd967..c0352787b1 100644
--- a/lmdeploy/pytorch/engine/mp_engine/base.py
+++ b/lmdeploy/pytorch/engine/mp_engine/base.py
@@ -65,9 +65,9 @@ def update_params(self, request: Any):
         """Update params."""
         return self._collective_rpc('update_params', request)
 
-    def get_schedule_metrics(self):
+    async def get_schedule_metrics(self):
         """Get schedule metrics."""
-        return self._collective_rpc('get_schedule_metrics')
+        return await self._collective_rpc_async('get_schedule_metrics')
 
     def p2p_initialize(self, conn_request: DistServeInitRequest):
         """Init rdma link."""
diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py
index 293262180b..284bcea916 100644
--- a/lmdeploy/serve/core/async_engine.py
+++ b/lmdeploy/serve/core/async_engine.py
@@ -201,8 +201,11 @@ def _build_stat_loggers(self):
         # set stats loggers of metrics processor
         metrics_processor.stat_loggers = self.stat_loggers
 
-    def get_schedule_metrics(self):
-        return self.engine.get_schedule_metrics()
+    async def get_schedule_metrics(self):
+        result = self.engine.get_schedule_metrics()
+        if asyncio.iscoroutine(result):
+            return await result
+        return result
 
     async def do_log_stats(self):
         """Loop through CLI logger and Prometheus logger and output the
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index f64586fff7..b8bcf31ae9 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -1379,7 +1379,7 @@ async def _force_log():
             await asyncio.sleep(log_interval)
 
             # periodically update schedule metrics, as they change less frequently than iteration stats
-            schedule_metrics = async_engine.get_schedule_metrics()
+            schedule_metrics = await async_engine.get_schedule_metrics()
             await metrics_processor.update_schedule_stats(schedule_metrics)
             await async_engine.do_log_stats()
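Aside on PATCH 2/7: the race fix makes get_schedule_metrics awaitable end to end. The mp engine now proxies it through an async collective RPC, while an in-process engine may still return a plain value, so AsyncEngine normalizes both cases with asyncio.iscoroutine. A standalone sketch of that dual-compatibility pattern; the engine classes here are stand-ins, not lmdeploy types:

    import asyncio

    class SyncEngine:
        def get_schedule_metrics(self):
            return {'num_running': 2}

    class AsyncRpcEngine:
        async def get_schedule_metrics(self):
            return {'num_running': 2}

    async def fetch_metrics(engine):
        # Normalize: the callee may be sync (plain value) or async (coroutine).
        result = engine.get_schedule_metrics()
        if asyncio.iscoroutine(result):
            return await result
        return result

    assert asyncio.run(fetch_metrics(SyncEngine())) == {'num_running': 2}
    assert asyncio.run(fetch_metrics(AsyncRpcEngine())) == {'num_running': 2}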
From 7a48257457f5448851b350d7596d2043194cf108 Mon Sep 17 00:00:00 2001
From: zxy
Date: Tue, 24 Mar 2026 20:37:41 +0800
Subject: [PATCH 3/7] only for qwen35 mtp, move get_model function

---
 lmdeploy/pytorch/engine/model_agent/agent.py | 14 ++++----------
 lmdeploy/pytorch/spec_decode/spec_agent.py   |  8 +++++++-
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py
index edbe1bb7ba..e268d4307a 100644
--- a/lmdeploy/pytorch/engine/model_agent/agent.py
+++ b/lmdeploy/pytorch/engine/model_agent/agent.py
@@ -1160,12 +1160,6 @@ def reset_graph_runner(self):
         self.spec_agent.reset_graph_runner()
 
-    def _get_spec_model(self):
-        """Return the spec-decode draft model, or None if not enabled."""
-        if self.spec_agent.is_enabled() and self.spec_agent.proposer.model is not None:
-            return self.spec_agent.proposer.model.get_model()
-        return None
-
     @torch.inference_mode()
     def update_params(self, request: UpdateParamsRequest):
         """Update params."""
@@ -1190,7 +1184,7 @@ def _deserialize_weights(serialized_data):
             return [(k, _construct(v)) for k, v in raw]
 
         def _split_main_and_draft(weights):
-            if not self.spec_agent.is_enabled():
+            if not self.spec_agent.is_enabled() or self.spec_agent.method != 'qwen3_5_mtp':
                 return weights, []
             main = [(name, weight) for name, weight in weights if not name.startswith('mtp.')]
             draft = [(name, weight) for name, weight in weights if name.startswith('mtp.')]
             return main, draft
@@ -1202,7 +1196,7 @@
                 serialized_data = serialized_data[self.dist_ctx.tp_group.rank]
 
             model = self.patched_model.get_model()
-            spec_model = self._get_spec_model()
+            spec_model = self.spec_agent.get_model()
 
             weights = _deserialize_weights(serialized_data)
             main_weights, draft_weights = _split_main_and_draft(weights)
@@ -1234,7 +1228,7 @@ async def sleep(self, level: int = 1):
         self.reset_graph_runner()
         self.patched_model.get_model().to(device=device, non_blocking=True)
 
-        spec_model = self._get_spec_model()
+        spec_model = self.spec_agent.get_model()
         if spec_model is not None:
             self.spec_agent.cache_engine = None
             spec_model.to(device=device, non_blocking=True)
@@ -1263,7 +1257,7 @@
                 self.build_graph_runner()
                 self.misc_config.empty_init = old_empty_init
 
-            spec_model = self._get_spec_model()
+            spec_model = self.spec_agent.get_model()
             if spec_model is not None:
                 spec_device = next(spec_model.parameters()).device
                 assert spec_device.type in ['cpu', 'meta']
diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py
index 4000461e82..aa16455e8a 100644
--- a/lmdeploy/pytorch/spec_decode/spec_agent.py
+++ b/lmdeploy/pytorch/spec_decode/spec_agent.py
@@ -231,6 +231,12 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig):
             self._forward_impl(inputs)
 
     def reset_graph_runner(self):
-        'reset graph runner'
+        """Reset graph runner."""
         if self.proposer.model is not None and hasattr(self.proposer.model, 'reset'):
             self.proposer.model.reset()
+
+    def get_model(self):
+        """Get model."""
+        if self.is_enable() and self.proposer.model is not None:
+            return self.proposer.model
+        return None
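Aside on the sleep path touched by PATCH 1/7 and 3/7: device = 'cpu' if level == 1 else 'meta' means a level-1 sleep offloads weights to host memory, from which they can simply be copied back on wakeup, while deeper levels move them to PyTorch's meta device, which keeps shape and dtype metadata but frees the storage entirely; this matches wakeup's separate rebuild branch for meta-resident models. A small sketch of the meta-device behavior; the module is illustrative:

    import torch

    m = torch.nn.Linear(1024, 1024)
    assert m.weight.device.type == 'cpu'

    # 'meta' keeps parameter metadata but drops the actual storage.
    m = m.to(device='meta')
    assert m.weight.device.type == 'meta'
    assert m.weight.shape == (1024, 1024)  # shape survives

    try:
        m.weight.to('cpu')  # data is gone: cannot copy out of a meta tensor
    except (NotImplementedError, RuntimeError):
        pass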
From b745d7fd1b5670c78759e3403f58ab24e7185c15 Mon Sep 17 00:00:00 2001
From: zxy
Date: Wed, 25 Mar 2026 11:43:53 +0800
Subject: [PATCH 4/7] release state cache

---
 lmdeploy/pytorch/engine/model_agent/agent.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py
index e268d4307a..2a74e35a2d 100644
--- a/lmdeploy/pytorch/engine/model_agent/agent.py
+++ b/lmdeploy/pytorch/engine/model_agent/agent.py
@@ -1225,6 +1225,7 @@ async def sleep(self, level: int = 1):
             await self.state.to_sleep.wait()
         device = 'cpu' if level == 1 else 'meta'
         self.cache_engine = None
+        self.state_cache_engine = None
         self.reset_graph_runner()
         self.patched_model.get_model().to(device=device, non_blocking=True)
 
@@ -1276,4 +1277,5 @@ def release(self):
         self.reset_graph_runner()
         self.patched_model = None
         self.cache_engine = None
+        self.state_cache_engine = None
         torch.cuda.empty_cache()

From 36090b8b19b66b24663ce2c731a8159ed3cac131 Mon Sep 17 00:00:00 2001
From: zxy
Date: Wed, 25 Mar 2026 12:06:09 +0800
Subject: [PATCH 5/7] add TODO

---
 lmdeploy/pytorch/engine/model_agent/agent.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py
index 2a74e35a2d..b3caca69fd 100644
--- a/lmdeploy/pytorch/engine/model_agent/agent.py
+++ b/lmdeploy/pytorch/engine/model_agent/agent.py
@@ -1184,6 +1184,7 @@ def _deserialize_weights(serialized_data):
             return [(k, _construct(v)) for k, v in raw]
 
         def _split_main_and_draft(weights):
+            # TODO, zhouxinyu, support split and update weights for other mtp methods
             if not self.spec_agent.is_enabled() or self.spec_agent.method != 'qwen3_5_mtp':
                 return weights, []
             main = [(name, weight) for name, weight in weights if not name.startswith('mtp.')]

From 7847174a0c30cc9c09f75f4404f3f9c078fe4669 Mon Sep 17 00:00:00 2001
From: zxy
Date: Fri, 3 Apr 2026 12:33:19 +0800
Subject: [PATCH 6/7] fix bug

---
 lmdeploy/pytorch/spec_decode/spec_agent.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py
index aa16455e8a..e8f2131228 100644
--- a/lmdeploy/pytorch/spec_decode/spec_agent.py
+++ b/lmdeploy/pytorch/spec_decode/spec_agent.py
@@ -237,6 +237,4 @@ def reset_graph_runner(self):
 
     def get_model(self):
         """Get model."""
-        if self.is_enable() and self.proposer.model is not None:
-            return self.proposer.model
-        return None
+        return self.proposer.model.get_model()
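Aside on PATCH 4/7: setting cache_engine and state_cache_engine to None only releases GPU memory if those were the last references to the underlying cache blocks; torch.cuda.empty_cache() then hands the reclaimed blocks from PyTorch's caching allocator back to the driver. A minimal sketch of that two-step release; the buffer size is arbitrary and a CUDA device is assumed:

    import torch

    if torch.cuda.is_available():
        cache = torch.empty(64 << 20, dtype=torch.uint8, device='cuda')  # 64 MiB
        reserved_before = torch.cuda.memory_reserved()

        # Step 1: drop the last reference so the allocator can reclaim the block.
        cache = None
        # Step 2: return reclaimed blocks to the driver, not just to the allocator pool.
        torch.cuda.empty_cache()

        assert torch.cuda.memory_reserved() <= reserved_before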
From 5febdff6c3c94b90e7f9e5fb56faad335d499bca Mon Sep 17 00:00:00 2001
From: zxy
Date: Fri, 3 Apr 2026 12:40:36 +0800
Subject: [PATCH 7/7] cleanup

---
 lmdeploy/pytorch/engine/model_agent/agent.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py
index d938f054d5..29c19c96dd 100644
--- a/lmdeploy/pytorch/engine/model_agent/agent.py
+++ b/lmdeploy/pytorch/engine/model_agent/agent.py
@@ -1207,17 +1207,6 @@ def _deserialize_weights(serialized_data):
                 return list(bucket.reconstruct_tensors())
             return [(k, _construct(v)) for k, v in raw]
 
-        def _deserialize_weights(serialized_data):
-            raw = ForkingPickler.loads(pybase64.b64decode(serialized_data))
-            if request.load_format == 'flattened_bucket':
-                metadata: list[FlattenedTensorMetadata] = raw['metadata']
-                if not metadata:
-                    return []
-                flattened_tensor: torch.Tensor = _construct(raw['flattened_tensor'])
-                bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata)
-                return list(bucket.reconstruct_tensors())
-            return [(k, _construct(v)) for k, v in raw]
-
         def _split_main_and_draft(weights):
             # TODO, zhouxinyu, support split and update weights for other mtp methods
             if not self.spec_agent.is_enabled() or self.spec_agent.method != 'qwen3_5_mtp':