diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 9271c577f5..4c52c4df0e 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -195,6 +195,7 @@ def __init__( # infer sleeping from empty_init: empty_init still builds runtime # resources and has its own weight-update workflow. self._sleeping_tags = set() + self._weights_update_lock: asyncio.Lock | None = None self._multimodal_session_trim_count = max(0, _envs.multimodal_session_trim_count) self._multimodal_session_end_count = 0 @@ -496,6 +497,29 @@ def update_params(self, request: Any): """Update params.""" self.executor.update_params(request) + def _get_weights_update_lock(self): + """Get the disaggregated weights-update lock.""" + if self._weights_update_lock is None: + self._weights_update_lock = asyncio.Lock() + return self._weights_update_lock + + async def _run_weights_update(self, func, request: Any): + """Run one serialized disaggregated weights-update operation.""" + async with self._get_weights_update_lock(): + return await asyncio.to_thread(func, request) + + async def init_weights_update_group(self, request: Any): + """Init disaggregated weights-update process group.""" + return await self._run_weights_update(self.executor.init_weights_update_group, request) + + async def update_weights_from_distributed(self, request: Any): + """Receive weights through the disaggregated process group.""" + return await self._run_weights_update(self.executor.update_weights_from_distributed, request) + + async def destroy_weights_update_group(self, request: Any): + """Tear down a previously initialized weights-update process group.""" + return await self._run_weights_update(self.executor.destroy_weights_update_group, request) + def _block_new_inputs(self): """Block new inference work from engine instances.""" logger.info('PyTorch engine is blocking new inference requests.') diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py index db4ad9008b..c1eea95551 100644 --- a/lmdeploy/pytorch/engine/executor/base.py +++ b/lmdeploy/pytorch/engine/executor/base.py @@ -99,6 +99,18 @@ def update_params(self, request: Any): """Update params.""" raise NotImplementedError('Not Implemented.') + def init_weights_update_group(self, request: Any): + """Init disaggregated weights-update process group.""" + raise NotImplementedError('Not Implemented.') + + def update_weights_from_distributed(self, request: Any): + """Receive weights through the disaggregated process group.""" + raise NotImplementedError('Not Implemented.') + + def destroy_weights_update_group(self, request: Any): + """Tear down a previously initialized weights-update process group.""" + raise NotImplementedError('Not Implemented.') + def get_input_processor(self): """Get input processor.""" raise NotImplementedError('Not Implemented.') diff --git a/lmdeploy/pytorch/engine/executor/base_worker.py b/lmdeploy/pytorch/engine/executor/base_worker.py index 038af61428..09eaf06e09 100644 --- a/lmdeploy/pytorch/engine/executor/base_worker.py +++ b/lmdeploy/pytorch/engine/executor/base_worker.py @@ -118,6 +118,18 @@ def update_params(self, request: Any): """Update params.""" self.model_agent.update_params(request) + def init_weights_update_group(self, request: Any): + """Init disaggregated weights-update process group.""" + return self.model_agent.init_weights_update_group(request) + + def update_weights_from_distributed(self, request: Any): + """Receive weights through the disaggregated process group.""" + return self.model_agent.update_weights_from_distributed(request) + + def destroy_weights_update_group(self, request: Any): + """Tear down a previously initialized weights-update process group.""" + return self.model_agent.destroy_weights_update_group(request) + def warmup(self): """warmup.""" self.model_agent.warmup() diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py index 5f6442b514..ea7cba386e 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -366,6 +366,29 @@ def update_params(self, request: Any): """Update params.""" self.collective_rpc('update_params', (request, )) + def _reduce_worker_status(self, results: list[tuple[bool, str]], op_name: str) -> tuple[bool, str]: + """Reduce worker status results.""" + successes, messages = zip(*results) + if all(successes): + return True, messages[0] + message = ' | '.join(f'rank{idx}: {message}' for idx, message in enumerate(messages)) + return False, f'{op_name}: {message}' + + def init_weights_update_group(self, request: Any): + """Init disaggregated weights-update process group.""" + results = self.collective_rpc('init_weights_update_group', (request, )) + return self._reduce_worker_status(results, 'init_weights_update_group') + + def update_weights_from_distributed(self, request: Any): + """Receive weights through the disaggregated process group.""" + results = self.collective_rpc('update_weights_from_distributed', (request, )) + return self._reduce_worker_status(results, 'update_weights_from_distributed') + + def destroy_weights_update_group(self, request: Any): + """Tear down a previously initialized weights-update process group.""" + results = self.collective_rpc('destroy_weights_update_group', (request, )) + return self._reduce_worker_status(results, 'destroy_weights_update_group') + def warmup(self): """Build cache engine.""" self.collective_rpc('warmup') @@ -499,8 +522,11 @@ async def forward_async(self, inputs): if self._prev_out is not None: try: - ray.get(self._prev_out) - except SystemExit: + # Await (instead of blocking ray.get) so the engine event loop is yielded while the + # previous forward runs. Blocking here stalls the whole loop, starving co-located + # async tasks such as the health probe. + await asyncio.gather(*self._prev_out) + except (SystemExit, ray.exceptions.RayActorError): logger.error('Ray worker exited.') raise finally: diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index e6e0cfbde9..ecf73faf97 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -28,9 +28,14 @@ from lmdeploy.pytorch.strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria from lmdeploy.pytorch.utils import get_gpu_memory, monkey_patch_hf_modules_cache, wait_for_async_tasks from lmdeploy.pytorch.weight_loader.model_weight_loader import ModelWeightLoader, load_model_weights -from lmdeploy.serve.openai.protocol import UpdateParamsRequest +from lmdeploy.serve.openai.protocol import ( + DestroyWeightsUpdateGroupRequest, + InitWeightsUpdateGroupRequest, + UpdateParamsRequest, + UpdateWeightsFromDistributedRequest, +) from lmdeploy.tokenizer import Tokenizer -from lmdeploy.utils import FlattenedTensorBucket, FlattenedTensorMetadata, get_logger +from lmdeploy.utils import FlattenedTensorBucket, FlattenedTensorMetadata, get_logger, init_custom_process_group from .dp_utils import DistGatherScalar, DPForwardMeta, GatheredDPForwardMeta from .inputs_maker import build_inputs_maker @@ -284,6 +289,9 @@ def __init__( self._update_params_ipc_tensor: torch.Tensor | None = None self._update_params_ipc_event: torch.cuda.Event | None = None + # disaggregated weight-update process groups, keyed by group_name + self._model_update_group: dict[str, dist.ProcessGroup] = {} + # microbatch self.enable_microbatch = self.dist_config.enable_microbatch self.enable_microbatch_prefill_batchsize_threshold = \ @@ -1194,6 +1202,126 @@ def _split_main_and_draft(weights): torch.cuda.empty_cache() + def init_weights_update_group(self, request: InitWeightsUpdateGroupRequest): + """Create a NCCL process group with an external trainer for the + disaggregated weight-update path. + + rank 0 is the trainer; this engine's local TP ranks fill `rank_offset .. rank_offset + tp - 1`. + """ + with self.all_context(): + group_name = request.group_name + if not group_name: + return False, 'group_name cannot be empty' + if group_name in self._model_update_group: + return False, f'group {group_name!r} already initialized' + + local_rank = self.dist_ctx.tp_group.rank + rank = request.rank_offset + local_rank + init_method = f'tcp://{request.master_address}:{request.master_port}' + logger.info(f'init weights update group: master={request.master_address}:{request.master_port}, ' + f'rank_offset={request.rank_offset}, rank={rank}, world_size={request.world_size}, ' + f'group_name={group_name}, backend={request.backend}') + try: + pg = init_custom_process_group( + backend=request.backend, + init_method=init_method, + world_size=request.world_size, + rank=rank, + group_name=group_name, + ) + self._model_update_group[group_name] = pg + return True, 'Succeeded to initialize weights update group.' + except Exception as e: + msg = f'Failed to initialize weights update group: {e}' + logger.exception(msg) + return False, msg + + @torch.inference_mode() + def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedRequest): + """Receive a bucket of weights through the previously initialized NCCL + group and load them into the running model.""" + with self.all_context(): + group_name = request.group_name + pg = self._model_update_group.get(group_name) + if pg is None: + return False, (f'group {group_name!r} not initialized. ' + 'Call init_weights_update_group first.') + + device = torch.cuda.current_device() + try: + if request.names: + named_tensors = [] + for name, dtype_str, shape in zip(request.names, request.dtypes, request.shapes): + target_dtype = getattr(torch, dtype_str) if isinstance(dtype_str, str) else dtype_str + named_tensors.append((name, torch.empty(shape, dtype=target_dtype, device=device))) + + if request.load_format == 'flattened_bucket': + bucket = FlattenedTensorBucket(named_tensors=named_tensors) + flattened_tensor = bucket.get_flattened_tensor() + dist.broadcast(flattened_tensor, src=0, group=pg) + weights = list(bucket.reconstruct_tensors()) + else: + handles = [] + for _, tensor in named_tensors: + handles.append(dist.broadcast(tensor, src=0, group=pg, async_op=True)) + for handle in handles: + handle.wait() + weights = named_tensors + else: + weights = [] + + model = self.patched_model.get_model() if self.patched_model is not None else None + spec_model = self.spec_agent.get_model() + # Same draft-split rule as update_params (currently only qwen3_5_mtp). + if self.spec_agent.is_enabled() and self.spec_agent.method == 'qwen3_5_mtp': + main_weights = [(n, w) for n, w in weights if not n.startswith('mtp.')] + draft_weights = [(n, w) for n, w in weights if n.startswith('mtp.')] + else: + main_weights, draft_weights = weights, [] + + for m, w, tag in [(model, main_weights, 'main'), (spec_model, draft_weights, 'draft')]: + if m is None or not w: + continue + renamed = list(ModelWeightLoader._rename_weights_iterator(w, m)) + logger.info(f'update_weights_from_distributed: {tag}_num_tensors={len(renamed)}') + m.load_weights(iter(renamed)) + + if request.finished: + for m in filter(None, [model, spec_model]): + for _, mod in m.named_modules(): + if hasattr(mod, 'update_weights'): + mod.update_weights() + torch.cuda.synchronize() + # FusedMoE.update_weights() above replaces the gate_up / down + # Parameter objects (LinearWeights.update_weight registers a new + # nn.Parameter), so any CUDA graph captured before the update + # still references the freed old pointers. Drop the captured + # graphs so the next forward re-captures with the new params. + self.reset_graph_runner() + + torch.cuda.empty_cache() + return True, 'Succeeded to update parameter online.' + except Exception as e: + msg = (f'Failed to update parameter online: {e}. The model weights are partially updated; ' + 'please discard them and reload.') + logger.exception(msg) + return False, msg + + def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupRequest): + """Destroy a previously initialized weights-update process group.""" + group_name = request.group_name + pg = self._model_update_group.get(group_name) + if pg is None: + return False, f'group {group_name!r} not initialized' + try: + dist.destroy_process_group(pg) + self._model_update_group.pop(group_name) + return True, f'Succeeded to destroy group {group_name!r}.' + except Exception as e: + msg = f'Failed to destroy weights update group {group_name!r}: {e}' + logger.exception(msg) + return False, msg + @torch.inference_mode() async def sleep(self, level: int = 1): """Sleep.""" diff --git a/lmdeploy/pytorch/engine/mp_engine/base.py b/lmdeploy/pytorch/engine/mp_engine/base.py index 5e70a78b10..8b834234d2 100644 --- a/lmdeploy/pytorch/engine/mp_engine/base.py +++ b/lmdeploy/pytorch/engine/mp_engine/base.py @@ -67,6 +67,18 @@ def update_params(self, request: Any): """Update params.""" return self._collective_rpc('update_params', request) + async def init_weights_update_group(self, request: Any): + """Init disaggregated weights-update process group.""" + return await self._collective_rpc_async('init_weights_update_group', request) + + async def update_weights_from_distributed(self, request: Any): + """Receive weights through the disaggregated process group.""" + return await self._collective_rpc_async('update_weights_from_distributed', request) + + async def destroy_weights_update_group(self, request: Any): + """Tear down a previously initialized weights-update process group.""" + return await self._collective_rpc_async('destroy_weights_update_group', request) + async def get_schedule_metrics(self): """Get schedule metrics.""" return await self._collective_rpc_async('get_schedule_metrics') diff --git a/lmdeploy/pytorch/engine/mp_engine/base_worker.py b/lmdeploy/pytorch/engine/mp_engine/base_worker.py index c42e45c70d..d80700d9a6 100644 --- a/lmdeploy/pytorch/engine/mp_engine/base_worker.py +++ b/lmdeploy/pytorch/engine/mp_engine/base_worker.py @@ -116,6 +116,18 @@ def update_params(self, request: Any): """Update params.""" return self.engine.update_params(request) + async def init_weights_update_group(self, request: Any): + """Init disaggregated weights-update process group.""" + return await self.engine.init_weights_update_group(request) + + async def update_weights_from_distributed(self, request: Any): + """Receive weights through the disaggregated process group.""" + return await self.engine.update_weights_from_distributed(request) + + async def destroy_weights_update_group(self, request: Any): + """Tear down a previously initialized weights-update process group.""" + return await self.engine.destroy_weights_update_group(request) + def close(self) -> None: """Close engine worker.""" self.engine.close() diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index a772f1ba37..b361c835f6 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -63,6 +63,7 @@ CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage, + DestroyWeightsUpdateGroupRequest, EmbeddingsRequest, EncodeRequest, EncodeResponse, @@ -70,6 +71,7 @@ GenerateReqInput, GenerateReqMetaOutput, GenerateReqOutput, + InitWeightsUpdateGroupRequest, LogProbs, ModelCard, ModelList, @@ -78,6 +80,7 @@ PoolingResponse, TopLogprob, UpdateParamsRequest, + UpdateWeightsFromDistributedRequest, UsageInfo, ) from lmdeploy.serve.utils.request_cleanup import with_request_cleanup @@ -1201,6 +1204,52 @@ def update_params(request: UpdateParamsRequest, raw_request: Request = None): return JSONResponse(content=None) +def _check_pytorch_backend_for_disagg_weight_update(): + """Disaggregated weight-update endpoints are PyTorch-backend only for + now.""" + backend = getattr(VariableInterface.async_engine, 'backend', None) + if backend != 'pytorch': + return create_error_response( + HTTPStatus.NOT_IMPLEMENTED, + f'Disaggregated weight-update endpoints require backend="pytorch", got {backend!r}.') + return None + + +@router.post('/init_weights_update_group', dependencies=[Depends(validate_json_request)]) +async def init_weights_update_group(request: InitWeightsUpdateGroupRequest, raw_request: Request = None): + """Initialize the torch.distributed process group used by an external + trainer to broadcast weights into this rollout engine.""" + err = _check_pytorch_backend_for_disagg_weight_update() + if err is not None: + return err + success, message = await VariableInterface.async_engine.engine.init_weights_update_group(request) + content = {'success': success, 'message': message} + return JSONResponse(content=content, status_code=200 if success else HTTPStatus.BAD_REQUEST) + + +@router.post('/update_weights_from_distributed', dependencies=[Depends(validate_json_request)]) +async def update_weights_from_distributed(request: UpdateWeightsFromDistributedRequest, raw_request: Request = None): + """Receive a bucket of weights through a previously initialized weights- + update group and load them into the running model.""" + err = _check_pytorch_backend_for_disagg_weight_update() + if err is not None: + return err + success, message = await VariableInterface.async_engine.engine.update_weights_from_distributed(request) + content = {'success': success, 'message': message} + return JSONResponse(content=content, status_code=200 if success else HTTPStatus.BAD_REQUEST) + + +@router.post('/destroy_weights_update_group', dependencies=[Depends(validate_json_request)]) +async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupRequest, raw_request: Request = None): + """Tear down a previously initialized weights-update group.""" + err = _check_pytorch_backend_for_disagg_weight_update() + if err is not None: + return err + success, message = await VariableInterface.async_engine.engine.destroy_weights_update_group(request) + content = {'success': success, 'message': message} + return JSONResponse(content=content, status_code=200 if success else HTTPStatus.BAD_REQUEST) + + @router.post('/sleep', dependencies=[Depends(validate_json_request)]) async def sleep(raw_request: Request = None): level = raw_request.query_params.get('level', '1') diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index a4c19d457f..978266c529 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -476,6 +476,32 @@ class UpdateParamsRequest(BaseModel): finished: bool = False +class InitWeightsUpdateGroupRequest(BaseModel): + """Initialize a torch.distributed process group used to broadcast weights + from an external trainer into the rollout engine.""" + master_address: str + master_port: int + rank_offset: int + world_size: int + group_name: str + backend: str = 'nccl' + + +class UpdateWeightsFromDistributedRequest(BaseModel): + """Receive weights through a previously initialized distributed group and + load them into the running model.""" + names: list[str] + dtypes: list[str] + shapes: list[list[int]] + group_name: str + load_format: str | None = None # 'flattened_bucket' or None + finished: bool = False # trigger mod.update_weights() finalization when True + + +class DestroyWeightsUpdateGroupRequest(BaseModel): + """Tear down a previously initialized weights-update process group.""" + group_name: str + # /generate input class GenerateReqInput(BaseModel): diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py index 11783a5898..c16e2497f5 100644 --- a/lmdeploy/utils.py +++ b/lmdeploy/utils.py @@ -623,3 +623,59 @@ def reconstruct_tensors(self) -> list[tuple[str, torch.Tensor]]: reconstructed[i] = (meta.name, tensor) return reconstructed + + +# Copied from Xtuner to allow creating a NCCL process group that is +# NOT a subgroup of the current default world. +# https://github.com/InternLM/xtuner/blob/main/xtuner/v1/rl/trainer/update_weighter.py#L491 +def init_custom_process_group(backend=None, + init_method=None, + timeout=None, + world_size: int = -1, + rank: int = -1, + store=None, + group_name: str | None = None, + pg_options=None): + from packaging.version import parse as parse_version + from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, + ) + + assert (store is None) or (init_method is None), 'Cannot specify both init_method and store.' + if store is not None: + assert world_size > 0, 'world_size must be positive if using store' + assert rank >= 0, 'rank must be non-negative if using store' + elif init_method is None: + init_method = 'env://' + + backend = Backend(backend) if backend else Backend('undefined') + if timeout is None: + timeout = default_pg_timeout + + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + if group_name is not None: + store = PrefixStore(group_name, store) + + # PyTorch >= 2.6 renamed pg_options -> backend_options. + pg_options_param_name = 'backend_options' if parse_version(torch.__version__) >= parse_version('2.6') else \ + 'pg_options' + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + ) + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + return pg