
Commit 8970ec1

delete huge objectref manually after ray.get() (#1648)
* delete huge objectref manually after ray.get
* rm trim_memory function
* delete default params in mapping function
* handle extra_info in rb (replay buffer)
1 parent 8738624 commit 8970ec1

8 files changed

Lines changed: 201 additions & 52 deletions


xtuner/v1/ray/base/accelerator.py

Lines changed: 20 additions & 2 deletions
@@ -196,6 +196,25 @@ def device_visible_env_name(self):
         else:
             raise ValueError(f"Unsupported accelerator type: {self.accelerator}")
 
+    def get_logical_local_rank(self) -> int:
+        """Resolve the assigned accelerator id to the logical local rank.
+
+        Ray reports accelerator ids in the physical numbering space. Torch selects devices from the
+        current visible-device list, which is indexed logically from zero after applying visibility masks.
+        """
+        accelerator_id = str(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0])
+        visible_devices = os.environ.get(self.device_visible_env_name)
+        if visible_devices is None:
+            return int(accelerator_id)
+
+        visible_device_ids = [device_id.strip() for device_id in visible_devices.split(",") if device_id.strip()]
+        if accelerator_id not in visible_device_ids:
+            raise ValueError(
+                f"Assigned accelerator id {accelerator_id} is not present in "
+                f"{self.device_visible_env_name}={visible_devices}."
+            )
+        return visible_device_ids.index(accelerator_id)
+
     def setup_distributed(self, rank: int, master_addr: str, master_port: int, world_size: int):
         """Set up the distributed environment for the worker.
 
@@ -215,8 +234,7 @@ def setup_distributed(self, rank: int, master_addr: str, master_port: int, world
         os.environ["MASTER_PORT"] = str(master_port)
         os.environ["RANK"] = str(rank)
         os.environ["WORLD_SIZE"] = str(world_size)
-        os.environ["LOCAL_RANK"] = str(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0])
-
+        os.environ["LOCAL_RANK"] = str(self.get_logical_local_rank())
        # The `backend` argument selects the communication backend; it is not read from environment variables
        # - 'nccl': communication between NVIDIA GPUs (recommended for GPUs)
        # - 'gloo': CPU communication, or cross-platform
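
For context, a minimal standalone sketch of the physical-to-logical mapping that `get_logical_local_rank` performs; the env var name and device ids here are illustrative:

```python
import os

# Illustrative mask: only physical GPUs 2 and 3 are visible. Ray reports the
# *physical* id ("3"), but torch.device("cuda:1") is what addresses it.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"


def to_logical_rank(physical_id: str, env_name: str = "CUDA_VISIBLE_DEVICES") -> int:
    visible = os.environ.get(env_name)
    if visible is None:
        return int(physical_id)  # no mask: physical and logical ranks coincide
    ids = [d.strip() for d in visible.split(",") if d.strip()]
    return ids.index(physical_id)


assert to_logical_rank("3") == 1
```

Before this fix, physical id 3 would have been exported as LOCAL_RANK=3, which torch rejects when only two logical devices are visible.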

xtuner/v1/ray/dataflow/replay_buffer.py

Lines changed: 143 additions & 38 deletions
Large diffs are not rendered by default.

xtuner/v1/ray/environment/single_turn_env.py

Lines changed: 5 additions & 5 deletions
@@ -94,14 +94,14 @@ async def generate(  # type: ignore[override]
         if self.rollout_controller:
             response_future = []
             for sample in group_data_items:
-                sample.data.extra_info["root_id"] = sample.uid.root_id
-                sample.data.extra_info["action_id"] = sample.uid.action_id
+                rollout_extra_info = dict(sample.data.extra_info)
+                rollout_extra_info["root_id"] = sample.uid.root_id
+                rollout_extra_info["action_id"] = sample.uid.action_id
                 update_sample_params = sample_params
 
                 if "partial_rollout_input_ids" in sample.env.rollout.extra_info:
                     input_ids_length = len(sample.data.input_ids) if sample.data.input_ids is not None else 0
                     current_partial_length = len(sample.env.rollout.extra_info["partial_rollout_input_ids"])
-                    rollout_extra_info = copy.deepcopy(sample.data.extra_info)
                     rollout_extra_info["partial_rollout_input_ids"] = sample.env.rollout.extra_info[
                         "partial_rollout_input_ids"
                     ]
@@ -113,8 +113,6 @@ async def generate(  # type: ignore[override]
                     self.logger.debug(
                         f"root_id: {sample.uid.root_id}, action_id {sample.uid.action_id} pass current_partial_length {current_partial_length}, input_ids_length {input_ids_length} to rollout and set max_tokens to {update_sample_params.max_tokens}"
                     )
-                else:
-                    rollout_extra_info = sample.data.extra_info
 
                 if "routed_experts" in sample.env.rollout.extra_info:
                     rollout_extra_info["routed_experts"] = sample.env.rollout.extra_info["routed_experts"]
@@ -126,6 +124,8 @@ async def generate(  # type: ignore[override]
                     extra_params=extra_params,
                     extra_info=rollout_extra_info,
                 )
+                del rollout_extra_info
+
                 response_future.append(fut)
             try:
                 rollout_responses = await asyncio.wait_for(
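
Note on the change above: the in-place writes to `sample.data.extra_info` (and the `copy.deepcopy` taken only on the partial-rollout branch) are replaced by one shallow `dict(...)` copy on every path, dropped as soon as the request is issued. A minimal sketch of the difference, with illustrative keys:

```python
sample_extra_info = {"prompt_meta": "...", "large_blob": bytes(1 << 20)}

# Shallow copy: per-rollout keys stay out of the stored sample, and unlike
# copy.deepcopy the 1 MiB blob is shared rather than duplicated.
rollout_extra_info = dict(sample_extra_info)
rollout_extra_info["action_id"] = 7

assert "action_id" not in sample_extra_info
assert rollout_extra_info["large_blob"] is sample_extra_info["large_blob"]

del rollout_extra_info  # hand-off done; release the per-rollout view
```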

xtuner/v1/ray/rollout/worker.py

Lines changed: 5 additions & 1 deletion
@@ -573,6 +573,7 @@ async def _handle_non_stream_response(
 
                 data = base64.b64decode(routed_experts)
                 routed_experts = ray.cloudpickle.loads(data)
+                del data
             else:
                 routed_experts = torch.tensor(routed_experts)  # n,layer,expert
             routed_experts = ray.put(routed_experts)
@@ -586,13 +587,14 @@ async def _handle_non_stream_response(
                 routed_experts = ray.cloudpickle.loads(data)
                 cur_routed_experts = await routed_experts  # n,layer,expert
                 ray.internal.free(routed_experts, local_only=False)
+                del data
             else:
                 routed_experts = torch.tensor(routed_experts)  # n,layer,expert
                 cur_routed_experts = routed_experts
 
             history_routed_experts = await input_extra_info["routed_experts"]  # n, layer, expert
             ray.internal.free(input_extra_info["routed_experts"], local_only=False)
-            del input_extra_info["routed_experts"]
+            del input_extra_info
 
             assert (history_routed_experts.shape[0] - 1) > 0 and history_routed_experts.shape[
                 0
@@ -613,6 +615,8 @@ async def _handle_non_stream_response(
                     f"Experts(exist={history_routed_experts.shape}, init_cur={init_cur_roued_experts}, cur={cur_routed_experts.shape}, concat={concat_routed_experts.shape})"
                 )
                 extra_info["routed_experts"] = ray.put(concat_routed_experts)
+                del history_routed_experts
+                del cur_routed_experts
             else:
                 assert finish_reason == "abort", (
                     f"routed_experts is None, but finish_reason is {finish_reason}, expected abort. response: {response}"

xtuner/v1/ray/utils.py

Lines changed: 11 additions & 0 deletions
@@ -5,6 +5,7 @@
 from typing import TYPE_CHECKING, Callable, Coroutine, List, Optional, cast
 
 import ray
+from ray import ObjectRef
 
 
 if TYPE_CHECKING:
@@ -208,3 +209,13 @@ def create_task(
     for callback in done_callbacks:
         task.add_done_callback(callback)
     return task
+
+
+def free_object_refs(refs: List[ObjectRef]) -> None:
+    valid_refs = [ref for ref in refs if isinstance(ref, ObjectRef)]
+    if not valid_refs:
+        return
+    try:
+        ray._private.internal_api.free(valid_refs, local_only=False)
+    except Exception:
+        ray.internal.free(valid_refs, local_only=False)
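
`free_object_refs` filters out anything that is not an `ObjectRef`, then frees through `ray._private.internal_api.free`, falling back to the older `ray.internal.free`; both are internal Ray surfaces, hence the defensive try/except. A hedged usage sketch:

```python
import ray

from xtuner.v1.ray.utils import free_object_refs

ray.init()

refs = [ray.put(bytes(100 * 1024 * 1024)) for _ in range(2)]  # ~200 MB in plasma
blobs = ray.get(refs)   # materialize local copies first
free_object_refs(refs)  # then eagerly evict the object-store copies
del blobs               # and drop the local copies when done
```

After the free, the refs must not be passed to `ray.get` again; the commit only frees once the values are materialized.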

xtuner/v1/rl/base/controller.py

Lines changed: 4 additions & 1 deletion
@@ -260,7 +260,10 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx:
                     rollout_idx=rollout_idx,
                 )
             )
-        log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
+        try:
+            log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
+        finally:
+            del packed_data_batches
         return log_infos
 
     @ray_method
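
The try/finally guarantees the driver-side references in `packed_data_batches` are dropped even when `ray.get` raises (e.g. on timeout), which is what lets Ray's distributed ref counting reclaim the underlying objects. A minimal sketch with an illustrative remote task:

```python
import ray

ray.init()


@ray.remote
def train_on(batch_refs):  # illustrative stand-in for a worker's fit()
    return [len(ray.get(ref)) for ref in batch_refs]


packed_data_batches = [ray.put(bytes(10 * 1024 * 1024)) for _ in range(4)]
handle = train_on.remote(packed_data_batches)
try:
    log_infos = ray.get(handle, timeout=600)
finally:
    del packed_data_batches  # runs on success *and* on timeout/exception
```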

xtuner/v1/rl/base/worker.py

Lines changed: 7 additions & 2 deletions
@@ -35,6 +35,7 @@
 from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
 from xtuner.v1.ray.base import SingleAcceleratorWorker
 from xtuner.v1.ray.config import RolloutConfig
+from xtuner.v1.ray.utils import free_object_refs
 from xtuner.v1.rl.base.loss import BaseRLLossContext
 from xtuner.v1.rl.utils import gather_logprobs
 from xtuner.v1.train.trainer import LoadCheckpointConfig
@@ -483,8 +484,12 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo
                 assert isinstance(pixel_values, list), (
                     f"pixel_values should be list of tensor, got {type(pixel_values)}"
                 )
-                pixel_values = [ray.get(pixel_obf) for pixel_obf in pixel_values]
-                pixel_values = torch.cat(pixel_values, dim=0)
+                pixel_value_refs = list(pixel_values)
+                try:
+                    pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0)
+                finally:
+                    free_object_refs(pixel_value_refs)
+
                 seq_ctx.pixel_values = pixel_values
 
             rollout_routed_experts = seq_ctx.rollout_routed_experts
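
The old code rebound `pixel_values` from the ref list to the materialized tensors, discarding the only handle needed to free the object-store copies. Keeping the refs under their own name (`pixel_value_refs`) makes them freeable even if `ray.get` or `torch.cat` raises. A minimal sketch using the helper added above; shapes are illustrative:

```python
import ray
import torch

from xtuner.v1.ray.utils import free_object_refs

ray.init()

pixel_value_refs = [ray.put(torch.randn(4, 1176)) for _ in range(3)]
try:
    pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0)  # (12, 1176)
finally:
    free_object_refs(pixel_value_refs)  # evict the shards from the object store
```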

xtuner/v1/train/rl_trainer.py

Lines changed: 6 additions & 3 deletions
@@ -120,7 +120,7 @@ def get_train_seq_ctx(
 ):
     seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu")
     if multimodal_train_info and len(multimodal_train_info) > 0:
-        position_ids = multimodal_train_info.get("position_ids")  # (1,n) or (3,1,n)
+        position_ids = multimodal_train_info.pop("position_ids")  # (1,n) or (3,1,n)
         if position_ids is not None and len(position_ids.shape) == 3:
             # qwen3vl needs special handling; the other models need none
             max_value = position_ids.max(dim=-1).values  # (3,1)
@@ -130,8 +130,9 @@ def get_train_seq_ctx(
             position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
             seq_ctx.position_ids = position_ids  # type: ignore[assignment]
             assert position_ids.size(-1) == input_ids.size(-1)
-        seq_ctx.pixel_values = multimodal_train_info.get("pixel_values")
-        seq_ctx.image_grid_thw = multimodal_train_info.get("image_grid_thw")
+        seq_ctx.pixel_values = multimodal_train_info.pop("pixel_values")
+        seq_ctx.image_grid_thw = multimodal_train_info.pop("image_grid_thw")
+        del multimodal_train_info
     return seq_ctx
 
 
@@ -629,6 +630,8 @@ def fit(self):
             # 1. Rollout to generate experience
             rollout_info = self._rollout_step(rollout_idx, step_timer_dict)
 
+            train_log_info = {}
+            eval_log_info = {}
             if not self._debug_rollout:
                 # 2. Train on the generated experience
                 train_log_info = self._train_step(
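
Two things happen here. First, switching `.get()` to `.pop()` (plus the trailing `del`) means the caller's `multimodal_train_info` dict no longer holds a second reference pinning the tensors after `get_train_seq_ctx` returns. Second, pre-initializing `train_log_info` and `eval_log_info` to `{}` presumably keeps later logging code from referencing unbound names when `_debug_rollout` skips training. A minimal sketch of the pop-vs-get difference, with an illustrative payload:

```python
import torch

info = {"pixel_values": torch.randn(1024, 1176)}  # illustrative multimodal dict

# .get() would leave `info` holding the tensor alongside seq_ctx;
# .pop() transfers the dict's only reference out.
pixel_values = info.pop("pixel_values")
assert "pixel_values" not in info

del info  # mirrors the `del multimodal_train_info` in the diff
```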
