Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions xtuner/v1/ray/base/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,25 @@ def device_visible_env_name(self):
else:
raise ValueError(f"Unsupported accelerator type: {self.accelerator}")

def get_logical_local_rank(self) -> int:
    """Map the Ray-assigned physical accelerator id to a logical local rank.

    Ray hands out accelerator ids in the physical numbering space, while torch
    indexes devices logically from zero within the currently visible device
    list (e.g. ``CUDA_VISIBLE_DEVICES``). This returns the position of the
    assigned device inside that visibility mask.

    Returns:
        int: Zero-based logical index of the assigned accelerator.

    Raises:
        ValueError: If the assigned accelerator id is absent from the
            visible-devices environment variable.
    """
    assigned_id = str(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0])
    mask = os.environ.get(self.device_visible_env_name)
    # No visibility mask set: physical and logical numbering coincide.
    if mask is None:
        return int(assigned_id)

    logical_ids = [entry.strip() for entry in mask.split(",") if entry.strip()]
    for logical_rank, device_id in enumerate(logical_ids):
        if device_id == assigned_id:
            return logical_rank
    raise ValueError(
        f"Assigned accelerator id {assigned_id} is not present in "
        f"{self.device_visible_env_name}={mask}."
    )

def setup_distributed(self, rank: int, master_addr: str, master_port: int, world_size: int):
"""Set up the distributed environment for the worker.

Expand All @@ -215,8 +234,7 @@ def setup_distributed(self, rank: int, master_addr: str, master_port: int, world
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_RANK"] = str(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0])

os.environ["LOCAL_RANK"] = str(self.get_logical_local_rank())
# backend 参数是指定通信后端,不是从环境变量获取
# - 'nccl': NVIDIA GPU 间通信(推荐用于 GPU)
# - 'gloo': CPU 通信或跨平台
Expand Down
181 changes: 143 additions & 38 deletions xtuner/v1/ray/dataflow/replay_buffer.py

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions xtuner/v1/ray/environment/single_turn_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ async def generate( # type: ignore[override]
if self.rollout_controller:
response_future = []
for sample in group_data_items:
sample.data.extra_info["root_id"] = sample.uid.root_id
sample.data.extra_info["action_id"] = sample.uid.action_id
rollout_extra_info = dict(sample.data.extra_info)
rollout_extra_info["root_id"] = sample.uid.root_id
rollout_extra_info["action_id"] = sample.uid.action_id
update_sample_params = sample_params

if "partial_rollout_input_ids" in sample.env.rollout.extra_info:
input_ids_length = len(sample.data.input_ids) if sample.data.input_ids is not None else 0
current_partial_length = len(sample.env.rollout.extra_info["partial_rollout_input_ids"])
rollout_extra_info = copy.deepcopy(sample.data.extra_info)
rollout_extra_info["partial_rollout_input_ids"] = sample.env.rollout.extra_info[
"partial_rollout_input_ids"
]
Expand All @@ -113,8 +113,6 @@ async def generate( # type: ignore[override]
self.logger.debug(
f"root_id: {sample.uid.root_id}, action_id {sample.uid.action_id} pass current_partial_length {current_partial_length}, input_ids_length {input_ids_length} to rollout and set max_tokens to {update_sample_params.max_tokens}"
)
else:
rollout_extra_info = sample.data.extra_info

if "routed_experts" in sample.env.rollout.extra_info:
rollout_extra_info["routed_experts"] = sample.env.rollout.extra_info["routed_experts"]
Expand All @@ -126,6 +124,8 @@ async def generate( # type: ignore[override]
extra_params=extra_params,
extra_info=rollout_extra_info,
)
del rollout_extra_info

response_future.append(fut)
try:
rollout_responses = await asyncio.wait_for(
Expand Down
6 changes: 5 additions & 1 deletion xtuner/v1/ray/rollout/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ async def _handle_non_stream_response(

data = base64.b64decode(routed_experts)
routed_experts = ray.cloudpickle.loads(data)
del data
else:
routed_experts = torch.tensor(routed_experts) # n,layer,expert
routed_experts = ray.put(routed_experts)
Expand All @@ -586,13 +587,14 @@ async def _handle_non_stream_response(
routed_experts = ray.cloudpickle.loads(data)
cur_routed_experts = await routed_experts # n,layer,expert
ray.internal.free(routed_experts, local_only=False)
del data
else:
routed_experts = torch.tensor(routed_experts) # n,layer,expert
cur_routed_experts = routed_experts

history_routed_experts = await input_extra_info["routed_experts"] # n, layer, expert
ray.internal.free(input_extra_info["routed_experts"], local_only=False)
del input_extra_info["routed_experts"]
del input_extra_info

assert (history_routed_experts.shape[0] - 1) > 0 and history_routed_experts.shape[
0
Expand All @@ -613,6 +615,8 @@ async def _handle_non_stream_response(
f"Experts(exist={history_routed_experts.shape}, init_cur={init_cur_roued_experts}, cur={cur_routed_experts.shape}, concat={concat_routed_experts.shape})"
)
extra_info["routed_experts"] = ray.put(concat_routed_experts)
del history_routed_experts
del cur_routed_experts
else:
assert finish_reason == "abort", (
f"routed_experts is None, but finish_reason is {finish_reason}, expected abort. response: {response}"
Expand Down
11 changes: 11 additions & 0 deletions xtuner/v1/ray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Callable, Coroutine, List, Optional, cast

import ray
from ray import ObjectRef


if TYPE_CHECKING:
Expand Down Expand Up @@ -208,3 +209,13 @@ def create_task(
for callback in done_callbacks:
task.add_done_callback(callback)
return task


def free_object_refs(refs: List[ObjectRef]) -> None:
    """Best-effort release of Ray object-store entries for the given refs.

    Non-``ObjectRef`` entries are silently ignored so callers can pass mixed
    lists. Freeing is attempted cluster-wide (``local_only=False``).
    """
    object_refs = list(filter(lambda item: isinstance(item, ObjectRef), refs))
    if not object_refs:
        return
    try:
        # NOTE(review): the private internal_api path is tried first; the broad
        # except presumably guards against this API moving between Ray
        # versions — confirm before narrowing the exception type.
        ray._private.internal_api.free(object_refs, local_only=False)
    except Exception:
        ray.internal.free(object_refs, local_only=False)
5 changes: 4 additions & 1 deletion xtuner/v1/rl/base/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,10 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx:
rollout_idx=rollout_idx,
)
)
log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
try:
log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
finally:
del packed_data_batches
return log_infos

@ray_method
Expand Down
9 changes: 7 additions & 2 deletions xtuner/v1/rl/base/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
from xtuner.v1.ray.base import SingleAcceleratorWorker
from xtuner.v1.ray.config import RolloutConfig
from xtuner.v1.ray.utils import free_object_refs
from xtuner.v1.rl.base.loss import BaseRLLossContext
from xtuner.v1.rl.utils import gather_logprobs
from xtuner.v1.train.trainer import LoadCheckpointConfig
Expand Down Expand Up @@ -483,8 +484,12 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo
assert isinstance(pixel_values, list), (
f"pixel_values should be list of tensor, got {type(pixel_values)}"
)
pixel_values = [ray.get(pixel_obf) for pixel_obf in pixel_values]
pixel_values = torch.cat(pixel_values, dim=0)
pixel_value_refs = list(pixel_values)
try:
pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0)
finally:
free_object_refs(pixel_value_refs)

seq_ctx.pixel_values = pixel_values

rollout_routed_experts = seq_ctx.rollout_routed_experts
Expand Down
9 changes: 6 additions & 3 deletions xtuner/v1/train/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_train_seq_ctx(
):
seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu")
if multimodal_train_info and len(multimodal_train_info) > 0:
position_ids = multimodal_train_info.get("position_ids") # (1,n) or (3,1,n)
position_ids = multimodal_train_info.pop("position_ids") # (1,n) or (3,1,n)
if position_ids is not None and len(position_ids.shape) == 3:
# qwen3vl 需要特殊处理,其余的不需要额外处理
max_value = position_ids.max(dim=-1).values # (3,1)
Expand All @@ -128,8 +128,9 @@ def get_train_seq_ctx(
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
seq_ctx.position_ids = position_ids # type: ignore[assignment]
assert position_ids.size(-1) == input_ids.size(-1)
seq_ctx.pixel_values = multimodal_train_info.get("pixel_values")
seq_ctx.image_grid_thw = multimodal_train_info.get("image_grid_thw")
seq_ctx.pixel_values = multimodal_train_info.pop("pixel_values")
seq_ctx.image_grid_thw = multimodal_train_info.pop("image_grid_thw")
del multimodal_train_info
return seq_ctx


Expand Down Expand Up @@ -623,6 +624,8 @@ def fit(self):
# 1. Rollout to generate experience
rollout_info = self._rollout_step(rollout_idx, step_timer_dict)

train_log_info = {}
eval_log_info = {}
if not self._debug_rollout:
# 2. Train on the generated experience
train_log_info = self._train_step(
Expand Down
Loading