Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def __init__(
# infer sleeping from empty_init: empty_init still builds runtime
# resources and has its own weight-update workflow.
self._sleeping_tags = set()
self._weights_update_lock: asyncio.Lock | None = None
self._multimodal_session_trim_count = max(0, _envs.multimodal_session_trim_count)
self._multimodal_session_end_count = 0

Expand Down Expand Up @@ -496,6 +497,29 @@ def update_params(self, request: Any):
"""Update params."""
self.executor.update_params(request)

def _get_weights_update_lock(self):
"""Get the disaggregated weights-update lock."""
if self._weights_update_lock is None:
self._weights_update_lock = asyncio.Lock()
return self._weights_update_lock

async def _run_weights_update(self, func, request: Any):
"""Run one serialized disaggregated weights-update operation."""
async with self._get_weights_update_lock():
return await asyncio.to_thread(func, request)

async def init_weights_update_group(self, request: Any):
    """Initialize the disaggregated weights-update process group.

    Runs ``self.executor.init_weights_update_group(request)`` through
    ``_run_weights_update`` and returns its result.
    """
    executor_call = self.executor.init_weights_update_group
    return await self._run_weights_update(executor_call, request)

async def update_weights_from_distributed(self, request: Any):
    """Receive weights through the disaggregated process group.

    Runs ``self.executor.update_weights_from_distributed(request)`` through
    ``_run_weights_update`` and returns its result.
    """
    executor_call = self.executor.update_weights_from_distributed
    return await self._run_weights_update(executor_call, request)

async def destroy_weights_update_group(self, request: Any):
    """Tear down a previously initialized weights-update process group.

    Runs ``self.executor.destroy_weights_update_group(request)`` through
    ``_run_weights_update`` and returns its result.
    """
    executor_call = self.executor.destroy_weights_update_group
    return await self._run_weights_update(executor_call, request)

def _block_new_inputs(self):
"""Block new inference work from engine instances."""
logger.info('PyTorch engine is blocking new inference requests.')
Expand Down
12 changes: 12 additions & 0 deletions lmdeploy/pytorch/engine/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,18 @@ def update_params(self, request: Any):
"""Update params."""
raise NotImplementedError('Not Implemented.')

def init_weights_update_group(self, request: Any):
    """Init the disaggregated weights-update process group.

    Raises:
        NotImplementedError: always; this base version has no implementation.
    """
    reason = 'Not Implemented.'
    raise NotImplementedError(reason)

def update_weights_from_distributed(self, request: Any):
    """Receive weights through the disaggregated process group.

    Raises:
        NotImplementedError: always; this base version has no implementation.
    """
    reason = 'Not Implemented.'
    raise NotImplementedError(reason)

def destroy_weights_update_group(self, request: Any):
    """Tear down a previously initialized weights-update process group.

    Raises:
        NotImplementedError: always; this base version has no implementation.
    """
    reason = 'Not Implemented.'
    raise NotImplementedError(reason)

def get_input_processor(self):
"""Get input processor."""
raise NotImplementedError('Not Implemented.')
Expand Down
12 changes: 12 additions & 0 deletions lmdeploy/pytorch/engine/executor/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,18 @@ def update_params(self, request: Any):
"""Update params."""
self.model_agent.update_params(request)

def init_weights_update_group(self, request: Any):
    """Init the disaggregated weights-update process group.

    Delegates to ``self.model_agent`` and returns its result unchanged.
    """
    agent = self.model_agent
    return agent.init_weights_update_group(request)

def update_weights_from_distributed(self, request: Any):
    """Receive weights through the disaggregated process group.

    Delegates to ``self.model_agent`` and returns its result unchanged.
    """
    agent = self.model_agent
    return agent.update_weights_from_distributed(request)

def destroy_weights_update_group(self, request: Any):
    """Tear down a previously initialized weights-update process group.

    Delegates to ``self.model_agent`` and returns its result unchanged.
    """
    agent = self.model_agent
    return agent.destroy_weights_update_group(request)

def warmup(self):
"""warmup."""
self.model_agent.warmup()
Expand Down
30 changes: 28 additions & 2 deletions lmdeploy/pytorch/engine/executor/ray_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,29 @@ def update_params(self, request: Any):
"""Update params."""
self.collective_rpc('update_params', (request, ))

def _reduce_worker_status(self, results: list[tuple[bool, str]], op_name: str) -> tuple[bool, str]:
"""Reduce worker status results."""
successes, messages = zip(*results)
if all(successes):
return True, messages[0]
message = ' | '.join(f'rank{idx}: {message}' for idx, message in enumerate(messages))
return False, f'{op_name}: {message}'

def init_weights_update_group(self, request: Any):
    """Init the disaggregated weights-update process group.

    Issues the ``init_weights_update_group`` collective RPC and reduces the
    per-worker results into one ``(success, message)`` status.
    """
    op_name = 'init_weights_update_group'
    worker_results = self.collective_rpc(op_name, (request, ))
    return self._reduce_worker_status(worker_results, op_name)

def update_weights_from_distributed(self, request: Any):
    """Receive weights through the disaggregated process group.

    Issues the ``update_weights_from_distributed`` collective RPC and reduces
    the per-worker results into one ``(success, message)`` status.
    """
    op_name = 'update_weights_from_distributed'
    worker_results = self.collective_rpc(op_name, (request, ))
    return self._reduce_worker_status(worker_results, op_name)

def destroy_weights_update_group(self, request: Any):
    """Tear down a previously initialized weights-update process group.

    Issues the ``destroy_weights_update_group`` collective RPC and reduces
    the per-worker results into one ``(success, message)`` status.
    """
    op_name = 'destroy_weights_update_group'
    worker_results = self.collective_rpc(op_name, (request, ))
    return self._reduce_worker_status(worker_results, op_name)

def warmup(self):
"""Build cache engine."""
self.collective_rpc('warmup')
Expand Down Expand Up @@ -499,8 +522,11 @@ async def forward_async(self, inputs):

if self._prev_out is not None:
try:
ray.get(self._prev_out)
except SystemExit:
# Await (instead of blocking ray.get) so the engine event loop is yielded while the
# previous forward runs. Blocking here stalls the whole loop, starving co-located
# async tasks such as the health probe.
await asyncio.gather(*self._prev_out)
except (SystemExit, ray.exceptions.RayActorError):
logger.error('Ray worker exited.')
raise
finally:
Expand Down
132 changes: 130 additions & 2 deletions lmdeploy/pytorch/engine/model_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,14 @@
from lmdeploy.pytorch.strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria
from lmdeploy.pytorch.utils import get_gpu_memory, monkey_patch_hf_modules_cache, wait_for_async_tasks
from lmdeploy.pytorch.weight_loader.model_weight_loader import ModelWeightLoader, load_model_weights
from lmdeploy.serve.openai.protocol import UpdateParamsRequest
from lmdeploy.serve.openai.protocol import (
DestroyWeightsUpdateGroupRequest,
InitWeightsUpdateGroupRequest,
UpdateParamsRequest,
UpdateWeightsFromDistributedRequest,
)
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import FlattenedTensorBucket, FlattenedTensorMetadata, get_logger
from lmdeploy.utils import FlattenedTensorBucket, FlattenedTensorMetadata, get_logger, init_custom_process_group

from .dp_utils import DistGatherScalar, DPForwardMeta, GatheredDPForwardMeta
from .inputs_maker import build_inputs_maker
Expand Down Expand Up @@ -284,6 +289,9 @@ def __init__(
self._update_params_ipc_tensor: torch.Tensor | None = None
self._update_params_ipc_event: torch.cuda.Event | None = None

# disaggregated weight-update process groups, keyed by group_name
self._model_update_group: dict[str, dist.ProcessGroup] = {}

# microbatch
self.enable_microbatch = self.dist_config.enable_microbatch
self.enable_microbatch_prefill_batchsize_threshold = \
Expand Down Expand Up @@ -1194,6 +1202,126 @@ def _split_main_and_draft(weights):

torch.cuda.empty_cache()

def init_weights_update_group(self, request: InitWeightsUpdateGroupRequest):
    """Create a process group shared with an external trainer for the
    disaggregated weight-update path.

    rank 0 is the trainer; this engine's local TP ranks fill
    `rank_offset .. rank_offset + tp - 1`.

    Returns:
        ``(success, message)``. Fails without creating anything when
        ``group_name`` is empty or already registered; failures raised while
        creating the group are logged and reported in ``message``.
    """
    with self.all_context():
        group_name = request.group_name
        if not group_name:
            return False, 'group_name cannot be empty'
        if group_name in self._model_update_group:
            return False, f'group {group_name!r} already initialized'

        # Rank inside the update group = trainer-assigned offset + local TP rank.
        tp_rank = self.dist_ctx.tp_group.rank
        group_rank = request.rank_offset + tp_rank
        master = f'{request.master_address}:{request.master_port}'
        logger.info(f'init weights update group: master={master}, '
                    f'rank_offset={request.rank_offset}, rank={group_rank}, world_size={request.world_size}, '
                    f'group_name={group_name}, backend={request.backend}')
        try:
            process_group = init_custom_process_group(
                backend=request.backend,
                init_method=f'tcp://{master}',
                world_size=request.world_size,
                rank=group_rank,
                group_name=group_name,
            )
            # Register only after creation succeeded.
            self._model_update_group[group_name] = process_group
            return True, 'Succeeded to initialize weights update group.'
        except Exception as e:
            msg = f'Failed to initialize weights update group: {e}'
            logger.exception(msg)
            return False, msg

@torch.inference_mode()
def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedRequest):
    """Receive a bucket of weights through the previously initialized
    weights-update group and load them into the running model.

    Args:
        request: carries ``group_name``, the parallel lists ``names`` /
            ``dtypes`` / ``shapes`` describing the incoming tensors,
            ``load_format`` and the ``finished`` flag.

    Returns:
        ``(success, message)``. The call fails before touching any weight
        when the group is unknown or when ``names`` / ``dtypes`` / ``shapes``
        differ in length. Failures raised later are logged and reported with
        a warning that the weights may be partially updated.
    """
    with self.all_context():
        group_name = request.group_name
        pg = self._model_update_group.get(group_name)
        if pg is None:
            return False, (f'group {group_name!r} not initialized. '
                           'Call init_weights_update_group first.')

        if request.names:
            # zip() below would silently truncate to the shortest list, so the
            # receive buffers would no longer match what the sender
            # broadcasts. Reject the request before joining any collective.
            counts = (len(request.names), len(request.dtypes or ()), len(request.shapes or ()))
            if len(set(counts)) != 1:
                return False, ('names/dtypes/shapes length mismatch: '
                               f'{counts[0]}/{counts[1]}/{counts[2]}. No weights were updated.')

        device = torch.cuda.current_device()
        try:
            if request.names:
                # Allocate one receive buffer per announced tensor.
                named_tensors = []
                for name, dtype_str, shape in zip(request.names, request.dtypes, request.shapes):
                    target_dtype = getattr(torch, dtype_str) if isinstance(dtype_str, str) else dtype_str
                    named_tensors.append((name, torch.empty(shape, dtype=target_dtype, device=device)))

                if request.load_format == 'flattened_bucket':
                    # Single broadcast into the flattened buffer, then
                    # rebuild the individual tensors from the bucket.
                    bucket = FlattenedTensorBucket(named_tensors=named_tensors)
                    flattened_tensor = bucket.get_flattened_tensor()
                    dist.broadcast(flattened_tensor, src=0, group=pg)
                    weights = list(bucket.reconstruct_tensors())
                else:
                    # One async broadcast per tensor; wait for all of them
                    # before loading.
                    handles = [dist.broadcast(tensor, src=0, group=pg, async_op=True) for _, tensor in named_tensors]
                    for handle in handles:
                        handle.wait()
                    weights = named_tensors
            else:
                weights = []

            model = self.patched_model.get_model() if self.patched_model is not None else None
            spec_model = self.spec_agent.get_model()
            # Same draft-split rule as update_params (currently only qwen3_5_mtp).
            if self.spec_agent.is_enabled() and self.spec_agent.method == 'qwen3_5_mtp':
                main_weights = [(n, w) for n, w in weights if not n.startswith('mtp.')]
                draft_weights = [(n, w) for n, w in weights if n.startswith('mtp.')]
            else:
                main_weights, draft_weights = weights, []

            for m, w, tag in [(model, main_weights, 'main'), (spec_model, draft_weights, 'draft')]:
                if m is None or not w:
                    continue
                renamed = list(ModelWeightLoader._rename_weights_iterator(w, m))
                logger.info(f'update_weights_from_distributed: {tag}_num_tensors={len(renamed)}')
                m.load_weights(iter(renamed))

            if request.finished:
                for m in filter(None, [model, spec_model]):
                    for _, mod in m.named_modules():
                        if hasattr(mod, 'update_weights'):
                            mod.update_weights()
                torch.cuda.synchronize()
                # FusedMoE.update_weights() above replaces the gate_up / down
                # Parameter objects (LinearWeights.update_weight registers a new
                # nn.Parameter), so any CUDA graph captured before the update
                # still references the freed old pointers. Drop the captured
                # graphs so the next forward re-captures with the new params.
                self.reset_graph_runner()

            torch.cuda.empty_cache()
            return True, 'Succeeded to update parameter online.'
        except Exception as e:
            msg = (f'Failed to update parameter online: {e}. The model weights are partially updated; '
                   'please discard them and reload.')
            logger.exception(msg)
            return False, msg

def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupRequest):
    """Destroy a previously initialized weights-update process group.

    Returns:
        ``(success, message)``. Fails when ``request.group_name`` is not
        registered. If destroying the group raises, the error is logged and
        reported, and the group stays registered.
    """
    group_name = request.group_name
    process_group = self._model_update_group.get(group_name)
    if process_group is None:
        return False, f'group {group_name!r} not initialized'
    try:
        dist.destroy_process_group(process_group)
        # Forget the group only once it has actually been destroyed.
        self._model_update_group.pop(group_name)
    except Exception as e:
        msg = f'Failed to destroy weights update group {group_name!r}: {e}'
        logger.exception(msg)
        return False, msg
    else:
        return True, f'Succeeded to destroy group {group_name!r}.'
