Skip to content

Commit ca0be5a

Browse files
committed
Fix mRoPE position ID crash when Qwen2-VL prompts are truncated
When training Qwen2.5-VL with agent-lightning + verl, prompt truncation changes the token count but image_grid_thw is computed from the original (untruncated) image_urls. This causes get_rope_index to fail with a shape mismatch because it finds fewer image tokens in the truncated input_ids than entries in image_grid_thw. After prompt truncation, count remaining image regions in the truncated token sequence and slice image_urls to match before computing image_grid_thw, ensuring consistency between the token content and the mRoPE spatial metadata. Fixes #441
1 parent 82d8535 commit ca0be5a

File tree

1 file changed

+48
-1
lines changed

1 file changed

+48
-1
lines changed

agentlightning/verl/daemon.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,45 @@ def _resolve_image_path(self, path: str) -> str:
310310
raise ValueError(f"Relative path '{path}' requires 'image_base_dir' to be set.")
311311
return os.path.join(self.image_base_dir, path)
312312

313+
def _count_images_in_tokens(self, token_ids: List[int]) -> int:
314+
"""Count the number of complete image regions in a token ID sequence.
315+
316+
Image regions are identified by finding ``vision_start_token_id``
317+
followed by ``image_token_id``, matching the detection logic used by
318+
``get_rope_index`` in the Qwen2-VL / Qwen2.5-VL model implementation.
319+
This is needed to reconcile ``image_grid_thw`` with truncated prompts
320+
so that mRoPE position IDs are computed correctly.
321+
322+
Args:
323+
token_ids: List of token IDs (possibly truncated).
324+
325+
Returns:
326+
Number of image regions found in the token sequence, or ``-1`` if
327+
the required special-token IDs could not be resolved (in which case
328+
the caller should fall back to the original image count).
329+
"""
330+
# Resolve image_token_id from the processor (set during __init__)
331+
image_token_id = getattr(self.processor, "image_token_id", None)
332+
if image_token_id is None and hasattr(self.tokenizer, "convert_tokens_to_ids"):
333+
image_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>")
334+
335+
# Resolve vision_start_token_id -- not stored on the processor, so we
336+
# try the tokenizer first and fall back to the well-known default.
337+
vision_start_token_id = None
338+
if hasattr(self.tokenizer, "convert_tokens_to_ids"):
339+
vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>")
340+
if vision_start_token_id is None:
341+
vision_start_token_id = 151652 # Qwen2-VL / Qwen2.5-VL default
342+
343+
if image_token_id is None:
344+
return -1
345+
346+
count = 0
347+
for i in range(len(token_ids) - 1):
348+
if token_ids[i] == vision_start_token_id and token_ids[i + 1] == image_token_id:
349+
count += 1
350+
return count
351+
313352
def _get_image_grid_thw(self, image_urls: List[str]) -> Optional[torch.Tensor]:
314353
"""Compute image_grid_thw from image URLs for M-RoPE computation.
315354
@@ -907,9 +946,17 @@ def get_train_data_batch(
907946
rollout_id_list.append(rollout_id)
908947
turn_index_list.append(turn_index)
909948

910-
# Compute image_grid_thw for this triplet using image_urls from prompt
949+
# Compute image_grid_thw for this triplet using image_urls from prompt.
950+
# After prompt truncation, some image tokens may have been removed,
951+
# so we must reconcile image_urls with the actual images remaining
952+
# in the (possibly truncated) prompt to avoid shape mismatches in
953+
# get_rope_index when computing mRoPE position IDs.
911954
if self._use_mrope:
912955
image_urls = trace.get("image_urls", [])
956+
if image_urls:
957+
n_images_in_tokens = self._count_images_in_tokens(prompt_ids)
958+
if n_images_in_tokens >= 0 and n_images_in_tokens < len(image_urls):
959+
image_urls = image_urls[:n_images_in_tokens]
913960
image_grid_thw_list.append(self._get_image_grid_thw(image_urls))
914961

915962
elif self.trace_aggregator.get("level", "transition") == "trajectory":

0 commit comments

Comments
 (0)