From 749ce1fc3fc654f1f05e0e275adf30978ff30962 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Wed, 1 Apr 2026 19:55:21 +0800 Subject: [PATCH 1/4] delete huge objectref manually after ray.get --- xtuner/v1/ray/dataflow/replay_buffer.py | 109 +++++++++++++++---- xtuner/v1/ray/environment/single_turn_env.py | 10 +- xtuner/v1/ray/rollout/worker.py | 6 +- xtuner/v1/ray/utils.py | 11 ++ xtuner/v1/rl/base/controller.py | 5 +- xtuner/v1/rl/base/worker.py | 9 +- xtuner/v1/train/rl_trainer.py | 9 +- xtuner/v1/utils/misc.py | 20 ++++ 8 files changed, 147 insertions(+), 32 deletions(-) diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index 7068406ef..172109840 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -26,6 +26,7 @@ is_valid_for_replaybuffer, ) from xtuner.v1.datasets.config import DataloaderConfig +from xtuner.v1.ray.utils import free_object_refs from xtuner.v1.utils import get_logger from xtuner.v1.utils.device import get_device @@ -118,23 +119,38 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re return replay_meta -def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta) -> List[RLDataFlowItem]: +def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta, *, consume_refs: bool = False) -> List[RLDataFlowItem]: env_str = replay_meta.env root_id = replay_meta.root_id action_id = replay_meta.action_id - data_ref = ray.get(replay_meta.action_ref) - group_data_item = [] - for obs_id, obs_ref in zip(replay_meta.observation_ids, replay_meta.observation_refs): - env_data = ray.get(obs_ref) - # NOTE: This mapping function used by both dump and get. ObjectRefs are kept during dump (for training continuity) - # but released during get (via del replaymeta) when no longer needed. So we do not free them manually here. - # ray._private.internal_api.free(obs_ref) + action_ref = replay_meta.action_ref + observation_refs = list(replay_meta.observation_refs) + + data_value = ray.get(action_ref) if action_ref is not None else None + + env_values = [ray.get(obs_ref) for obs_ref in observation_refs] + + if consume_refs: + refs_to_free: List[ObjectRef] = [] + if isinstance(action_ref, ObjectRef): + refs_to_free.append(action_ref) + refs_to_free.extend([ref for ref in observation_refs if isinstance(ref, ObjectRef)]) + free_object_refs(refs_to_free) + replay_meta.action_ref = None + replay_meta.observation_refs.clear() + + group_data_item = [] + for obs_id, env_data in zip(replay_meta.observation_ids, env_values): item = RLDataFlowItem( uid=RLUIDItem( - env=env_str, root_id=root_id, action_id=action_id, observation_id=obs_id, version=replay_meta.version + env=env_str, + root_id=root_id, + action_id=action_id, + observation_id=obs_id, + version=replay_meta.version, ), - data=data_ref, + data=data_value, env=env_data, extra_info=RLExtraDataItem(), ) @@ -323,6 +339,25 @@ def __init__(self, replay_buffer_cfg): self.sample_from_aborted_count = 0 self.sample_from_expired_count = 0 + def _update_replay_meta_state(self, replay_meta: ReplayMeta, new_state: RolloutState): + for observation_id in replay_meta.observation_ids: + old_state = self._observations2states.get(observation_id) + if old_state and observation_id in self._states.get(old_state, []): + self._states[old_state].remove(observation_id) + self._observations2states[observation_id] = new_state + if observation_id not in self._states[new_state]: + self._states[new_state].append(observation_id) + replay_meta.state = new_state + + def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: RolloutState): + """Keep prompt refs only and drop rollout outputs that will not be + reused.""" + old_obs_refs = [ref for ref in replay_meta.observation_refs if ref is not None] + if old_obs_refs: + ray.internal.free(old_obs_refs, local_only=False) + replay_meta.observation_refs = [ray.put(RLEnvDataItem()) for _ in replay_meta.observation_ids] + self._update_replay_meta_state(replay_meta, new_state) + def add(self, grouped_dataitem: List[RLDataFlowItem]): """Adds a group of data items to the storage. @@ -392,7 +427,7 @@ def get(self, global_batch_size: int) -> Tuple[List[List[RLDataFlowItem]], List[ self.logger.warning("Get action_id None from completed_actions and skip this iteration.") continue replay_meta = self._actions.pop(action_id) - group_samples = mapping_replaymeta_to_dataitem(replay_meta) + group_samples = mapping_replaymeta_to_dataitem(replay_meta, consume_refs=True) # 将这条数据彻底清除,不用再记录root_id对应的action_ids了 self._clear_meta_for_root(replay_meta) multimodal_train_info = None @@ -426,6 +461,9 @@ def sample(self, sample_from_expired_states) -> List[RLDataFlowItem]: return [] def clear(self): + for replay_meta in list(self._actions.values()): + self._release_replay_meta_refs(replay_meta) + attrs_to_clear = [ "_aborted_actions", "_completed_actions", @@ -487,9 +525,10 @@ def convert_to_ray_objref(self, data_item: RLDataFlowItem): data_item.data.multimodal_train_info["pixel_values"] = pixel_values_ref # type: ignore[index] # convert rollout.extra_info.router_experts to ray.ObjectRef if "routed_experts" in data_item.env.rollout.extra_info: - routed_experts_ref = ray.put(data_item.env.rollout.extra_info["routed_experts"]) - del data_item.env.rollout.extra_info["routed_experts"] - data_item.env.rollout.extra_info["routed_experts"] = routed_experts_ref + if not isinstance(data_item.env.rollout.extra_info["routed_experts"], ray.ObjectRef): + routed_experts_ref = ray.put(data_item.env.rollout.extra_info["routed_experts"]) + del data_item.env.rollout.extra_info["routed_experts"] + data_item.env.rollout.extra_info["routed_experts"] = routed_experts_ref def has_objectref(self, item: RLDataFlowItem) -> bool: def check(obj): @@ -519,13 +558,15 @@ def dump(self, file_path: Path): file_path (str): The path to the file where the state will be saved. """ - all_data_items = [mapping_replaymeta_to_dataitem(replay_meta) for replay_meta in self._actions.values()] - - for data_items in all_data_items: + all_data_items = [] + for replay_meta in self._actions.values(): + # dump 仅用于序列化快照,这里可直接消费 refs,避免长时间占用 object store + data_items = mapping_replaymeta_to_dataitem(replay_meta, consume_refs=True) for item in data_items: self.resolve_ray_objects(item) res = self.has_objectref(item) assert not res, "ReplayBufferStorage.dump found unresolved ray.ObjectRef in RLDataFlowItem" + all_data_items.append(data_items) state = { "_completed_actions": self._completed_actions, @@ -603,7 +644,7 @@ def _sample_from_expired_storage(self) -> List[RLDataFlowItem]: assert len(self._expired_actions) > 0 action_id = self._expired_actions.pop() replay_meta = self._actions.pop(action_id) - group_samples = mapping_replaymeta_to_dataitem(replay_meta) + group_samples = mapping_replaymeta_to_dataitem(replay_meta, consume_refs=True) # 把这条数据上次的记录全部删掉,重新开始rollout,root2actions也要清除 self._clear_meta_for_root(replay_meta) @@ -638,7 +679,7 @@ def _sample_from_aborted_storage(self) -> List[RLDataFlowItem]: # 通过self.aborted_samples_count判断过这里不会返回None replay_meta = self._actions.pop(action_id) # type: ignore[arg-type] replay_meta_version = replay_meta.version - group_samples = mapping_replaymeta_to_dataitem(replay_meta) + group_samples = mapping_replaymeta_to_dataitem(replay_meta, consume_refs=True) # 把这条数据上次rollout产生的输出的记录都删掉,上次的数据已经记录在了RLEnvDataItem中了 self._clear_meta_for_actions(replay_meta) @@ -699,6 +740,10 @@ def _check_completed_samples_expired(self): for version in expired_versions: bucket = self._completed_actions.pop(version) + for action_id in bucket: + replay_meta = self._actions.get(action_id) + if replay_meta is not None: + self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED) self._expired_actions.extend(bucket) self.logger.info( f"Moved {len(bucket)} completed samples with version {version} to expired samples due to exceeding tail_batch_candidate_steps." @@ -709,12 +754,25 @@ def _check_completed_samples_aborted(self): return for version, bucket in self._completed_actions.items(): + for action_id in bucket: + replay_meta = self._actions.get(action_id) + if replay_meta is not None: + self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED) self._aborted_actions[0].extend(bucket) self.logger.info( f"Moved {len(bucket)} completed samples with version {version} to aborted samples due to partial rollout disabled." ) self._completed_actions.clear() + def _release_replay_meta_refs(self, replay_meta: ReplayMeta): + refs_to_free: List[ObjectRef] = [] + if isinstance(replay_meta.action_ref, ObjectRef): + refs_to_free.append(replay_meta.action_ref) + refs_to_free.extend([ref for ref in replay_meta.observation_refs if isinstance(ref, ObjectRef)]) + free_object_refs(refs_to_free) + replay_meta.action_ref = None + replay_meta.observation_refs.clear() + def _clear_meta_for_actions(self, replay_meta: ReplayMeta): """Completely removes an action and all its associated data from the storage. @@ -723,12 +781,15 @@ def _clear_meta_for_actions(self, replay_meta: ReplayMeta): """ action_id = replay_meta.action_id + self._release_replay_meta_refs(replay_meta) + for observation_id in replay_meta.observation_ids: self._observations.pop(observation_id, None) state = self._observations2states.pop(observation_id, None) if state and observation_id in self._states.get(state, []): self._states[state].remove(observation_id) + self._actions.pop(action_id, None) self._action2observations.pop(action_id, None) del replay_meta @@ -747,13 +808,18 @@ def _clear_meta_for_root(self, replay_meta: ReplayMeta): and clear all related actions. """ root_id = replay_meta.root_id + current_action_id = replay_meta.action_id + + self._clear_meta_for_actions(replay_meta) + if root_id in self._root2actions: for action_id in self._root2actions[root_id]: + if action_id == current_action_id: + continue new_replay_meta = self._actions.pop(action_id, None) if new_replay_meta: self._clear_meta_for_actions(new_replay_meta) del self._root2actions[root_id] - del replay_meta def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta): """Checks the rollout state of a ReplayMeta object and inserts its @@ -776,10 +842,13 @@ def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta): if self.tail_batch_candidate_steps > 0 and replay_meta.version >= self.tail_batch_candidate_steps: # 过期的数据需要重置状态 self._expired_actions.append(action_id) + self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED) self.logger.debug( f"Add expired sample with action_id: {action_id} to _expired_actions because version: {replay_meta.version} >= tail_batch_candidate_steps: {self.tail_batch_candidate_steps}." ) else: + if not self.enable_partial_rollout: + self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED) self._aborted_actions[replay_meta.version].append(action_id) self.logger.debug( f"Add aborted sample with action_id: {action_id} version: {replay_meta.version} to _aborted_actions." diff --git a/xtuner/v1/ray/environment/single_turn_env.py b/xtuner/v1/ray/environment/single_turn_env.py index 5a3211bdb..8967b567d 100644 --- a/xtuner/v1/ray/environment/single_turn_env.py +++ b/xtuner/v1/ray/environment/single_turn_env.py @@ -94,14 +94,14 @@ async def generate( # type: ignore[override] if self.rollout_controller: response_future = [] for sample in group_data_items: - sample.data.extra_info["root_id"] = sample.uid.root_id - sample.data.extra_info["action_id"] = sample.uid.action_id + rollout_extra_info = dict(sample.data.extra_info) + rollout_extra_info["root_id"] = sample.uid.root_id + rollout_extra_info["action_id"] = sample.uid.action_id update_sample_params = sample_params if "partial_rollout_input_ids" in sample.env.rollout.extra_info: input_ids_length = len(sample.data.input_ids) if sample.data.input_ids is not None else 0 current_partial_length = len(sample.env.rollout.extra_info["partial_rollout_input_ids"]) - rollout_extra_info = copy.deepcopy(sample.data.extra_info) rollout_extra_info["partial_rollout_input_ids"] = sample.env.rollout.extra_info[ "partial_rollout_input_ids" ] @@ -113,8 +113,6 @@ async def generate( # type: ignore[override] self.logger.debug( f"root_id: {sample.uid.root_id}, action_id {sample.uid.action_id} pass current_partial_length {current_partial_length}, input_ids_length {input_ids_length} to rollout and set max_tokens to {update_sample_params.max_tokens}" ) - else: - rollout_extra_info = sample.data.extra_info if "routed_experts" in sample.env.rollout.extra_info: rollout_extra_info["routed_experts"] = sample.env.rollout.extra_info["routed_experts"] @@ -126,6 +124,8 @@ async def generate( # type: ignore[override] extra_params=extra_params, extra_info=rollout_extra_info, ) + del rollout_extra_info + response_future.append(fut) try: rollout_responses = await asyncio.wait_for( diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py index 42936d29e..acad50204 100644 --- a/xtuner/v1/ray/rollout/worker.py +++ b/xtuner/v1/ray/rollout/worker.py @@ -573,6 +573,7 @@ async def _handle_non_stream_response( data = base64.b64decode(routed_experts) routed_experts = ray.cloudpickle.loads(data) + del data else: routed_experts = torch.tensor(routed_experts) # n,layer,expert routed_experts = ray.put(routed_experts) @@ -586,13 +587,14 @@ async def _handle_non_stream_response( routed_experts = ray.cloudpickle.loads(data) cur_routed_experts = await routed_experts # n,layer,expert ray.internal.free(routed_experts, local_only=False) + del data else: routed_experts = torch.tensor(routed_experts) # n,layer,expert cur_routed_experts = routed_experts history_routed_experts = await input_extra_info["routed_experts"] # n, layer, expert ray.internal.free(input_extra_info["routed_experts"], local_only=False) - del input_extra_info["routed_experts"] + del input_extra_info assert (history_routed_experts.shape[0] - 1) > 0 and history_routed_experts.shape[ 0 @@ -613,6 +615,8 @@ async def _handle_non_stream_response( f"Experts(exist={history_routed_experts.shape}, init_cur={init_cur_roued_experts}, cur={cur_routed_experts.shape}, concat={concat_routed_experts.shape})" ) extra_info["routed_experts"] = ray.put(concat_routed_experts) + del history_routed_experts + del cur_routed_experts else: assert finish_reason == "abort", ( f"routed_experts is None, but finish_reason is {finish_reason}, expected abort. response: {response}" diff --git a/xtuner/v1/ray/utils.py b/xtuner/v1/ray/utils.py index 4eea3b228..985429907 100644 --- a/xtuner/v1/ray/utils.py +++ b/xtuner/v1/ray/utils.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable, Coroutine, List, Optional, cast import ray +from ray import ObjectRef if TYPE_CHECKING: @@ -208,3 +209,13 @@ def create_task( for callback in done_callbacks: task.add_done_callback(callback) return task + + +def free_object_refs(refs: List[ObjectRef]) -> None: + valid_refs = [ref for ref in refs if isinstance(ref, ObjectRef)] + if not valid_refs: + return + try: + ray._private.internal_api.free(valid_refs, local_only=False) + except Exception: + ray.internal.free(valid_refs, local_only=False) diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index b500b53e4..d5f10a15d 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -260,7 +260,10 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: rollout_idx=rollout_idx, ) ) - log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) + try: + log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) + finally: + del packed_data_batches return log_infos @ray_method diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index 08a96620c..087871dff 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -35,6 +35,7 @@ from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo from xtuner.v1.ray.base import SingleAcceleratorWorker from xtuner.v1.ray.config import RolloutConfig +from xtuner.v1.ray.utils import free_object_refs from xtuner.v1.rl.base.loss import BaseRLLossContext from xtuner.v1.rl.utils import gather_logprobs from xtuner.v1.train.trainer import LoadCheckpointConfig @@ -483,8 +484,12 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo assert isinstance(pixel_values, list), ( f"pixel_values should be list of tensor, got {type(pixel_values)}" ) - pixel_values = [ray.get(pixel_obf) for pixel_obf in pixel_values] - pixel_values = torch.cat(pixel_values, dim=0) + pixel_value_refs = list(pixel_values) + try: + pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0) + finally: + free_object_refs(pixel_value_refs) + seq_ctx.pixel_values = pixel_values rollout_routed_experts = seq_ctx.rollout_routed_experts diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 398e0572d..39639a063 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -118,7 +118,7 @@ def get_train_seq_ctx( ): seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu") if multimodal_train_info and len(multimodal_train_info) > 0: - position_ids = multimodal_train_info.get("position_ids") # (1,n) or (3,1,n) + position_ids = multimodal_train_info.pop("position_ids") # (1,n) or (3,1,n) if position_ids is not None and len(position_ids.shape) == 3: # qwen3vl 需要特殊处理,其余的不需要额外处理 max_value = position_ids.max(dim=-1).values # (3,1) @@ -128,8 +128,9 @@ def get_train_seq_ctx( position_ids = torch.cat([position_ids, response_position_ids], dim=-1) seq_ctx.position_ids = position_ids # type: ignore[assignment] assert position_ids.size(-1) == input_ids.size(-1) - seq_ctx.pixel_values = multimodal_train_info.get("pixel_values") - seq_ctx.image_grid_thw = multimodal_train_info.get("image_grid_thw") + seq_ctx.pixel_values = multimodal_train_info.pop("pixel_values") + seq_ctx.image_grid_thw = multimodal_train_info.pop("image_grid_thw") + del multimodal_train_info return seq_ctx @@ -623,6 +624,8 @@ def fit(self): # 1. Rollout to generate experience rollout_info = self._rollout_step(rollout_idx, step_timer_dict) + train_log_info = {} + eval_log_info = {} if not self._debug_rollout: # 2. Train on the generated experience train_log_info = self._train_step( diff --git a/xtuner/v1/utils/misc.py b/xtuner/v1/utils/misc.py index a588e2f1a..80b1896a5 100644 --- a/xtuner/v1/utils/misc.py +++ b/xtuner/v1/utils/misc.py @@ -1,3 +1,4 @@ +import logging import os import sys import threading @@ -21,6 +22,8 @@ HF_PATCH_MODULES_CACHE_PREFIX = "modules_pid_" +_TRIM_MEMORY_WARNED = False + logger = get_logger() XTUNER_DETERMINISTIC = os.getenv("XTUNER_DETERMINISTIC") == "true" @@ -214,3 +217,20 @@ def clean_param_name(name: str) -> str: if "_orig_mod." in name: name = name.replace("_orig_mod.", "") return name + + +def trim_memory(logger: logging.Logger | None = None): + """Try to return free heap pages to OS.""" + global _TRIM_MEMORY_WARNED + if logger is None: + logger = get_logger() + try: + import ctypes + + libc = ctypes.CDLL("libc.so.6") + return libc.malloc_trim(0) + except Exception as e: + if not _TRIM_MEMORY_WARNED: + logger.warning(f" >>>>>>>>> [trim_memory] Failed to trim memory: {e} <<<<<<<<") + _TRIM_MEMORY_WARNED = True + return False From 236793b5931549a42f88e4dde869bcef40873273 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 7 Apr 2026 17:41:31 +0800 Subject: [PATCH 2/4] rm trim_memory function --- xtuner/v1/utils/misc.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/xtuner/v1/utils/misc.py b/xtuner/v1/utils/misc.py index 80b1896a5..a588e2f1a 100644 --- a/xtuner/v1/utils/misc.py +++ b/xtuner/v1/utils/misc.py @@ -1,4 +1,3 @@ -import logging import os import sys import threading @@ -22,8 +21,6 @@ HF_PATCH_MODULES_CACHE_PREFIX = "modules_pid_" -_TRIM_MEMORY_WARNED = False - logger = get_logger() XTUNER_DETERMINISTIC = os.getenv("XTUNER_DETERMINISTIC") == "true" @@ -217,20 +214,3 @@ def clean_param_name(name: str) -> str: if "_orig_mod." in name: name = name.replace("_orig_mod.", "") return name - - -def trim_memory(logger: logging.Logger | None = None): - """Try to return free heap pages to OS.""" - global _TRIM_MEMORY_WARNED - if logger is None: - logger = get_logger() - try: - import ctypes - - libc = ctypes.CDLL("libc.so.6") - return libc.malloc_trim(0) - except Exception as e: - if not _TRIM_MEMORY_WARNED: - logger.warning(f" >>>>>>>>> [trim_memory] Failed to trim memory: {e} <<<<<<<<") - _TRIM_MEMORY_WARNED = True - return False From eb3bc6b3700fd3ebf5f8f7520370f119a09b40b5 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 7 Apr 2026 17:52:22 +0800 Subject: [PATCH 3/4] delete default params in mapping function --- xtuner/v1/ray/dataflow/replay_buffer.py | 31 ++++++++++++++----------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index 172109840..956cd04d3 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -119,7 +119,7 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re return replay_meta -def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta, *, consume_refs: bool = False) -> List[RLDataFlowItem]: +def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta) -> List[RLDataFlowItem]: env_str = replay_meta.env root_id = replay_meta.root_id action_id = replay_meta.action_id @@ -131,14 +131,13 @@ def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta, *, consume_refs: boo env_values = [ray.get(obs_ref) for obs_ref in observation_refs] - if consume_refs: - refs_to_free: List[ObjectRef] = [] - if isinstance(action_ref, ObjectRef): - refs_to_free.append(action_ref) - refs_to_free.extend([ref for ref in observation_refs if isinstance(ref, ObjectRef)]) - free_object_refs(refs_to_free) - replay_meta.action_ref = None - replay_meta.observation_refs.clear() + refs_to_free: List[ObjectRef] = [] + if isinstance(action_ref, ObjectRef): + refs_to_free.append(action_ref) + refs_to_free.extend([ref for ref in observation_refs if isinstance(ref, ObjectRef)]) + free_object_refs(refs_to_free) + replay_meta.action_ref = None + replay_meta.observation_refs.clear() group_data_item = [] for obs_id, env_data in zip(replay_meta.observation_ids, env_values): @@ -427,7 +426,7 @@ def get(self, global_batch_size: int) -> Tuple[List[List[RLDataFlowItem]], List[ self.logger.warning("Get action_id None from completed_actions and skip this iteration.") continue replay_meta = self._actions.pop(action_id) - group_samples = mapping_replaymeta_to_dataitem(replay_meta, consume_refs=True) + group_samples = mapping_replaymeta_to_dataitem(replay_meta) # 将这条数据彻底清除,不用再记录root_id对应的action_ids了 self._clear_meta_for_root(replay_meta) multimodal_train_info = None @@ -490,6 +489,7 @@ def resolve_ray_objects(self, data_item: RLDataFlowItem): """ # Resolve data.multimodal_train_info + free_refs_list = [] if hasattr(data_item.data, "multimodal_train_info"): multimodal_info = data_item.data.multimodal_train_info if multimodal_info and "pixel_values" in multimodal_info: @@ -497,6 +497,7 @@ def resolve_ray_objects(self, data_item: RLDataFlowItem): if isinstance(pixel_values_ref, ObjectRef): multimodal_info["pixel_values"] = ray.get(pixel_values_ref) data_item.data.multimodal_train_info = multimodal_info + free_refs_list.append(pixel_values_ref) # Resolve rollout.extra_info.router_experts if "routed_experts" in data_item.env.rollout.extra_info: if isinstance(data_item.env.rollout.extra_info["routed_experts"], ObjectRef): @@ -504,7 +505,9 @@ def resolve_ray_objects(self, data_item: RLDataFlowItem): ray.internal.free(data_item.env.rollout.extra_info["routed_experts"], local_only=False) del data_item.env.rollout.extra_info["routed_experts"] data_item.env.rollout.extra_info["routed_experts"] = routed_experts - self.logger.info("Resolved routed_experts ObjectRef in rollout.extra_info") + free_refs_list.append(pixel_values_ref) + + free_object_refs(free_refs_list) def convert_to_ray_objref(self, data_item: RLDataFlowItem): """Converts large tensors in RLDataFlowItem to ray.ObjectRefs. @@ -561,7 +564,7 @@ def dump(self, file_path: Path): all_data_items = [] for replay_meta in self._actions.values(): # dump 仅用于序列化快照,这里可直接消费 refs,避免长时间占用 object store - data_items = mapping_replaymeta_to_dataitem(replay_meta, consume_refs=True) + data_items = mapping_replaymeta_to_dataitem(replay_meta) for item in data_items: self.resolve_ray_objects(item) res = self.has_objectref(item) @@ -644,7 +647,7 @@ def _sample_from_expired_storage(self) -> List[RLDataFlowItem]: assert len(self._expired_actions) > 0 action_id = self._expired_actions.pop() replay_meta = self._actions.pop(action_id) - group_samples = mapping_replaymeta_to_dataitem(replay_meta, consume_refs=True) + group_samples = mapping_replaymeta_to_dataitem(replay_meta) # 把这条数据上次的记录全部删掉,重新开始rollout,root2actions也要清除 self._clear_meta_for_root(replay_meta) @@ -679,7 +682,7 @@ def _sample_from_aborted_storage(self) -> List[RLDataFlowItem]: # 通过self.aborted_samples_count判断过这里不会返回None replay_meta = self._actions.pop(action_id) # type: ignore[arg-type] replay_meta_version = replay_meta.version - group_samples = mapping_replaymeta_to_dataitem(replay_meta, consume_refs=True) + group_samples = mapping_replaymeta_to_dataitem(replay_meta) # 把这条数据上次rollout产生的输出的记录都删掉,上次的数据已经记录在了RLEnvDataItem中了 self._clear_meta_for_actions(replay_meta) From faab53053c849583d7f7c9cc862f72a8078ad1f5 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 7 Apr 2026 21:18:15 +0800 Subject: [PATCH 4/4] handle extra_info in rb --- xtuner/v1/ray/base/accelerator.py | 22 ++++- xtuner/v1/ray/dataflow/replay_buffer.py | 103 ++++++++++++++++-------- 2 files changed, 88 insertions(+), 37 deletions(-) diff --git a/xtuner/v1/ray/base/accelerator.py b/xtuner/v1/ray/base/accelerator.py index df41feebf..cf0550952 100644 --- a/xtuner/v1/ray/base/accelerator.py +++ b/xtuner/v1/ray/base/accelerator.py @@ -196,6 +196,25 @@ def device_visible_env_name(self): else: raise ValueError(f"Unsupported accelerator type: {self.accelerator}") + def get_logical_local_rank(self) -> int: + """Resolve the assigned accelerator id to the logical local rank. + + Ray reports accelerator ids in the physical numbering space. Torch selects devices from the current visible- + device list, which is indexed logically from zero after applying visibility masks. + """ + accelerator_id = str(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) + visible_devices = os.environ.get(self.device_visible_env_name) + if visible_devices is None: + return int(accelerator_id) + + visible_device_ids = [device_id.strip() for device_id in visible_devices.split(",") if device_id.strip()] + if accelerator_id not in visible_device_ids: + raise ValueError( + f"Assigned accelerator id {accelerator_id} is not present in " + f"{self.device_visible_env_name}={visible_devices}." + ) + return visible_device_ids.index(accelerator_id) + def setup_distributed(self, rank: int, master_addr: str, master_port: int, world_size: int): """Set up the distributed environment for the worker. @@ -215,8 +234,7 @@ def setup_distributed(self, rank: int, master_addr: str, master_port: int, world os.environ["MASTER_PORT"] = str(master_port) os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) - os.environ["LOCAL_RANK"] = str(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) - + os.environ["LOCAL_RANK"] = str(self.get_logical_local_rank()) # backend 参数是指定通信后端,不是从环境变量获取 # - 'nccl': NVIDIA GPU 间通信(推荐用于 GPU) # - 'gloo': CPU 通信或跨平台 diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index 956cd04d3..4d2a92c11 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -59,6 +59,7 @@ class ReplayMeta: observation_ids: List[int] = field(default_factory=list) observation_refs: List[ObjectRef] = field(default_factory=list) observation_versions: List[int] = field(default_factory=list) # 目前发数据为按组下发,暂时用不到 + observation_extra_infos: List[RLExtraDataItem] = field(default_factory=list) state: RolloutState = RolloutState.INIT version: int = 0 # version for partial rollout extra_info: Dict[str, Any] = field(default_factory=dict) @@ -95,10 +96,14 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re group_version = grouped_dataitem[0].uid.version observation_ids = [] observation_refs = [] + observation_versions = [] + observation_extra_infos = [] for item in grouped_dataitem: observation_ids.append(item.uid.observation_id) observation_refs.append(ray.put(item.env)) + observation_versions.append(item.uid.version) + observation_extra_infos.append(item.extra_info.model_copy(deep=True)) group_state = determine_group_state(grouped_dataitem) logger.debug( @@ -112,6 +117,8 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re action_ref=ray.put(data), observation_ids=observation_ids, observation_refs=observation_refs, + observation_versions=observation_versions, + observation_extra_infos=observation_extra_infos, state=group_state, version=group_version, extra_info={}, @@ -119,7 +126,7 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re return replay_meta -def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta) -> List[RLDataFlowItem]: +def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta, consume_refs: bool = True) -> List[RLDataFlowItem]: env_str = replay_meta.env root_id = replay_meta.root_id action_id = replay_meta.action_id @@ -131,27 +138,40 @@ def mapping_replaymeta_to_dataitem(replay_meta: ReplayMeta) -> List[RLDataFlowIt env_values = [ray.get(obs_ref) for obs_ref in observation_refs] - refs_to_free: List[ObjectRef] = [] - if isinstance(action_ref, ObjectRef): - refs_to_free.append(action_ref) - refs_to_free.extend([ref for ref in observation_refs if isinstance(ref, ObjectRef)]) - free_object_refs(refs_to_free) - replay_meta.action_ref = None - replay_meta.observation_refs.clear() + if consume_refs: + refs_to_free: List[ObjectRef] = [] + if isinstance(action_ref, ObjectRef): + refs_to_free.append(action_ref) + refs_to_free.extend([ref for ref in observation_refs if isinstance(ref, ObjectRef)]) + free_object_refs(refs_to_free) + replay_meta.action_ref = None + replay_meta.observation_refs.clear() group_data_item = [] - for obs_id, env_data in zip(replay_meta.observation_ids, env_values): + observation_versions = replay_meta.observation_versions or [replay_meta.version] * len(replay_meta.observation_ids) + observation_extra_infos = replay_meta.observation_extra_infos or [ + RLExtraDataItem() for _ in replay_meta.observation_ids + ] + for idx, (obs_id, env_data) in enumerate(zip(replay_meta.observation_ids, env_values)): + observation_version = observation_versions[idx] if idx < len(observation_versions) else replay_meta.version + extra_info = ( + observation_extra_infos[idx].model_copy(deep=True) + if idx < len(observation_extra_infos) + else RLExtraDataItem() + ) + if env_data.rollout.state == RolloutState.INIT and replay_meta.state != RolloutState.INIT: + env_data.rollout.state = replay_meta.state item = RLDataFlowItem( uid=RLUIDItem( env=env_str, root_id=root_id, action_id=action_id, observation_id=obs_id, - version=replay_meta.version, + version=observation_version, ), data=data_value, env=env_data, - extra_info=RLExtraDataItem(), + extra_info=extra_info, ) group_data_item.append(item) return group_data_item @@ -473,6 +493,7 @@ def clear(self): "_observations2states", "_states", "_action2observations", + "_multimodal_train_infos", ] for attr in attrs_to_clear: getattr(self, attr).clear() @@ -487,28 +508,33 @@ def resolve_ray_objects(self, data_item: RLDataFlowItem): Returns: RLDataFlowItem: The data item with ray.ObjectRefs resolved. """ - - # Resolve data.multimodal_train_info - free_refs_list = [] - if hasattr(data_item.data, "multimodal_train_info"): - multimodal_info = data_item.data.multimodal_train_info - if multimodal_info and "pixel_values" in multimodal_info: - pixel_values_ref = multimodal_info["pixel_values"] - if isinstance(pixel_values_ref, ObjectRef): - multimodal_info["pixel_values"] = ray.get(pixel_values_ref) - data_item.data.multimodal_train_info = multimodal_info - free_refs_list.append(pixel_values_ref) - # Resolve rollout.extra_info.router_experts - if "routed_experts" in data_item.env.rollout.extra_info: - if isinstance(data_item.env.rollout.extra_info["routed_experts"], ObjectRef): - routed_experts = ray.get(data_item.env.rollout.extra_info["routed_experts"]) - ray.internal.free(data_item.env.rollout.extra_info["routed_experts"], local_only=False) - del data_item.env.rollout.extra_info["routed_experts"] - data_item.env.rollout.extra_info["routed_experts"] = routed_experts - free_refs_list.append(pixel_values_ref) - + free_refs_list: List[ObjectRef] = [] + self._resolve_nested_objectrefs(data_item, free_refs_list) free_object_refs(free_refs_list) + def _resolve_nested_objectrefs(self, obj: Any, refs_to_free: List[ObjectRef]): + if isinstance(obj, ObjectRef): + value = ray.get(obj) + refs_to_free.append(obj) + return self._resolve_nested_objectrefs(value, refs_to_free) + if isinstance(obj, BaseModel): + for field_name in type(obj).model_fields: + setattr(obj, field_name, self._resolve_nested_objectrefs(getattr(obj, field_name), refs_to_free)) + return obj + if isinstance(obj, list): + for idx, value in enumerate(obj): + obj[idx] = self._resolve_nested_objectrefs(value, refs_to_free) + return obj + if isinstance(obj, tuple): + return tuple(self._resolve_nested_objectrefs(value, refs_to_free) for value in obj) + if isinstance(obj, set): + return {self._resolve_nested_objectrefs(value, refs_to_free) for value in obj} + if isinstance(obj, dict): + for key, value in list(obj.items()): + obj[key] = self._resolve_nested_objectrefs(value, refs_to_free) + return obj + return obj + def convert_to_ray_objref(self, data_item: RLDataFlowItem): """Converts large tensors in RLDataFlowItem to ray.ObjectRefs. @@ -538,7 +564,7 @@ def check(obj): if isinstance(obj, ray.ObjectRef): return True if isinstance(obj, BaseModel): - return any(check(getattr(obj, f)) for f in obj.model_fields) + return any(check(getattr(obj, f)) for f in type(obj).model_fields) if isinstance(obj, (list, tuple, set)): return any(check(x) for x in obj) if isinstance(obj, dict): @@ -564,7 +590,7 @@ def dump(self, file_path: Path): all_data_items = [] for replay_meta in self._actions.values(): # dump 仅用于序列化快照,这里可直接消费 refs,避免长时间占用 object store - data_items = mapping_replaymeta_to_dataitem(replay_meta) + data_items = mapping_replaymeta_to_dataitem(replay_meta, consume_refs=False) for item in data_items: self.resolve_ray_objects(item) res = self.has_objectref(item) @@ -580,6 +606,9 @@ def dump(self, file_path: Path): "_observations2states": self._observations2states, "_states": dict(self._states), "_action2observations": dict(self._action2observations), + "_multimodal_train_infos": self._multimodal_train_infos, + "sample_from_aborted_count": self.sample_from_aborted_count, + "sample_from_expired_count": self.sample_from_expired_count, } torch.save(state, file_path) @@ -597,13 +626,16 @@ def resume(self, file_path: Path): state = torch.load(file_path, map_location="cpu", weights_only=False) - self._completed_actions = state["_completed_actions"] - self._aborted_actions = state["_aborted_actions"] + self._completed_actions = defaultdict(list, state["_completed_actions"]) + self._aborted_actions = defaultdict(list, state["_aborted_actions"]) self._expired_actions = state["_expired_actions"] self._root2actions = defaultdict(list, state["_root2actions"]) self._observations2states = state["_observations2states"] self._states = defaultdict(list, state["_states"]) self._action2observations = defaultdict(list, state["_action2observations"]) + self._multimodal_train_infos = state.get("_multimodal_train_infos", {}) + self.sample_from_aborted_count = state.get("sample_from_aborted_count", 0) + self.sample_from_expired_count = state.get("sample_from_expired_count", 0) dump_actions = state["_actions"] # 重建 _actions 和 _observations: 与replaymeta相关 @@ -992,6 +1024,7 @@ def save(self, file_path: Path | str): """ if isinstance(file_path, str): file_path = Path(file_path) + file_path.mkdir(parents=True, exist_ok=True) # save dataloader dataloader_path = file_path / "dataloader"