diff --git a/tests/rl/test_update_weight_disaggregated.py b/tests/rl/test_update_weight_disaggregated.py index 7c1db67629..8a4460ec98 100644 --- a/tests/rl/test_update_weight_disaggregated.py +++ b/tests/rl/test_update_weight_disaggregated.py @@ -39,8 +39,8 @@ def _is_sglang_update_weight_sha256_test_enabled(): This test-only switch controls whether the unit test expects SGLang to compute and return received bucket hashes for sent/received hash - comparison. - + comparison. + ! Note that upstream SGLang does not provide this SHA256 check by default. """ @@ -82,7 +82,7 @@ def request_update_params(self, state_dict, train_enable_ep=False, finished=Fals train_enable_ep=train_enable_ep, finished=finished, ) - + def _hook_compare_test_sent_and_received_weight_hash( self, result: dict, @@ -119,7 +119,7 @@ def tearDownClass(cls) -> None: del os.environ["XTUNER_USE_FA3"] def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) + ray.init(num_cpus=128, ignore_reinit_error=True) self.model_path = MODEL_PATH self.temp_dir = tempfile.TemporaryDirectory() self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") @@ -158,7 +158,7 @@ def init_config(self): model_path=MODEL_PATH, model_name=os.path.basename(MODEL_PATH).lower(), tokenizer_path=MODEL_PATH, - rollout_cross_node_comm=False, + rollout_cross_node_comm=os.environ.get("XTUNER_USE_SGLANG", "0") != "0", tensor_parallel_size=rollout_tp_size, expert_parallel_size=rollout_ep_size, gpus_per_node=int(os.environ.get("GPUS_PER_NODE", "8")), # gpu: 8, npu: 16 @@ -185,7 +185,7 @@ def init_config(self): ), ignore_idx=-100, use_kl_loss=False, - kl_loss_coef=0.001, + kl_loss_coef=0.001, kl_loss_type="low_var_kl", mode="eager"), lr_cfg=lr_cfg, @@ -209,7 +209,6 @@ def _build_train_controller(self, worker_cls=BaseTrainingWorker): ) ray.get([worker.test_all_reduce.remote() for worker in train_workers]) train_controller = TrainingController(workers=train_workers) - train_controller.set_train_rollout_mode("disaggregated") return train_controller def _build_sglang_rollout_controller(self): @@ -238,7 +237,6 @@ def test_sglang_disaggregated_update_weight_and_generate(self): futures = [worker.test_all_reduce.remote() for worker in train_workers] ray.get(futures) train_controller = TrainingController(workers=train_workers) - train_controller.set_train_rollout_mode("disaggregated") # init rollout on a separate placement group rollout_pg = AutoAcceleratorWorkers.build_placement_group( @@ -255,6 +253,7 @@ def test_sglang_disaggregated_update_weight_and_generate(self): info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) train_controller.update_rollout_info(info_dict) + train_controller.set_train_rollout_mode("disaggregated") train_controller.update_weights() @@ -273,6 +272,7 @@ def test_sglang_disaggregated_update_weight_after_pause_and_generate(self): info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) train_controller.update_rollout_info(info_dict) + train_controller.set_train_rollout_mode("disaggregated") ray.get(rollout_controller.pause_generation.remote()) time.sleep(float(os.environ.get("XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP", "2"))) @@ -290,6 +290,7 @@ def test_sglang_disaggregated_update_weight_sha256_is_stable(self): info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) train_controller.update_rollout_info(info_dict) + train_controller.set_train_rollout_mode("disaggregated") ray.get([worker.reset_update_weight_sha256.remote() for worker in train_controller.workers]) train_controller.update_weights() @@ -313,6 +314,94 @@ def test_sglang_disaggregated_update_weight_sha256_is_stable(self): ray.get(rollout_controller.shutdown.remote(), timeout=60) + def _build_lmdeploy_rollout_controller(self): + rollout_pg = AutoAcceleratorWorkers.build_placement_group( + self.rollout_resources_cfg, + name=f"test_update_weight_rollout_{id(self)}", + ) + set_cpu_resource_manager(CPUResourceManager(accelerator_placement_groups=[self.pg, rollout_pg])) + self.rollout_cfg.skip_load_weights = False + return self.rollout_cfg.build(rollout_pg) + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_lmdeploy_disaggregated_update_weight_and_generate(self): + train_controller = self._build_train_controller() + rollout_controller = self._build_lmdeploy_rollout_controller() + + sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) + input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) + res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) + + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) + train_controller.update_rollout_info(info_dict) + train_controller.set_train_rollout_mode("disaggregated") + + train_controller.update_weights() + + res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) + self.assertEqual(res_update_weight.response, res_baseline.response) + ray.get(rollout_controller.shutdown.remote(), timeout=60) + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_lmdeploy_disaggregated_update_weight_after_pause_and_generate(self): + train_controller = self._build_train_controller() + rollout_controller = self._build_lmdeploy_rollout_controller() + + sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) + input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) + res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) + + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) + train_controller.update_rollout_info(info_dict) + train_controller.set_train_rollout_mode("disaggregated") + + ray.get(rollout_controller.pause_generation.remote()) + time.sleep(float(os.environ.get("XTUNER_UPDATE_WEIGHT_PAUSE_SLEEP", "2"))) + train_controller.update_weights() + ray.get(rollout_controller.continue_generation.remote()) + + res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) + self.assertEqual(res_update_weight.response, res_baseline.response) + ray.get(rollout_controller.shutdown.remote(), timeout=60) + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_lmdeploy_disaggregated_multi_update_and_generate(self): + """Drive N consecutive update_weights+generate cycles on a single rollout engine. + + LMDeploy's PyTorch backend runs a per-FusedMoE ``update_weights()`` finalize that + REPLACES ``gate_up.weight`` / ``down.weight`` Parameter objects (see + ``lmdeploy/pytorch/nn/moe/default.py`` ``LinearWeights.update_weight``). The CUDA-graph + staleness this introduces is handled by ``reset_graph_runner()`` inside the finalize, + but the second-round behaviour of the transpose-contig-transpose layout transform is + untested. This test catches any regression in back-to-back updates without sleep/wakeup + between them. Same method also exercises ascend / NPU where the finalize is a no-op + and graph capture is disabled (eager mode), so it should be trivially safe there. + """ + train_controller = self._build_train_controller() + rollout_controller = self._build_lmdeploy_rollout_controller() + + sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) + input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) + res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) + + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) + train_controller.update_rollout_info(info_dict) + train_controller.set_train_rollout_mode("disaggregated") + + # Trainer never actually steps, so each broadcast carries the same bytes; + # the rollout response should remain identical to baseline across all rounds. + num_iterations = int(os.environ.get("XTUNER_LMDEPLOY_MULTI_UPDATE_ITERS", "2")) + for i in range(num_iterations): + train_controller.update_weights() + res = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) + self.assertEqual( + res.response, + res_baseline.response, + f"iteration {i}: response diverged from baseline after multi-update", + ) + + ray.get(rollout_controller.shutdown.remote(), timeout=60) + if __name__ == "__main__": unittest.main() diff --git a/xtuner/v1/rl/trainer/update_weighter.py b/xtuner/v1/rl/trainer/update_weighter.py index da3f15d5b7..0125f03b42 100644 --- a/xtuner/v1/rl/trainer/update_weighter.py +++ b/xtuner/v1/rl/trainer/update_weighter.py @@ -73,8 +73,13 @@ def _init_update_weighter(self): self._sglang_disagg_executor: ThreadPoolExecutor | None = None self._train_update_sync_group: dist.ProcessGroup | None = None self._sglang_disagg_update_lock = Lock() + self._lmdeploy_disagg_group: dist.ProcessGroup | None = None + self._lmdeploy_disagg_group_name: str | None = None + self._lmdeploy_disagg_engine_urls: list[str] = [] + self._lmdeploy_disagg_executor: ThreadPoolExecutor | None = None + self._lmdeploy_disagg_update_lock = Lock() self.use_fake_weight_update = ( - False # 仅在 lmdeploy 后端的 disaggregated 模式下使用,表示是否使用 fake 接口进行权重更新 + False # 仅在 lmdeploy turbomind 后端的 disaggregated 模式下使用,表示是否使用 fake 接口进行权重更新 ) def _hook_compare_test_sent_and_received_weight_hash( @@ -150,9 +155,12 @@ def set_train_rollout_mode(self, train_rollout_mode: str): if backend == "vllm": raise NotImplementedError("Disaggregated train-rollout mode is not supported for vLLM backend.") - elif backend == "pytorch" or backend == "turbomind": + elif backend == "pytorch": + self.use_fake_weight_update = False + + elif backend == "turbomind": self.logger.warning( - "Disaggregated train-rollout mode for lmdeploy backend is not fully supported yet. " + "Disaggregated train-rollout mode for lmdeploy turbomind backend is not yet supported. " "A fake no-op interface will be used temporarily.", ) self.use_fake_weight_update = True # 后续 fake 接口可根据这个标志跳过实际同步 @@ -172,6 +180,7 @@ def set_train_rollout_mode(self, train_rollout_mode: str): if self.is_train_rollout_colocated: self._reset_sglang_disagg_group() + self._reset_lmdeploy_disagg_group() def _reset_sglang_disagg_group(self): if self._sglang_disagg_executor is not None: @@ -186,6 +195,19 @@ def _reset_sglang_disagg_group(self): self._sglang_disagg_engine_urls = [] self._sglang_disagg_executor = None + def _reset_lmdeploy_disagg_group(self): + if self._lmdeploy_disagg_executor is not None: + self._lmdeploy_disagg_executor.shutdown(wait=False, cancel_futures=True) + try: + if self._lmdeploy_disagg_group is not None: + dist.destroy_process_group(self._lmdeploy_disagg_group) + except Exception: + pass + self._lmdeploy_disagg_group = None + self._lmdeploy_disagg_group_name = None + self._lmdeploy_disagg_engine_urls = [] + self._lmdeploy_disagg_executor = None + def _get_train_update_sync_group(self) -> dist.ProcessGroup: if self._train_update_sync_group is None: ranks = list(range(dist.get_world_size())) @@ -329,11 +351,9 @@ def _update_weights_hf_generator(self, submodule=None, final_update=False): ) train_enable_ep = model.fsdp_config is not None and model.fsdp_config.ep_size > 1 - if train_enable_ep and not self.is_train_rollout_colocated: - raise NotImplementedError("Disaggregated update_weights with train expert parallelism is not supported.") if train_enable_ep: - if self.rollout_cfg_info["ep"] > 1: + if self.is_train_rollout_colocated and self.rollout_cfg_info["ep"] > 1: rollout_device_mesh = self._ensure_rollout_device_mesh() fused_gen = self._rl_get_fused_ep_hf_param( model, @@ -342,6 +362,9 @@ def _update_weights_hf_generator(self, submodule=None, final_update=False): bucket_size=bucket_size, ) else: + # Disaggregated update uses one external trainer+rollout process group. + # Broadcast the same full expert bucket to every rollout rank and let + # the backend loader apply its local TP/EP slicing. fused_gen = self._rl_get_fused_ep_hf_param( model, target_ep_rank=0, @@ -591,7 +614,7 @@ def _build_lmdeploy_flattened_tensor_data(self, state_dict: dict, flattened_tens flattened_tensor_data["event_ipc_handle"] = self._update_params_ipc_event.ipc_handle() return flattened_tensor_data - def _get_sglang_disagg_engine_info(self) -> RolloutEngineInfo: + def _get_disagg_engine_info(self) -> RolloutEngineInfo: engine_info: RolloutEngineInfo = [] seen_urls: set[str] = set() rank_to_engine_size: dict[int, int] = {} @@ -622,7 +645,7 @@ def _get_sglang_disagg_engine_info(self) -> RolloutEngineInfo: def _ensure_sglang_disagg_group(self): if self._sglang_disagg_group is not None: return - engine_info = self._get_sglang_disagg_engine_info() + engine_info = self._get_disagg_engine_info() if not engine_info: self.logger.error("No active rollout engine url, cannot init sglang weight update group") return @@ -687,6 +710,73 @@ def _ensure_sglang_disagg_group(self): self._sglang_disagg_group_name = group_name self._sglang_disagg_engine_urls = [url for _, url, _ in engine_info] + def _ensure_lmdeploy_disagg_group(self): + if self._lmdeploy_disagg_group is not None: + return + engine_info = self._get_disagg_engine_info() + if not engine_info: + self.logger.error("No active rollout engine url, cannot init lmdeploy weight update group") + return + + os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" + backend = "nccl" + + master_address = None + master_port = None + try: + import ray + + master_address = ray.util.get_node_ip_address() + except Exception: + master_address = socket.gethostbyname(socket.gethostname()) + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + master_port = int(sock.getsockname()[1]) + + group_name = f"xtuner_lmdeploy_weight_update_{self.rank}" + world_size = sum(engine_size for _, _, engine_size in engine_info) + 1 + + self._lmdeploy_disagg_executor = ThreadPoolExecutor(max_workers=max(1, len(engine_info))) + init_futures = [] + rank_offset = 1 + for _, url, engine_size in engine_info: + payload = { + "master_address": master_address, + "master_port": master_port, + "rank_offset": rank_offset, + "world_size": world_size, + "group_name": group_name, + "backend": backend, + } + init_futures.append( + self._lmdeploy_disagg_executor.submit( + requests.post, + f"{url}/init_weights_update_group", + json=payload, + ) + ) + rank_offset += engine_size + + self._lmdeploy_disagg_group = self._init_external_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name=group_name, + ) + + for init_future in init_futures: + response = init_future.result() + response.raise_for_status() + result = response.json() + assert result.get("success", True), ( + f"LMDeploy init_weights_update_group failed: {result.get('message', result)}" + ) + + self._lmdeploy_disagg_group_name = group_name + self._lmdeploy_disagg_engine_urls = [url for _, url, _ in engine_info] + def _request_update_params_sglang_disaggregated(self, state_dict): if not state_dict: return @@ -751,6 +841,96 @@ def _request_update_params_sglang_disaggregated(self, state_dict): ) dist.barrier(group=train_sync_group) + def _request_update_params_lmdeploy_disaggregated(self, state_dict, finished: bool = False): + if not state_dict and not finished: + return + + train_sync_group = self._get_train_update_sync_group() + head_rank = 0 + if dist.get_rank() != head_rank: + dist.barrier(group=train_sync_group) + return + + self._ensure_lmdeploy_disagg_group() + if self._lmdeploy_disagg_group is None: + dist.barrier(group=train_sync_group) + return + + assert self._lmdeploy_disagg_executor is not None + assert self._lmdeploy_disagg_group_name is not None + with self._lmdeploy_disagg_update_lock: + try: + from lmdeploy.utils import FlattenedTensorBucket + except Exception as e: + raise RuntimeError( + "Disaggregated update_weights for lmdeploy backend requires lmdeploy builds that provide " + "`lmdeploy.utils.FlattenedTensorBucket`." + ) from e + + if state_dict: + names = list(state_dict.keys()) + tensors = [ + tensor.detach().to(device=DEVICE, non_blocking=True).contiguous() for tensor in state_dict.values() + ] + payload = { + "names": names, + "dtypes": [str(tensor.dtype).replace("torch.", "") for tensor in tensors], + "shapes": [list(tensor.shape) for tensor in tensors], + "group_name": self._lmdeploy_disagg_group_name, + "load_format": "flattened_bucket", + "finished": finished, + } + update_futures = [ + self._lmdeploy_disagg_executor.submit( + requests.post, + f"{url}/update_weights_from_distributed", + json=payload, + ) + for url in self._lmdeploy_disagg_engine_urls + ] + flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=list(zip(names, tensors))) + flattened_tensor = flattened_tensor_bucket.get_flattened_tensor() + dist.broadcast(flattened_tensor, src=0, group=self._lmdeploy_disagg_group) + DEVICE_MODULE.synchronize() + for update_future in update_futures: + response = update_future.result() + response.raise_for_status() + result = response.json() + self._hook_compare_test_sent_and_received_weight_hash( + result, + names=names, + ) + assert result.get("success", True), ( + f"LMDeploy update_weights_from_distributed failed: {result.get('message', result)}" + ) + else: + # finalize-only request: no tensors to broadcast, just trigger the + # rollout side's mod.update_weights() finalization hooks. + payload = { + "names": [], + "dtypes": [], + "shapes": [], + "group_name": self._lmdeploy_disagg_group_name, + "load_format": "flattened_bucket", + "finished": True, + } + update_futures = [ + self._lmdeploy_disagg_executor.submit( + requests.post, + f"{url}/update_weights_from_distributed", + json=payload, + ) + for url in self._lmdeploy_disagg_engine_urls + ] + for update_future in update_futures: + response = update_future.result() + response.raise_for_status() + result = response.json() + assert result.get("success", True), ( + f"LMDeploy update_weights_from_distributed (finalize) failed: {result.get('message', result)}" + ) + dist.barrier(group=train_sync_group) + @ray_method def request_update_params(self, state_dict, train_enable_ep=False, finished=False): """Send a request to update the parameters on the rollout workers. @@ -771,6 +951,10 @@ def request_update_params(self, state_dict, train_enable_ep=False, finished=Fals self._request_update_params_sglang_disaggregated(state_dict) return + if self.rollout_cfg_info["backend"] == "pytorch" and not self.is_train_rollout_colocated: + self._request_update_params_lmdeploy_disaggregated(state_dict, finished=finished) + return + cpu_mesh = self._ensure_rollout_device_mesh()["engine_parallel"] cpu_group = cpu_mesh.get_group() head_rank = cpu_mesh.mesh[0].item()