Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 97 additions & 8 deletions tests/rl/test_update_weight_disaggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def _is_sglang_update_weight_sha256_test_enabled():

This test-only switch controls whether the unit test expects SGLang to
compute and return received bucket hashes for sent/received hash
comparison.
comparison.

! Note that upstream SGLang does not provide this SHA256 check
by default.
"""
Expand Down Expand Up @@ -82,7 +82,7 @@ def request_update_params(self, state_dict, train_enable_ep=False, finished=Fals
train_enable_ep=train_enable_ep,
finished=finished,
)

def _hook_compare_test_sent_and_received_weight_hash(
self,
result: dict,
Expand Down Expand Up @@ -119,7 +119,7 @@ def tearDownClass(cls) -> None:
del os.environ["XTUNER_USE_FA3"]

def setUp(self):
ray.init(num_cpus=80, ignore_reinit_error=True)
ray.init(num_cpus=128, ignore_reinit_error=True)
self.model_path = MODEL_PATH
self.temp_dir = tempfile.TemporaryDirectory()
self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
Expand Down Expand Up @@ -158,7 +158,7 @@ def init_config(self):
model_path=MODEL_PATH,
model_name=os.path.basename(MODEL_PATH).lower(),
tokenizer_path=MODEL_PATH,
rollout_cross_node_comm=False,
rollout_cross_node_comm=os.environ.get("XTUNER_USE_SGLANG", "0") != "0",
tensor_parallel_size=rollout_tp_size,
expert_parallel_size=rollout_ep_size,
gpus_per_node=int(os.environ.get("GPUS_PER_NODE", "8")), # gpu: 8, npu: 16
Expand All @@ -185,7 +185,7 @@ def init_config(self):
),
ignore_idx=-100,
use_kl_loss=False,
kl_loss_coef=0.001,
kl_loss_coef=0.001,
kl_loss_type="low_var_kl",
mode="eager"),
lr_cfg=lr_cfg,
Expand All @@ -209,7 +209,6 @@ def _build_train_controller(self, worker_cls=BaseTrainingWorker):
)
ray.get([worker.test_all_reduce.remote() for worker in train_workers])
train_controller = TrainingController(workers=train_workers)
train_controller.set_train_rollout_mode("disaggregated")
return train_controller

def _build_sglang_rollout_controller(self):
Expand Down Expand Up @@ -238,7 +237,6 @@ def test_sglang_disaggregated_update_weight_and_generate(self):
futures = [worker.test_all_reduce.remote() for worker in train_workers]
ray.get(futures)
train_controller = TrainingController(workers=train_workers)
train_controller.set_train_rollout_mode("disaggregated")

# init rollout on a separate placement group
rollout_pg = AutoAcceleratorWorkers.build_placement_group(
Expand All @@ -255,6 +253,7 @@ def test_sglang_disaggregated_update_weight_and_generate(self):

info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())
train_controller.update_rollout_info(info_dict)
train_controller.set_train_rollout_mode("disaggregated")

train_controller.update_weights()

Expand All @@ -273,6 +272,7 @@ def test_sglang_disaggregated_update_weight_after_pause_and_generate(self):

info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())
train_controller.update_rollout_info(info_dict)
train_controller.set_train_rollout_mode("disaggregated")

ray.get(rollout_controller.pause_generation.remote())
time.sleep(float(os.environ.get("XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP", "2")))
Expand All @@ -290,6 +290,7 @@ def test_sglang_disaggregated_update_weight_sha256_is_stable(self):

info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())
train_controller.update_rollout_info(info_dict)
train_controller.set_train_rollout_mode("disaggregated")

ray.get([worker.reset_update_weight_sha256.remote() for worker in train_controller.workers])
train_controller.update_weights()
Expand All @@ -313,6 +314,94 @@ def test_sglang_disaggregated_update_weight_sha256_is_stable(self):

ray.get(rollout_controller.shutdown.remote(), timeout=60)

def _build_lmdeploy_rollout_controller(self):
    """Build a rollout controller on a dedicated placement group.

    Steps, in order:
      1. Create a placement group from ``self.rollout_resources_cfg`` with a
         name derived from ``id(self)``.
      2. Install a ``CPUResourceManager`` that covers both ``self.pg`` and the
         new rollout placement group.
      3. Set ``skip_load_weights`` to ``False`` on ``self.rollout_cfg``.
      4. Build the rollout controller on the new placement group.

    Returns:
        Whatever ``self.rollout_cfg.build`` returns for the new placement group
        (the tests use it as a Ray actor handle).
    """
    placement_group_name = f"test_update_weight_rollout_{id(self)}"
    rollout_placement_group = AutoAcceleratorWorkers.build_placement_group(
        self.rollout_resources_cfg,
        name=placement_group_name,
    )

    # The CPU resource manager must be registered before the rollout is built.
    cpu_resource_manager = CPUResourceManager(
        accelerator_placement_groups=[self.pg, rollout_placement_group],
    )
    set_cpu_resource_manager(cpu_resource_manager)

    self.rollout_cfg.skip_load_weights = False
    rollout_controller = self.rollout_cfg.build(rollout_placement_group)
    return rollout_controller

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_disaggregated_update_weight_and_generate(self):
    """Check that one disaggregated weight update leaves generation unchanged.

    A baseline response is generated first, then the train controller pushes
    its weights to the rollout engine and the same request is generated again.
    No training step runs in between, so the test asserts that both responses
    are equal.

    The rollout controller is shut down in a ``finally`` block so that a failed
    assertion or remote call does not leave the rollout engine running.
    """
    train_controller = self._build_train_controller()
    rollout_controller = self._build_lmdeploy_rollout_controller()

    try:
        sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1)
        input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params)
        res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state))

        # The rollout metadata is handed to the trainer before the mode is set.
        info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())
        train_controller.update_rollout_info(info_dict)
        train_controller.set_train_rollout_mode("disaggregated")

        train_controller.update_weights()

        res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state))
        self.assertEqual(res_update_weight.response, res_baseline.response)
    finally:
        # Release the rollout engine even when the test body fails.
        ray.get(rollout_controller.shutdown.remote(), timeout=60)

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_disaggregated_update_weight_after_pause_and_generate(self):
    """Check a weight update performed while rollout generation is paused.

    The sequence is: generate a baseline response, pause generation, sleep for
    ``XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP`` seconds (default ``2``), update the
    weights, resume generation, and generate again. No training step runs in
    between, so the test asserts that both responses are equal.

    The rollout controller is shut down in a ``finally`` block so that a failed
    assertion or remote call does not leave the rollout engine running.
    """
    train_controller = self._build_train_controller()
    rollout_controller = self._build_lmdeploy_rollout_controller()

    try:
        sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1)
        input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params)
        res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state))

        # The rollout metadata is handed to the trainer before the mode is set.
        info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())
        train_controller.update_rollout_info(info_dict)
        train_controller.set_train_rollout_mode("disaggregated")

        ray.get(rollout_controller.pause_generation.remote())
        # Wait between the pause request and the weight update; the duration is
        # configurable through the environment.
        time.sleep(float(os.environ.get("XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP", "2")))
        train_controller.update_weights()
        ray.get(rollout_controller.continue_generation.remote())

        res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state))
        self.assertEqual(res_update_weight.response, res_baseline.response)
    finally:
        # Release the rollout engine even when the test body fails.
        ray.get(rollout_controller.shutdown.remote(), timeout=60)

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_disaggregated_multi_update_and_generate(self):
    """Drive N consecutive update_weights+generate cycles on a single rollout engine.

    LMDeploy's PyTorch backend runs a per-FusedMoE ``update_weights()`` finalize that
    REPLACES ``gate_up.weight`` / ``down.weight`` Parameter objects (see
    ``lmdeploy/pytorch/nn/moe/default.py`` ``LinearWeights.update_weight``). The CUDA-graph
    staleness this introduces is handled by ``reset_graph_runner()`` inside the finalize,
    but the second-round behaviour of the transpose-contig-transpose layout transform is
    untested. This test catches any regression in back-to-back updates without sleep/wakeup
    between them. Same method also exercises ascend / NPU where the finalize is a no-op
    and graph capture is disabled (eager mode), so it should be trivially safe there.

    The number of cycles is read from ``XTUNER_LMDEPLOY_MULTI_UPDATE_ITERS``
    (default ``2``). The rollout controller is shut down in a ``finally`` block
    so that a failed assertion or remote call does not leave the rollout engine
    running.
    """
    train_controller = self._build_train_controller()
    rollout_controller = self._build_lmdeploy_rollout_controller()

    try:
        sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1)
        input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params)
        res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state))

        info_dict = ray.get(rollout_controller.get_rollout_metadata.remote())
        train_controller.update_rollout_info(info_dict)
        train_controller.set_train_rollout_mode("disaggregated")

        # Trainer never actually steps, so each broadcast carries the same bytes;
        # the rollout response should remain identical to baseline across all rounds.
        num_iterations = int(os.environ.get("XTUNER_LMDEPLOY_MULTI_UPDATE_ITERS", "2"))
        for i in range(num_iterations):
            train_controller.update_weights()
            res = ray.get(rollout_controller.generate.remote(rollout_state=input_state))
            self.assertEqual(
                res.response,
                res_baseline.response,
                f"iteration {i}: response diverged from baseline after multi-update",
            )
    finally:
        # Release the rollout engine even when an iteration fails.
        ray.get(rollout_controller.shutdown.remote(), timeout=60)


# Run the unittest test runner when this module is executed directly as a script.
if __name__ == "__main__":
unittest.main()
Loading
Loading