|
28 | 28 | from lmdeploy.pytorch.strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria |
29 | 29 | from lmdeploy.pytorch.utils import get_gpu_memory, monkey_patch_hf_modules_cache, wait_for_async_tasks |
30 | 30 | from lmdeploy.pytorch.weight_loader.model_weight_loader import ModelWeightLoader, load_model_weights |
31 | | -from lmdeploy.serve.openai.protocol import UpdateParamsRequest |
| 31 | +from lmdeploy.serve.openai.protocol import ( |
| 32 | + DestroyWeightsUpdateGroupRequest, |
| 33 | + InitWeightsUpdateGroupRequest, |
| 34 | + UpdateParamsRequest, |
| 35 | + UpdateWeightsFromDistributedRequest, |
| 36 | +) |
32 | 37 | from lmdeploy.tokenizer import Tokenizer |
33 | | -from lmdeploy.utils import FlattenedTensorBucket, FlattenedTensorMetadata, get_logger |
| 38 | +from lmdeploy.utils import FlattenedTensorBucket, FlattenedTensorMetadata, get_logger, init_custom_process_group |
34 | 39 |
|
35 | 40 | from .dp_utils import DistGatherScalar, DPForwardMeta, GatheredDPForwardMeta |
36 | 41 | from .inputs_maker import build_inputs_maker |
@@ -299,6 +304,9 @@ def __init__( |
299 | 304 | self._update_params_ipc_tensor: torch.Tensor | None = None |
300 | 305 | self._update_params_ipc_event: torch.cuda.Event | None = None |
301 | 306 |
|
| 307 | + # disaggregated weight-update process groups, keyed by group_name |
| 308 | + self._model_update_group: dict[str, dist.ProcessGroup] = {} |
| 309 | + |
302 | 310 | # microbatch |
303 | 311 | self.enable_microbatch = self.dist_config.enable_microbatch |
304 | 312 | self.enable_microbatch_prefill_batchsize_threshold = \ |
@@ -1214,6 +1222,126 @@ def _split_main_and_draft(weights): |
1214 | 1222 |
|
1215 | 1223 | torch.cuda.empty_cache() |
1216 | 1224 |
|
| 1225 | + def init_weights_update_group(self, request: InitWeightsUpdateGroupRequest): |
| 1226 | + """Create a NCCL process group with an external trainer for the |
| 1227 | + disaggregated weight-update path. |
| 1228 | +
|
| 1229 | + rank 0 is the trainer; this engine's local TP ranks fill `rank_offset .. rank_offset + tp - 1`. |
| 1230 | + """ |
| 1231 | + with self.all_context(): |
| 1232 | + group_name = request.group_name |
| 1233 | + if not group_name: |
| 1234 | + return False, 'group_name cannot be empty' |
| 1235 | + if group_name in self._model_update_group: |
| 1236 | + return False, f'group {group_name!r} already initialized' |
| 1237 | + |
| 1238 | + local_rank = self.dist_ctx.tp_group.rank |
| 1239 | + rank = request.rank_offset + local_rank |
| 1240 | + init_method = f'tcp://{request.master_address}:{request.master_port}' |
| 1241 | + logger.info(f'init weights update group: master={request.master_address}:{request.master_port}, ' |
| 1242 | + f'rank_offset={request.rank_offset}, rank={rank}, world_size={request.world_size}, ' |
| 1243 | + f'group_name={group_name}, backend={request.backend}') |
| 1244 | + try: |
| 1245 | + pg = init_custom_process_group( |
| 1246 | + backend=request.backend, |
| 1247 | + init_method=init_method, |
| 1248 | + world_size=request.world_size, |
| 1249 | + rank=rank, |
| 1250 | + group_name=group_name, |
| 1251 | + ) |
| 1252 | + self._model_update_group[group_name] = pg |
| 1253 | + return True, 'Succeeded to initialize weights update group.' |
| 1254 | + except Exception as e: |
| 1255 | + msg = f'Failed to initialize weights update group: {e}' |
| 1256 | + logger.exception(msg) |
| 1257 | + return False, msg |
| 1258 | + |
| 1259 | + @torch.inference_mode() |
| 1260 | + def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedRequest): |
| 1261 | + """Receive a bucket of weights through the previously initialized NCCL |
| 1262 | + group and load them into the running model.""" |
| 1263 | + with self.all_context(): |
| 1264 | + group_name = request.group_name |
| 1265 | + pg = self._model_update_group.get(group_name) |
| 1266 | + if pg is None: |
| 1267 | + return False, (f'group {group_name!r} not initialized. ' |
| 1268 | + 'Call init_weights_update_group first.') |
| 1269 | + |
| 1270 | + device = torch.cuda.current_device() |
| 1271 | + try: |
| 1272 | + if request.names: |
| 1273 | + named_tensors = [] |
| 1274 | + for name, dtype_str, shape in zip(request.names, request.dtypes, request.shapes): |
| 1275 | + target_dtype = getattr(torch, dtype_str) if isinstance(dtype_str, str) else dtype_str |
| 1276 | + named_tensors.append((name, torch.empty(shape, dtype=target_dtype, device=device))) |
| 1277 | + |
| 1278 | + if request.load_format == 'flattened_bucket': |
| 1279 | + bucket = FlattenedTensorBucket(named_tensors=named_tensors) |
| 1280 | + flattened_tensor = bucket.get_flattened_tensor() |
| 1281 | + dist.broadcast(flattened_tensor, src=0, group=pg) |
| 1282 | + weights = list(bucket.reconstruct_tensors()) |
| 1283 | + else: |
| 1284 | + handles = [] |
| 1285 | + for _, tensor in named_tensors: |
| 1286 | + handles.append(dist.broadcast(tensor, src=0, group=pg, async_op=True)) |
| 1287 | + for handle in handles: |
| 1288 | + handle.wait() |
| 1289 | + weights = named_tensors |
| 1290 | + else: |
| 1291 | + weights = [] |
| 1292 | + |
| 1293 | + model = self.patched_model.get_model() if self.patched_model is not None else None |
| 1294 | + spec_model = self.spec_agent.get_model() |
| 1295 | + # Same draft-split rule as update_params (currently only qwen3_5_mtp). |
| 1296 | + if self.spec_agent.is_enabled() and self.spec_agent.method == 'qwen3_5_mtp': |
| 1297 | + main_weights = [(n, w) for n, w in weights if not n.startswith('mtp.')] |
| 1298 | + draft_weights = [(n, w) for n, w in weights if n.startswith('mtp.')] |
| 1299 | + else: |
| 1300 | + main_weights, draft_weights = weights, [] |
| 1301 | + |
| 1302 | + for m, w, tag in [(model, main_weights, 'main'), (spec_model, draft_weights, 'draft')]: |
| 1303 | + if m is None or not w: |
| 1304 | + continue |
| 1305 | + renamed = list(ModelWeightLoader._rename_weights_iterator(w, m)) |
| 1306 | + logger.info(f'update_weights_from_distributed: {tag}_num_tensors={len(renamed)}') |
| 1307 | + m.load_weights(iter(renamed)) |
| 1308 | + |
| 1309 | + if request.finished: |
| 1310 | + for m in filter(None, [model, spec_model]): |
| 1311 | + for _, mod in m.named_modules(): |
| 1312 | + if hasattr(mod, 'update_weights'): |
| 1313 | + mod.update_weights() |
| 1314 | + torch.cuda.synchronize() |
| 1315 | + # FusedMoE.update_weights() above replaces the gate_up / down |
| 1316 | + # Parameter objects (LinearWeights.update_weight registers a new |
| 1317 | + # nn.Parameter), so any CUDA graph captured before the update |
| 1318 | + # still references the freed old pointers. Drop the captured |
| 1319 | + # graphs so the next forward re-captures with the new params. |
| 1320 | + self.reset_graph_runner() |
| 1321 | + |
| 1322 | + torch.cuda.empty_cache() |
| 1323 | + return True, 'Succeeded to update parameter online.' |
| 1324 | + except Exception as e: |
| 1325 | + msg = (f'Failed to update parameter online: {e}. The model weights are partially updated; ' |
| 1326 | + 'please discard them and reload.') |
| 1327 | + logger.exception(msg) |
| 1328 | + return False, msg |
| 1329 | + |
| 1330 | + def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupRequest): |
| 1331 | + """Destroy a previously initialized weights-update process group.""" |
| 1332 | + group_name = request.group_name |
| 1333 | + pg = self._model_update_group.get(group_name) |
| 1334 | + if pg is None: |
| 1335 | + return False, f'group {group_name!r} not initialized' |
| 1336 | + try: |
| 1337 | + dist.destroy_process_group(pg) |
| 1338 | + self._model_update_group.pop(group_name) |
| 1339 | + return True, f'Succeeded to destroy group {group_name!r}.' |
| 1340 | + except Exception as e: |
| 1341 | + msg = f'Failed to destroy weights update group {group_name!r}: {e}' |
| 1342 | + logger.exception(msg) |
| 1343 | + return False, msg |
| 1344 | + |
1217 | 1345 | @torch.inference_mode() |
1218 | 1346 | async def sleep(self, level: int = 1): |
1219 | 1347 | """Sleep.""" |
|
0 commit comments