@@ -39,8 +39,8 @@ def _is_sglang_update_weight_sha256_test_enabled():
3939
4040 This test-only switch controls whether the unit test expects SGLang to
4141 compute and return received bucket hashes for sent/received hash
42- comparison.
43-
42+ comparison.
43+
4444 ! Note that upstream SGLang does not provide this SHA256 check
4545 by default.
4646 """
@@ -82,7 +82,7 @@ def request_update_params(self, state_dict, train_enable_ep=False, finished=Fals
8282 train_enable_ep = train_enable_ep ,
8383 finished = finished ,
8484 )
85-
85+
8686 def _hook_compare_test_sent_and_received_weight_hash (
8787 self ,
8888 result : dict ,
@@ -119,7 +119,7 @@ def tearDownClass(cls) -> None:
119119 del os .environ ["XTUNER_USE_FA3" ]
120120
121121 def setUp (self ):
122- ray .init (num_cpus = 80 , ignore_reinit_error = True )
122+ ray .init (num_cpus = 128 , ignore_reinit_error = True )
123123 self .model_path = MODEL_PATH
124124 self .temp_dir = tempfile .TemporaryDirectory ()
125125 self .worker_log_dir = os .path .join (self .temp_dir .name , "work_dirs" )
@@ -158,7 +158,7 @@ def init_config(self):
158158 model_path = MODEL_PATH ,
159159 model_name = os .path .basename (MODEL_PATH ).lower (),
160160 tokenizer_path = MODEL_PATH ,
161- rollout_cross_node_comm = False ,
161+ rollout_cross_node_comm = os . environ . get ( "XTUNER_USE_SGLANG" , "0" ) != "0" ,
162162 tensor_parallel_size = rollout_tp_size ,
163163 expert_parallel_size = rollout_ep_size ,
164164 gpus_per_node = int (os .environ .get ("GPUS_PER_NODE" , "8" )), # gpu: 8, npu: 16
@@ -185,7 +185,7 @@ def init_config(self):
185185 ),
186186 ignore_idx = - 100 ,
187187 use_kl_loss = False ,
188- kl_loss_coef = 0.001 ,
188+ kl_loss_coef = 0.001 ,
189189 kl_loss_type = "low_var_kl" ,
190190 mode = "eager" ),
191191 lr_cfg = lr_cfg ,
@@ -209,7 +209,6 @@ def _build_train_controller(self, worker_cls=BaseTrainingWorker):
209209 )
210210 ray .get ([worker .test_all_reduce .remote () for worker in train_workers ])
211211 train_controller = TrainingController (workers = train_workers )
212- train_controller .set_train_rollout_mode ("disaggregated" )
213212 return train_controller
214213
215214 def _build_sglang_rollout_controller (self ):
@@ -238,7 +237,6 @@ def test_sglang_disaggregated_update_weight_and_generate(self):
238237 futures = [worker .test_all_reduce .remote () for worker in train_workers ]
239238 ray .get (futures )
240239 train_controller = TrainingController (workers = train_workers )
241- train_controller .set_train_rollout_mode ("disaggregated" )
242240
243241 # init rollout on a separate placement group
244242 rollout_pg = AutoAcceleratorWorkers .build_placement_group (
@@ -255,6 +253,7 @@ def test_sglang_disaggregated_update_weight_and_generate(self):
255253
256254 info_dict = ray .get (rollout_controller .get_rollout_metadata .remote ())
257255 train_controller .update_rollout_info (info_dict )
256+ train_controller .set_train_rollout_mode ("disaggregated" )
258257
259258 train_controller .update_weights ()
260259
@@ -273,6 +272,7 @@ def test_sglang_disaggregated_update_weight_after_pause_and_generate(self):
273272
274273 info_dict = ray .get (rollout_controller .get_rollout_metadata .remote ())
275274 train_controller .update_rollout_info (info_dict )
275+ train_controller .set_train_rollout_mode ("disaggregated" )
276276
277277 ray .get (rollout_controller .pause_generation .remote ())
278278 time .sleep (float (os .environ .get ("XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP" , "2" )))
@@ -290,6 +290,7 @@ def test_sglang_disaggregated_update_weight_sha256_is_stable(self):
290290
291291 info_dict = ray .get (rollout_controller .get_rollout_metadata .remote ())
292292 train_controller .update_rollout_info (info_dict )
293+ train_controller .set_train_rollout_mode ("disaggregated" )
293294
294295 ray .get ([worker .reset_update_weight_sha256 .remote () for worker in train_controller .workers ])
295296 train_controller .update_weights ()
@@ -313,6 +314,94 @@ def test_sglang_disaggregated_update_weight_sha256_is_stable(self):
313314
314315 ray .get (rollout_controller .shutdown .remote (), timeout = 60 )
315316
317+ def _build_lmdeploy_rollout_controller (self ):
318+ rollout_pg = AutoAcceleratorWorkers .build_placement_group (
319+ self .rollout_resources_cfg ,
320+ name = f"test_update_weight_rollout_{ id (self )} " ,
321+ )
322+ set_cpu_resource_manager (CPUResourceManager (accelerator_placement_groups = [self .pg , rollout_pg ]))
323+ self .rollout_cfg .skip_load_weights = False
324+ return self .rollout_cfg .build (rollout_pg )
325+
326+ @unittest .skipIf (os .environ .get ("XTUNER_USE_LMDEPLOY" , "0" ) == "0" , "lmdeploy backend is not enabled" )
327+ def test_lmdeploy_disaggregated_update_weight_and_generate (self ):
328+ train_controller = self ._build_train_controller ()
329+ rollout_controller = self ._build_lmdeploy_rollout_controller ()
330+
331+ sample_params = SampleParams (temperature = 0.0 , max_tokens = 128 , top_k = 1 )
332+ input_state = RolloutState (message = TEST_TEXT_MESSAGES , sample_params = sample_params )
333+ res_baseline = ray .get (rollout_controller .generate .remote (rollout_state = input_state ))
334+
335+ info_dict = ray .get (rollout_controller .get_rollout_metadata .remote ())
336+ train_controller .update_rollout_info (info_dict )
337+ train_controller .set_train_rollout_mode ("disaggregated" )
338+
339+ train_controller .update_weights ()
340+
341+ res_update_weight = ray .get (rollout_controller .generate .remote (rollout_state = input_state ))
342+ self .assertEqual (res_update_weight .response , res_baseline .response )
343+ ray .get (rollout_controller .shutdown .remote (), timeout = 60 )
344+
345+ @unittest .skipIf (os .environ .get ("XTUNER_USE_LMDEPLOY" , "0" ) == "0" , "lmdeploy backend is not enabled" )
346+ def test_lmdeploy_disaggregated_update_weight_after_pause_and_generate (self ):
347+ train_controller = self ._build_train_controller ()
348+ rollout_controller = self ._build_lmdeploy_rollout_controller ()
349+
350+ sample_params = SampleParams (temperature = 0.0 , max_tokens = 128 , top_k = 1 )
351+ input_state = RolloutState (message = TEST_TEXT_MESSAGES , sample_params = sample_params )
352+ res_baseline = ray .get (rollout_controller .generate .remote (rollout_state = input_state ))
353+
354+ info_dict = ray .get (rollout_controller .get_rollout_metadata .remote ())
355+ train_controller .update_rollout_info (info_dict )
356+ train_controller .set_train_rollout_mode ("disaggregated" )
357+
358+ ray .get (rollout_controller .pause_generation .remote ())
359+ time .sleep (float (os .environ .get ("XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP" , "2" )))
360+ train_controller .update_weights ()
361+ ray .get (rollout_controller .continue_generation .remote ())
362+
363+ res_update_weight = ray .get (rollout_controller .generate .remote (rollout_state = input_state ))
364+ self .assertEqual (res_update_weight .response , res_baseline .response )
365+ ray .get (rollout_controller .shutdown .remote (), timeout = 60 )
366+
367+ @unittest .skipIf (os .environ .get ("XTUNER_USE_LMDEPLOY" , "0" ) == "0" , "lmdeploy backend is not enabled" )
368+ def test_lmdeploy_disaggregated_multi_update_and_generate (self ):
369+ """Drive N consecutive update_weights+generate cycles on a single rollout engine.
370+
371+ LMDeploy's PyTorch backend runs a per-FusedMoE ``update_weights()`` finalize that
372+ REPLACES ``gate_up.weight`` / ``down.weight`` Parameter objects (see
373+ ``lmdeploy/pytorch/nn/moe/default.py`` ``LinearWeights.update_weight``). The CUDA-graph
374+ staleness this introduces is handled by ``reset_graph_runner()`` inside the finalize,
375+ but the second-round behaviour of the transpose-contig-transpose layout transform is
376+ untested. This test catches any regression in back-to-back updates without sleep/wakeup
377+ between them. Same method also exercises ascend / NPU where the finalize is a no-op
378+ and graph capture is disabled (eager mode), so it should be trivially safe there.
379+ """
380+ train_controller = self ._build_train_controller ()
381+ rollout_controller = self ._build_lmdeploy_rollout_controller ()
382+
383+ sample_params = SampleParams (temperature = 0.0 , max_tokens = 128 , top_k = 1 )
384+ input_state = RolloutState (message = TEST_TEXT_MESSAGES , sample_params = sample_params )
385+ res_baseline = ray .get (rollout_controller .generate .remote (rollout_state = input_state ))
386+
387+ info_dict = ray .get (rollout_controller .get_rollout_metadata .remote ())
388+ train_controller .update_rollout_info (info_dict )
389+ train_controller .set_train_rollout_mode ("disaggregated" )
390+
391+ # Trainer never actually steps, so each broadcast carries the same bytes;
392+ # the rollout response should remain identical to baseline across all rounds.
393+ num_iterations = int (os .environ .get ("XTUNER_LMDEPLOY_MULTI_UPDATE_ITERS" , "2" ))
394+ for i in range (num_iterations ):
395+ train_controller .update_weights ()
396+ res = ray .get (rollout_controller .generate .remote (rollout_state = input_state ))
397+ self .assertEqual (
398+ res .response ,
399+ res_baseline .response ,
400+ f"iteration { i } : response diverged from baseline after multi-update" ,
401+ )
402+
403+ ray .get (rollout_controller .shutdown .remote (), timeout = 60 )
404+
316405
317406if __name__ == "__main__" :
318407 unittest .main ()
0 commit comments