Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions xtuner/v1/rl/base/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.model.compose.base import BaseComposeConfig
from xtuner.v1.ray.utils import free_object_refs
from xtuner.v1.train.trainer import LoadCheckpointConfig
from xtuner.v1.utils import ray_method

Expand Down Expand Up @@ -263,6 +264,13 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx:
try:
log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
finally:
# free pixel values ref
free_pixel_value_refs: list[ray.ObjectRef] = []
for data in packed_data_batches:
if data["seq_ctx"].pixel_values is not None:
free_pixel_value_refs.extend(data["seq_ctx"].pixel_values)
if len(free_pixel_value_refs) > 0:
free_object_refs(free_pixel_value_refs)
del packed_data_batches
return log_infos

Expand Down
7 changes: 1 addition & 6 deletions xtuner/v1/rl/base/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
from xtuner.v1.ray.base import SingleAcceleratorWorker
from xtuner.v1.ray.config import RolloutConfig
from xtuner.v1.ray.utils import free_object_refs
from xtuner.v1.rl.base.loss import BaseRLLossContext
from xtuner.v1.train.trainer import LoadCheckpointConfig
from xtuner.v1.utils import (
Expand Down Expand Up @@ -485,11 +484,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo
f"pixel_values should be list of tensor, got {type(pixel_values)}"
)
pixel_value_refs = list(pixel_values)
try:
pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0)
finally:
free_object_refs(pixel_value_refs)

pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0)
seq_ctx.pixel_values = pixel_values

rollout_routed_experts = seq_ctx.rollout_routed_experts
Expand Down
9 changes: 5 additions & 4 deletions xtuner/v1/train/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def get_train_seq_ctx(
):
seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu")
if multimodal_train_info and len(multimodal_train_info) > 0:
position_ids = multimodal_train_info.pop("position_ids") # (1,n) or (3,1,n)
position_ids = multimodal_train_info.get("position_ids") # (1,n) or (3,1,n)
if position_ids is not None and len(position_ids.shape) == 3:
# qwen3vl 需要特殊处理,其余的不需要额外处理
max_value = position_ids.max(dim=-1).values # (3,1)
Expand All @@ -130,9 +130,8 @@ def get_train_seq_ctx(
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
seq_ctx.position_ids = position_ids # type: ignore[assignment]
assert position_ids.size(-1) == input_ids.size(-1)
seq_ctx.pixel_values = multimodal_train_info.pop("pixel_values")
seq_ctx.image_grid_thw = multimodal_train_info.pop("image_grid_thw")
del multimodal_train_info
seq_ctx.pixel_values = multimodal_train_info.get("pixel_values")
seq_ctx.image_grid_thw = multimodal_train_info.get("image_grid_thw")
return seq_ctx


Expand Down Expand Up @@ -803,6 +802,8 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf
seq_ctx.rollout_routed_experts = routed_experts # n,layer,expert

data_batches.append(data_dict)
if multimodal_train_info is not None:
del multimodal_train_info
random.shuffle(data_batches)

rewards_t = torch.tensor(rewards_list).float() if rewards_list else torch.tensor([0.0]).float()
Expand Down
Loading