Skip to content

Commit 638d8c5

Browse files
authored
feat: support disaggregated weight update with lmdeploy backend (#1854)
* fix test_update_weight_disaggregated.py * support lmdeploy disaggregated update weight * revert raise_for_status --------- Co-authored-by: irexyc@gmail.com <irexyc>
1 parent 334e80c commit 638d8c5

2 files changed

Lines changed: 289 additions & 16 deletions

File tree

tests/rl/test_update_weight_disaggregated.py

Lines changed: 97 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ def _is_sglang_update_weight_sha256_test_enabled():
3939
4040
This test-only switch controls whether the unit test expects SGLang to
4141
compute and return received bucket hashes for sent/received hash
42-
comparison.
43-
42+
comparison.
43+
4444
! Note that upstream SGLang does not provide this SHA256 check
4545
by default.
4646
"""
@@ -82,7 +82,7 @@ def request_update_params(self, state_dict, train_enable_ep=False, finished=Fals
8282
train_enable_ep=train_enable_ep,
8383
finished=finished,
8484
)
85-
85+
8686
def _hook_compare_test_sent_and_received_weight_hash(
8787
self,
8888
result: dict,
@@ -119,7 +119,7 @@ def tearDownClass(cls) -> None:
119119
del os.environ["XTUNER_USE_FA3"]
120120

121121
def setUp(self):
122-
ray.init(num_cpus=80, ignore_reinit_error=True)
122+
ray.init(num_cpus=128, ignore_reinit_error=True)
123123
self.model_path = MODEL_PATH
124124
self.temp_dir = tempfile.TemporaryDirectory()
125125
self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
@@ -158,7 +158,7 @@ def init_config(self):
158158
model_path=MODEL_PATH,
159159
model_name=os.path.basename(MODEL_PATH).lower(),
160160
tokenizer_path=MODEL_PATH,
161-
rollout_cross_node_comm=False,
161+
rollout_cross_node_comm=os.environ.get("XTUNER_USE_SGLANG", "0") != "0",
162162
tensor_parallel_size=rollout_tp_size,
163163
expert_parallel_size=rollout_ep_size,
164164
gpus_per_node=int(os.environ.get("GPUS_PER_NODE", "8")), # gpu: 8, npu: 16
@@ -185,7 +185,7 @@ def init_config(self):
185185
),
186186
ignore_idx=-100,
187187
use_kl_loss=False,
188-
kl_loss_coef=0.001,
188+
kl_loss_coef=0.001,
189189
kl_loss_type="low_var_kl",
190190
mode="eager"),
191191
lr_cfg=lr_cfg,
@@ -209,7 +209,6 @@ def _build_train_controller(self, worker_cls=BaseTrainingWorker):
209209
)
210210
ray.get([worker.test_all_reduce.remote() for worker in train_workers])
211211
train_controller = TrainingController(workers=train_workers)
212-
train_controller.set_train_rollout_mode("disaggregated")
213212
return train_controller
214213

215214
def _build_sglang_rollout_controller(self):
@@ -238,7 +237,6 @@ def test_sglang_disaggregated_update_weight_and_generate(self):
238237
futures = [worker.test_all_reduce.remote() for worker in train_workers]
239238
ray.get(futures)
240239
train_controller = TrainingController(workers=train_workers)
241-
train_controller.set_train_rollout_mode("disaggregated")
242240

243241
# init rollout on a separate placement group
244242
rollout_pg = AutoAcceleratorWorkers.build_placement_group(
@@ -255,6 +253,7 @@ def test_sglang_disaggregated_update_weight_and_generate(self):
255253

256254
info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())
257255
train_controller.update_rollout_info(info_dict)
256+
train_controller.set_train_rollout_mode("disaggregated")
258257

259258
train_controller.update_weights()
260259

@@ -273,6 +272,7 @@ def test_sglang_disaggregated_update_weight_after_pause_and_generate(self):
273272

274273
info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())
275274
train_controller.update_rollout_info(info_dict)
275+
train_controller.set_train_rollout_mode("disaggregated")
276276

277277
ray.get(rollout_controller.pause_generation.remote())
278278
time.sleep(float(os.environ.get("XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP", "2")))
@@ -290,6 +290,7 @@ def test_sglang_disaggregated_update_weight_sha256_is_stable(self):
290290

291291
info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())
292292
train_controller.update_rollout_info(info_dict)
293+
train_controller.set_train_rollout_mode("disaggregated")
293294

294295
ray.get([worker.reset_update_weight_sha256.remote() for worker in train_controller.workers])
295296
train_controller.update_weights()
@@ -313,6 +314,94 @@ def test_sglang_disaggregated_update_weight_sha256_is_stable(self):
313314

314315
ray.get(rollout_controller.shutdown.remote(), timeout=60)
315316

317+
def _build_lmdeploy_rollout_controller(self):
    """Build the rollout controller used by the LMDeploy tests.

    A new placement group is created from ``self.rollout_resources_cfg``, the
    CPU resource manager is registered with both the existing ``self.pg``
    group and the new one, and ``skip_load_weights`` is set to ``False`` on
    the rollout config before it is built.

    Returns:
        The object produced by ``self.rollout_cfg.build`` for the new
        placement group.
    """
    # The group name embeds id(self) of the current test-case instance.
    group_name = f"test_update_weight_rollout_{id(self)}"
    rollout_group = AutoAcceleratorWorkers.build_placement_group(
        self.rollout_resources_cfg,
        name=group_name,
    )

    cpu_manager = CPUResourceManager(accelerator_placement_groups=[self.pg, rollout_group])
    set_cpu_resource_manager(cpu_manager)

    self.rollout_cfg.skip_load_weights = False
    return self.rollout_cfg.build(rollout_group)
325+
326+
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_disaggregated_update_weight_and_generate(self):
    """Generate, push weights in disaggregated mode, and generate again.

    A baseline response is produced with deterministic sampling parameters
    (``temperature=0.0``, ``top_k=1``). The train controller is then given
    the rollout metadata, switched to ``"disaggregated"`` mode, and asked to
    update the rollout weights. The response generated afterwards must be
    equal to the baseline.

    The rollout controller is shut down in a ``finally`` block so that a
    failed assertion or a failed remote call does not leave it running.
    """
    train_controller = self._build_train_controller()
    rollout_controller = self._build_lmdeploy_rollout_controller()

    try:
        sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1)
        input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params)
        res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state))

        # The rollout info is registered before the mode is switched,
        # matching the order used by the other tests in this file.
        info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())
        train_controller.update_rollout_info(info_dict)
        train_controller.set_train_rollout_mode("disaggregated")

        train_controller.update_weights()

        res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state))
        self.assertEqual(res_update_weight.response, res_baseline.response)
    finally:
        # Always release the rollout controller, even if the test body failed.
        ray.get(rollout_controller.shutdown.remote(), timeout=60)
344+
345+
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_disaggregated_update_weight_after_pause_and_generate(self):
    """Update weights while generation is paused, then resume and compare.

    After recording a baseline response, generation is paused on the rollout
    controller, the test sleeps for ``XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP``
    seconds (default ``2``), the weights are updated, and generation is
    continued. The response generated afterwards must be equal to the
    baseline.

    The rollout controller is shut down in a ``finally`` block so that a
    failed assertion or a failed remote call does not leave it running.
    """
    train_controller = self._build_train_controller()
    rollout_controller = self._build_lmdeploy_rollout_controller()

    try:
        sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1)
        input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params)
        res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state))

        info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())
        train_controller.update_rollout_info(info_dict)
        train_controller.set_train_rollout_mode("disaggregated")

        # Pause, wait for the configured interval, update, then resume.
        ray.get(rollout_controller.pause_generation.remote())
        time.sleep(float(os.environ.get("XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP", "2")))
        train_controller.update_weights()
        ray.get(rollout_controller.continue_generation.remote())

        res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state))
        self.assertEqual(res_update_weight.response, res_baseline.response)
    finally:
        # Always release the rollout controller, even if the test body failed.
        ray.get(rollout_controller.shutdown.remote(), timeout=60)
366+
367+
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_disaggregated_multi_update_and_generate(self):
    """Drive N consecutive update_weights+generate cycles on a single rollout engine.

    LMDeploy's PyTorch backend runs a per-FusedMoE ``update_weights()`` finalize that
    REPLACES ``gate_up.weight`` / ``down.weight`` Parameter objects (see
    ``lmdeploy/pytorch/nn/moe/default.py`` ``LinearWeights.update_weight``). The CUDA-graph
    staleness this introduces is handled by ``reset_graph_runner()`` inside the finalize,
    but the second-round behaviour of the transpose-contig-transpose layout transform is
    untested. This test catches any regression in back-to-back updates without sleep/wakeup
    between them. Same method also exercises ascend / NPU where the finalize is a no-op
    and graph capture is disabled (eager mode), so it should be trivially safe there.

    The number of cycles is read from ``XTUNER_LMDEPLOY_MULTI_UPDATE_ITERS``
    (default ``2``). The rollout controller is shut down in a ``finally``
    block so that a divergence in any iteration does not leave it running.
    """
    train_controller = self._build_train_controller()
    rollout_controller = self._build_lmdeploy_rollout_controller()

    try:
        sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1)
        input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params)
        res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state))

        info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())
        train_controller.update_rollout_info(info_dict)
        train_controller.set_train_rollout_mode("disaggregated")

        # Trainer never actually steps, so each broadcast carries the same bytes;
        # the rollout response should remain identical to baseline across all rounds.
        num_iterations = int(os.environ.get("XTUNER_LMDEPLOY_MULTI_UPDATE_ITERS", "2"))
        for i in range(num_iterations):
            train_controller.update_weights()
            res = ray.get(rollout_controller.generate.remote(rollout_state=input_state))
            self.assertEqual(
                res.response,
                res_baseline.response,
                f"iteration {i}: response diverged from baseline after multi-update",
            )
    finally:
        # Always release the rollout controller, even if an iteration failed.
        ray.get(rollout_controller.shutdown.remote(), timeout=60)
404+
316405

317406
if __name__ == "__main__":
318407
unittest.main()

0 commit comments

Comments
 (0)