Skip to content

Commit 2bd7aa8

Browse files
hhaAndroid and nil0x9
authored and committed
Fix position_ids error with VL Model RL (#1664)
fix pixelvalue
1 parent e8a2042 commit 2bd7aa8

File tree

4 files changed

+27
-10
lines changed

4 files changed

+27
-10
lines changed

xtuner/v1/data_proto/templates/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@
6262
image_context_token="<|image_pad|>",
6363
video_context_token="<|video_pad|>",
6464
),
65+
"qwen3.5-vl": HybridChatTemplate(
66+
template_name="qwen3.5-vl",
67+
system="<|im_start|>system\n{system}<|im_end|>\n",
68+
tool_prompt="# Tools\n\nYou have access to the following functions:\n\n<tools>{tool_text}\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>",
69+
tool_extractor="<|im_start|>user\n<tool_response>\n{tool_extractor}\n</tool_response><|im_end|>\n<|im_start|>assistant\n",
70+
user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
71+
stop_words=["<|im_end|>", "<|endoftext|>"],
72+
assistant="{assistant}<|im_end|>",
73+
image_start_token="<|vision_start|>",
74+
image_end_token="<|vision_end|>",
75+
image_context_token="<|image_pad|>",
76+
video_context_token="<|video_pad|>",
77+
),
6578
"llama3": HybridChatTemplate(
6679
system="<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>",
6780
user=(

xtuner/v1/rl/base/controller.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from xtuner.v1.data_proto.sequence_context import SequenceContext
1010
from xtuner.v1.model.compose.base import BaseComposeConfig
11+
from xtuner.v1.ray.utils import free_object_refs
1112
from xtuner.v1.train.trainer import LoadCheckpointConfig
1213
from xtuner.v1.utils import ray_method
1314

@@ -263,6 +264,13 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx:
263264
try:
264265
log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
265266
finally:
267+
# free pixel values ref
268+
free_pixel_value_refs: list[ray.ObjectRef] = []
269+
for data in packed_data_batches:
270+
if data["seq_ctx"].pixel_values is not None:
271+
free_pixel_value_refs.extend(data["seq_ctx"].pixel_values)
272+
if len(free_pixel_value_refs) > 0:
273+
free_object_refs(free_pixel_value_refs)
266274
del packed_data_batches
267275
return log_infos
268276

xtuner/v1/rl/base/worker.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
3636
from xtuner.v1.ray.base import SingleAcceleratorWorker
3737
from xtuner.v1.ray.config import RolloutConfig
38-
from xtuner.v1.ray.utils import free_object_refs
3938
from xtuner.v1.rl.base.loss import BaseRLLossContext
4039
from xtuner.v1.train.trainer import LoadCheckpointConfig
4140
from xtuner.v1.utils import (
@@ -485,11 +484,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo
485484
f"pixel_values should be list of tensor, got {type(pixel_values)}"
486485
)
487486
pixel_value_refs = list(pixel_values)
488-
try:
489-
pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0)
490-
finally:
491-
free_object_refs(pixel_value_refs)
492-
487+
pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0)
493488
seq_ctx.pixel_values = pixel_values
494489

495490
rollout_routed_experts = seq_ctx.rollout_routed_experts

xtuner/v1/train/rl_trainer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def get_train_seq_ctx(
120120
):
121121
seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu")
122122
if multimodal_train_info and len(multimodal_train_info) > 0:
123-
position_ids = multimodal_train_info.pop("position_ids") # (1,n) or (3,1,n)
123+
position_ids = multimodal_train_info.get("position_ids") # (1,n) or (3,1,n)
124124
if position_ids is not None and len(position_ids.shape) == 3:
125125
# qwen3vl 需要特殊处理,其余的不需要额外处理
126126
max_value = position_ids.max(dim=-1).values # (3,1)
@@ -130,9 +130,8 @@ def get_train_seq_ctx(
130130
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
131131
seq_ctx.position_ids = position_ids # type: ignore[assignment]
132132
assert position_ids.size(-1) == input_ids.size(-1)
133-
seq_ctx.pixel_values = multimodal_train_info.pop("pixel_values")
134-
seq_ctx.image_grid_thw = multimodal_train_info.pop("image_grid_thw")
135-
del multimodal_train_info
133+
seq_ctx.pixel_values = multimodal_train_info.get("pixel_values")
134+
seq_ctx.image_grid_thw = multimodal_train_info.get("image_grid_thw")
136135
return seq_ctx
137136

138137

@@ -803,6 +802,8 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf
803802
seq_ctx.rollout_routed_experts = routed_experts # n,layer,expert
804803

805804
data_batches.append(data_dict)
805+
if multimodal_train_info is not None:
806+
del multimodal_train_info
806807
random.shuffle(data_batches)
807808

808809
rewards_t = torch.tensor(rewards_list).float() if rewards_list else torch.tensor([0.0]).float()

0 commit comments

Comments
 (0)