diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index 53ea6d1951..29c19c96dd 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -1177,6 +1177,44 @@ def _construct(item, require_clone: bool = True): ipc_tensor = func(*args) return ipc_tensor.clone() if require_clone else ipc_tensor + def _deserialize_weights(serialized_data): + raw = ForkingPickler.loads(pybase64.b64decode(serialized_data)) + if request.load_format == 'flattened_bucket': + metadata: list[FlattenedTensorMetadata] = raw['metadata'] + if not metadata: + return [] + if 'flattened_tensor' in raw: + # Determine if clone is required + require_clone = raw.get('require_clone', True) + if 'event_ipc_handle' in raw and not hasattr(torch.cuda.Event, 'from_ipc_handle'): + # Force clone when IPC event is provided but cannot be used + require_clone = True + self._update_params_ipc_tensor = _construct(raw['flattened_tensor'], + require_clone=require_clone) + elif self._update_params_ipc_tensor is None: + raise ValueError( + 'flattened_tensor is not provided in weights and no cached ipc tensor is available. 
' + 'Please provide flattened_tensor on the first update_params call.') + if 'event_ipc_handle' in raw and hasattr(torch.cuda.Event, 'from_ipc_handle'): + self._update_params_ipc_event = torch.cuda.Event.from_ipc_handle( + device=torch.cuda.current_device(), + handle=raw['event_ipc_handle'], + ) + flattened_tensor: torch.Tensor = self._update_params_ipc_tensor + if self._update_params_ipc_event is not None: + self._update_params_ipc_event.wait() + bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata) + return list(bucket.reconstruct_tensors()) + return [(k, _construct(v)) for k, v in raw] + + def _split_main_and_draft(weights): + # TODO, zhouxinyu, support split and update weights for other mtp methods + if not self.spec_agent.is_enabled() or self.spec_agent.method != 'qwen3_5_mtp': + return weights, [] + main = [(name, weight) for name, weight in weights if not name.startswith('mtp.')] + draft = [(name, weight) for name, weight in weights if name.startswith('mtp.')] + return main, draft + with self.all_context(): # After deserialization, weights is a dict with following keys: # - metadata: List[FlattenedTensorMetadata] @@ -1186,52 +1224,33 @@ def _construct(item, require_clone: bool = True): serialized_data = request.serialized_named_tensors if isinstance(serialized_data, list): serialized_data = serialized_data[self.dist_ctx.tp_group.rank] + model = self.patched_model.get_model() - weights = ForkingPickler.loads(pybase64.b64decode(serialized_data)) - if request.load_format == 'flattened_bucket': - metadata: list[FlattenedTensorMetadata] = weights['metadata'] - if metadata: - if 'flattened_tensor' in weights: - # Determine if clone is required - require_clone = weights.get('require_clone', True) - if 'event_ipc_handle' in weights and not hasattr(torch.cuda.Event, 'from_ipc_handle'): - # Force clone when IPC event is provided but cannot be used - require_clone = True - self._update_params_ipc_tensor = 
_construct(weights['flattened_tensor'], - require_clone=require_clone) - elif self._update_params_ipc_tensor is None: - raise ValueError( - 'flattened_tensor is not provided in weights and no cached ipc tensor is available. ' - 'Please provide flattened_tensor on the first update_params call.') - if 'event_ipc_handle' in weights and hasattr(torch.cuda.Event, 'from_ipc_handle'): - self._update_params_ipc_event = torch.cuda.Event.from_ipc_handle( - device=torch.cuda.current_device(), - handle=weights['event_ipc_handle'], - ) - flattened_tensor: torch.Tensor = self._update_params_ipc_tensor - if self._update_params_ipc_event is not None: - self._update_params_ipc_event.wait() - bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata) - weights = bucket.reconstruct_tensors() - else: - # empty data - weights = [] - else: - weights = [(k, _construct(v)) for k, v in weights] + spec_model = self.spec_agent.get_model() - weights = ModelWeightLoader._rename_weights_iterator(weights, model) - model.load_weights(weights) - if self._update_params_ipc_event is not None: - self._update_params_ipc_event.record() + weights = _deserialize_weights(serialized_data) + main_weights, draft_weights = _split_main_and_draft(weights) + + for m, w, tag in [(model, main_weights, 'main'), (spec_model, draft_weights, 'draft')]: + if m is None or not w: + continue + + w = list(ModelWeightLoader._rename_weights_iterator(w, m)) + logger.info(f'Update_params: {tag}_num_tensors={len(w)}') + m.load_weights(iter(w)) + + if self._update_params_ipc_event is not None: + self._update_params_ipc_event.record() if request.finished: - for _, mod in model.named_modules(): - if not hasattr(mod, 'update_weights'): - continue - mod.update_weights() - torch.cuda.synchronize() - self._update_params_ipc_event = None - self._update_params_ipc_tensor = None + for m in filter(None, [model, spec_model]): + for _, mod in m.named_modules(): + if hasattr(mod, 'update_weights'): + 
mod.update_weights() + + torch.cuda.synchronize() + self._update_params_ipc_event = None + self._update_params_ipc_tensor = None torch.cuda.empty_cache() @@ -1241,11 +1260,17 @@ async def sleep(self, level: int = 1): self.state.is_sleeping = True if self.dist_config.dp > 1: await self.state.to_sleep.wait() + device = 'cpu' if level == 1 else 'meta' self.cache_engine = None self.state_cache_engine = None self.reset_graph_runner() - device = 'cpu' if level == 1 else 'meta' self.patched_model.get_model().to(device=device, non_blocking=True) + + spec_model = self.spec_agent.get_model() + if spec_model is not None: + self.spec_agent.cache_engine = None + spec_model.to(device=device, non_blocking=True) + torch.cuda.synchronize() # force clean _update_params_ipc tensor and event after all gpu jobs done self._update_params_ipc_tensor = None @@ -1258,11 +1283,16 @@ def wakeup(self, tags: list[str] | None = None): """Wakeup.""" if tags is None: tags = ['weights', 'kv_cache'] + if 'weights' in tags: device = next(self.patched_model.get_model().parameters()).device assert device.type in ['cpu', 'meta'] + spec_model = self.spec_agent.get_model() + if device.type == 'cpu': self.patched_model.get_model().to(torch.cuda.current_device()) + if spec_model is not None: + spec_model.to(torch.cuda.current_device()) else: # user should update weights after wakeup old_empty_init = self.misc_config.empty_init diff --git a/lmdeploy/pytorch/engine/mp_engine/base.py b/lmdeploy/pytorch/engine/mp_engine/base.py index a5c16dd967..c0352787b1 100644 --- a/lmdeploy/pytorch/engine/mp_engine/base.py +++ b/lmdeploy/pytorch/engine/mp_engine/base.py @@ -65,9 +65,9 @@ def update_params(self, request: Any): """Update params.""" return self._collective_rpc('update_params', request) - def get_schedule_metrics(self): + async def get_schedule_metrics(self): """Get schedule metrics.""" - return self._collective_rpc('get_schedule_metrics') + return await self._collective_rpc_async('get_schedule_metrics') def 
p2p_initialize(self, conn_request: DistServeInitRequest): """Init rdma link.""" diff --git a/lmdeploy/pytorch/spec_decode/base.py b/lmdeploy/pytorch/spec_decode/base.py index 3ecfab5f82..d39c8a7082 100644 --- a/lmdeploy/pytorch/spec_decode/base.py +++ b/lmdeploy/pytorch/spec_decode/base.py @@ -65,3 +65,7 @@ def update_main_model_outputs(self, output: dict[str, torch.Tensor], model_input # replace with aux output['hidden_states'] = output.pop('aux_hidden_states') return hidden_states, output + + def get_model(self): + """Get model.""" + return None diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py index 4000461e82..e8f2131228 100644 --- a/lmdeploy/pytorch/spec_decode/spec_agent.py +++ b/lmdeploy/pytorch/spec_decode/spec_agent.py @@ -231,6 +231,10 @@ def warmup(self, max_batches: int, target_model_config: ModelConfig): self._forward_impl(inputs) def reset_graph_runner(self): - 'reset graph runner' + """Reset graph runner.""" if self.proposer.model is not None and hasattr(self.proposer.model, 'reset'): self.proposer.model.reset() + + def get_model(self): + """Get model.""" + return self.proposer.model.get_model() diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py index 60807bcc77..c5dfcd0364 100644 --- a/lmdeploy/serve/core/async_engine.py +++ b/lmdeploy/serve/core/async_engine.py @@ -201,8 +201,11 @@ def _build_stat_loggers(self): # set stats loggers of metrics processor metrics_processor.stat_loggers = self.stat_loggers - def get_schedule_metrics(self): - return self.engine.get_schedule_metrics() + async def get_schedule_metrics(self): + result = self.engine.get_schedule_metrics() + if asyncio.iscoroutine(result): + return await result + return result async def do_log_stats(self): """Loop through CLI logger and Prometheus logger and output the diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 5e84ee4221..2c552febd0 100644 --- 
a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -1381,7 +1381,7 @@ async def _force_log(): await asyncio.sleep(log_interval) # periodically update schedule metrics, as they change less frequently than iteration stats - schedule_metrics = async_engine.get_schedule_metrics() + schedule_metrics = await async_engine.get_schedule_metrics() await metrics_processor.update_schedule_stats(schedule_metrics) await async_engine.do_log_stats()