Skip to content

Commit 18600ad

Browse files
irexycRunningLeon
andauthored
support disaggregated weight update (#4638)
* support disaggregated weight update * fix routed experts length * update * fix * fix race after wakeup before warmup * Revert "fix race after wakeup before warmup" This reverts commit 94adf31. * Revert "update" This reverts commit f40b1a7. * Revert "fix routed experts length" This reverts commit f29dfde. * use async func * fix EngineHealthMonitor * revert meaningless changes * fix ci * fix comments * fix health probe --------- Co-authored-by: RunningLeon <mnsheng@yeah.net>
1 parent 75f5ddc commit 18600ad

10 files changed

Lines changed: 361 additions & 4 deletions

File tree

lmdeploy/pytorch/engine/engine.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def __init__(
196196
# infer sleeping from empty_init: empty_init still builds runtime
197197
# resources and has its own weight-update workflow.
198198
self._sleeping_tags = set()
199+
self._weights_update_lock: asyncio.Lock | None = None
199200
self._multimodal_session_trim_count = max(0, _envs.multimodal_session_trim_count)
200201
self._multimodal_session_end_count = 0
201202

@@ -499,6 +500,29 @@ def update_params(self, request: Any):
499500
"""Update params."""
500501
self.executor.update_params(request)
501502

503+
def _get_weights_update_lock(self):
504+
"""Get the disaggregated weights-update lock."""
505+
if self._weights_update_lock is None:
506+
self._weights_update_lock = asyncio.Lock()
507+
return self._weights_update_lock
508+
509+
async def _run_weights_update(self, func, request: Any):
510+
"""Run one serialized disaggregated weights-update operation."""
511+
async with self._get_weights_update_lock():
512+
return await asyncio.to_thread(func, request)
513+
514+
async def init_weights_update_group(self, request: Any):
515+
"""Init disaggregated weights-update process group."""
516+
return await self._run_weights_update(self.executor.init_weights_update_group, request)
517+
518+
async def update_weights_from_distributed(self, request: Any):
519+
"""Receive weights through the disaggregated process group."""
520+
return await self._run_weights_update(self.executor.update_weights_from_distributed, request)
521+
522+
async def destroy_weights_update_group(self, request: Any):
523+
"""Tear down a previously initialized weights-update process group."""
524+
return await self._run_weights_update(self.executor.destroy_weights_update_group, request)
525+
502526
def _block_new_inputs(self):
503527
"""Block new inference work from engine instances."""
504528
logger.info('PyTorch engine is blocking new inference requests.')

lmdeploy/pytorch/engine/executor/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,18 @@ def update_params(self, request: Any):
106106
"""Update params."""
107107
raise NotImplementedError('Not Implemented.')
108108

109+
def init_weights_update_group(self, request: Any):
110+
"""Init disaggregated weights-update process group."""
111+
raise NotImplementedError('Not Implemented.')
112+
113+
def update_weights_from_distributed(self, request: Any):
114+
"""Receive weights through the disaggregated process group."""
115+
raise NotImplementedError('Not Implemented.')
116+
117+
def destroy_weights_update_group(self, request: Any):
118+
"""Tear down a previously initialized weights-update process group."""
119+
raise NotImplementedError('Not Implemented.')
120+
109121
def get_input_processor(self):
110122
"""Get input processor."""
111123
raise NotImplementedError('Not Implemented.')

lmdeploy/pytorch/engine/executor/base_worker.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,18 @@ def update_params(self, request: Any):
118118
"""Update params."""
119119
self.model_agent.update_params(request)
120120

121+
def init_weights_update_group(self, request: Any):
122+
"""Init disaggregated weights-update process group."""
123+
return self.model_agent.init_weights_update_group(request)
124+
125+
def update_weights_from_distributed(self, request: Any):
126+
"""Receive weights through the disaggregated process group."""
127+
return self.model_agent.update_weights_from_distributed(request)
128+
129+
def destroy_weights_update_group(self, request: Any):
130+
"""Tear down a previously initialized weights-update process group."""
131+
return self.model_agent.destroy_weights_update_group(request)
132+
121133
def warmup(self):
122134
"""warmup."""
123135
self.model_agent.warmup()

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,29 @@ def update_params(self, request: Any):
366366
"""Update params."""
367367
self.collective_rpc('update_params', (request, ))
368368

369+
def _reduce_worker_status(self, results: list[tuple[bool, str]], op_name: str) -> tuple[bool, str]:
370+
"""Reduce worker status results."""
371+
successes, messages = zip(*results)
372+
if all(successes):
373+
return True, messages[0]
374+
message = ' | '.join(f'rank{idx}: {message}' for idx, message in enumerate(messages))
375+
return False, f'{op_name}: {message}'
376+
377+
def init_weights_update_group(self, request: Any):
378+
"""Init disaggregated weights-update process group."""
379+
results = self.collective_rpc('init_weights_update_group', (request, ))
380+
return self._reduce_worker_status(results, 'init_weights_update_group')
381+
382+
def update_weights_from_distributed(self, request: Any):
383+
"""Receive weights through the disaggregated process group."""
384+
results = self.collective_rpc('update_weights_from_distributed', (request, ))
385+
return self._reduce_worker_status(results, 'update_weights_from_distributed')
386+
387+
def destroy_weights_update_group(self, request: Any):
388+
"""Tear down a previously initialized weights-update process group."""
389+
results = self.collective_rpc('destroy_weights_update_group', (request, ))
390+
return self._reduce_worker_status(results, 'destroy_weights_update_group')
391+
369392
def warmup(self):
370393
"""Build cache engine."""
371394
self.collective_rpc('warmup')
@@ -499,8 +522,11 @@ async def forward_async(self, inputs):
499522

500523
if self._prev_out is not None:
501524
try:
502-
ray.get(self._prev_out)
503-
except SystemExit:
525+
# Await (instead of blocking ray.get) so the engine event loop is yielded while the
526+
# previous forward runs. Blocking here stalls the whole loop, starving co-located
527+
# async tasks such as the health probe.
528+
await asyncio.gather(*self._prev_out)
529+
except (SystemExit, ray.exceptions.RayActorError):
504530
logger.error('Ray worker exited.')
505531
raise
506532
finally:

lmdeploy/pytorch/engine/model_agent/agent.py

Lines changed: 130 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,14 @@
2828
from lmdeploy.pytorch.strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria
2929
from lmdeploy.pytorch.utils import get_gpu_memory, monkey_patch_hf_modules_cache, wait_for_async_tasks
3030
from lmdeploy.pytorch.weight_loader.model_weight_loader import ModelWeightLoader, load_model_weights
31-
from lmdeploy.serve.openai.protocol import UpdateParamsRequest
31+
from lmdeploy.serve.openai.protocol import (
32+
DestroyWeightsUpdateGroupRequest,
33+
InitWeightsUpdateGroupRequest,
34+
UpdateParamsRequest,
35+
UpdateWeightsFromDistributedRequest,
36+
)
3237
from lmdeploy.tokenizer import Tokenizer
33-
from lmdeploy.utils import FlattenedTensorBucket, FlattenedTensorMetadata, get_logger
38+
from lmdeploy.utils import FlattenedTensorBucket, FlattenedTensorMetadata, get_logger, init_custom_process_group
3439

3540
from .dp_utils import DistGatherScalar, DPForwardMeta, GatheredDPForwardMeta
3641
from .inputs_maker import build_inputs_maker
@@ -299,6 +304,9 @@ def __init__(
299304
self._update_params_ipc_tensor: torch.Tensor | None = None
300305
self._update_params_ipc_event: torch.cuda.Event | None = None
301306

307+
# disaggregated weight-update process groups, keyed by group_name
308+
self._model_update_group: dict[str, dist.ProcessGroup] = {}
309+
302310
# microbatch
303311
self.enable_microbatch = self.dist_config.enable_microbatch
304312
self.enable_microbatch_prefill_batchsize_threshold = \
@@ -1214,6 +1222,126 @@ def _split_main_and_draft(weights):
12141222

12151223
torch.cuda.empty_cache()
12161224

1225+
def init_weights_update_group(self, request: InitWeightsUpdateGroupRequest):
1226+
"""Create a NCCL process group with an external trainer for the
1227+
disaggregated weight-update path.
1228+
1229+
rank 0 is the trainer; this engine's local TP ranks fill `rank_offset .. rank_offset + tp - 1`.
1230+
"""
1231+
with self.all_context():
1232+
group_name = request.group_name
1233+
if not group_name:
1234+
return False, 'group_name cannot be empty'
1235+
if group_name in self._model_update_group:
1236+
return False, f'group {group_name!r} already initialized'
1237+
1238+
local_rank = self.dist_ctx.tp_group.rank
1239+
rank = request.rank_offset + local_rank
1240+
init_method = f'tcp://{request.master_address}:{request.master_port}'
1241+
logger.info(f'init weights update group: master={request.master_address}:{request.master_port}, '
1242+
f'rank_offset={request.rank_offset}, rank={rank}, world_size={request.world_size}, '
1243+
f'group_name={group_name}, backend={request.backend}')
1244+
try:
1245+
pg = init_custom_process_group(
1246+
backend=request.backend,
1247+
init_method=init_method,
1248+
world_size=request.world_size,
1249+
rank=rank,
1250+
group_name=group_name,
1251+
)
1252+
self._model_update_group[group_name] = pg
1253+
return True, 'Succeeded to initialize weights update group.'
1254+
except Exception as e:
1255+
msg = f'Failed to initialize weights update group: {e}'
1256+
logger.exception(msg)
1257+
return False, msg
1258+
1259+
@torch.inference_mode()
1260+
def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedRequest):
1261+
"""Receive a bucket of weights through the previously initialized NCCL
1262+
group and load them into the running model."""
1263+
with self.all_context():
1264+
group_name = request.group_name
1265+
pg = self._model_update_group.get(group_name)
1266+
if pg is None:
1267+
return False, (f'group {group_name!r} not initialized. '
1268+
'Call init_weights_update_group first.')
1269+
1270+
device = torch.cuda.current_device()
1271+
try:
1272+
if request.names:
1273+
named_tensors = []
1274+
for name, dtype_str, shape in zip(request.names, request.dtypes, request.shapes):
1275+
target_dtype = getattr(torch, dtype_str) if isinstance(dtype_str, str) else dtype_str
1276+
named_tensors.append((name, torch.empty(shape, dtype=target_dtype, device=device)))
1277+
1278+
if request.load_format == 'flattened_bucket':
1279+
bucket = FlattenedTensorBucket(named_tensors=named_tensors)
1280+
flattened_tensor = bucket.get_flattened_tensor()
1281+
dist.broadcast(flattened_tensor, src=0, group=pg)
1282+
weights = list(bucket.reconstruct_tensors())
1283+
else:
1284+
handles = []
1285+
for _, tensor in named_tensors:
1286+
handles.append(dist.broadcast(tensor, src=0, group=pg, async_op=True))
1287+
for handle in handles:
1288+
handle.wait()
1289+
weights = named_tensors
1290+
else:
1291+
weights = []
1292+
1293+
model = self.patched_model.get_model() if self.patched_model is not None else None
1294+
spec_model = self.spec_agent.get_model()
1295+
# Same draft-split rule as update_params (currently only qwen3_5_mtp).
1296+
if self.spec_agent.is_enabled() and self.spec_agent.method == 'qwen3_5_mtp':
1297+
main_weights = [(n, w) for n, w in weights if not n.startswith('mtp.')]
1298+
draft_weights = [(n, w) for n, w in weights if n.startswith('mtp.')]
1299+
else:
1300+
main_weights, draft_weights = weights, []
1301+
1302+
for m, w, tag in [(model, main_weights, 'main'), (spec_model, draft_weights, 'draft')]:
1303+
if m is None or not w:
1304+
continue
1305+
renamed = list(ModelWeightLoader._rename_weights_iterator(w, m))
1306+
logger.info(f'update_weights_from_distributed: {tag}_num_tensors={len(renamed)}')
1307+
m.load_weights(iter(renamed))
1308+
1309+
if request.finished:
1310+
for m in filter(None, [model, spec_model]):
1311+
for _, mod in m.named_modules():
1312+
if hasattr(mod, 'update_weights'):
1313+
mod.update_weights()
1314+
torch.cuda.synchronize()
1315+
# FusedMoE.update_weights() above replaces the gate_up / down
1316+
# Parameter objects (LinearWeights.update_weight registers a new
1317+
# nn.Parameter), so any CUDA graph captured before the update
1318+
# still references the freed old pointers. Drop the captured
1319+
# graphs so the next forward re-captures with the new params.
1320+
self.reset_graph_runner()
1321+
1322+
torch.cuda.empty_cache()
1323+
return True, 'Succeeded to update parameter online.'
1324+
except Exception as e:
1325+
msg = (f'Failed to update parameter online: {e}. The model weights are partially updated; '
1326+
'please discard them and reload.')
1327+
logger.exception(msg)
1328+
return False, msg
1329+
1330+
def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupRequest):
1331+
"""Destroy a previously initialized weights-update process group."""
1332+
group_name = request.group_name
1333+
pg = self._model_update_group.get(group_name)
1334+
if pg is None:
1335+
return False, f'group {group_name!r} not initialized'
1336+
try:
1337+
dist.destroy_process_group(pg)
1338+
self._model_update_group.pop(group_name)
1339+
return True, f'Succeeded to destroy group {group_name!r}.'
1340+
except Exception as e:
1341+
msg = f'Failed to destroy weights update group {group_name!r}: {e}'
1342+
logger.exception(msg)
1343+
return False, msg
1344+
12171345
@torch.inference_mode()
12181346
async def sleep(self, level: int = 1):
12191347
"""Sleep."""

lmdeploy/pytorch/engine/mp_engine/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,18 @@ def update_params(self, request: Any):
6767
"""Update params."""
6868
return self._collective_rpc('update_params', request)
6969

70+
async def init_weights_update_group(self, request: Any):
71+
"""Init disaggregated weights-update process group."""
72+
return await self._collective_rpc_async('init_weights_update_group', request)
73+
74+
async def update_weights_from_distributed(self, request: Any):
75+
"""Receive weights through the disaggregated process group."""
76+
return await self._collective_rpc_async('update_weights_from_distributed', request)
77+
78+
async def destroy_weights_update_group(self, request: Any):
79+
"""Tear down a previously initialized weights-update process group."""
80+
return await self._collective_rpc_async('destroy_weights_update_group', request)
81+
7082
async def get_schedule_metrics(self):
7183
"""Get schedule metrics."""
7284
return await self._collective_rpc_async('get_schedule_metrics')

lmdeploy/pytorch/engine/mp_engine/base_worker.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,18 @@ def update_params(self, request: Any):
116116
"""Update params."""
117117
return self.engine.update_params(request)
118118

119+
async def init_weights_update_group(self, request: Any):
120+
"""Init disaggregated weights-update process group."""
121+
return await self.engine.init_weights_update_group(request)
122+
123+
async def update_weights_from_distributed(self, request: Any):
124+
"""Receive weights through the disaggregated process group."""
125+
return await self.engine.update_weights_from_distributed(request)
126+
127+
async def destroy_weights_update_group(self, request: Any):
128+
"""Tear down a previously initialized weights-update process group."""
129+
return await self.engine.destroy_weights_update_group(request)
130+
119131
def close(self) -> None:
120132
"""Close engine worker."""
121133
self.engine.close()

lmdeploy/serve/openai/api_server.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,15 @@
6363
CompletionResponseStreamChoice,
6464
CompletionStreamResponse,
6565
DeltaMessage,
66+
DestroyWeightsUpdateGroupRequest,
6667
EmbeddingsRequest,
6768
EncodeRequest,
6869
EncodeResponse,
6970
ErrorResponse,
7071
GenerateReqInput,
7172
GenerateReqMetaOutput,
7273
GenerateReqOutput,
74+
InitWeightsUpdateGroupRequest,
7375
LogProbs,
7476
ModelCard,
7577
ModelList,
@@ -78,6 +80,7 @@
7880
PoolingResponse,
7981
TopLogprob,
8082
UpdateParamsRequest,
83+
UpdateWeightsFromDistributedRequest,
8184
UsageInfo,
8285
)
8386
from lmdeploy.serve.openai.responses import create_responses_router
@@ -1205,6 +1208,52 @@ def update_params(request: UpdateParamsRequest, raw_request: Request = None):
12051208
return JSONResponse(content=None)
12061209