Comment thread
irexyc marked this conversation as resolved.

@torch.inference_mode()
async def sleep(self, level: int = 1):
"""Sleep."""
Expand Down
12 changes: 12 additions & 0 deletions lmdeploy/pytorch/engine/mp_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ def update_params(self, request: Any):
"""Update params."""
return self._collective_rpc('update_params', request)

async def init_weights_update_group(self, request: Any):
    """Init the disaggregated weights-update process group.

    Forwards ``request`` through ``_collective_rpc_async`` and returns the
    awaited result.
    """
    reply = await self._collective_rpc_async('init_weights_update_group', request)
    return reply

async def update_weights_from_distributed(self, request: Any):
    """Receive weights through the disaggregated process group.

    Forwards ``request`` through ``_collective_rpc_async`` and returns the
    awaited result.
    """
    reply = await self._collective_rpc_async('update_weights_from_distributed', request)
    return reply

async def destroy_weights_update_group(self, request: Any):
    """Tear down a previously initialized weights-update process group.

    Forwards ``request`` through ``_collective_rpc_async`` and returns the
    awaited result.
    """
    reply = await self._collective_rpc_async('destroy_weights_update_group', request)
    return reply

async def get_schedule_metrics(self):
"""Get schedule metrics."""
return await self._collective_rpc_async('get_schedule_metrics')
Expand Down
12 changes: 12 additions & 0 deletions lmdeploy/pytorch/engine/mp_engine/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@ def update_params(self, request: Any):
"""Update params."""
return self.engine.update_params(request)

