Skip to content

Commit 3b78823

Browse files
authored
Fix position_ids error with VL Model RL (#1664)
Fix pixel value handling
1 parent e8a2042 commit 3b78823

3 files changed

Lines changed: 14 additions & 10 deletions

File tree

xtuner/v1/rl/base/controller.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from xtuner.v1.data_proto.sequence_context import SequenceContext
1010
from xtuner.v1.model.compose.base import BaseComposeConfig
11+
from xtuner.v1.ray.utils import free_object_refs
1112
from xtuner.v1.train.trainer import LoadCheckpointConfig
1213
from xtuner.v1.utils import ray_method
1314

@@ -263,6 +264,13 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx:
263264
try:
264265
log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
265266
finally:
267+
# free pixel values ref
268+
free_pixel_value_refs: list[ray.ObjectRef] = []
269+
for data in packed_data_batches:
270+
if data["seq_ctx"].pixel_values is not None:
271+
free_pixel_value_refs.extend(data["seq_ctx"].pixel_values)
272+
if len(free_pixel_value_refs) > 0:
273+
free_object_refs(free_pixel_value_refs)
266274
del packed_data_batches
267275
return log_infos
268276

xtuner/v1/rl/base/worker.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
3636
from xtuner.v1.ray.base import SingleAcceleratorWorker
3737
from xtuner.v1.ray.config import RolloutConfig
38-
from xtuner.v1.ray.utils import free_object_refs
3938
from xtuner.v1.rl.base.loss import BaseRLLossContext
4039
from xtuner.v1.train.trainer import LoadCheckpointConfig
4140
from xtuner.v1.utils import (
@@ -485,11 +484,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo
485484
f"pixel_values should be list of tensor, got {type(pixel_values)}"
486485
)
487486
pixel_value_refs = list(pixel_values)
488-
try:
489-
pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0)
490-
finally:
491-
free_object_refs(pixel_value_refs)
492-
487+
pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0)
493488
seq_ctx.pixel_values = pixel_values
494489

495490
rollout_routed_experts = seq_ctx.rollout_routed_experts

xtuner/v1/train/rl_trainer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def get_train_seq_ctx(
120120
):
121121
seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu")
122122
if multimodal_train_info and len(multimodal_train_info) > 0:
123-
position_ids = multimodal_train_info.pop("position_ids") # (1,n) or (3,1,n)
123+
position_ids = multimodal_train_info.get("position_ids") # (1,n) or (3,1,n)
124124
if position_ids is not None and len(position_ids.shape) == 3:
125125
# qwen3vl 需要特殊处理,其余的不需要额外处理
126126
max_value = position_ids.max(dim=-1).values # (3,1)
@@ -130,9 +130,8 @@ def get_train_seq_ctx(
130130
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
131131
seq_ctx.position_ids = position_ids # type: ignore[assignment]
132132
assert position_ids.size(-1) == input_ids.size(-1)
133-
seq_ctx.pixel_values = multimodal_train_info.pop("pixel_values")
134-
seq_ctx.image_grid_thw = multimodal_train_info.pop("image_grid_thw")
135-
del multimodal_train_info
133+
seq_ctx.pixel_values = multimodal_train_info.get("pixel_values")
134+
seq_ctx.image_grid_thw = multimodal_train_info.get("image_grid_thw")
136135
return seq_ctx
137136

138137

@@ -803,6 +802,8 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf
803802
seq_ctx.rollout_routed_experts = routed_experts # n,layer,expert
804803

805804
data_batches.append(data_dict)
805+
if multimodal_train_info is not None:
806+
del multimodal_train_info
806807
random.shuffle(data_batches)
807808

808809
rewards_t = torch.tensor(rewards_list).float() if rewards_list else torch.tensor([0.0]).float()

0 commit comments

Comments (0)