12071210

1211+
def _check_pytorch_backend_for_disagg_weight_update():
1212+
"""Disaggregated weight-update endpoints are PyTorch-backend only for
1213+
now."""
1214+
backend = getattr(VariableInterface.async_engine, 'backend', None)
1215+
if backend != 'pytorch':
1216+
return create_error_response(
1217+
HTTPStatus.NOT_IMPLEMENTED,
1218+
f'Disaggregated weight-update endpoints require backend="pytorch", got {backend!r}.')
1219+
return None
1220+
1221+
1222+
@router.post('/init_weights_update_group', dependencies=[Depends(validate_json_request)])
1223+
async def init_weights_update_group(request: InitWeightsUpdateGroupRequest, raw_request: Request = None):
1224+
"""Initialize the torch.distributed process group used by an external
1225+
trainer to broadcast weights into this rollout engine."""
1226+
err = _check_pytorch_backend_for_disagg_weight_update()
1227+
if err is not None:
1228+
return err
1229+
success, message = await VariableInterface.async_engine.engine.init_weights_update_group(request)
1230+
content = {'success': success, 'message': message}
1231+
return JSONResponse(content=content, status_code=200 if success else HTTPStatus.BAD_REQUEST)
1232+
1233+
1234+
@router.post('/update_weights_from_distributed', dependencies=[Depends(validate_json_request)])
1235+
async def update_weights_from_distributed(request: UpdateWeightsFromDistributedRequest, raw_request: Request = None):
1236+
"""Receive a bucket of weights through a previously initialized weights-
1237+
update group and load them into the running model."""
1238+
err = _check_pytorch_backend_for_disagg_weight_update()
1239+
if err is not None:
1240+
return err
1241+
success, message = await VariableInterface.async_engine.engine.update_weights_from_distributed(request)
1242+
content = {'success': success, 'message': message}
1243+
return JSONResponse(content=content, status_code=200 if success else HTTPStatus.BAD_REQUEST)
1244+
1245+
1246+
@router.post('/destroy_weights_update_group', dependencies=[Depends(validate_json_request)])
1247+
async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupRequest, raw_request: Request = None):
1248+
"""Tear down a previously initialized weights-update group."""
1249+
err = _check_pytorch_backend_for_disagg_weight_update()
1250+
if err is not None:
1251+
return err
1252+
success, message = await VariableInterface.async_engine.engine.destroy_weights_update_group(request)
1253+
content = {'success': success, 'message': message}
1254+
return JSONResponse(content=content, status_code=200 if success else HTTPStatus.BAD_REQUEST)
1255+
1256+
12081257
@router.post('/sleep', dependencies=[Depends(validate_json_request)])
12091258
async def sleep(raw_request: Request = None):
12101259
level = raw_request.query_params.get('level', '1')

0 commit comments

Comments
 (0)