async def init_weights_update_group(self, request: Any):
    """Init the disaggregated weights-update process group.

    Awaits the wrapped engine's method of the same name and returns its
    result.
    """
    engine = self.engine
    return await engine.init_weights_update_group(request)

async def update_weights_from_distributed(self, request: Any):
    """Receive weights through the disaggregated process group.

    Awaits the wrapped engine's method of the same name and returns its
    result.
    """
    engine = self.engine
    return await engine.update_weights_from_distributed(request)

async def destroy_weights_update_group(self, request: Any):
    """Tear down a previously initialized weights-update process group.

    Awaits the wrapped engine's method of the same name and returns its
    result.
    """
    engine = self.engine
    return await engine.destroy_weights_update_group(request)

def close(self) -> None:
"""Close engine worker."""
self.engine.close()
Expand Down
49 changes: 49 additions & 0 deletions lmdeploy/serve/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,15 @@
CompletionResponseStreamChoice,
CompletionStreamResponse,
DeltaMessage,
DestroyWeightsUpdateGroupRequest,
EmbeddingsRequest,
EncodeRequest,
EncodeResponse,
ErrorResponse,
GenerateReqInput,
GenerateReqMetaOutput,
GenerateReqOutput,
InitWeightsUpdateGroupRequest,
LogProbs,
ModelCard,
ModelList,
Expand All @@ -78,6 +80,7 @@
PoolingResponse,
TopLogprob,
UpdateParamsRequest,
UpdateWeightsFromDistributedRequest,
UsageInfo,
)
from lmdeploy.serve.utils.request_cleanup import with_request_cleanup
Expand Down Expand Up @@ -1201,6 +1204,52 @@ def update_params(request: UpdateParamsRequest, raw_request: Request = None):
return JSONResponse(content=None)


