
Commit 768bb46

fix(vlm): forward get_rope_index to neat packing for mRoPE models
The VLM recipe never passed the model's get_rope_index callable to neat_pack_dataset_vlm. With it absent, PackedDatasetWrapper sets has_mrope=False and emits 1D position_ids per pack. The collater then forwards 2D [B, L] position_ids to the model, which short-circuits get_rope_index inside model.forward, and the language model expands the same 1D positions across all 3 mRoPE channels.

Net effect: packed Qwen2.5-VL / Qwen3-VL / Qwen3-VL-MoE / Qwen3-Omni training silently degraded mRoPE to plain 1D rotary, losing image spatial/temporal positional information. Non-packed and non-mRoPE VLMs were unaffected.

The fix is plumbing only: extract get_rope_index via getattr(model_parts[0], ...) in the recipe and forward it through build_dataloader to neat_pack_dataset_vlm. Models without the method (Mistral3, LLaVA-OV, KimiVL, Gemma4-VLM) keep the prior behavior, since getattr returns None.

Adds two unit tests guarding the wiring against regression.

Signed-off-by: khazic <khazzz1c@gmail.com>
1 parent 2afd94e commit 768bb46
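
To make the failure mode concrete, the following is a minimal sketch of the degradation described in the commit message. The [3, batch, seq] position_ids shape follows the Qwen2.5-VL mRoPE convention; the variable names and the broadcast shown here are illustrative, not the model's exact forward code.

    # Sketch only: what the language model effectively sees when the collater
    # sends plain 1D position_ids instead of mRoPE 3D indices.
    import torch

    batch, seq = 1, 8

    # Packed collater output without get_rope_index: one position per token.
    pos_1d = torch.arange(seq).unsqueeze(0)                    # [batch, seq]

    # Inside an mRoPE model, 2D position_ids end up broadcast so the temporal,
    # height, and width rotary channels all receive the same sequence index.
    pos_degraded = pos_1d.unsqueeze(0).expand(3, batch, seq)   # [3, batch, seq]
    assert torch.equal(pos_degraded[0], pos_degraded[1])       # spatial info lost

    # A real model.get_rope_index would instead return distinct temporal /
    # height / width indices over image-token spans, which is exactly the
    # information the degraded path discards.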

2 files changed

Lines changed: 144 additions & 0 deletions


nemo_automodel/recipes/vlm/finetune.py

Lines changed: 14 additions & 0 deletions
@@ -348,6 +348,7 @@ def build_dataloader(
     local_batch_size,
     cfg_model=None,
     cfg_ps=None,
+    get_rope_index=None,
 ) -> tuple[DataLoader, ProcessorMixin]:
     """Build a DataLoader for the VLM dataset.
 
@@ -362,6 +363,11 @@ def build_dataloader(
         cfg_model: Model configuration (used to detect attention backend).
         cfg_ps: Packed sequence configuration (top-level ``packed_sequence:`` section).
             When provided, takes precedence over ``dataset.packing``.
+        get_rope_index: Optional ``model.get_rope_index`` callable. When provided,
+            VLM neat packing computes mRoPE 3D position IDs per sample so packed
+            mRoPE-aware models (Qwen2.5-VL, Qwen3-VL, ...) preserve multimodal
+            position semantics across pack boundaries instead of falling back to
+            plain 1D positions.
 
     Returns:
         The instantiated DataLoader and processor.
@@ -479,6 +485,7 @@ def build_dataloader(
             packing_ratio=packing_cfg.get("packing_ratio", 1.0),
             processor=processor,
             balance_media_tokens=packing_cfg.get("balance_media_tokens", True),
+            get_rope_index=get_rope_index,
         )
         _pad_id = getattr(processor.tokenizer, "pad_token_id", 0) or 0
         _collate_max_length = packing_cfg.get("collate_max_length", None)
@@ -832,6 +839,11 @@ def setup(self):
             self.model_parts = [model]
             self.pp = None
 
+        # Extract mRoPE position-id builder from the model so VLM neat packing can
+        # produce 3D position_ids per sample. Without this, packed Qwen2.5-VL /
+        # Qwen3-VL training silently degrades mRoPE to plain 1D positions.
+        get_rope_index = getattr(self.model_parts[0], "get_rope_index", None)
+
         self.dataloader, self.processor = build_dataloader(
             self.cfg.dataset,
             self.cfg.dataloader,
@@ -842,6 +854,7 @@ def setup(self):
             local_batch_size=self.cfg.get("step_scheduler.local_batch_size", 1),
             cfg_model=self.cfg.model,
             cfg_ps=self.cfg.get("packed_sequence", None),
+            get_rope_index=get_rope_index,
         )
 
         # Build validation dataloader if the config provides it
@@ -855,6 +868,7 @@ def setup(self):
                 device_mesh=self.device_mesh,
                 seed=self.cfg.get("seed", 42),
                 local_batch_size=self.cfg.get("step_scheduler.local_batch_size", 1),
+                get_rope_index=get_rope_index,
             )
 
         self.best_metric_key = self.cfg.get("checkpoint.best_metric_key", "default")
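
For orientation, the sketch below shows one way a packing routine can consume the per-sample get_rope_index callable that the recipe now forwards. neat_pack_dataset_vlm's real internals are not part of this diff, so the helper name _concat_position_ids and the exact call pattern are assumptions modeled on the Qwen2.5-VL-style get_rope_index interface.

    # Hypothetical sketch, not the actual neat_pack_dataset_vlm implementation.
    from typing import Callable, Optional

    import torch

    def _concat_position_ids(samples, get_rope_index: Optional[Callable] = None):
        """Build position_ids for one pack by concatenating per-sample positions."""
        chunks = []
        for sample in samples:
            input_ids = sample["input_ids"].unsqueeze(0)           # [1, L_i]
            if get_rope_index is not None:
                # mRoPE path: distinct temporal/height/width indices, [3, 1, L_i].
                pos, _ = get_rope_index(
                    input_ids, image_grid_thw=sample.get("image_grid_thw")
                )
            else:
                # Fallback: plain 1D positions restarted per sample, [1, 1, L_i].
                pos = torch.arange(input_ids.shape[1]).view(1, 1, -1)
            chunks.append(pos)
        # Per-sample positions are concatenated along the sequence axis so each
        # sample's positional frame restarts at its own pack boundary.
        return torch.cat(chunks, dim=-1)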

tests/unit_tests/recipes/test_finetune_vlm_helpers.py

Lines changed: 130 additions & 0 deletions
@@ -2432,3 +2432,133 @@ def test_fallback_mismatched_images(self):
         assert len(pv_chunks) == 2
         assert pv_chunks[0].shape[0] == 12  # all in first
         assert pv_chunks[1].shape[0] == 0  # empty
+
+
+# -----------------------------------------------------------------------------
+# get_rope_index forwarding tests for build_dataloader
+#
+# Guard against a regression where the VLM recipe forgot to pass
+# get_rope_index to neat_pack_dataset_vlm, silently degrading mRoPE to
+# plain 1D positions for packed Qwen2.5-VL / Qwen3-VL training.
+# -----------------------------------------------------------------------------
+
+
+def _make_packing_cfg(pack_size=128):
+    cfg = MagicMock()
+    cfg.pack_size = pack_size
+    cfg.pretokenize = True
+    cfg.max_length = pack_size
+    cfg.get.side_effect = lambda key, default=None: {
+        "pack_size": pack_size,
+        "drop_long_samples": True,
+        "max_packs": None,
+        "packing_ratio": 1.0,
+        "balance_media_tokens": True,
+        "collate_max_length": None,
+        "post_tokenize_hook_fn": None,
+    }.get(key, default)
+    return cfg
+
+
+def _make_dataset_cfg():
+    cfg = MagicMock(spec=["get", "instantiate", "path_or_dataset"])
+    cfg.get.side_effect = lambda key, default=None: {
+        "path_or_dataset": None,
+        "truncate": True,
+    }.get(key, default)
+    cfg.path_or_dataset = None
+    cfg.instantiate.return_value = []
+    return cfg
+
+
+def _patches_for_packing(neat_pack_side_effect):
+    processor = MagicMock()
+    processor.tokenizer.pad_token_id = 0
+    processor.chat_template = "{{ x }}"
+    return processor, [
+        patch("transformers.AutoProcessor.from_pretrained", return_value=processor),
+        patch("torch.utils.data.distributed.DistributedSampler"),
+        patch(
+            "nemo_automodel.components.datasets.vlm.datasets.PreTokenizedDatasetWrapper",
+            return_value=MagicMock(),
+        ),
+        patch(
+            "nemo_automodel.components.datasets.vlm.neat_packing_vlm.neat_pack_dataset_vlm",
+            side_effect=neat_pack_side_effect,
+        ),
+        patch("nemo_automodel.components.models.common.packing.configure_packing"),
+        patch(
+            "nemo_automodel.components.models.common.packing.get_attn_implementation",
+            return_value="sdpa",
+        ),
+    ]
+
+
+def test_build_dataloader_forwards_get_rope_index_to_packing():
+    """get_rope_index passed to build_dataloader must reach neat_pack_dataset_vlm."""
+    from contextlib import ExitStack
+
+    from nemo_automodel.recipes.vlm.finetune import build_dataloader
+
+    sentinel = MagicMock(name="get_rope_index")
+    captured = {}
+
+    def fake_neat_pack(*args, **kwargs):
+        captured.update(kwargs)
+        return MagicMock()
+
+    _, ctx_managers = _patches_for_packing(fake_neat_pack)
+
+    with ExitStack() as stack:
+        for cm in ctx_managers:
+            stack.enter_context(cm)
+        build_dataloader(
+            _make_dataset_cfg(),
+            MagicMock(get=MagicMock(return_value=None), instantiate=MagicMock(return_value=MagicMock())),
+            "test/model",
+            None,
+            None,
+            42,
+            1,
+            cfg_ps=_make_packing_cfg(pack_size=64),
+            get_rope_index=sentinel,
+        )
+
+    assert captured.get("get_rope_index") is sentinel, (
+        "build_dataloader must forward get_rope_index to neat_pack_dataset_vlm; "
+        f"got kwargs={list(captured.keys())}"
+    )
+
+
+def test_build_dataloader_default_get_rope_index_is_none():
+    """When the model does not expose get_rope_index, packing must receive None."""
+    from contextlib import ExitStack
+
+    from nemo_automodel.recipes.vlm.finetune import build_dataloader
+
+    captured = {}
+
+    def fake_neat_pack(*args, **kwargs):
+        captured.update(kwargs)
+        return MagicMock()
+
+    _, ctx_managers = _patches_for_packing(fake_neat_pack)
+
+    with ExitStack() as stack:
+        for cm in ctx_managers:
+            stack.enter_context(cm)
+        build_dataloader(
+            _make_dataset_cfg(),
+            MagicMock(get=MagicMock(return_value=None), instantiate=MagicMock(return_value=MagicMock())),
+            "test/model",
+            None,
+            None,
+            42,
+            1,
+            cfg_ps=_make_packing_cfg(pack_size=64),
+        )
+
+    assert "get_rope_index" in captured, (
+        "neat_pack_dataset_vlm must receive get_rope_index kwarg even when None"
+    )
+    assert captured["get_rope_index"] is None
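
Both tests replace neat_pack_dataset_vlm with a capture stub and drive build_dataloader with mocked configs, so they exercise only the keyword-argument plumbing rather than building a real packed dataset; the second test additionally pins down that the kwarg is always passed, even when it is None, keeping the non-mRoPE fallback explicit. Assuming the repository's usual pytest layout, they can be run with pytest tests/unit_tests/recipes/test_finetune_vlm_helpers.py -k get_rope_index.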