def _check_pytorch_backend_for_disagg_weight_update():
    """Check that the serving engine uses the PyTorch backend.

    Disaggregated weight-update endpoints are PyTorch-backend only for now.

    Returns:
        ``None`` when ``VariableInterface.async_engine.backend`` equals
        ``'pytorch'``; otherwise a NOT_IMPLEMENTED error response.
    """
    backend = getattr(VariableInterface.async_engine, 'backend', None)
    if backend == 'pytorch':
        return None
    return create_error_response(
        HTTPStatus.NOT_IMPLEMENTED,
        f'Disaggregated weight-update endpoints require backend="pytorch", got {backend!r}.')


@router.post('/init_weights_update_group', dependencies=[Depends(validate_json_request)])
async def init_weights_update_group(request: InitWeightsUpdateGroupRequest, raw_request: Request = None):
    """Initialize the torch.distributed process group used by an external
    trainer to broadcast weights into this rollout engine.

    Returns the backend-check error response when the check fails. Otherwise
    replies with JSON ``{'success', 'message'}``: status 200 on success and
    BAD_REQUEST on failure.
    """
    backend_error = _check_pytorch_backend_for_disagg_weight_update()
    if backend_error is not None:
        return backend_error
    engine = VariableInterface.async_engine.engine
    success, message = await engine.init_weights_update_group(request)
    if success:
        status_code = 200
    else:
        status_code = HTTPStatus.BAD_REQUEST
    return JSONResponse(content={'success': success, 'message': message}, status_code=status_code)


@router.post('/update_weights_from_distributed', dependencies=[Depends(validate_json_request)])
async def update_weights_from_distributed(request: UpdateWeightsFromDistributedRequest, raw_request: Request = None):
    """Receive a bucket of weights through a previously initialized weights-
    update group and load them into the running model.

    Returns the backend-check error response when the check fails. Otherwise
    replies with JSON ``{'success', 'message'}``: status 200 on success and
    BAD_REQUEST on failure.
    """
    backend_error = _check_pytorch_backend_for_disagg_weight_update()
    if backend_error is not None:
        return backend_error
    engine = VariableInterface.async_engine.engine
    success, message = await engine.update_weights_from_distributed(request)
    if success:
        status_code = 200
    else:
        status_code = HTTPStatus.BAD_REQUEST
    return JSONResponse(content={'success': success, 'message': message}, status_code=status_code)


@router.post('/destroy_weights_update_group', dependencies=[Depends(validate_json_request)])
async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupRequest, raw_request: Request = None):
    """Tear down a previously initialized weights-update group.

    Returns the backend-check error response when the check fails. Otherwise
    replies with JSON ``{'success', 'message'}``: status 200 on success and
    BAD_REQUEST on failure.
    """
    backend_error = _check_pytorch_backend_for_disagg_weight_update()
    if backend_error is not None:
        return backend_error
    engine = VariableInterface.async_engine.engine
    success, message = await engine.destroy_weights_update_group(request)
    if success:
        status_code = 200
    else:
        status_code = HTTPStatus.BAD_REQUEST
    return JSONResponse(content={'success': success, 'message': message}, status_code=status_code)


@router.post('/sleep', dependencies=[Depends(validate_json_request)])
async def sleep(raw_request: Request = None):
level = raw_request.query_params.get('level', '1')
Expand Down
Loading
Loading