diff --git a/examples/multimodal_dev/README.md b/examples/multimodal_dev/README.md
new file mode 100644
index 00000000000..bdc34414da9
--- /dev/null
+++ b/examples/multimodal_dev/README.md
@@ -0,0 +1,162 @@
+# multimodal_dev — Standalone Multimodal Training
+
+Standalone, model-agnostic training entry point for multimodal
+vision-language models built on Megatron-Core (FSDP + EP).
+
+## Directory Structure
+
+```
+multimodal_dev/
+├── pretrain_multimodal.py # Training entry point (model-agnostic)
+├── forward_step.py # Forward step, TP broadcast, loss computation
+├── arguments.py # Multimodal CLI arguments
+├── data/
+│ └── mock.py # Mock dataset for end-to-end testing
+├── models/
+│ ├── __init__.py # MODEL_REGISTRY — central model registry
+│ ├── base.py # MultimodalModel base class (vision encoder + GPTModel)
+│ └── qwen35_vl/ # Qwen3.5-VL architecture
+│ ├── factory.py # Factory functions for pretrain entry point
+│ ├── model.py # Qwen35VLModel (MRoPE, vision encoder wiring)
+│ ├── configuration.py # TransformerConfig builders and constants
+│ ├── specs.py # Layer spec builders (hybrid attention, ViT)
+│ ├── mrope.py # 3D MRoPE position ID computation
+│ └── vision_encoder.py # ViT encoder (patch embed, merger, RoPE)
+└── scripts/ # Launch scripts (torchrun, Slurm)
+```
+
+## Quick Start
+
+```bash
+torchrun --nproc_per_node=8 multimodal_dev/pretrain_multimodal.py \
+ --model-arch qwen35_vl \
+ --dataset-provider mock \
+ ... # other Megatron args (--num-layers, --hidden-size, etc.)
+```
+
+## Architecture
+
+`pretrain_multimodal.py` is **model-agnostic**. All model-specific logic
+is delegated to factory functions registered in `MODEL_REGISTRY`
+(`models/__init__.py`). The entry point handles only generic concerns:
+
+- Building `language_config` from Megatron CLI args
+- Constructing `vision_config` via the registry
+- Applying vision recompute and dtype propagation
+- Routing to model and dataset factories
+
+The `forward_step` is also model-agnostic — it uses the model's
+`compute_position_ids()` method polymorphically and passes a standard
+batch dict.
+
+## Adding a New Model Architecture
+
+Adding a new model (e.g. `llava_next`) requires **no changes** to
+`pretrain_multimodal.py` or `forward_step.py`. Follow these steps:
+
+### Step 1 — Create the model package
+
+```
+multimodal_dev/models/llava_next/
+├── __init__.py
+├── factory.py # Required: factory functions
+├── configuration.py # Vision/language TransformerConfig builders
+├── model.py # Model class (subclass MultimodalModel)
+├── specs.py # Layer spec builders
+└── vision_encoder.py # Vision encoder (if custom)
+```
+
+### Step 2 — Implement factory functions
+
+Create `factory.py` with up to three functions:
+
+```python
+# models/llava_next/factory.py
+
+def post_language_config(language_config, args):
+ """(Optional) Mutate language_config with model-specific fields."""
+ # e.g. language_config.some_field = value
+ pass
+
+def set_vision_flops_metadata(args, language_config, vision_config):
+ """(Optional) Set vision FLOPs metadata on args."""
+ args.count_vision_model_flops = True
+ args.vision_flops_variant = "llava_next"
+ # ... set dimension fields for FLOPs calculation
+
+def build_model(args, language_config, vision_config, **kwargs):
+ """(Required) Build and return the complete model instance."""
+ from .model import LlavaNextModel
+ from .specs import get_llava_next_language_spec
+
+ language_spec = get_llava_next_language_spec(
+ config=language_config,
+ vp_stage=kwargs.get("vp_stage", None),
+ pp_rank=None,
+ )
+ return LlavaNextModel(
+ language_config=language_config,
+ language_spec=language_spec,
+ vision_config=vision_config,
+ # ... model-specific args
+ )
+```
+
+### Step 3 — Register in `MODEL_REGISTRY`
+
+Add an entry in `models/__init__.py`:
+
+```python
+from multimodal_dev.models.llava_next.configuration import (
+ get_llava_next_vision_config,
+)
+from multimodal_dev.models.llava_next.factory import (
+ build_model as _build_llava_next_model,
+ post_language_config as _llava_next_post_language_config,
+ set_vision_flops_metadata as _llava_next_vision_flops,
+)
+
+MODEL_REGISTRY["llava_next"] = {
+ "model_factory_fn": _build_llava_next_model, # required
+ "vision_config_fn": get_llava_next_vision_config, # required
+ "post_language_config_fn": _llava_next_post_language_config, # optional
+ "vision_flops_fn": _llava_next_vision_flops, # optional
+ "dataset_providers": { # optional
+ "mock": "multimodal_dev.data.llava_mock.train_valid_test_datasets_provider",
+ },
+}
+```
+
+### Step 4 — (Optional) Add a dataset provider
+
+Create a dataset module under `data/` if the model needs custom data
+preprocessing. The provider function signature is:
+
+```python
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+ """Return (train_dataset, val_dataset, test_dataset)."""
+ ...
+```
+
+Register it in the `dataset_providers` dict of the registry entry.
+Providers can be either direct callables or dotted import path strings
+(resolved lazily at runtime).
+
+### Step 5 — Launch
+
+```bash
+torchrun --nproc_per_node=8 multimodal_dev/pretrain_multimodal.py \
+ --model-arch llava_next \
+ --dataset-provider mock \
+ ...
+```
+
+## Registry Entry Reference
+
+| Field | Required | Signature |
+|-------|----------|-----------|
+| `model_factory_fn` | Yes | `(args, language_config, vision_config, **kwargs) -> MegatronModule` |
+| `vision_config_fn` | Yes | `(num_layers_override=None) -> TransformerConfig` |
+| `post_language_config_fn` | No | `(language_config, args) -> None` |
+| `vision_flops_fn` | No | `(args, language_config, vision_config) -> None` |
+| `dataset_providers` | No | `Dict[str, str \| callable]` |
diff --git a/examples/multimodal_dev/__init__.py b/examples/multimodal_dev/__init__.py
new file mode 100644
index 00000000000..e76ed74857b
--- /dev/null
+++ b/examples/multimodal_dev/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
diff --git a/examples/multimodal_dev/arguments.py b/examples/multimodal_dev/arguments.py
new file mode 100644
index 00000000000..16d39c82857
--- /dev/null
+++ b/examples/multimodal_dev/arguments.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Extra CLI arguments for multimodal_dev standalone training."""
+
+
+def add_multimodal_args(parser):
+    """Add multimodal-specific arguments to the Megatron argument parser.
+
+    Args:
+        parser: The Megatron ``argparse.ArgumentParser`` to extend.
+
+    Returns:
+        The same parser, for chaining.
+    """
+    group = parser.add_argument_group(
+        "Multimodal", "Multimodal model arguments",
+    )
+
+    group.add_argument(
+        "--model-arch",
+        type=str,
+        default="qwen35_vl",
+        help="Model architecture. Available: qwen35_vl",
+    )
+    group.add_argument(
+        "--model-variant",
+        type=str,
+        default="proxy",
+        help="Model variant (size). E.g. proxy, 9b, 397b_a17b",
+    )
+    group.add_argument(
+        "--dataset-provider",
+        type=str,
+        default="mock",
+        help="Dataset provider: mock",
+    )
+    group.add_argument(
+        "--image-token-id",
+        type=int,
+        # Presumably matches QWEN35_VL_IMAGE_TOKEN_ID in
+        # models/qwen35_vl/configuration.py — TODO confirm they stay in sync.
+        default=248056,
+        help="Token ID for image placeholder tokens",
+    )
+    group.add_argument(
+        "--image-size",
+        type=int,
+        default=224,
+        help="Image size (height and width) for mock data",
+    )
+    group.add_argument(
+        "--total-seq-length",
+        type=int,
+        default=1024,
+        help="Total sequence length for mock data",
+    )
+    group.add_argument(
+        "--image-seq-length",
+        type=int,
+        default=256,
+        help="Number of image tokens in mock data",
+    )
+    group.add_argument(
+        "--vision-num-layers",
+        type=int,
+        default=None,
+        help=(
+            "Override for vision backbone depth. "
+            "Useful for proxy perf runs."
+        ),
+    )
+    group.add_argument(
+        "--hf-processor-path",
+        type=str,
+        default=None,
+        help=(
+            "HuggingFace processor path for real VLM datasets "
+            "(e.g. Qwen/Qwen2.5-VL-7B-Instruct)"
+        ),
+    )
+    group.add_argument(
+        "--recompute-vision",
+        action="store_true",
+        default=False,
+        help=(
+            "Enable full activation recomputation for vision encoder layers. "
+            "Uses uniform method and recomputes every layer. "
+            "Independent of the decoder --recompute-* flags."
+        ),
+    )
+    group.add_argument(
+        "--use-packed-sequence",
+        action="store_true",
+        default=False,
+        help=(
+            "Pack variable-length sequences into THD format to eliminate "
+            "padding waste."
+        ),
+    )
+    group.add_argument(
+        "--use-vanilla-collate-fn",
+        action="store_true",
+        default=False,
+        help=(
+            "Use vanilla collate function to collate the data."
+        ),
+    )
+
+    return parser
diff --git a/examples/multimodal_dev/data/__init__.py b/examples/multimodal_dev/data/__init__.py
new file mode 100644
index 00000000000..e76ed74857b
--- /dev/null
+++ b/examples/multimodal_dev/data/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
diff --git a/examples/multimodal_dev/data/mock.py b/examples/multimodal_dev/data/mock.py
new file mode 100644
index 00000000000..2daac8e4eba
--- /dev/null
+++ b/examples/multimodal_dev/data/mock.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Mock dataset for multimodal_dev end-to-end testing.
+
+Generates synthetic image + text data. Each sample has random text
+tokens with image-token placeholders, random pixel values sized for the
+vision encoder, 3D MRoPE position IDs, and shifted labels.
+"""
+
+import torch
+from torch.utils.data import Dataset
+
+from examples.multimodal_dev.models.qwen35_vl.configuration import (
+ QWEN35_VL_IMAGE_TOKEN_ID,
+ QWEN35_VL_VIDEO_TOKEN_ID,
+ QWEN35_VL_VISION_START_TOKEN_ID,
+)
+from examples.multimodal_dev.models.qwen35_vl.mrope import get_rope_index
+
+
+class MockQwen35VLDataset(Dataset):
+    """Synthetic Qwen3.5-VL training samples.
+
+    Each sample is ``[text prefix | vision_start | image tokens | text
+    suffix]`` of exactly ``seq_length`` tokens, with random pixel patches
+    and 3D MRoPE position IDs computed by :func:`get_rope_index`.
+
+    Args:
+        num_samples: Number of samples.
+        seq_length: Total sequence length (text + image tokens).
+        image_seq_length: Number of image tokens per sample.
+        vocab_size: Vocabulary size for random text tokens.
+        image_token_id: Token ID for image placeholders.
+        video_token_id: Token ID for video placeholders.
+        vision_start_token_id: Token ID marking start of a vision region.
+        image_size: Image height and width in pixels.
+        patch_size: Spatial patch size.
+        temporal_patch_size: Temporal patch size.
+        spatial_merge_size: Spatial merge factor.
+    """
+
+    def __init__(
+        self,
+        num_samples: int = 1000,
+        seq_length: int = 1024,
+        image_seq_length: int = 256,
+        vocab_size: int = 248320,
+        image_token_id: int = QWEN35_VL_IMAGE_TOKEN_ID,
+        video_token_id: int = QWEN35_VL_VIDEO_TOKEN_ID,
+        vision_start_token_id: int = QWEN35_VL_VISION_START_TOKEN_ID,
+        image_size: int = 224,
+        patch_size: int = 16,
+        temporal_patch_size: int = 2,
+        spatial_merge_size: int = 2,
+    ):
+        self.num_samples = num_samples
+        self.seq_length = seq_length
+        self.vocab_size = vocab_size
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id
+        self.vision_start_token_id = vision_start_token_id
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.spatial_merge_size = spatial_merge_size
+
+        # One image per sample with a [T, H, W] patch grid; T is set to
+        # temporal_patch_size here (single frame duplicated temporally).
+        h_patches = image_size // patch_size
+        w_patches = image_size // patch_size
+        t_patches = temporal_patch_size
+        self.grid_thw = torch.tensor([[t_patches, h_patches, w_patches]])
+
+        # Tokens remaining after spatial merge: every
+        # spatial_merge_size x spatial_merge_size patch block becomes one token.
+        self.num_merged_tokens = (
+            t_patches
+            * (h_patches // spatial_merge_size)
+            * (w_patches // spatial_merge_size)
+        )
+        # Clamp so we never place more image placeholders than the image
+        # actually yields merged tokens.
+        self.image_seq_length = min(
+            image_seq_length, self.num_merged_tokens,
+        )
+        # Raw (pre-merge) patch count — the first dim of pixel_values.
+        self.total_patches = t_patches * h_patches * w_patches
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        # Reserve 1 slot for the vision_start sentinel before image tokens.
+        text_length = self.seq_length - self.image_seq_length - 1
+        text_tokens = torch.randint(
+            1, self.vocab_size, (text_length,), dtype=torch.long,
+        )
+        special_ids = {
+            self.image_token_id,
+            self.video_token_id,
+            self.vision_start_token_id,
+        }
+        # Scrub randomly-drawn special IDs so get_rope_index only sees the
+        # one deliberately placed vision region.
+        for sid in special_ids:
+            text_tokens[text_tokens == sid] = 1
+
+        # Layout: [prefix text | vision_start | image tokens | suffix text].
+        prefix_len = text_length // 2
+        suffix_len = text_length - prefix_len
+        input_ids = torch.cat([
+            text_tokens[:prefix_len],
+            torch.tensor(
+                [self.vision_start_token_id], dtype=torch.long,
+            ),
+            torch.full(
+                (self.image_seq_length,),
+                self.image_token_id,
+                dtype=torch.long,
+            ),
+            text_tokens[prefix_len: prefix_len + suffix_len],
+        ])
+
+        # Next-token labels: shift left by one; the final position has no
+        # target and is masked below.
+        labels = input_ids.clone()
+        labels[:-1] = input_ids[1:]
+        labels[-1] = 0
+
+        # Train only on text positions; image placeholders and the last
+        # (target-less) position contribute no loss.
+        loss_mask = (input_ids != self.image_token_id).float()
+        loss_mask[-1] = 0
+
+        # Per-patch feature size: channels * temporal * spatial patch volume.
+        pixel_dim = (
+            3
+            * self.temporal_patch_size
+            * self.patch_size
+            * self.patch_size
+        )
+        pixel_values = torch.randn(self.total_patches, pixel_dim)
+
+        image_grid_thw = self.grid_thw.clone()
+
+        position_ids, _ = get_rope_index(
+            spatial_merge_size=self.spatial_merge_size,
+            image_token_id=self.image_token_id,
+            video_token_id=self.video_token_id,
+            vision_start_token_id=self.vision_start_token_id,
+            input_ids=input_ids.unsqueeze(0),
+            image_grid_thw=image_grid_thw,
+        )
+        # Drop the batch dim: [3, 1, S] -> [3, S]; axis 0 is (t, h, w).
+        position_ids = position_ids.squeeze(1)
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+            "loss_mask": loss_mask,
+            # Single unpacked sequence, so cu_seqlens is just [0, S].
+            "cu_seqlens": torch.tensor([0, self.seq_length], dtype=torch.int32),
+            "cu_seqlens_padded": torch.tensor(
+                [0, self.seq_length], dtype=torch.int32,
+            ),
+            "max_seqlen": torch.tensor(self.seq_length, dtype=torch.int32),
+            "position_ids": position_ids,
+            "pixel_values": pixel_values,
+            "image_grid_thw": image_grid_thw,
+        }
+
+
+def mock_collate_fn(batch):
+    """Collate samples; stacks ``position_ids`` ``[3, S]`` along dim 1.
+
+    pixel_values / image_grid_thw are concatenated (variable patch counts
+    per image); everything else is stacked into a leading batch dim.
+    """
+    result = {}
+    keys = batch[0].keys()
+    for key in keys:
+        tensors = [sample[key] for sample in batch]
+        if key == "position_ids":
+            # [3, S] stacked on dim=1 -> [3, B, S]: MRoPE axis stays first.
+            result[key] = torch.stack(tensors, dim=1)
+        elif key == "image_grid_thw":
+            result[key] = torch.cat(tensors, dim=0)
+        elif key == "pixel_values":
+            result[key] = torch.cat(tensors, dim=0)
+        else:
+            result[key] = torch.stack(tensors, dim=0)
+    return result
+
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Provide mock train / val / test datasets.
+
+    Args:
+        train_val_test_num_samples: 3-sequence of sample counts for
+            train / validation / test respectively.
+
+    Returns:
+        Tuple of three :class:`MockQwen35VLDataset` instances.
+    """
+    # Imported lazily so this module stays importable without Megatron.
+    from megatron.training import get_args
+
+    args = get_args()
+    # getattr defaults guard against args objects missing the multimodal flags.
+    kwargs = dict(
+        seq_length=getattr(args, "total_seq_length", 1024),
+        image_seq_length=getattr(args, "image_seq_length", 256),
+        vocab_size=getattr(args, "padded_vocab_size", 248320),
+        image_token_id=getattr(args, "image_token_id", 248056),
+        image_size=getattr(args, "image_size", 224),
+    )
+
+    train_ds = MockQwen35VLDataset(
+        num_samples=train_val_test_num_samples[0], **kwargs,
+    )
+    val_ds = MockQwen35VLDataset(
+        num_samples=train_val_test_num_samples[1], **kwargs,
+    )
+    test_ds = MockQwen35VLDataset(
+        num_samples=train_val_test_num_samples[2], **kwargs,
+    )
+
+    return train_ds, val_ds, test_ds
diff --git a/examples/multimodal_dev/data/vlm_dataset.py b/examples/multimodal_dev/data/vlm_dataset.py
new file mode 100644
index 00000000000..47a323d61ad
--- /dev/null
+++ b/examples/multimodal_dev/data/vlm_dataset.py
@@ -0,0 +1,249 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Simple VLM dataset for multimodal_dev training.
+
+Single-turn image-text dataset using a HuggingFace ``AutoProcessor`` for
+tokenization and image preprocessing. Currently supports CORD-V2 (receipt
+OCR). No multi-turn support — each sample is one image + question →
+answer pair.
+
+Because Megatron's DataLoader uses ``default_collate`` (no custom collate
+function), all images are resized to a fixed resolution so that
+``pixel_values`` has a consistent shape across samples.
+
+Usage::
+
+ torchrun ... pretrain_multimodal.py \\
+ --model-arch qwen35_vl --dataset-provider cord_v2 \\
+ --hf-processor-path Qwen/Qwen2.5-VL-7B-Instruct \\
+ --image-size 448 --seq-length 2048
+"""
+
+import json
+import logging
+import random
+from typing import Dict, List, Optional
+
+import torch
+from torch.utils.data import Dataset
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# CORD-V2 helpers
+# ---------------------------------------------------------------------------
+
+def _json2token(obj, sort_json_key=True):
+ """Convert a JSON object to a token-sequence string (Donut format)."""
+ if isinstance(obj, dict):
+ if len(obj) == 1 and "text_sequence" in obj:
+ return obj["text_sequence"]
+ output = ""
+ keys = sorted(obj.keys(), reverse=True) if sort_json_key else obj.keys()
+ for k in keys:
+ output += f"" + _json2token(obj[k], sort_json_key) + f""
+ return output
+ if isinstance(obj, list):
+ return "".join(_json2token(item, sort_json_key) for item in obj)
+ return str(obj)
+
+
+def load_cord_v2(split="train"):
+    """Load CORD-V2 and return a list of ``{image, question, answer}`` dicts.
+
+    The answer text is one of the ground-truth parses serialized via
+    :func:`_json2token`; a fixed-seed RNG makes the choice reproducible.
+    """
+    # Imported lazily: `datasets` is only needed for real-data runs.
+    from datasets import load_dataset
+
+    ds = load_dataset("naver-clova-ix/cord-v2", split=split)
+    rng = random.Random(42)  # deterministic parse selection across runs
+    examples = []
+    for ex in ds:
+        gt = json.loads(ex["ground_truth"])
+        # Some samples carry multiple parses under "gt_parses"; otherwise
+        # fall back to the single "gt_parse".
+        gt_jsons = gt.get("gt_parses") or [gt["gt_parse"]]
+        text = rng.choice(
+            [_json2token(g, sort_json_key=True) for g in gt_jsons]
+        )
+        examples.append(
+            {"image": ex["image"], "question": "Describe this image.", "answer": text}
+        )
+    return examples
+
+
+# ---------------------------------------------------------------------------
+# Dataset
+# ---------------------------------------------------------------------------
+
+class CordV2VLMDataset(Dataset):
+    """Single-turn VLM dataset backed by CORD-V2.
+
+    Each sample is tokenized by the HF ``AutoProcessor`` and resized to a
+    fixed resolution so that ``pixel_values`` has consistent shape across
+    samples (required by ``default_collate``).
+
+    Args:
+        examples: Output of :func:`load_cord_v2`.
+        processor: ``AutoProcessor`` instance.
+        seq_length: Pad / truncate ``input_ids`` to this length.
+        image_size: Resize all images to ``(image_size, image_size)``.
+        image_token_id: Token ID for image placeholders.
+        target_length: Virtual dataset length (repeats examples if needed).
+
+    NOTE(review):
+        ``seq_length`` and ``pad_token_id`` are stored but ``__getitem__``
+        never pads or truncates to ``seq_length`` — samples come back at
+        their natural length, so batching relies on a downstream
+        pack/pad step. Confirm against the docstring claim above.
+
+    NOTE:
+        For qwen3.5 vl processor, an example is below.
+        - For images, the processor duplicates the frame so T=2, which makes the 3D conv behave exactly like a 2D conv on one image, to support both image and video inputs.
+        - image_array, 448 * 448 * 3
+        - pixel_values.shape (28 * 28, 1536 = 16 * 16 * 3 * 2)
+        - image_grid_thw, 1, 28, 28 (H/16, W/16)
+    """
+
+    def __init__(
+        self,
+        examples: List[Dict],
+        processor,
+        seq_length: int = 2048,
+        image_size: int = 448,
+        image_token_id: Optional[int] = None,
+        target_length: Optional[int] = None,
+    ):
+        self.examples = examples
+        self.processor = processor
+        self.seq_length = seq_length
+        self.image_size = (image_size, image_size)
+        # Virtual length lets a small split serve an arbitrary sample budget.
+        self._length = target_length if target_length else len(examples)
+        tok = processor.tokenizer
+        self.pad_token_id = tok.pad_token_id if tok.pad_token_id is not None else 0
+
+        # Resolve image token ID: explicit arg wins, else probe known
+        # placeholder tokens in the tokenizer vocab, else None (no masking).
+        if image_token_id is not None:
+            self.image_token_id = image_token_id
+        else:
+            vocab = tok.get_vocab()
+            for candidate in ("<|image_pad|>", "<|placeholder|>"):
+                if candidate in vocab:
+                    self.image_token_id = vocab[candidate]
+                    break
+            else:
+                self.image_token_id = None
+
+    def __len__(self) -> int:
+        return self._length
+
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+        # Wrap around so idx may exceed len(self.examples).
+        example = self.examples[idx % len(self.examples)]
+
+        # Fixed-resolution image
+        image = example["image"].convert("RGB").resize(self.image_size)
+
+        # Build single-turn conversation
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": example["question"]},
+                ],
+            },
+            {
+                "role": "assistant",
+                "content": example["answer"],
+            },
+        ]
+
+        # Tokenize + extract pixel values via HF processor
+        text = self.processor.apply_chat_template(
+            conversation, tokenize=False, add_generation_prompt=False,
+        )
+        batch = self.processor(
+            text=[text], images=[image], return_tensors="pt",
+        )
+
+        input_ids = batch["input_ids"].squeeze(0)
+        pixel_values = batch["pixel_values"]  # [total_patches, pixel_dim]
+        image_grid_thw = batch["image_grid_thw"]  # [1, 3]
+
+        # Labels: shifted next-token prediction targets; -100 is the
+        # conventional ignore index for cross-entropy.
+        labels = input_ids.clone()
+        labels[:-1] = input_ids[1:]
+        labels[-1] = -100
+
+        # Loss mask: 1 for trainable positions, 0 for padding / image tokens
+        loss_mask = torch.ones_like(input_ids, dtype=torch.float32)
+        loss_mask[input_ids == self.pad_token_id] = 0.0
+        if self.image_token_id is not None:
+            loss_mask[input_ids == self.image_token_id] = 0.0
+        loss_mask[-1] = 0.0
+        # Keep labels consistent with the mask: masked positions ignored.
+        labels[loss_mask == 0] = -100
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+            "loss_mask": loss_mask,
+            "pixel_values": pixel_values,
+            "image_grid_thw": image_grid_thw,
+        }
+
+
+# ---------------------------------------------------------------------------
+# Megatron dataset provider interface
+# ---------------------------------------------------------------------------
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Provide CORD-V2 train / val / test datasets.
+
+    Requires ``--hf-processor-path`` to point to a HuggingFace VL model
+    (e.g. ``Qwen/Qwen2.5-VL-7B-Instruct``) whose processor handles
+    tokenization and image preprocessing.
+
+    Args:
+        train_val_test_num_samples: 3-sequence of sample counts.
+
+    Raises:
+        ValueError: If ``--hf-processor-path`` was not supplied.
+    """
+    # Lazy imports keep this module importable without transformers/Megatron.
+    from transformers import AutoProcessor
+
+    from megatron.training import get_args
+
+    args = get_args()
+
+    processor_path = getattr(args, "hf_processor_path", None)
+    if processor_path is None:
+        raise ValueError(
+            "cord_v2 dataset requires --hf-processor-path "
+            "(e.g. Qwen/Qwen2.5-VL-7B-Instruct)"
+        )
+    processor = AutoProcessor.from_pretrained(
+        processor_path, trust_remote_code=True,
+    )
+
+    # Prefer the multimodal --total-seq-length; fall back to --seq-length.
+    seq_length = (
+        getattr(args, "total_seq_length", None)
+        or getattr(args, "seq_length", 2048)
+    )
+    image_size = getattr(args, "image_size", 448)
+    image_token_id = getattr(args, "image_token_id", None)
+
+    # Load real data
+    train_examples = load_cord_v2(split="train")
+    val_examples = load_cord_v2(split="validation")
+    test_examples = load_cord_v2(split="test")
+
+    def _make(examples, num_samples):
+        # Shared constructor wrapper; num_samples sets the virtual length.
+        return CordV2VLMDataset(
+            examples=examples,
+            processor=processor,
+            seq_length=seq_length,
+            image_size=image_size,
+            image_token_id=image_token_id,
+            target_length=num_samples,
+        )
+
+    train_ds = _make(train_examples, train_val_test_num_samples[0])
+    # Val/test counts may be 0; clamp to 1 so the datasets are non-empty.
+    val_ds = _make(val_examples, max(train_val_test_num_samples[1], 1))
+    test_ds = _make(test_examples, max(train_val_test_num_samples[2], 1))
+
+    return train_ds, val_ds, test_ds
+
+
+if __name__ == "__main__":
+    # Ad-hoc smoke test: build one dataset sample with a real HF processor.
+    # NOTE(review): "Qwen/Qwen3.5-35B-A3B" looks like a text-only model id,
+    # not a VL processor — verify this path loads an image-capable processor.
+    from transformers import AutoProcessor
+    processor = AutoProcessor.from_pretrained("Qwen/Qwen3.5-35B-A3B", trust_remote_code=True)
+    examples = load_cord_v2(split="train")
+    dataset = CordV2VLMDataset(
+        examples=examples,
+        processor=processor,
+    )
+    print(dataset[0])
diff --git a/examples/multimodal_dev/forward_step.py b/examples/multimodal_dev/forward_step.py
new file mode 100644
index 00000000000..4ee7d688a2b
--- /dev/null
+++ b/examples/multimodal_dev/forward_step.py
@@ -0,0 +1,446 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Forward step, TP broadcast, and loss for multimodal_dev training."""
+
+import math
+from functools import partial
+from itertools import accumulate
+from typing import Any, Dict, Iterator
+
+import torch
+import torch.nn.functional as F
+
+from megatron.core import mpu
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.parallel_state import (
+ get_context_parallel_world_size,
+ get_tensor_model_parallel_group,
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_src_rank,
+)
+from megatron.training import get_args
+
+# -------------------------------------------------------------------
+# dtype <-> int mapping for cross-rank broadcast
+# -------------------------------------------------------------------
+
+_DTYPE_MAP = {
+ torch.float32: 0,
+ torch.float16: 1,
+ torch.bfloat16: 2,
+ torch.int64: 3,
+ torch.int32: 4,
+ torch.bool: 5,
+}
+_ID_MAP = {v: k for k, v in _DTYPE_MAP.items()}
+
+
+def _dtype_to_id(dtype):
+    """Map a torch dtype to its wire ID; unknown dtypes fall back to 0 (float32)."""
+    return _DTYPE_MAP.get(dtype, 0)
+
+
+def _id_to_dtype(id_val):
+    """Inverse of :func:`_dtype_to_id`; unknown IDs fall back to float32."""
+    return _ID_MAP.get(id_val, torch.float32)
+
+
+# -------------------------------------------------------------------
+# Tensor broadcast helper
+# -------------------------------------------------------------------
+
+def _broadcast_tensor(tensor, src, group, device):
+    """Broadcast a single tensor from *src* to all ranks in *group*.
+
+    Protocol (four collective broadcasts, in order):
+      1. ndim — 0 is the sentinel for "tensor is None on src".
+      2. shape
+      3. dtype ID (see ``_DTYPE_MAP``)
+      4. the tensor payload itself
+
+    Args:
+        tensor: The tensor on *src*; may be None on non-src ranks (and on
+            src, which broadcasts the None sentinel).
+        src: Global source rank.
+        group: Process group to broadcast within.
+        device: Device for the metadata tensors and any allocated buffer.
+
+    Returns:
+        The broadcast tensor, or None if src had None.
+    """
+    ndim = torch.tensor(
+        [len(tensor.shape) if tensor is not None else 0],
+        dtype=torch.long,
+        device=device,
+    )
+    torch.distributed.broadcast(ndim, src, group=group)
+
+    # ndim == 0 means src had no tensor for this key; all ranks return None.
+    if ndim.item() == 0:
+        return None
+
+    if tensor is not None:
+        shape_tensor = torch.tensor(
+            list(tensor.shape), dtype=torch.long, device=device,
+        )
+        dtype_id = torch.tensor(
+            [_dtype_to_id(tensor.dtype)],
+            dtype=torch.long,
+            device=device,
+        )
+    else:
+        # Receiving rank: allocate placeholders to be overwritten.
+        shape_tensor = torch.zeros(
+            ndim.item(), dtype=torch.long, device=device,
+        )
+        dtype_id = torch.zeros(1, dtype=torch.long, device=device)
+
+    torch.distributed.broadcast(shape_tensor, src, group=group)
+    torch.distributed.broadcast(dtype_id, src, group=group)
+
+    dtype = _id_to_dtype(dtype_id.item())
+    shape = tuple(shape_tensor.tolist())
+
+    if tensor is None:
+        # Receiving rank: allocate the payload buffer now that shape/dtype
+        # are known.
+        tensor = torch.empty(shape, dtype=dtype, device=device)
+    torch.distributed.broadcast(tensor, src, group=group)
+    return tensor
+
+
+# -------------------------------------------------------------------
+# Batch broadcast across TP ranks
+# -------------------------------------------------------------------
+
+def broadcast_data_batch(data, device="cuda"):
+    """Broadcast a data-batch dict from TP rank 0 to all TP ranks.
+
+    The key set is serialized as a comma-joined UTF-8 string (length first,
+    then bytes), after which each value is broadcast via
+    :func:`_broadcast_tensor`.
+
+    NOTE(review): non-tensor values are passed to ``_broadcast_tensor`` as
+    None, so they come back as None on every rank (including rank 0) —
+    confirm callers only put tensors in *data*. Keys containing commas
+    would also corrupt the serialized key list.
+
+    Args:
+        data: Dict of tensors on TP rank 0; ignored (may be None) elsewhere.
+        device: Device to stage broadcast buffers on.
+
+    Returns:
+        Dict with the same keys on every TP rank; values on *device*.
+    """
+    src = get_tensor_model_parallel_src_rank()
+    group = get_tensor_model_parallel_group()
+
+    if data is None:
+        data = {}
+
+    # Step 1: broadcast the byte length of the comma-joined key string.
+    if get_tensor_model_parallel_rank() == 0:
+        keys = list(data.keys())
+        key_str = ",".join(keys)
+        key_bytes = key_str.encode("utf-8")
+        key_len = torch.tensor(
+            [len(key_bytes)], dtype=torch.long, device=device,
+        )
+    else:
+        key_len = torch.zeros(1, dtype=torch.long, device=device)
+        keys = []
+
+    torch.distributed.broadcast(key_len, src, group=group)
+
+    # Step 2: broadcast the key bytes themselves.
+    if get_tensor_model_parallel_rank() == 0:
+        key_tensor = torch.tensor(
+            list(key_bytes), dtype=torch.uint8, device=device,
+        )
+    else:
+        key_tensor = torch.zeros(
+            key_len.item(), dtype=torch.uint8, device=device,
+        )
+
+    torch.distributed.broadcast(key_tensor, src, group=group)
+
+    # Non-src ranks decode the key list they just received.
+    if get_tensor_model_parallel_rank() != 0:
+        key_str = bytes(key_tensor.cpu().tolist()).decode("utf-8")
+        keys = key_str.split(",") if key_str else []
+
+    # Step 3: broadcast each value tensor in key order.
+    result = {}
+    for key in keys:
+        tensor = data.get(key, None) if data else None
+        if tensor is not None and isinstance(tensor, torch.Tensor):
+            tensor = tensor.to(device)
+        result[key] = _broadcast_tensor(
+            tensor if isinstance(tensor, torch.Tensor) else None,
+            src, group, device,
+        )
+
+    return result
+
+
+# -------------------------------------------------------------------
+# THD (packed sequence) helpers
+# -------------------------------------------------------------------
+
+def _build_packed_seq_params(
+    seq_lengths: torch.Tensor, device: torch.device,
+) -> PackedSeqParams:
+    """Build ``PackedSeqParams`` from per-sample valid sequence lengths.
+
+    Args:
+        seq_lengths: ``[B]`` valid token counts per sample.
+        device: Target device for cu_seqlens tensors.
+
+    Returns:
+        A ``PackedSeqParams`` instance with ``qkv_format='thd'``.
+    """
+    if not isinstance(seq_lengths, torch.Tensor):
+        seq_lengths = torch.tensor(seq_lengths)
+    lengths_t = seq_lengths.to(device=device, dtype=torch.int32)
+    # cu_seqlens[0] = 0, then the running sum of lengths (cumsum written
+    # into the tail slice of the pre-zeroed buffer).
+    cu_seqlens = torch.zeros(
+        lengths_t.numel() + 1, dtype=torch.int32, device=device,
+    )
+    torch.cumsum(lengths_t, dim=0, out=cu_seqlens[1:])
+    max_seqlen = int(lengths_t.max().item())
+    return _build_packed_seq_params_from_cu_seqlens(
+        cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
+    )
+
+
+def _build_packed_seq_params_from_cu_seqlens(
+    cu_seqlens: torch.Tensor, max_seqlen: int,
+) -> PackedSeqParams:
+    """Build ``PackedSeqParams`` from packed cumulative sequence lengths.
+
+    ``cu_seqlens`` must already be on the target compute device. The same
+    cu_seqlens are used for q and kv (self-attention) and for the padded
+    variants (no per-sample padding assumed here).
+    """
+    cs = cu_seqlens.to(dtype=torch.int32)
+    # Last cumulative entry is the total packed token count.
+    total_tokens = int(cs[-1].item())
+    return PackedSeqParams(
+        cu_seqlens_q=cs,
+        cu_seqlens_kv=cs,
+        cu_seqlens_q_padded=cs,
+        cu_seqlens_kv_padded=cs,
+        max_seqlen_q=max_seqlen,
+        max_seqlen_kv=max_seqlen,
+        qkv_format='thd',
+        total_tokens=total_tokens,
+    )
+
+
+
+
+def pack_or_pad_batch(batch: list[Dict[str, Any]], use_packed_sequence: bool=False, seq_length: int=None, device = "cuda") -> list[Dict[str, Any]]:
+    """Pack or pad a ``[B, S]`` batch into ``[1, T]`` THD format.
+
+    Args:
+        batch: List of per-sample dicts (input_ids, labels, loss_mask,
+            pixel_values, image_grid_thw, ...).
+        use_packed_sequence: If True, concatenate all samples into one
+            packed THD sequence; otherwise right-pad to a common length.
+        seq_length: Required in the padding path; caps the padded length.
+        device: Target device for the broadcast batch.
+
+    Returns:
+        A single batch dict, broadcast to all TP ranks.
+    """
+    tp_size = mpu.get_tensor_model_parallel_world_size()
+    cp_size = mpu.get_context_parallel_world_size()
+
+    # SP is an explicit runtime option; TP>1 does not imply SP is enabled.
+    try:
+        has_sp = bool(get_args().sequence_parallel)
+    except Exception:
+        has_sp = False
+
+    # Alignment factor so sequence splits are even across CP (x2 for the
+    # CP load-balancing halves) and, when SP is on, across TP as well.
+    if cp_size > 1:
+        divisible_by = (tp_size * cp_size * 2) if has_sp else (cp_size * 2)
+    else:
+        divisible_by = tp_size if has_sp else 1
+    # NOTE: don't consider fp8 padding now
+
+    if use_packed_sequence:
+        input_ids_list, labels_list, loss_mask_list, pixel_values_list, image_grid_thw_list = [], [], [], [], []
+        seqlens_list, seqlens_padded_list = [], []
+
+        # NOTE: no attention_mask is built (THD relies on cu_seqlens), and
+        # position_ids are left for the model to derive. Input ids are NOT
+        # truncated here even if the packed total exceeds seq_length.
+
+        packed_batch = dict()
+
+        for sample in batch:
+            seqlen = sample["input_ids"].shape[0]
+            assert sample["labels"].shape == sample["input_ids"].shape == sample["loss_mask"].shape, "labels, input_ids, and loss_mask must have the same shape"
+            # Pad each sample up to the alignment factor before packing.
+            target_len = math.ceil(seqlen / divisible_by) * divisible_by
+            input_ids = F.pad(sample["input_ids"], (0, target_len - seqlen), value=0)
+            labels = F.pad(sample["labels"], (0, target_len - seqlen), value=-100)
+            loss_mask = F.pad(sample["loss_mask"], (0, target_len - seqlen), value=0)
+
+            input_ids_list.append(input_ids)
+            labels_list.append(labels)
+            loss_mask_list.append(loss_mask)
+            seqlens_list.append(seqlen)
+            seqlens_padded_list.append(target_len)
+            pixel_values_list.append(sample["pixel_values"])
+            image_grid_thw_list.append(sample["image_grid_thw"])
+
+        # Cumulative offsets: real lengths vs padded (layout) lengths.
+        cu_seqlens = list(accumulate(seqlens_list, initial=0))
+        cu_seqlens_padded = list(accumulate(seqlens_padded_list, initial=0))
+
+        # padding_mask: True at collate-padded positions within each packed
+        # sample. Real tokens occupy [cu_seqlens_padded[i], +seqlens_list[i]);
+        # the tail up to cu_seqlens_padded[i+1] is padding. Consumed by MoE
+        # routing in megatron.core to exclude padded tokens from aux loss,
+        # z-loss, and expert-bias accumulation.
+        total_tokens_padded = cu_seqlens_padded[-1]
+        padding_mask_thd = torch.zeros(total_tokens_padded, dtype=torch.bool)
+        for i, real_seqlen in enumerate(seqlens_list):
+            pad_start = cu_seqlens_padded[i] + real_seqlen
+            pad_end = cu_seqlens_padded[i + 1]
+            if pad_end > pad_start:
+                padding_mask_thd[pad_start:pad_end] = True
+
+        # Concatenate into one packed sequence with a leading batch dim of 1.
+        packed_batch["input_ids"] = torch.concat(input_ids_list, dim=0).unsqueeze(0)
+        packed_batch["labels"] = torch.concat(labels_list, dim=0).unsqueeze(0)
+        packed_batch["loss_mask"] = torch.concat(loss_mask_list, dim=0).unsqueeze(0)
+        packed_batch["padding_mask"] = padding_mask_thd.unsqueeze(0)
+
+        # TODO, maybe pixel_values's seqlens needs to be recorded.
+        packed_batch["pixel_values"] = torch.concat(pixel_values_list)
+        packed_batch["image_grid_thw"] = torch.concat(image_grid_thw_list)
+
+        # broadcast to all tp ranks
+        packed_batch = broadcast_data_batch(packed_batch, device=device)
+
+        # cu_seqlens carry real lengths; *_padded carry layout offsets.
+        packed_batch["packed_seq_params"] = PackedSeqParams(
+            qkv_format='thd',
+            cu_seqlens_q=torch.tensor(cu_seqlens, dtype=torch.int32, device=device),
+            cu_seqlens_kv=torch.tensor(cu_seqlens, dtype=torch.int32, device=device),
+            cu_seqlens_q_padded=torch.tensor(cu_seqlens_padded, dtype=torch.int32, device=device),
+            cu_seqlens_kv_padded=torch.tensor(cu_seqlens_padded, dtype=torch.int32, device=device),
+            max_seqlen_q=max(seqlens_padded_list),
+            max_seqlen_kv=max(seqlens_padded_list),
+            total_tokens=cu_seqlens_padded[-1],
+        )
+        return packed_batch
+    else:
+        assert seq_length is not None, "seq_length must be provided when use_packed_sequence is False"
+        max_seqlens = max([x["input_ids"].shape[0] for x in batch])
+        target_seqlens = min(max_seqlens, seq_length)
+        # Round target seqlen up to the parallelism alignment factor so the
+        # batched tensor is divisible for CP (+SP) splitting downstream.
+        if divisible_by > 1:
+            target_seqlens = math.ceil(target_seqlens / divisible_by) * divisible_by
+        padded_batch = dict()
+        # Capture real lengths before in-place padding so we can build a
+        # padding_mask for MoE routing (True at collate-padded positions).
+        real_seqlens = [s["input_ids"].shape[0] for s in batch]
+
+        # NOTE(review): this mutates the caller's sample dicts in place, and
+        # a sample longer than target_seqlens gets a negative pad amount —
+        # F.pad then truncates the tail. Confirm truncation is intended.
+        for sample in batch:
+            sample["input_ids"] = F.pad(sample["input_ids"], (0, target_seqlens - sample["input_ids"].shape[0]), value=0)
+            sample["labels"] = F.pad(sample["labels"], (0, target_seqlens - sample["labels"].shape[0]), value=-100)
+            sample["loss_mask"] = F.pad(sample["loss_mask"], (0, target_seqlens - sample["loss_mask"].shape[0]), value=0)
+
+        padded_batch["input_ids"] = torch.concat([x["input_ids"].unsqueeze(0) for x in batch], dim=0)
+        padded_batch["labels"] = torch.concat([x["labels"].unsqueeze(0) for x in batch], dim=0)
+        padded_batch["loss_mask"] = torch.concat([x["loss_mask"].unsqueeze(0) for x in batch], dim=0)
+        # True wherever position index >= the sample's real length.
+        positions = torch.arange(target_seqlens).unsqueeze(0)
+        padded_batch["padding_mask"] = positions >= torch.tensor(real_seqlens).unsqueeze(1)
+        padded_batch["pixel_values"] = torch.concat([x["pixel_values"] for x in batch])
+        padded_batch["image_grid_thw"] = torch.concat([x["image_grid_thw"] for x in batch])
+        # broadcast to all tp ranks
+        padded_batch = broadcast_data_batch(padded_batch, device=device)
+        return padded_batch
+
+
+# -------------------------------------------------------------------
+# get_batch
+# -------------------------------------------------------------------
+
+def get_batch(data_iterator: Iterator[Dict[str, Any]]):
+    """Get a batch from *data_iterator* and broadcast across TP ranks.
+
+    Only TP rank 0 draws from the iterator. A one-byte ``has_data`` flag
+    is broadcast first so all TP ranks agree on whether the iterator is
+    exhausted; when it is, every rank returns ``None`` together.
+    """
+    device = "cuda"
+    args = get_args()
+
+    if get_tensor_model_parallel_rank() == 0:
+        try:
+            data = next(data_iterator)
+            has_data = torch.tensor(
+                [1], dtype=torch.uint8, device=device,
+            )
+        except StopIteration:
+            has_data = torch.tensor(
+                [0], dtype=torch.uint8, device=device,
+            )
+            data = None
+    else:
+        # Non-source ranks receive the flag via the broadcast below.
+        has_data = torch.empty(1, dtype=torch.uint8, device=device)
+        data = None
+
+    src = get_tensor_model_parallel_src_rank()
+    group = get_tensor_model_parallel_group()
+    torch.distributed.broadcast(has_data, src, group=group)
+
+    if has_data.item() == 0:
+        return None
+
+    # Because broadcast will not broadcast packed_seq_params, we move it into pack_or_pad_batch
+    batch = pack_or_pad_batch(data, args.use_packed_sequence, args.seq_length, device=device)
+
+    # Fix shapes produced by default_collate.
+    # position_ids: presumably collated as [B, 3, S] for MRoPE; permute so
+    # the 3 rotary axes lead, i.e. [3, B, S] — TODO confirm against dataset.
+    if "position_ids" in batch and batch["position_ids"] is not None:
+        p = batch["position_ids"]
+        if p.dim() == 3 and p.shape[1] == 3:
+            batch["position_ids"] = p.permute(1, 0, 2).contiguous()
+
+    # pixel_values: flatten a batched [B, P, D] patch tensor to [B*P, D].
+    if "pixel_values" in batch and batch["pixel_values"] is not None:
+        pv = batch["pixel_values"]
+        if pv.dim() == 3:
+            B, P, D = pv.shape
+            batch["pixel_values"] = pv.reshape(B * P, D)
+
+    # image_grid_thw: drop a spurious middle dim so it is [num_images, 3].
+    if (
+        "image_grid_thw" in batch
+        and batch["image_grid_thw"] is not None
+    ):
+        g = batch["image_grid_thw"]
+        if g.dim() == 3:
+            batch["image_grid_thw"] = g.squeeze(1)
+
+    return batch
+
+
+# -------------------------------------------------------------------
+# Loss
+# -------------------------------------------------------------------
+
+def loss_func(loss_mask, output_tensor):
+    """Compute masked language model loss.
+
+    Args:
+        loss_mask: Mask with 1 at positions that contribute to the loss.
+        output_tensor: Per-token loss values from the model.
+
+    Returns:
+        Tuple of ``(total_loss, total_tokens, {"lm loss": reporting_loss})``
+        where ``total_loss`` keeps the autograd graph, ``total_tokens`` is the
+        (detached, int) count of contributing tokens, and ``reporting_loss``
+        packs ``[loss_sum, token_count]`` for logging-side reduction.
+    """
+    losses = output_tensor.float()
+    loss_mask = loss_mask.contiguous().view(-1).float()
+
+    # Detached copies feed logging only; `total_loss` retains gradients.
+    total_tokens = loss_mask.sum().clone().detach().to(torch.int)
+    total_loss = torch.sum(losses.view(-1) * loss_mask)
+    reporting_loss = torch.cat(
+        [total_loss.clone().detach().view(1), total_tokens.view(1)],
+    )
+
+    return (total_loss, total_tokens, {"lm loss": reporting_loss})
+
+
+# -------------------------------------------------------------------
+# Forward step
+# -------------------------------------------------------------------
+
+def forward_step(data_iterator, model):
+    """Forward step for multimodal_dev training.
+
+    Pulls a batch via :func:`get_batch`, runs the model forward, and
+    returns ``(output_tensor, loss_fn)`` as expected by Megatron's
+    training loop — or ``(None, None)`` when the iterator is exhausted.
+    """
+    batch = get_batch(data_iterator)
+
+    if batch is None:
+        return None, None
+
+    # Cast fp32 pixels to bf16 to match the bf16 language/vision stack.
+    # NOTE(review): `is_floating_point()` is redundant given the explicit
+    # float32 dtype check on the next line.
+    pixel_values = batch.get("pixel_values", None)
+    if (
+        pixel_values is not None
+        and pixel_values.is_floating_point()
+        and pixel_values.dtype == torch.float32
+    ):
+        pixel_values = pixel_values.bfloat16()
+
+    # We don't provide position_ids, now. Let model handle it itself.
+    output_tensor = model(
+        input_ids=batch["input_ids"],
+        position_ids=batch.get("position_ids"),
+        attention_mask=batch.get("attention_mask", None),
+        labels=batch.get("labels", None),
+        loss_mask=batch.get("loss_mask", None),
+        padding_mask=batch.get("padding_mask", None),
+        pixel_values=pixel_values,
+        image_grid_thw=batch.get("image_grid_thw", None),
+        packed_seq_params=batch.get("packed_seq_params", None),
+    )
+
+    # Default: every token contributes to the loss.
+    loss_mask = batch.get("loss_mask", None)
+    if loss_mask is None:
+        loss_mask = torch.ones_like(
+            batch["input_ids"], dtype=torch.float,
+        )
+
+    # CP-split loss_mask to match the model output (which is CP-split
+    # inside MultimodalModel.forward / Qwen35VLModel.forward).
+    # THD: use the same TE-based per-sample partition index as the model.
+    # BSHD: use the matching zigzag split.
+    cp_size = get_context_parallel_world_size()
+    if cp_size > 1:
+        from megatron.core.parallel_state import get_context_parallel_rank
+
+        from examples.multimodal_dev.models.base import (
+            _cp_split_tensor,
+            _thd_cp_partition_index,
+        )
+
+        cp_rank = get_context_parallel_rank()
+        psp = batch.get("packed_seq_params", None)
+        if psp is not None:
+            idx = _thd_cp_partition_index(
+                psp.cu_seqlens_q_padded,
+                loss_mask.shape[1], cp_size, cp_rank,
+            )
+            loss_mask = loss_mask.index_select(1, idx)
+        else:
+            loss_mask = _cp_split_tensor(
+                loss_mask, seq_dim=1,
+                cp_size=cp_size, cp_rank=cp_rank,
+            )
+
+    return output_tensor, partial(loss_func, loss_mask)
diff --git a/examples/multimodal_dev/models/__init__.py b/examples/multimodal_dev/models/__init__.py
new file mode 100644
index 00000000000..33788bb3619
--- /dev/null
+++ b/examples/multimodal_dev/models/__init__.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Model registry for multimodal_dev training.
+
+Maps ``--model-arch`` to a set of factory functions that fully encapsulate
+model-specific logic. The training entry point (``pretrain_multimodal.py``)
+remains model-agnostic — adding a new architecture only requires a new
+registry entry (and its backing module) without touching the entry point.
+
+Registry entry fields
+---------------------
+``model_factory_fn`` *(required)*
+ ``(args, language_config, vision_config, **kwargs) -> MegatronModule``
+ Builds and returns the complete model instance.
+
+``vision_config_fn`` *(required)*
+ ``(num_layers_override=None, variant=None) -> TransformerConfig``
+ Returns the vision encoder TransformerConfig.
+
+``post_language_config_fn`` *(optional)*
+ ``(language_config, args) -> None``
+ Mutates the language TransformerConfig in-place with model-specific
+ fields (e.g. ``mrope_section``).
+
+``vision_flops_fn`` *(optional)*
+ ``(args, language_config, vision_config) -> None``
+ Sets vision FLOPs metadata on ``args`` for training throughput logging.
+
+``dataset_providers`` *(optional)*
+ ``Dict[str, str | callable]``
+ Maps ``--dataset-provider`` names to callables (or dotted import paths
+ resolved lazily) with signature
+ ``(train_val_test_num_samples) -> (train_ds, val_ds, test_ds)``.
+"""
+
+from examples.multimodal_dev.models.qwen35_vl.configuration import get_qwen35_vl_vision_config
+from examples.multimodal_dev.models.qwen35_vl.factory import build_model as _build_qwen35_vl_model
+from examples.multimodal_dev.models.qwen35_vl.factory import (
+ post_language_config as _qwen35_vl_post_language_config,
+)
+from examples.multimodal_dev.models.qwen35_vl.factory import (
+ set_vision_flops_metadata as _qwen35_vl_vision_flops,
+)
+
+# Maps ``--model-arch`` -> factory-function bundle (see module docstring
+# for the per-field contracts). Dataset providers are dotted import paths
+# resolved lazily by the entry point.
+MODEL_REGISTRY = {
+    "qwen35_vl": {
+        "model_factory_fn": _build_qwen35_vl_model,
+        "vision_config_fn": get_qwen35_vl_vision_config,
+        "post_language_config_fn": _qwen35_vl_post_language_config,
+        "vision_flops_fn": _qwen35_vl_vision_flops,
+        "dataset_providers": {
+            "mock": (
+                "examples.multimodal_dev.data.mock"
+                ".train_valid_test_datasets_provider"
+            ),
+            "cord_v2": (
+                "examples.multimodal_dev.data.vlm_dataset"
+                ".train_valid_test_datasets_provider"
+            ),
+        },
+    },
+}
diff --git a/examples/multimodal_dev/models/base.py b/examples/multimodal_dev/models/base.py
new file mode 100644
index 00000000000..4f1ed97b284
--- /dev/null
+++ b/examples/multimodal_dev/models/base.py
@@ -0,0 +1,404 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Base multimodal model for FSDP + EP training.
+
+Composes a vision encoder and a ``GPTModel`` language decoder. Designed
+for FSDP + EP: always builds the **full** model on every rank (no PP
+flags). PP support is only available through the MIMO ``MimoModel``
+assembly path.
+
+Subclasses override ``compute_position_ids()`` for model-specific
+position encoding (e.g. MRoPE for Qwen3.5-VL).
+"""
+
+import contextlib
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+from megatron.core import parallel_state, tensor_parallel
+from megatron.core.models.gpt import GPTModel
+from megatron.core.transformer.module import MegatronModule
+from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_config import TransformerConfig
+
+
+def _cp_split_tensor(tensor, seq_dim, cp_size, cp_rank):
+    """Zigzag-split *tensor* along *seq_dim* for context parallelism (BSHD).
+
+    Splits the sequence into ``2 * cp_size`` equal chunks, then selects
+    chunks ``[cp_rank, 2*cp_size - cp_rank - 1]`` and concatenates them.
+    This mirrors ``megatron.core.utils.get_batch_on_this_cp_rank``.
+    """
+    S = tensor.shape[seq_dim]
+    assert S % (2 * cp_size) == 0, (
+        f"seq_len {S} not divisible by 2*cp_size={2 * cp_size}"
+    )
+    # Reshape so seq_dim becomes (2*cp_size, S // (2*cp_size)) chunks.
+    tensor = tensor.view(
+        *tensor.shape[:seq_dim],
+        2 * cp_size,
+        S // (2 * cp_size),
+        *tensor.shape[seq_dim + 1 :],
+    )
+    # Pick this rank's head chunk and its mirrored tail chunk (the zigzag
+    # pairing balances causal-attention work across CP ranks).
+    index = torch.zeros(2, dtype=torch.int64, device=tensor.device)
+    index[0] = cp_rank
+    index[1] = 2 * cp_size - cp_rank - 1
+    tensor = tensor.index_select(seq_dim, index)
+    # Merge the two selected chunks back into one contiguous seq dim.
+    tensor = tensor.view(
+        *tensor.shape[:seq_dim],
+        -1,
+        *tensor.shape[seq_dim + 2 :],
+    )
+    return tensor
+
+
+class _NoCPGroup:
+    """Dummy size-1 process group used to bypass MRoPE's BSHD-style
+    zigzag of pre-computed THD freqs (Megatron-Core gap:
+    ``MultimodalRotaryEmbedding.forward`` lacks the ``not packed_seq``
+    skip that plain ``RotaryEmbedding`` has).
+    """
+
+    def size(self):
+        # Report a world size of 1 so CP-aware code takes the no-op path.
+        return 1
+
+    def rank(self):
+        # Sole member of the fake group.
+        return 0
+
+
+# Shared singleton; stateless, so one instance is enough.
+_NO_CP_GROUP = _NoCPGroup()
+
+# Note: reported ``mtp_1 loss`` drifts ~1.3% from the CP=1 baseline under
+# THD+CP. Megatron-Core's logging averages per-rank pre-divided ratios
+# with op=AVG, and per-rank num_tokens are unequal after MTP rolling.
+# Gradients are correct; only the *logged* value drifts.
+
+
+def _thd_cp_partition_index(cu_seqlens_padded, total_tokens, cp_size, cp_rank):
+    """Per-rank token index for THD + CP via TE's
+    ``thd_get_partitioned_indices``. Cast to int64 so the result can be
+    used directly with ``index_select`` regardless of TE's return dtype.
+    """
+    # Lazy import keeps transformer_engine optional until CP>1 is used.
+    from transformer_engine.pytorch import cpp_extensions as tex
+
+    idx = tex.thd_get_partitioned_indices(
+        cu_seqlens_padded, total_tokens, cp_size, cp_rank,
+    )
+    return idx.long()
+
+
+class MultimodalModel(MegatronModule):
+    """Base class for multimodal vision-language models.
+
+    Composes a pre-constructed vision encoder and a ``GPTModel`` language
+    decoder. Designed for FSDP + EP; always builds the full model on
+    every rank.
+
+    Args:
+        language_config: ``TransformerConfig`` for the language decoder.
+        language_spec: ``ModuleSpec`` for decoder transformer layers.
+        vision_encoder: Pre-constructed vision encoder module.
+        vocab_size: Language model vocabulary size.
+        max_sequence_length: Maximum sequence length.
+        image_token_id: Token ID for image placeholder tokens.
+        position_embedding_type: Position embedding type for the decoder.
+        rotary_percent: Fraction of hidden dim for RoPE.
+        rotary_base: Base frequency for RoPE.
+        mrope_section: MRoPE channel sections.
+        mtp_block_spec: Optional MTP block spec.
+        parallel_output: Keep outputs split across TP ranks.
+        share_embeddings_and_output_weights: Tie input/output embeddings.
+    """
+
+    def __init__(
+        self,
+        language_config: TransformerConfig,
+        language_spec: ModuleSpec,
+        vision_encoder: MegatronModule,
+        vocab_size: int,
+        max_sequence_length: int,
+        image_token_id: int,
+        position_embedding_type: str = "rope",
+        rotary_percent: float = 1.0,
+        rotary_base: int = 10000,
+        mrope_section: list = None,
+        mtp_block_spec: ModuleSpec = None,
+        parallel_output: bool = True,
+        share_embeddings_and_output_weights: bool = False,
+    ):
+        super().__init__(config=language_config)
+
+        # Placeholder token ID whose positions receive vision embeddings.
+        self.image_token_id = image_token_id
+
+        self.vision_model = vision_encoder
+        # Full decoder on every rank (pre/post process True — no PP split).
+        self.language_model = GPTModel(
+            config=language_config,
+            transformer_layer_spec=language_spec,
+            vocab_size=vocab_size,
+            max_sequence_length=max_sequence_length,
+            pre_process=True,
+            post_process=True,
+            parallel_output=parallel_output,
+            share_embeddings_and_output_weights=(
+                share_embeddings_and_output_weights
+            ),
+            position_embedding_type=position_embedding_type,
+            rotary_percent=rotary_percent,
+            rotary_base=rotary_base,
+            mtp_block_spec=mtp_block_spec,
+        )
+
+    def set_input_tensor(self, input_tensor):
+        """Route input tensors (simplified, no PP routing)."""
+        # Always forwards a single tensor to the language model.
+        if not isinstance(input_tensor, list):
+            input_tensor = [input_tensor]
+        assert len(input_tensor) == 1
+        self.language_model.set_input_tensor(input_tensor[0])
+
+    def _scatter_vision_embeddings(
+        self,
+        input_ids: Tensor,
+        text_embeddings: Tensor,
+        vision_embeddings: Tensor,
+    ) -> Tensor:
+        """Replace image-token positions with vision embeddings.
+
+        Handles sequence parallelism (gather → scatter → re-scatter).
+
+        Args:
+            input_ids: ``[B, S]`` token IDs.
+            text_embeddings: ``[S, B, D]`` (or ``[S/TP, B, D]`` with SP).
+            vision_embeddings: ``[num_visual_tokens, D]``.
+
+        Returns:
+            Combined embeddings, same shape as *text_embeddings*.
+        """
+        sp = (
+            self.config.sequence_parallel
+            and parallel_state.get_tensor_model_parallel_world_size()
+            > 1
+        )
+
+        # With SP the embedding output is sharded along S; gather to full
+        # length so image-token positions can be addressed globally.
+        if sp:
+            text_embeddings = (
+                tensor_parallel.gather_from_sequence_parallel_region(
+                    text_embeddings, tensor_parallel_output_grad=False,
+                )
+            )
+
+        combined = text_embeddings.transpose(0, 1).contiguous()
+        image_mask = input_ids == self.image_token_id
+        mask_expanded = image_mask.unsqueeze(-1).expand_as(combined)
+        # masked_scatter fills masked positions in row-major order from
+        # vision_embeddings. NOTE(review): assumes the number of
+        # image-token positions equals vision_embeddings.shape[0]; there
+        # is no explicit check here — confirm upstream guarantees it.
+        combined = combined.masked_scatter(
+            mask_expanded, vision_embeddings,
+        )
+        combined = combined.transpose(0, 1).contiguous()
+
+        # Re-shard along S to restore the SP layout.
+        if sp:
+            combined = (
+                tensor_parallel.scatter_to_sequence_parallel_region(
+                    combined,
+                )
+            )
+
+        return combined
+
+    def compute_position_ids(
+        self,
+        input_ids: Tensor,
+        image_grid_thw: Optional[Tensor] = None,
+        packed_seq_params=None,
+    ) -> Tensor:
+        """Compute position IDs. Override for MRoPE etc.
+
+        Default: simple sequential positions. ``packed_seq_params`` is
+        accepted for subclass compatibility (e.g. MRoPE in THD mode).
+        """
+        B, S = input_ids.shape
+        # [B, S] of 0..S-1, shared (expanded, not copied) across the batch.
+        return (
+            torch.arange(S, device=input_ids.device)
+            .unsqueeze(0)
+            .expand(B, -1)
+        )
+
+    def _cp_split_for_forward(
+        self,
+        *,
+        decoder_input,
+        input_ids,
+        labels,
+        loss_mask,
+        attention_mask,
+        position_ids,
+        packed_seq_params,
+        padding_mask=None,
+    ):
+        """Apply CP split to model-forward inputs.
+
+        BSHD path zigzag-splits each tensor along its seq dim. THD path
+        partitions per-sample via ``tex.thd_get_partitioned_indices`` so
+        chunks line up with ``cu_seqlens_q_padded`` boundaries.
+        ``position_ids`` and ``attention_mask`` are NOT split in THD —
+        MRoPE returns full freqs and TE attention's
+        ``_apply_rotary_pos_emb_thd`` does the per-sample CP zigzag
+        itself via ``_get_thd_freqs_on_this_cp_rank``.
+        """
+        cp_size = parallel_state.get_context_parallel_world_size()
+        if cp_size <= 1:
+            # No CP: pass everything through unchanged.
+            return (
+                decoder_input, input_ids, labels, loss_mask,
+                attention_mask, position_ids, padding_mask,
+            )
+        cp_rank = parallel_state.get_context_parallel_rank()
+
+        if packed_seq_params is not None:
+            # THD: one shared per-rank token index for all seq-dim tensors.
+            total_tokens = (
+                decoder_input.shape[0]
+                if decoder_input is not None
+                else input_ids.shape[1]
+            )
+            idx = _thd_cp_partition_index(
+                packed_seq_params.cu_seqlens_q_padded,
+                total_tokens, cp_size, cp_rank,
+            )
+            if decoder_input is not None:
+                decoder_input = decoder_input.index_select(0, idx)
+            if input_ids is not None:
+                input_ids = input_ids.index_select(1, idx)
+            if labels is not None:
+                labels = labels.index_select(1, idx)
+            if loss_mask is not None:
+                loss_mask = loss_mask.index_select(1, idx)
+            if padding_mask is not None:
+                padding_mask = padding_mask.index_select(1, idx)
+        else:
+            # BSHD: zigzag-split every seq-dim tensor independently.
+            def _split(t, seq_dim):
+                return None if t is None else _cp_split_tensor(
+                    t, seq_dim=seq_dim, cp_size=cp_size, cp_rank=cp_rank,
+                )
+            decoder_input = _split(decoder_input, 0)
+            input_ids = _split(input_ids, 1)
+            labels = _split(labels, 1)
+            loss_mask = _split(loss_mask, 1)
+            attention_mask = _split(attention_mask, 1)
+            padding_mask = _split(padding_mask, 1)
+
+        return (
+            decoder_input, input_ids, labels, loss_mask,
+            attention_mask, position_ids, padding_mask,
+        )
+
+    @contextlib.contextmanager
+    def _thd_mrope_no_cp_override(self, packed_seq_params):
+        """Force ``rotary_pos_emb.cp_group`` to size 1 for the wrapped
+        forward call so MRoPE returns full-length freqs in THD mode.
+        Attention then applies per-sample CP zigzag itself via
+        ``_apply_rotary_pos_emb_thd``. Done by direct mutation rather
+        than via ``packed_seq_params.cp_group`` so MTP's CP-aware roll
+        (which reads that field) still sees the real CP group.
+        """
+        # Only active in THD mode with CP>1; otherwise a no-op wrapper.
+        mrope = (
+            getattr(self.language_model, "rotary_pos_emb", None)
+            if packed_seq_params is not None
+            and parallel_state.get_context_parallel_world_size() > 1
+            else None
+        )
+        saved = getattr(mrope, "cp_group", None) if mrope is not None else None
+        if mrope is not None:
+            mrope.cp_group = _NO_CP_GROUP
+        try:
+            yield
+        finally:
+            # Always restore, even if the wrapped forward raises.
+            if mrope is not None:
+                mrope.cp_group = saved
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        position_ids: Tensor,
+        attention_mask: Tensor = None,
+        labels: Tensor = None,
+        loss_mask: Tensor = None,
+        padding_mask: Tensor = None,
+        pixel_values: Tensor = None,
+        image_grid_thw: Tensor = None,
+        decoder_input: Tensor = None,
+        packed_seq_params=None,
+        **kwargs,
+    ):
+        """Forward pass.
+
+        Args:
+            input_ids: ``[B, S]`` token IDs (or ``[1, T]`` in THD mode).
+            position_ids: ``[3, B, S]`` for MRoPE or ``[B, S]``
+                (``[3, 1, T]`` / ``[1, T]`` in THD mode).
+            attention_mask: ``[B, S]`` attention mask (None in THD).
+            labels: ``[B, S]`` target token IDs (``[1, T]`` in THD).
+            loss_mask: ``[B, S]`` mask for loss (``[1, T]`` in THD).
+            padding_mask: ``[B, S]`` bool mask, True at collate-padded
+                positions (``[1, T]`` in THD). Forwarded to the language
+                decoder so MoE routing excludes padded tokens from aux
+                loss / z-loss / expert-bias accumulation. Distinct from
+                ``loss_mask``: only true padding, not SFT prompt tokens.
+            pixel_values: Preprocessed image pixels.
+            image_grid_thw: ``[num_images, 3]`` grid dimensions.
+            decoder_input: Pre-computed decoder input (skip embed).
+            packed_seq_params: ``PackedSeqParams`` for THD attention.
+
+        Returns:
+            Loss tensor (post_process=True) or hidden states.
+        """
+        # Derive position IDs when the caller did not supply them
+        # (subclasses may compute MRoPE positions here).
+        if position_ids is None:
+            position_ids = self.compute_position_ids(
+                input_ids=input_ids,
+                image_grid_thw=image_grid_thw,
+                packed_seq_params=packed_seq_params,
+            )
+
+        vision_embeddings = None
+        if (
+            self.vision_model is not None
+            and pixel_values is not None
+        ):
+            vision_embeddings = self.vision_model(
+                pixel_values, image_grid_thw,
+            )
+
+        if decoder_input is None and self.language_model is not None:
+            # [S, B, D] text embeddings; image positions overwritten below.
+            text_embeddings = self.language_model.embedding(
+                input_ids=input_ids, position_ids=None,
+            )
+
+            if vision_embeddings is not None:
+                decoder_input = self._scatter_vision_embeddings(
+                    input_ids, text_embeddings, vision_embeddings,
+                )
+            else:
+                decoder_input = text_embeddings
+
+        # CP-split all seq-dim inputs before entering the decoder.
+        (
+            decoder_input, input_ids, labels, loss_mask,
+            attention_mask, position_ids, padding_mask,
+        ) = self._cp_split_for_forward(
+            decoder_input=decoder_input,
+            input_ids=input_ids,
+            labels=labels,
+            loss_mask=loss_mask,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            packed_seq_params=packed_seq_params,
+            padding_mask=padding_mask,
+        )
+
+        with self._thd_mrope_no_cp_override(packed_seq_params):
+            return self.language_model(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                attention_mask=attention_mask,
+                decoder_input=decoder_input,
+                labels=labels,
+                loss_mask=loss_mask,
+                padding_mask=padding_mask,
+                packed_seq_params=packed_seq_params,
+            )
diff --git a/examples/multimodal_dev/models/qwen35_vl/__init__.py b/examples/multimodal_dev/models/qwen35_vl/__init__.py
new file mode 100644
index 00000000000..44e82fa14bb
--- /dev/null
+++ b/examples/multimodal_dev/models/qwen35_vl/__init__.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Qwen3.5-VL model components — the single source of truth.
+
+Both the standalone ``multimodal_dev`` training path and the MIMO path
+import from here.
+"""
+
+from examples.multimodal_dev.models.qwen35_vl.configuration import (
+ MROPE_SECTION,
+ QWEN35_VL_IMAGE_TOKEN_ID,
+ QWEN35_VL_VIDEO_TOKEN_ID,
+ QWEN35_VL_VISION_END_TOKEN_ID,
+ QWEN35_VL_VISION_START_TOKEN_ID,
+ QWEN35_VL_VOCAB_SIZE,
+ ROTARY_BASE,
+ ROTARY_PERCENT,
+ VISION_KWARGS,
+ get_qwen35_vl_language_config,
+ get_qwen35_vl_vision_config,
+)
+from examples.multimodal_dev.models.qwen35_vl.factory import (
+ build_model,
+ post_language_config,
+ set_vision_flops_metadata,
+)
+from examples.multimodal_dev.models.qwen35_vl.model import Qwen35VLModel
+from examples.multimodal_dev.models.qwen35_vl.mrope import get_rope_index
+from examples.multimodal_dev.models.qwen35_vl.specs import (
+ get_qwen35_vl_language_spec,
+ get_qwen35_vl_vision_spec,
+)
+from examples.multimodal_dev.models.qwen35_vl.vision_encoder import (
+ Qwen35VLPatchEmbed,
+ Qwen35VLPatchMerger,
+ Qwen35VLVisionEncoder,
+ Qwen35VLVisionRotaryEmbedding,
+)
+
+# Public API of the qwen35_vl package — keep in sync with the imports above.
+__all__ = [
+    # Model class
+    "Qwen35VLModel",
+    # Factory functions
+    "build_model",
+    "post_language_config",
+    "set_vision_flops_metadata",
+    # Vision encoder
+    "Qwen35VLVisionEncoder",
+    "Qwen35VLPatchEmbed",
+    "Qwen35VLPatchMerger",
+    "Qwen35VLVisionRotaryEmbedding",
+    # Config helpers
+    "get_qwen35_vl_vision_config",
+    "get_qwen35_vl_language_config",
+    # Spec helpers
+    "get_qwen35_vl_language_spec",
+    "get_qwen35_vl_vision_spec",
+    # MRoPE
+    "get_rope_index",
+    # Constants
+    "QWEN35_VL_IMAGE_TOKEN_ID",
+    "QWEN35_VL_VIDEO_TOKEN_ID",
+    "QWEN35_VL_VISION_START_TOKEN_ID",
+    "QWEN35_VL_VISION_END_TOKEN_ID",
+    "QWEN35_VL_VOCAB_SIZE",
+    "ROTARY_BASE",
+    "ROTARY_PERCENT",
+    "MROPE_SECTION",
+    "VISION_KWARGS",
+]
diff --git a/examples/multimodal_dev/models/qwen35_vl/configuration.py b/examples/multimodal_dev/models/qwen35_vl/configuration.py
new file mode 100644
index 00000000000..41f6dda2524
--- /dev/null
+++ b/examples/multimodal_dev/models/qwen35_vl/configuration.py
@@ -0,0 +1,355 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Configuration helpers for Qwen3.5-VL vision-language model.
+
+Provides TransformerConfig builders for the vision encoder and all language
+decoder variants. Both the standalone ``multimodal_dev`` training path and the
+MIMO path import from here — this is the single source of truth.
+
+Supported language variants (HuggingFace Qwen3.5 series):
+ ``0.8b`` Dense 0.8B
+ ``2b`` Dense 2B
+ ``4b`` Dense 4B
+ ``9b`` Dense 9B
+ ``27b`` Dense 27B
+ ``35b_a3b`` MoE 35B-A3B (256 experts, top-8)
+ ``122b_a10b`` MoE 122B-A10B (256 experts, top-8)
+ ``397b_a17b`` MoE 397B-A17B (512 experts, top-10)
+ ``35b_a3b_light`` Reduced 35B-A3B for testing
+ ``proxy`` Reduced proxy based on 397B for single-node testing
+"""
+
+from typing import Optional
+
+import torch
+
+from megatron.core.transformer.transformer_config import TransformerConfig
+
+# ---------------------------------------------------------------------------
+# Public constants
+# ---------------------------------------------------------------------------
+
+# Special-token IDs — presumably from the HF Qwen3.5 tokenizer; confirm
+# against the released tokenizer config.
+QWEN35_VL_IMAGE_TOKEN_ID: int = 248056
+QWEN35_VL_VIDEO_TOKEN_ID: int = 248057
+QWEN35_VL_VISION_START_TOKEN_ID: int = 248053
+QWEN35_VL_VISION_END_TOKEN_ID: int = 248054
+QWEN35_VL_VOCAB_SIZE: int = 248320
+
+ROTARY_BASE: int = 10_000_000
+ROTARY_PERCENT: float = 0.25
+# NOTE: mutable module-level list — consumers copy it via list(MROPE_SECTION).
+MROPE_SECTION: list = [11, 11, 10]
+
+# ---------------------------------------------------------------------------
+# Vision config
+# ---------------------------------------------------------------------------
+
+# Patch-embedding / merger geometry shared by all vision variants.
+VISION_KWARGS = {
+    "in_channels": 3,
+    "patch_size": 16,
+    "temporal_patch_size": 2,
+    "spatial_merge_size": 2,
+    "out_hidden_size": 3584,
+    "max_num_positions": 2304,
+}
+
+# Three distinct vision encoder architectures in the Qwen3.5 family.
+_VISION_SMALL = {
+    "num_layers": 12, "hidden_size": 768, "num_attention_heads": 12,
+    "kv_channels": 64, "ffn_hidden_size": 3072,
+}
+_VISION_MEDIUM = {
+    "num_layers": 24, "hidden_size": 1024, "num_attention_heads": 16,
+    "kv_channels": 64, "ffn_hidden_size": 4096,
+}
+_VISION_LARGE = {
+    "num_layers": 27, "hidden_size": 1152, "num_attention_heads": 16,
+    "kv_channels": 72, "ffn_hidden_size": 4304,
+}
+
+# Per-variant vision config. ``out_hidden_size`` equals the language model's
+# hidden_size and controls the merger projection output dimension.
+_VISION_VARIANT_CONFIGS = {
+    "0.8b": {**_VISION_SMALL, "out_hidden_size": 1024},
+    "2b": {**_VISION_MEDIUM, "out_hidden_size": 2048},
+    "4b": {**_VISION_MEDIUM, "out_hidden_size": 2560},
+    "9b": {**_VISION_LARGE, "out_hidden_size": 4096},
+    "27b": {**_VISION_LARGE, "out_hidden_size": 5120},
+    "35b_a3b": {**_VISION_LARGE, "out_hidden_size": 2048},
+    "122b_a10b": {**_VISION_LARGE, "out_hidden_size": 3072},
+    "397b_a17b": {**_VISION_LARGE, "out_hidden_size": 4096},
+}
+
+# Fallback for proxy/unknown variants (large ViT, generic out_hidden_size).
+_VISION_DEFAULT = {**_VISION_LARGE, "out_hidden_size": 3584}
+
+
+def get_qwen35_vl_vision_config(
+    num_layers_override: Optional[int] = None,
+    variant: Optional[str] = None,
+) -> TransformerConfig:
+    """TransformerConfig for the Qwen3.5-VL vision encoder.
+
+    Three ViT architectures are used across the family:
+    - Small (0.8b): depth 12, 768-dim, 12 heads
+    - Medium (2b, 4b): depth 24, 1024-dim, 16 heads
+    - Large (9b, 27b, MoE variants): depth 27, 1152-dim, 16 heads
+
+    Args:
+        num_layers_override: Override vision backbone depth for proxy runs.
+        variant: Language model variant name. When set, selects the
+            matching vision config from ``_VISION_VARIANT_CONFIGS`` if one
+            exists; otherwise the default large-ViT config is used.
+    """
+    vcfg = _VISION_VARIANT_CONFIGS.get(variant, _VISION_DEFAULT)
+    num_layers = vcfg["num_layers"]
+    if num_layers_override is not None:
+        num_layers = num_layers_override
+
+    # NOTE(review): vcfg["out_hidden_size"] is not consumed here —
+    # presumably the merger output dim is wired up by the model factory;
+    # confirm it is not silently dropped.
+    return TransformerConfig(
+        num_layers=num_layers,
+        hidden_size=vcfg["hidden_size"],
+        num_attention_heads=vcfg["num_attention_heads"],
+        kv_channels=vcfg["kv_channels"],
+        ffn_hidden_size=vcfg["ffn_hidden_size"],
+        hidden_dropout=0.0,
+        attention_dropout=0.0,
+        layernorm_epsilon=1e-6,
+        normalization="LayerNorm",
+        gated_linear_unit=False,
+        # HF-style GELU with tanh approximation.
+        activation_func=lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
+        bias_activation_fusion=False,
+        apply_query_key_layer_scaling=False,
+        apply_rope_fusion=False,
+        # bf16=False here; dtype propagation is applied by the entry point
+        # (see README "vision recompute and dtype propagation") — confirm.
+        bf16=False,
+    )
+
+
+# ---------------------------------------------------------------------------
+# Language config variants
+# ---------------------------------------------------------------------------
+
+# Per-variant language-decoder hyperparameters. Dense variants set every
+# ``*moe*`` field to None; MoE variants fill them in (consumed by
+# get_qwen35_vl_language_config below).
+_VARIANT_CONFIGS = {
+    "0.8b": {
+        "num_layers": 24,
+        "hidden_size": 1024,
+        "ffn_hidden_size": 3584,
+        "num_attention_heads": 8,
+        "num_query_groups": 2,
+        "kv_channels": 256,
+        "linear_num_value_heads": 16,
+        "num_moe_experts": None,
+        "moe_router_topk": None,
+        "moe_ffn_hidden_size": None,
+        "moe_shared_expert_intermediate_size": None,
+    },
+    "2b": {
+        "num_layers": 24,
+        "hidden_size": 2048,
+        "ffn_hidden_size": 6144,
+        "num_attention_heads": 8,
+        "num_query_groups": 2,
+        "kv_channels": 256,
+        "linear_num_value_heads": 16,
+        "num_moe_experts": None,
+        "moe_router_topk": None,
+        "moe_ffn_hidden_size": None,
+        "moe_shared_expert_intermediate_size": None,
+    },
+    "4b": {
+        "num_layers": 32,
+        "hidden_size": 2560,
+        "ffn_hidden_size": 9216,
+        "num_attention_heads": 16,
+        "num_query_groups": 4,
+        "kv_channels": 256,
+        "linear_num_value_heads": 32,
+        "num_moe_experts": None,
+        "moe_router_topk": None,
+        "moe_ffn_hidden_size": None,
+        "moe_shared_expert_intermediate_size": None,
+    },
+    "9b": {
+        "num_layers": 32,
+        "hidden_size": 4096,
+        "ffn_hidden_size": 12288,
+        "num_attention_heads": 16,
+        "num_query_groups": 4,
+        "kv_channels": 256,
+        "linear_num_value_heads": 32,
+        "num_moe_experts": None,
+        "moe_router_topk": None,
+        "moe_ffn_hidden_size": None,
+        "moe_shared_expert_intermediate_size": None,
+    },
+    "27b": {
+        "num_layers": 64,
+        "hidden_size": 5120,
+        "ffn_hidden_size": 17408,
+        "num_attention_heads": 24,
+        "num_query_groups": 4,
+        "kv_channels": 256,
+        "linear_num_value_heads": 48,
+        "num_moe_experts": None,
+        "moe_router_topk": None,
+        "moe_ffn_hidden_size": None,
+        "moe_shared_expert_intermediate_size": None,
+    },
+    "35b_a3b": {
+        "num_layers": 40,
+        "hidden_size": 2048,
+        "ffn_hidden_size": 4096,
+        "num_attention_heads": 16,
+        "num_query_groups": 2,
+        "kv_channels": 256,
+        "linear_num_value_heads": 32,
+        "num_moe_experts": 256,
+        "moe_router_topk": 8,
+        "moe_ffn_hidden_size": 512,
+        "moe_shared_expert_intermediate_size": 512,
+    },
+    # Reduced-depth 35b_a3b for faster end-to-end testing.
+    "35b_a3b_light": {
+        "num_layers": 20,
+        "hidden_size": 2048,
+        "ffn_hidden_size": 4096,
+        "num_attention_heads": 16,
+        "num_query_groups": 2,
+        "kv_channels": 256,
+        "linear_num_value_heads": 32,
+        "num_moe_experts": 256,
+        "moe_router_topk": 8,
+        "moe_ffn_hidden_size": 512,
+        "moe_shared_expert_intermediate_size": 512,
+    },
+    "122b_a10b": {
+        "num_layers": 48,
+        "hidden_size": 3072,
+        "ffn_hidden_size": 8192,
+        "num_attention_heads": 32,
+        "num_query_groups": 2,
+        "kv_channels": 256,
+        "linear_num_value_heads": 64,
+        "num_moe_experts": 256,
+        "moe_router_topk": 8,
+        "moe_ffn_hidden_size": 1024,
+        "moe_shared_expert_intermediate_size": 1024,
+    },
+    "397b_a17b": {
+        "num_layers": 60,
+        "hidden_size": 4096,
+        "ffn_hidden_size": 10240,
+        "num_attention_heads": 32,
+        "num_query_groups": 2,
+        "kv_channels": 256,
+        "linear_num_value_heads": 64,
+        "num_moe_experts": 512,
+        "moe_router_topk": 10,
+        "moe_ffn_hidden_size": 1024,
+        "moe_shared_expert_intermediate_size": 1024,
+    },
+    # Shrunken 397b_a17b (4 layers, 16 experts) for single-node testing.
+    "proxy": {
+        "num_layers": 4,
+        "hidden_size": 4096,
+        "ffn_hidden_size": 10240,
+        "num_attention_heads": 32,
+        "num_query_groups": 2,
+        "kv_channels": 256,
+        "linear_num_value_heads": 64,
+        "num_moe_experts": 16,
+        "moe_router_topk": 2,
+        "moe_ffn_hidden_size": 1024,
+        "moe_shared_expert_intermediate_size": 1024,
+    },
+}
+
+
+def get_qwen35_vl_language_config(
+    variant: str = "proxy",
+    **overrides,
+) -> TransformerConfig:
+    """TransformerConfig for the Qwen3.5-VL language decoder.
+
+    The ``397b_a17b`` variant reproduces the MIMO
+    ``get_qwen35_language_model_config()`` output exactly.
+
+    Args:
+        variant: One of ``0.8b``, ``2b``, ``4b``, ``9b``, ``27b``,
+            ``35b_a3b``, ``122b_a10b``, ``397b_a17b``,
+            ``35b_a3b_light``, ``proxy``.
+        **overrides: Override any TransformerConfig field.
+
+    Returns:
+        Fully-populated TransformerConfig.
+
+    Raises:
+        ValueError: If *variant* is not a known key of ``_VARIANT_CONFIGS``.
+    """
+    if variant not in _VARIANT_CONFIGS:
+        raise ValueError(
+            f"Unknown variant '{variant}'. "
+            f"Choose from {list(_VARIANT_CONFIGS.keys())}"
+        )
+
+    v = _VARIANT_CONFIGS[variant]
+
+    kwargs = dict(
+        # Architecture
+        num_layers=v["num_layers"],
+        hidden_size=v["hidden_size"],
+        ffn_hidden_size=v["ffn_hidden_size"],
+        num_attention_heads=v["num_attention_heads"],
+        num_query_groups=v["num_query_groups"],
+        kv_channels=v["kv_channels"],
+        # Normalization & activation
+        normalization="RMSNorm",
+        layernorm_epsilon=1e-6,
+        layernorm_zero_centered_gamma=True,
+        apply_residual_connection_post_layernorm=False,
+        gated_linear_unit=True,
+        activation_func=torch.nn.functional.silu,
+        # MRoPE section (interleaved T/H/W layout, Qwen3.5-VL style)
+        # list(...) copies the module-level constant to avoid shared state.
+        mrope_section=list(MROPE_SECTION),
+        mrope_interleaved=True,
+        rotary_interleaved=False,
+        # Attention
+        qk_layernorm=True,
+        attention_output_gate=True,
+        attention_dropout=0.0,
+        hidden_dropout=0.0,
+        add_bias_linear=False,
+        # Hybrid attention (GatedDeltaNet)
+        experimental_attention_variant="gated_delta_net",
+        linear_attention_freq=4,
+        linear_conv_kernel_dim=4,
+        linear_key_head_dim=128,
+        linear_value_head_dim=128,
+        linear_num_key_heads=16,
+        linear_num_value_heads=v["linear_num_value_heads"],
+        # Kernel / TE fusions
+        bias_activation_fusion=True,
+        masked_softmax_fusion=True,
+        persist_layer_norm=True,
+        bias_dropout_fusion=True,
+        apply_rope_fusion=False,
+        # Precision
+        bf16=True,
+    )
+
+    # MoE config (only for MoE variants)
+    if v["num_moe_experts"] is not None:
+        kwargs.update(
+            num_moe_experts=v["num_moe_experts"],
+            moe_router_topk=v["moe_router_topk"],
+            moe_ffn_hidden_size=v["moe_ffn_hidden_size"],
+            moe_shared_expert_intermediate_size=v[
+                "moe_shared_expert_intermediate_size"
+            ],
+            moe_shared_expert_gate=True,
+            moe_layer_freq=1,
+            moe_router_pre_softmax=False,
+            moe_router_load_balancing_type="global_aux_loss",
+            moe_permute_fusion=True,
+            moe_aux_loss_coeff=1e-3,
+            moe_grouped_gemm=True,
+            moe_token_dispatcher_type="alltoall",
+            moe_router_dtype="fp32",
+        )
+
+    # Caller-supplied overrides take precedence over variant defaults.
+    kwargs.update(overrides)
+    return TransformerConfig(**kwargs)
diff --git a/examples/multimodal_dev/models/qwen35_vl/factory.py b/examples/multimodal_dev/models/qwen35_vl/factory.py
new file mode 100644
index 00000000000..1b9076845e7
--- /dev/null
+++ b/examples/multimodal_dev/models/qwen35_vl/factory.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Factory functions for Qwen3.5-VL model construction.
+
+Encapsulates all Qwen3.5-VL-specific logic needed by ``pretrain_multimodal.py``
+so that the training entry point remains model-agnostic.
+"""
+
+from examples.multimodal_dev.models.qwen35_vl.configuration import (
+ MROPE_SECTION,
+ VISION_KWARGS,
+)
+
+
+def post_language_config(language_config, args):
+ """Apply Qwen3.5-VL-specific settings to the language TransformerConfig.
+
+ Called after ``core_transformer_config_from_args`` to inject model-specific
+ fields that cannot be expressed via CLI args alone.
+ """
+ language_config.mrope_section = list(MROPE_SECTION)
+ language_config.mrope_interleaved = True
+
+
+def set_vision_flops_metadata(args, language_config, vision_config):
+ """Expose Qwen3.5-VL vision-model dimensions for FLOPs estimation."""
+ args.count_vision_model_flops = True
+ args.vision_flops_variant = "qwen35_vl_v2"
+ args.vision_num_layers = vision_config.num_layers
+ args.vision_hidden_size = vision_config.hidden_size
+ args.vision_ffn_hidden_size = vision_config.ffn_hidden_size
+ args.vision_num_attention_heads = vision_config.num_attention_heads
+ args.vision_kv_channels = vision_config.kv_channels
+ args.vision_in_channels = VISION_KWARGS["in_channels"]
+ args.vision_patch_size = VISION_KWARGS["patch_size"]
+ args.vision_temporal_patch_size = VISION_KWARGS["temporal_patch_size"]
+ args.vision_spatial_merge_size = VISION_KWARGS["spatial_merge_size"]
+ args.vision_out_hidden_size = language_config.hidden_size
+
+
+def build_model(args, language_config, vision_config, **kwargs):
+    """Build a complete Qwen3.5-VL model instance.
+
+    Handles language spec construction, optional MTP block spec, and
+    model instantiation with Qwen3.5-VL-specific parameters.
+
+    Args:
+        args: Megatron parsed arguments.
+        language_config: ``TransformerConfig`` for the language decoder
+            (already post-processed by :func:`post_language_config`).
+        vision_config: ``TransformerConfig`` for the vision encoder.
+        **kwargs: Extra keyword arguments (e.g. ``vp_stage``).
+
+    Returns:
+        A :class:`Qwen35VLModel` instance.
+    """
+    from megatron.core.models.gpt.gpt_layer_specs import (
+        get_gpt_mtp_block_spec,
+    )
+
+    from examples.multimodal_dev.models.qwen35_vl.model import Qwen35VLModel
+    from examples.multimodal_dev.models.qwen35_vl.specs import (
+        get_qwen35_vl_language_spec,
+    )
+
+    # Per-layer specs for the hybrid GatedDeltaNet / full-attention decoder.
+    language_spec = get_qwen35_vl_language_spec(
+        config=language_config,
+        vp_stage=kwargs.get("vp_stage", None),
+        pp_rank=None,
+    )
+
+    # Multi-token-prediction block is built only when --mtp-num-layers is set.
+    mtp_block_spec = None
+    if getattr(args, "mtp_num_layers", None):
+        mtp_block_spec = get_gpt_mtp_block_spec(
+            config=language_config,
+            spec=language_spec,
+            use_transformer_engine=(
+                args.transformer_impl == "transformer_engine"
+            ),
+            vp_stage=kwargs.get("vp_stage", None),
+            pp_rank=None,
+        )
+
+    # When --untie-embeddings-and-output-weights is NOT passed, Megatron
+    # defaults to tied embeddings (share_embeddings_and_output_weights=True).
+    # The 0.8B variant uses tied embeddings, while larger variants untie them.
+    share_embeddings = not getattr(
+        args, "untie_embeddings_and_output_weights", False
+    )
+
+    return Qwen35VLModel(
+        language_config=language_config,
+        language_spec=language_spec,
+        vision_config=vision_config,
+        vocab_size=args.padded_vocab_size,
+        max_sequence_length=args.max_position_embeddings,
+        # NOTE(review): 248056 fallback — confirm it matches the tokenizer.
+        image_token_id=getattr(args, "image_token_id", 248056),
+        mtp_block_spec=mtp_block_spec,
+        parallel_output=True,
+        share_embeddings_and_output_weights=share_embeddings,
+    )
diff --git a/examples/multimodal_dev/models/qwen35_vl/model.py b/examples/multimodal_dev/models/qwen35_vl/model.py
new file mode 100644
index 00000000000..2e2d1455af8
--- /dev/null
+++ b/examples/multimodal_dev/models/qwen35_vl/model.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Qwen3.5-VL multimodal model for standalone FSDP + EP training.
+
+Composes a Megatron-native Qwen3.5 vision encoder with a ``GPTModel``
+language decoder using MRoPE and hybrid GatedDeltaNet / full-attention
+layers.
+"""
+
+from typing import Optional
+
+from torch import Tensor
+
+from examples.multimodal_dev.models.base import MultimodalModel
+from examples.multimodal_dev.models.qwen35_vl.configuration import (
+ QWEN35_VL_IMAGE_TOKEN_ID,
+ QWEN35_VL_VIDEO_TOKEN_ID,
+ QWEN35_VL_VISION_START_TOKEN_ID,
+ QWEN35_VL_VOCAB_SIZE,
+ ROTARY_BASE,
+ ROTARY_PERCENT,
+ VISION_KWARGS,
+)
+from examples.multimodal_dev.models.qwen35_vl.mrope import get_rope_index
+from examples.multimodal_dev.models.qwen35_vl.specs import get_qwen35_vl_vision_spec
+from examples.multimodal_dev.models.qwen35_vl.vision_encoder import Qwen35VLVisionEncoder
+from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_config import TransformerConfig
+
+
+class Qwen35VLModel(MultimodalModel):
+    """Qwen3.5-VL multimodal model.
+
+    Args:
+        language_config: ``TransformerConfig`` for the language decoder.
+        language_spec: ``ModuleSpec`` for language decoder layers.
+        vision_config: ``TransformerConfig`` for the vision encoder.
+        vision_spec: ``ModuleSpec`` for vision encoder layers.
+        vocab_size: Vocabulary size.
+        max_sequence_length: Maximum sequence length.
+        image_token_id: Token ID for image placeholders.
+        video_token_id: Token ID for video placeholders.
+        vision_start_token_id: Token ID marking the start of a vision region.
+        spatial_merge_size: Vision encoder spatial merge factor.
+        mtp_block_spec: Optional MTP block spec.
+        parallel_output: Keep outputs split across TP.
+        share_embeddings_and_output_weights: Tie embeddings.
+    """
+
+    def __init__(
+        self,
+        language_config: TransformerConfig,
+        language_spec: ModuleSpec,
+        vision_config: TransformerConfig,
+        vision_spec: ModuleSpec = None,
+        vocab_size: int = QWEN35_VL_VOCAB_SIZE,
+        max_sequence_length: int = 262144,
+        image_token_id: int = QWEN35_VL_IMAGE_TOKEN_ID,
+        video_token_id: int = QWEN35_VL_VIDEO_TOKEN_ID,
+        vision_start_token_id: int = QWEN35_VL_VISION_START_TOKEN_ID,
+        spatial_merge_size: int = 2,
+        mtp_block_spec: ModuleSpec = None,
+        parallel_output: bool = True,
+        share_embeddings_and_output_weights: bool = False,
+    ):
+        # Default to the TE ViT layer spec (fp32 RoPE variant, see specs.py).
+        if vision_spec is None:
+            vision_spec = get_qwen35_vl_vision_spec()
+
+        # Stored for compute_position_ids below.
+        self.video_token_id = video_token_id
+        self.vision_start_token_id = vision_start_token_id
+        self.spatial_merge_size = spatial_merge_size
+
+        # Copy the shared vision kwargs, then override the two fields that
+        # depend on this instance (merge factor and LM hidden size).
+        vkw = dict(VISION_KWARGS)
+        vkw["spatial_merge_size"] = spatial_merge_size
+        vkw["out_hidden_size"] = language_config.hidden_size
+
+        vision_encoder = Qwen35VLVisionEncoder(
+            config=vision_config,
+            transformer_layer_spec=vision_spec,
+            in_channels=vkw["in_channels"],
+            patch_size=vkw["patch_size"],
+            temporal_patch_size=vkw["temporal_patch_size"],
+            spatial_merge_size=vkw["spatial_merge_size"],
+            out_hidden_size=vkw["out_hidden_size"],
+            max_num_positions=vkw["max_num_positions"],
+        )
+
+        # mrope_section is taken from language_config so the decoder's RoPE
+        # and the position IDs produced below stay consistent.
+        super().__init__(
+            language_config=language_config,
+            language_spec=language_spec,
+            vision_encoder=vision_encoder,
+            vocab_size=vocab_size,
+            max_sequence_length=max_sequence_length,
+            image_token_id=image_token_id,
+            position_embedding_type="mrope",
+            rotary_percent=ROTARY_PERCENT,
+            rotary_base=ROTARY_BASE,
+            mrope_section=language_config.mrope_section,
+            mtp_block_spec=mtp_block_spec,
+            parallel_output=parallel_output,
+            share_embeddings_and_output_weights=(
+                share_embeddings_and_output_weights
+            ),
+        )
+
+    def compute_position_ids(
+        self,
+        input_ids: Tensor,
+        image_grid_thw: Optional[Tensor] = None,
+        packed_seq_params=None,
+    ) -> Tensor:
+        """Compute 3D MRoPE position IDs for Qwen3.5-VL.
+
+        In THD mode ``input_ids`` is ``[1, T]`` and ``packed_seq_params``
+        supplies per-segment boundaries; positions restart at 0 per
+        segment. In BSHD mode ``input_ids`` is ``[B, S]`` and
+        ``packed_seq_params`` should be ``None``.
+
+        Returns:
+            ``[3, B, S]`` position IDs for MRoPE (``[3, 1, T]`` in THD).
+        """
+        # get_rope_index selects the THD branch when packed_seq_params is
+        # given; the returned deltas are discarded here.
+        position_ids, _ = get_rope_index(
+            spatial_merge_size=self.spatial_merge_size,
+            image_token_id=self.image_token_id,
+            video_token_id=self.video_token_id,
+            vision_start_token_id=self.vision_start_token_id,
+            input_ids=input_ids,
+            image_grid_thw=image_grid_thw,
+            packed_seq_params=packed_seq_params,
+        )
+        return position_ids
diff --git a/examples/multimodal_dev/models/qwen35_vl/mrope.py b/examples/multimodal_dev/models/qwen35_vl/mrope.py
new file mode 100644
index 00000000000..44806a9bdf0
--- /dev/null
+++ b/examples/multimodal_dev/models/qwen35_vl/mrope.py
@@ -0,0 +1,378 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""MRoPE (Multimodal Rotary Position Embedding) position ID computation.
+
+Computes 3D position IDs for Qwen3.5-VL: for text tokens all three
+dimensions share sequential positions; for image/video tokens the three
+dimensions encode (temporal, height, width) in the merged spatial grid.
+
+Supports two input layouts:
+
+* **BSHD** — ``input_ids`` is ``[B, S]``; each row is an independent
+ sample (possibly padded) and ``attention_mask`` marks valid tokens.
+* **THD** — ``input_ids`` is ``[1, T]``, a concatenation of ``N``
+ sub-sequences. ``packed_seq_params.cu_seqlens_q_padded`` gives the
+ physical segment boundaries in the packed tensor and
+ ``cu_seqlens_q`` gives the valid (unpadded) token count inside each
+ segment. Position IDs restart at 0 at every segment boundary; image
+ / video grid rows are consumed in packed order across segments.
+
+Ported from Megatron-Bridge ``get_rope_index`` (which itself is adapted
+from HF ``Qwen3VLForConditionalGeneration.get_rope_index``). The inner
+loop iterates over vision occurrences, not individual tokens.
+"""
+
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+from megatron.core.packed_seq_params import PackedSeqParams
+
+
+def _build_sample_mrope_positions(
+    sample_input_ids: Tensor,
+    image_grid_thw: Optional[Tensor],
+    video_grid_thw: Optional[Tensor],
+    image_index: int,
+    video_index: int,
+    spatial_merge_size: int,
+    image_token_id: int,
+    video_token_id: int,
+    vision_start_token_id: int,
+) -> tuple[Tensor, int, int]:
+    """Compute MRoPE position IDs for a single sub-sequence.
+
+    Walks vision occurrences in ``sample_input_ids`` and produces a
+    ``[3, L]`` position tensor whose values start at 0. Advances
+    ``image_index`` / ``video_index`` through ``image_grid_thw`` /
+    ``video_grid_thw`` so callers can keep a running cursor across
+    multiple sub-sequences.
+    """
+    # The token immediately after each vision-start marker identifies the
+    # modality (image vs video) of that occurrence.
+    vision_start_indices = torch.argwhere(
+        sample_input_ids == vision_start_token_id,
+    ).squeeze(1)
+    vision_tokens = sample_input_ids[vision_start_indices + 1]
+    image_nums = int((vision_tokens == image_token_id).sum())
+    video_nums = int((vision_tokens == video_token_id).sum())
+    input_tokens = sample_input_ids.tolist()
+    llm_pos_ids_list: list = []
+    st = 0
+    remain_images, remain_videos = image_nums, video_nums
+
+    # One iteration per vision occurrence, in textual order.
+    for _ in range(image_nums + video_nums):
+        # len(input_tokens) + 1 acts as a past-the-end sentinel when a
+        # modality has no remaining occurrence, so the `<` below always
+        # picks the other modality.
+        if image_token_id in input_tokens and remain_images > 0:
+            ed_image = input_tokens.index(image_token_id, st)
+        else:
+            ed_image = len(input_tokens) + 1
+        if video_token_id in input_tokens and remain_videos > 0:
+            ed_video = input_tokens.index(video_token_id, st)
+        else:
+            ed_video = len(input_tokens) + 1
+
+        if ed_image < ed_video:
+            t, h, w = (
+                image_grid_thw[image_index][0],
+                image_grid_thw[image_index][1],
+                image_grid_thw[image_index][2],
+            )
+            image_index += 1
+            remain_images -= 1
+            ed = ed_image
+        else:
+            t, h, w = (
+                video_grid_thw[video_index][0],
+                video_grid_thw[video_index][1],
+                video_grid_thw[video_index][2],
+            )
+            video_index += 1
+            remain_videos -= 1
+            ed = ed_video
+
+        # Grid in LLM-token units: H and W shrink by the spatial merge factor.
+        llm_grid_t, llm_grid_h, llm_grid_w = (
+            t.item(),
+            h.item() // spatial_merge_size,
+            w.item() // spatial_merge_size,
+        )
+        text_len = ed - st
+
+        # Positions resume right after the largest position used so far.
+        st_idx = (
+            llm_pos_ids_list[-1].max() + 1
+            if llm_pos_ids_list
+            else 0
+        )
+        # Text span before the vision block: all three axes advance together.
+        llm_pos_ids_list.append(
+            torch.arange(text_len).view(1, -1).expand(3, -1)
+            + st_idx
+        )
+
+        # Vision block: (t, h, w) indices over the merged grid, offset
+        # past the preceding text span.
+        t_index = (
+            torch.arange(llm_grid_t)
+            .view(-1, 1)
+            .expand(-1, llm_grid_h * llm_grid_w)
+            .flatten()
+        )
+        h_index = (
+            torch.arange(llm_grid_h)
+            .view(1, -1, 1)
+            .expand(llm_grid_t, -1, llm_grid_w)
+            .flatten()
+        )
+        w_index = (
+            torch.arange(llm_grid_w)
+            .view(1, 1, -1)
+            .expand(llm_grid_t, llm_grid_h, -1)
+            .flatten()
+        )
+        llm_pos_ids_list.append(
+            torch.stack([t_index, h_index, w_index])
+            + text_len
+            + st_idx
+        )
+        # Skip past the vision placeholder tokens just consumed.
+        st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+    # Trailing text after the final vision block (or the whole sequence
+    # when there were no vision occurrences).
+    if st < len(input_tokens):
+        st_idx = (
+            llm_pos_ids_list[-1].max() + 1
+            if llm_pos_ids_list
+            else 0
+        )
+        text_len = len(input_tokens) - st
+        llm_pos_ids_list.append(
+            torch.arange(text_len).view(1, -1).expand(3, -1)
+            + st_idx
+        )
+
+    if llm_pos_ids_list:
+        positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+    else:
+        # Empty sub-sequence: return a [3, 0] tensor.
+        positions = torch.zeros(
+            3, 0,
+            dtype=sample_input_ids.dtype,
+            device=sample_input_ids.device,
+        )
+    return positions, image_index, video_index
+
+
+def get_rope_index(
+    spatial_merge_size: int,
+    image_token_id: int,
+    video_token_id: int,
+    vision_start_token_id: int,
+    input_ids: Optional[Tensor] = None,
+    image_grid_thw: Optional[Tensor] = None,
+    video_grid_thw: Optional[Tensor] = None,
+    attention_mask: Optional[Tensor] = None,
+    packed_seq_params: Optional[PackedSeqParams] = None,
+) -> tuple[Tensor, Tensor]:
+    """Compute 3D MRoPE position IDs for Qwen3-VL / Qwen3.5-VL.
+
+    Qwen3-VL uses timestamps rather than absolute time position IDs.
+
+    For text tokens all three dimensions share sequential positions.
+    For vision tokens the three dimensions encode (temporal, height,
+    width) in the merged spatial grid.
+
+    Args:
+        spatial_merge_size: Merge factor for spatial dimensions.
+        image_token_id: Token ID for image placeholders.
+        video_token_id: Token ID for video placeholders.
+        vision_start_token_id: Token ID marking start of a vision region.
+        input_ids: ``[B, S]`` in BSHD or ``[1, T]`` in THD.
+        image_grid_thw: ``[num_images, 3]`` per-image
+            ``(temporal, height, width)`` in patch-grid units. Rows are
+            consumed in the order their image tokens appear in
+            ``input_ids`` (packed order across segments in THD).
+        video_grid_thw: ``[num_videos, 3]`` per-video grid dimensions.
+        attention_mask: ``[B, S]`` mask (1 = keep, 0 = pad). BSHD only.
+        packed_seq_params: When provided, selects the THD branch and
+            supplies segment boundaries via ``cu_seqlens_q`` (valid
+            lengths) and ``cu_seqlens_q_padded`` (packed layout).
+
+    Returns:
+        ``(position_ids, mrope_position_deltas)`` where *position_ids*
+        has shape ``[3, B, S]`` (``[3, 1, T]`` in THD).
+    """
+    # Expand each video into per-frame rows (t copies with t=1) so the
+    # per-sample walk treats every frame like a single image grid.
+    if video_grid_thw is not None:
+        video_grid_thw = torch.repeat_interleave(
+            video_grid_thw, video_grid_thw[:, 0], dim=0,
+        )
+        video_grid_thw[:, 0] = 1
+
+    # -----------------------------------------------------------------
+    # THD (packed) branch
+    # -----------------------------------------------------------------
+    if packed_seq_params is not None and input_ids is not None:
+        cu_seqlens = packed_seq_params.cu_seqlens_q
+        # Fall back to the unpadded boundaries when no padded layout exists.
+        cu_seqlens_padded = getattr(
+            packed_seq_params, "cu_seqlens_q_padded", None,
+        )
+        if cu_seqlens_padded is None:
+            cu_seqlens_padded = cu_seqlens
+
+        assert (
+            input_ids.dim() == 2 and input_ids.shape[0] == 1
+        ), "THD get_rope_index expects input_ids shape [1, T]"
+
+        total_tokens = input_ids.shape[1]
+        device = input_ids.device
+
+        # Padding slots default to 1 (matches BSHD convention where
+        # masked positions get filled with 1).
+        position_ids = torch.ones(
+            3, 1, total_tokens,
+            dtype=input_ids.dtype, device=device,
+        )
+        deltas: list = []
+        image_index = 0
+        video_index = 0
+        num_segs = cu_seqlens.numel() - 1
+
+        for k in range(num_segs):
+            # Segment starts at its padded offset; only the first
+            # valid_len tokens inside it get real positions.
+            seg_start = int(cu_seqlens_padded[k].item())
+            valid_len = int(
+                cu_seqlens[k + 1].item() - cu_seqlens[k].item()
+            )
+            valid_end = seg_start + valid_len
+
+            if valid_len == 0:
+                deltas.append(0)
+                continue
+
+            sample_input_ids = input_ids[0, seg_start:valid_end]
+
+            if (
+                image_grid_thw is not None
+                or video_grid_thw is not None
+            ):
+                # image_index / video_index carry across segments so grid
+                # rows are consumed in packed order.
+                (
+                    positions,
+                    image_index,
+                    video_index,
+                ) = _build_sample_mrope_positions(
+                    sample_input_ids=sample_input_ids,
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    image_index=image_index,
+                    video_index=video_index,
+                    spatial_merge_size=spatial_merge_size,
+                    image_token_id=image_token_id,
+                    video_token_id=video_token_id,
+                    vision_start_token_id=vision_start_token_id,
+                )
+            else:
+                # Text-only segment: plain sequential positions on all axes.
+                positions = (
+                    torch.arange(valid_len, device=device)
+                    .view(1, -1)
+                    .expand(3, -1)
+                )
+
+            position_ids[:, 0, seg_start:valid_end] = positions.to(
+                device=device, dtype=position_ids.dtype,
+            )
+
+            # delta = (next position) - (valid length), per segment.
+            if positions.numel() > 0:
+                deltas.append(
+                    int(positions.max().item()) + 1 - valid_len
+                )
+            else:
+                deltas.append(0)
+
+        mrope_position_deltas = torch.tensor(
+            deltas, device=device,
+        ).unsqueeze(1)
+        return position_ids, mrope_position_deltas
+
+    # -----------------------------------------------------------------
+    # BSHD branch with vision
+    # -----------------------------------------------------------------
+    if input_ids is not None and (
+        image_grid_thw is not None or video_grid_thw is not None
+    ):
+        total_input_ids = input_ids
+        # Collapse higher-rank masks (e.g. [B, 1, S, S]) down to [B, S].
+        if attention_mask is None:
+            attention_mask = torch.ones_like(total_input_ids)
+        elif attention_mask.dim() > 2:
+            attention_mask = attention_mask.any(dim=-1)
+            if attention_mask.dim() == 3:
+                attention_mask = attention_mask.squeeze(1)
+            attention_mask = attention_mask.to(dtype=total_input_ids.dtype)
+
+        position_ids = torch.ones(
+            3,
+            input_ids.shape[0],
+            input_ids.shape[1],
+            dtype=input_ids.dtype,
+            device=input_ids.device,
+        )
+        mrope_position_deltas = []
+        image_index, video_index = 0, 0
+        attention_mask = attention_mask.to(total_input_ids.device)
+
+        for i, sample_input_ids in enumerate(total_input_ids):
+            # Positions are computed over valid (unmasked) tokens only,
+            # then scattered back into the padded row.
+            sample_input_ids = sample_input_ids[attention_mask[i] == 1]
+            (
+                llm_positions,
+                image_index,
+                video_index,
+            ) = _build_sample_mrope_positions(
+                sample_input_ids=sample_input_ids,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                image_index=image_index,
+                video_index=video_index,
+                spatial_merge_size=spatial_merge_size,
+                image_token_id=image_token_id,
+                video_token_id=video_token_id,
+                vision_start_token_id=vision_start_token_id,
+            )
+            position_ids[
+                ..., i, attention_mask[i] == 1
+            ] = llm_positions.to(position_ids.device)
+            mrope_position_deltas.append(
+                llm_positions.max() + 1 - len(total_input_ids[i]),
+            )
+
+        mrope_position_deltas = torch.tensor(
+            mrope_position_deltas, device=total_input_ids.device,
+        ).unsqueeze(1)
+        return position_ids, mrope_position_deltas
+
+    # -----------------------------------------------------------------
+    # Text-only fallback
+    # -----------------------------------------------------------------
+    if attention_mask is not None:
+        if attention_mask.dim() > 2:
+            attention_mask = attention_mask.any(dim=-1)
+        if attention_mask.dim() == 3:
+            attention_mask = attention_mask.squeeze(1)
+        attention_mask = attention_mask.to(dtype=torch.long)
+        # Sequential positions over valid tokens; padding slots become 1.
+        position_ids = attention_mask.long().cumsum(-1) - 1
+        position_ids.masked_fill_(attention_mask == 0, 1)
+        position_ids = (
+            position_ids.unsqueeze(0)
+            .expand(3, -1, -1)
+            .to(attention_mask.device)
+        )
+        max_position_ids = (
+            position_ids.max(0, keepdim=False)[0]
+            .max(-1, keepdim=True)[0]
+        )
+        mrope_position_deltas = (
+            max_position_ids + 1 - attention_mask.shape[-1]
+        )
+    else:
+        # No mask at all: every row gets 0..S-1 on all three axes.
+        position_ids = (
+            torch.arange(
+                input_ids.shape[1], device=input_ids.device,
+            )
+            .view(1, 1, -1)
+            .expand(3, input_ids.shape[0], -1)
+        )
+        mrope_position_deltas = torch.zeros(
+            [input_ids.shape[0], 1],
+            device=input_ids.device,
+            dtype=input_ids.dtype,
+        )
+
+    return position_ids, mrope_position_deltas
diff --git a/examples/multimodal_dev/models/qwen35_vl/specs.py b/examples/multimodal_dev/models/qwen35_vl/specs.py
new file mode 100644
index 00000000000..64b864e223e
--- /dev/null
+++ b/examples/multimodal_dev/models/qwen35_vl/specs.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Layer spec helpers for Qwen3.5-VL vision encoder and language decoder.
+
+Provides ModuleSpec builders that define the transformer layer composition.
+Both the standalone and MIMO training paths import from here.
+"""
+
+from typing import Optional
+
+from examples.multimodal_dev.models.base import _NO_CP_GROUP
+from megatron.core.models.gpt.experimental_attention_variant_module_specs import (
+ get_transformer_block_with_experimental_attention_variant_spec,
+)
+from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec
+from megatron.core.transformer.attention import SelfAttention
+from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
+from megatron.core.transformer.transformer_config import TransformerConfig
+
+
+def _apply_rope_fp32(t, freqs, config, cu_seqlens=None, mscale=1.0, cp_group=None):
+ """Apply rotary positional embedding in fp32, then cast back to original dtype.
+
+ Mirrors ``Qwen3VLSelfAttention.apply_rotary_pos_emb_absolute`` in Megatron-Bridge
+ with ``apply_rotary_pos_emb_in_fp32=True``.
+ """
+ from megatron.core import parallel_state
+ from megatron.core.models.common.embeddings.rope_utils import (
+ _apply_rotary_pos_emb_bshd,
+ _apply_rotary_pos_emb_thd,
+ )
+
+ orig_dtype = t.dtype
+ t_fp32 = t.float()
+
+ if cu_seqlens is None:
+ out = _apply_rotary_pos_emb_bshd(
+ t_fp32,
+ freqs,
+ rotary_interleaved=config.rotary_interleaved,
+ multi_latent_attention=getattr(config, 'multi_latent_attention', False),
+ mscale=mscale,
+ )
+ else:
+ if cp_group is None:
+ cp_group = parallel_state.get_context_parallel_group()
+ out = _apply_rotary_pos_emb_thd(
+ t_fp32,
+ cu_seqlens,
+ freqs,
+ rotary_interleaved=config.rotary_interleaved,
+ multi_latent_attention=getattr(config, 'multi_latent_attention', False),
+ mscale=mscale,
+ cp_group=cp_group,
+ )
+ return out.to(orig_dtype)
+
+
+def _apply_rope_fp32_no_cp(t, freqs, config, cu_seqlens=None, mscale=1.0, cp_group=None):
+ """Same as ``_apply_rope_fp32`` but forces CP-size=1.
+
+ The vision encoder uses THD packed sequences for variable-resolution
+ images. When the language model uses CP>1, the global CP group would
+ incorrectly split the vision seqlens. This wrapper substitutes a
+ trivial group so the vision RoPE sees the full packed sequence.
+ """
+ return _apply_rope_fp32(
+ t, freqs, config, cu_seqlens, mscale, cp_group=_NO_CP_GROUP,
+ )
+
+
+class Qwen35VLVisionSelfAttention(SelfAttention):
+    """ViT self-attention with RoPE applied in fp32.
+
+    Matches Bridge's ``Qwen3VLSelfAttention`` behaviour when
+    ``apply_rotary_pos_emb_in_fp32=True``: query and key are cast to float32
+    before the rotary multiply and cast back to bf16 afterwards. The
+    monkey-patch approach avoids duplicating the 300-line ``SelfAttention.forward``
+    while keeping the change local to this class.
+    """
+
+    def forward(self, *args, **kwargs):
+        # Temporarily rebind the module-level RoPE hook so the parent
+        # forward picks up the fp32, CP-free variant; restored in
+        # ``finally`` even if the forward raises.
+        # NOTE(review): this mutates process-global state — not safe if
+        # another SelfAttention.forward runs concurrently in this process.
+        import megatron.core.transformer.attention as _attn_mod
+
+        _orig = _attn_mod.apply_rotary_pos_emb
+        _attn_mod.apply_rotary_pos_emb = _apply_rope_fp32_no_cp
+        try:
+            return super().forward(*args, **kwargs)
+        finally:
+            _attn_mod.apply_rotary_pos_emb = _orig
+
+
+def get_qwen35_vl_language_spec(
+ config: TransformerConfig,
+ vp_stage: Optional[int] = None,
+ pp_rank: Optional[int] = None,
+) -> TransformerBlockSubmodules:
+ """Transformer block spec for the Qwen3.5-VL language decoder.
+
+ Uses the experimental attention variant infrastructure to build hybrid
+ GatedDeltaNet + full-attention layers with optional MoE interleaving.
+
+ Args:
+ config: Language decoder TransformerConfig.
+ vp_stage: Virtual pipeline stage.
+ pp_rank: Pipeline parallel rank.
+
+ Returns:
+ TransformerBlockSubmodules with per-layer specs.
+ """
+ return get_transformer_block_with_experimental_attention_variant_spec(
+ config=config,
+ vp_stage=vp_stage,
+ pp_rank=pp_rank,
+ )
+
+
+def get_qwen35_vl_vision_spec() -> ModuleSpec:
+ """ModuleSpec for vision encoder transformer layers.
+
+ Uses ``TEDotProductAttention`` which supports packed-sequence (THD)
+ attention via ``PackedSeqParams`` for variable-length images.
+
+ ``Qwen35VLVisionSelfAttention`` replaces the default ``SelfAttention`` so
+ that RoPE is applied in fp32, matching Bridge's
+ ``apply_rotary_pos_emb_in_fp32=True`` behaviour.
+ """
+ spec = get_vit_layer_with_transformer_engine_spec()
+ spec.submodules.self_attention.module = Qwen35VLVisionSelfAttention
+ return spec
diff --git a/examples/multimodal_dev/models/qwen35_vl/vision_encoder.py b/examples/multimodal_dev/models/qwen35_vl/vision_encoder.py
new file mode 100644
index 00000000000..8e16417937d
--- /dev/null
+++ b/examples/multimodal_dev/models/qwen35_vl/vision_encoder.py
@@ -0,0 +1,593 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Megatron-native Qwen3.5-VL vision encoder.
+
+Architecture (matches HF ``Qwen3VLVisionModel`` exactly):
+
+ PatchEmbed (Conv3d)
+ → learned position embedding (bilinear interpolation)
+ → 2D Vision RoPE
+ → TransformerBlock × N (with PackedSeqParams / THD attention)
+ → PatchMerger (per-token LN → spatial merge → MLP)
+
+Key design choices:
+ * ``Conv3d`` patch embedding is replicated across TP ranks (no MCore
+ equivalent for 3D convolutions).
+ * ``PatchMerger`` MLP uses ``ColumnParallelLinear`` / ``RowParallelLinear``
+ for TP sharding.
+ * Inherits from ``VisionModule``.
+ * Expects pixel values in block-merge order (as produced by the HF
+ processor) so the merger's simple reshape is correct.
+"""
+
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from megatron.core.models.common.vision_module.vision_module import (
+ VisionModule,
+)
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.tensor_parallel.layers import (
+ ColumnParallelLinear,
+ RowParallelLinear,
+)
+from megatron.core.extensions.transformer_engine import TENorm
+from megatron.core.transformer.module import MegatronModule
+from megatron.core.transformer.spec_utils import ModuleSpec, build_module
+from megatron.core.transformer.transformer_block import TransformerBlock
+from megatron.core.transformer.transformer_config import TransformerConfig
+
+# -------------------------------------------------------------------
+# PatchEmbed — Conv3d (replicated, no TP sharding)
+# -------------------------------------------------------------------
+
+class Qwen35VLPatchEmbed(MegatronModule):
+ """3D convolution patch embedding matching HF ``Qwen3VLVisionPatchEmbed``.
+
+ Uses ``nn.Conv3d`` with kernel = stride = ``[temporal_patch_size,
+ patch_size, patch_size]`` and ``bias=True``. The module is replicated
+ across TP ranks (no MCore equivalent for 3D conv).
+
+ Args:
+ config: TransformerConfig (used by MegatronModule base).
+ in_channels: Number of input channels (3 for RGB).
+ hidden_size: Output embedding dimension.
+ patch_size: Spatial patch size.
+ temporal_patch_size: Temporal patch size.
+ """
+
+ def __init__(
+ self,
+ config: TransformerConfig,
+ in_channels: int = 3,
+ hidden_size: int = 1152,
+ patch_size: int = 16,
+ temporal_patch_size: int = 2,
+ ):
+ super().__init__(config=config)
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.in_channels = in_channels
+ self.hidden_size = hidden_size
+
+ kernel = [temporal_patch_size, patch_size, patch_size]
+ self.proj = torch.nn.Conv3d(
+ in_channels,
+ hidden_size,
+ kernel_size=kernel,
+ stride=kernel,
+ bias=True,
+ )
+
+ def forward(self, pixel_values: Tensor) -> Tensor:
+ """Forward pass.
+
+ Args:
+ pixel_values: ``[total_patches, C * T * pH * pW]``
+ pre-extracted flat patches.
+
+ Returns:
+ Patch embeddings ``[total_patches, hidden_size]``.
+ """
+ target_dtype = self.proj.weight.dtype
+ pixel_values = pixel_values.view(
+ -1,
+ self.in_channels,
+ self.temporal_patch_size,
+ self.patch_size,
+ self.patch_size,
+ )
+ return self.proj(pixel_values.to(dtype=target_dtype)).view(
+ -1, self.hidden_size
+ )
+
+
+# -------------------------------------------------------------------
+# VisionRotaryEmbedding — 1D frequency table
+# -------------------------------------------------------------------
+
+class Qwen35VLVisionRotaryEmbedding(MegatronModule):
+ """1D rotary position frequency table for the vision transformer.
+
+ Generates RoPE frequencies for integer positions ``0 .. seqlen-1``.
+ The encoder maps 2D (row, col) positions to embeddings via table
+ lookup. Matches HF ``Qwen3VLVisionRotaryEmbedding``.
+
+ Args:
+ dim: Frequency dimension (``head_dim // 2``).
+ theta: RoPE base frequency.
+ config: Optional TransformerConfig for MegatronModule base.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ theta: float = 10000.0,
+ config: Optional[TransformerConfig] = None,
+ ):
+ super().__init__(config=config)
+ self.dim = dim
+ self.theta = theta
+ inv_freq = 1.0 / (
+ theta
+ ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def _get_inv_freq(self, device: torch.device) -> Tensor:
+ """Return ``inv_freq`` in float32 on *device*.
+
+ Always recomputes in float32 regardless of the buffer's stored dtype.
+ This matches Bridge's lazy-init behaviour where ``inv_freq`` is
+ constructed fresh (in float32) on the first forward call, after any
+ ``model.bfloat16()`` cast has already occurred.
+ """
+ return 1.0 / (
+ self.theta
+ ** (
+ torch.arange(
+ 0, self.dim, 2,
+ dtype=torch.float32, device=device,
+ )
+ / self.dim
+ )
+ )
+
+ def forward(
+ self,
+ seqlen: int,
+ device: Optional[torch.device] = None,
+ ) -> Tensor:
+ """Frequency lookup table for positions ``0 .. seqlen-1``.
+
+ Args:
+ seqlen: Number of positions.
+ device: Runtime device (required for meta-init safety).
+
+ Returns:
+ ``[seqlen, dim // 2]`` frequencies.
+ """
+ if device is None:
+ if self.inv_freq.device.type != "meta":
+ device = self.inv_freq.device
+ else:
+ device = torch.device(
+ "cuda", torch.cuda.current_device()
+ )
+ inv_freq = self._get_inv_freq(device)
+ seq = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
+ return torch.outer(seq, inv_freq)
+
+
+# -------------------------------------------------------------------
+# PatchMerger — per-token LN, spatial merge, TP-sharded MLP
+# -------------------------------------------------------------------
+
+class Qwen35VLPatchMerger(MegatronModule):
+    """Spatial patch merger matching HF ``Qwen3VLVisionPatchMerger``.
+
+    Per-token ``LayerNorm`` on ``hidden_size`` → reshape to merge
+    ``spatial_merge_size ** 2`` adjacent patches → two-layer MLP
+    (``ColumnParallelLinear`` → GELU → ``RowParallelLinear``).
+
+    MLP dimensions: ``merge_dim → merge_dim → out_hidden_size``
+    where ``merge_dim = hidden_size * spatial_merge_size ** 2``.
+
+    Args:
+        config: TransformerConfig (provides TP settings, init_method).
+        hidden_size: Per-token hidden size from the ViT.
+        out_hidden_size: Output dimension (language model hidden_size).
+        spatial_merge_size: Merge factor per spatial dimension.
+    """
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        hidden_size: int = 1152,
+        out_hidden_size: int = 3584,
+        spatial_merge_size: int = 2,
+    ):
+        super().__init__(config=config)
+        self.spatial_merge_size = spatial_merge_size
+        self.merge_dim = hidden_size * (spatial_merge_size ** 2)
+        merge_dim = self.merge_dim
+
+        # Per-token norm runs on the ViT hidden size, before merging.
+        self.patch_norm = TENorm(config=config, hidden_size=hidden_size, eps=1e-6)
+        # fc1 is column-parallel (output stays sharded, gather_output=False)
+        # and fc2 is row-parallel (input_is_parallel=True), so the GELU in
+        # between operates on the local shard only.
+        self.linear_fc1 = build_module(
+            ColumnParallelLinear,
+            merge_dim,
+            merge_dim,
+            config=config,
+            init_method=config.init_method,
+            bias=True,
+            gather_output=False,
+        )
+        self.linear_fc2 = build_module(
+            RowParallelLinear,
+            merge_dim,
+            out_hidden_size,
+            config=config,
+            init_method=config.output_layer_init_method,
+            bias=True,
+            input_is_parallel=True,
+            skip_bias_add=False,
+        )
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        """Merge patches spatially.
+
+        Args:
+            hidden_states: ``[total_patches, hidden_size]`` in block-merge
+                order from the ViT transformer blocks.
+
+        Returns:
+            ``[total_merged_patches, out_hidden_size]``.
+        """
+        hidden_states = self.patch_norm(hidden_states)
+        # Block-merge input order makes each group of spatial_merge_size**2
+        # adjacent patches contiguous, so a plain reshape performs the merge.
+        merged = hidden_states.view(-1, self.merge_dim)
+        merged, _ = self.linear_fc1(merged)
+        # NOTE(review): HF's Qwen3VLVisionPatchMerger uses nn.GELU() with the
+        # default approximate='none', while this uses the tanh approximation.
+        # That contradicts the class docstring's "matching HF" claim —
+        # confirm the deviation is intentional before matching checkpoints.
+        merged = torch.nn.functional.gelu(merged, approximate="tanh")
+        merged, _ = self.linear_fc2(merged)
+        return merged
+
+
+# -------------------------------------------------------------------
+# Qwen35VLVisionEncoder — top-level encoder module
+# -------------------------------------------------------------------
+
class Qwen35VLVisionEncoder(VisionModule):
    """Megatron-native Qwen3.5-VL vision encoder.

    Processes image / video inputs through:

    1. ``Qwen35VLPatchEmbed`` (Conv3d)
    2. Learned ``nn.Embedding`` position table with bilinear interpolation
    3. 2D Vision RoPE from ``(row, col)`` patch positions
    4. ``TransformerBlock`` × N with ``PackedSeqParams`` (THD attention)
    5. ``Qwen35VLPatchMerger``

    Output dimension matches the language model ``hidden_size``.

    Args:
        config: Vision ``TransformerConfig``.
        transformer_layer_spec: ``ModuleSpec`` for ViT layers.
        in_channels: Image channels (3 for RGB).
        patch_size: Spatial patch size.
        temporal_patch_size: Temporal patch size.
        spatial_merge_size: Spatial merge factor.
        out_hidden_size: Output dim (language decoder hidden_size).
        max_num_positions: Size of the learned position table; treated as a
            flattened square grid, so it should be a perfect square
            (default 2304 = 48 × 48).
    """

    def __init__(
        self,
        config: TransformerConfig,
        transformer_layer_spec: ModuleSpec = None,
        in_channels: int = 3,
        patch_size: int = 16,
        temporal_patch_size: int = 2,
        spatial_merge_size: int = 2,
        out_hidden_size: int = 3584,
        max_num_positions: int = 2304,
    ):
        super().__init__(config=config)

        self.hidden_size = config.hidden_size
        self.spatial_merge_size = spatial_merge_size

        # --- Patch embedding (Conv3d) ---
        self.patch_embed = Qwen35VLPatchEmbed(
            config=config,
            in_channels=in_channels,
            hidden_size=config.hidden_size,
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
        )

        # --- Learned position embedding with bilinear interpolation ---
        self.pos_embed = torch.nn.Embedding(
            max_num_positions, config.hidden_size,
        )
        # The table is a flattened square grid: side = sqrt(max_num_positions).
        self.num_grid_per_side = int(max_num_positions ** 0.5)

        # --- Vision rotary embeddings ---
        # The per-axis frequency table is built with dim = head_dim // 2;
        # each of the (row, col) axes ultimately contributes head_dim // 4
        # frequencies per patch (see _compute_rotary_pos_emb).
        head_dim = config.hidden_size // config.num_attention_heads
        self.rot_pos_emb = Qwen35VLVisionRotaryEmbedding(
            head_dim // 2, config=config,
        )

        # --- Transformer blocks ---
        if transformer_layer_spec is None:
            # Lazy import: the default spec module is only needed when no
            # custom spec is supplied by the caller.
            from examples.multimodal_dev.models.qwen35_vl.specs import (
                get_qwen35_vl_vision_spec,
            )
            transformer_layer_spec = get_qwen35_vl_vision_spec()

        self.decoder = TransformerBlock(
            config=config,
            spec=transformer_layer_spec,
            pre_process=True,
            post_process=True,
            post_layer_norm=False,
        )

        # --- Patch merger ---
        self.merger = Qwen35VLPatchMerger(
            config=config,
            hidden_size=config.hidden_size,
            out_hidden_size=out_hidden_size,
            spatial_merge_size=spatial_merge_size,
        )

    # ---------------------------------------------------------------
    # Learned position embedding with bilinear interpolation
    # ---------------------------------------------------------------

    def _fast_pos_embed_interpolate(
        self, grid_thw: Tensor,
    ) -> Tensor:
        """Bilinear interpolation of the learned 2D position table.

        Matches HF ``Qwen3VLVisionModel.fast_pos_embed_interpolate``.

        Each image's ``h × w`` patch grid is mapped onto the fixed
        ``n × n`` learned table; every patch position becomes a bilinear
        blend of its four surrounding table entries.

        Args:
            grid_thw: ``[num_images, 3]`` (T, H, W) in patch-grid units.

        Returns:
            ``[total_patches, hidden_size]`` position embeddings in
            block-merge order.
        """
        grid_thw_list = grid_thw.tolist()
        grid_ts = [int(row[0]) for row in grid_thw_list]
        grid_hs = [int(row[1]) for row in grid_thw_list]
        grid_ws = [int(row[2]) for row in grid_thw_list]
        device = self.pos_embed.weight.device
        n = self.num_grid_per_side

        # One (index, weight) list per bilinear corner, in the order:
        # [floor/floor, floor/ceil, ceil/floor, ceil/ceil].
        idx_list: List[List[int]] = [[] for _ in range(4)]
        weight_list: List[List[float]] = [[] for _ in range(4)]

        for t, h, w in grid_thw_list:
            t, h, w = int(t), int(h), int(w)
            # Fractional coordinates of each patch row / col on the n×n table.
            h_idxs = torch.linspace(0, n - 1, h)
            w_idxs = torch.linspace(0, n - 1, w)

            h_floor = h_idxs.int()
            w_floor = w_idxs.int()
            # Clip so the ceil corner stays inside the table at the far edge.
            h_ceil = (h_floor + 1).clip(max=n - 1)
            w_ceil = (w_floor + 1).clip(max=n - 1)

            # Fractional parts are the bilinear blend factors.
            dh = h_idxs - h_floor.float()
            dw = w_idxs - w_floor.float()

            # Row bases into the flattened (row-major) n×n table.
            base_h = h_floor * n
            base_h_ceil = h_ceil * n

            indices = [
                (base_h[None].T + w_floor[None]).flatten(),
                (base_h[None].T + w_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_floor[None]).flatten(),
                (base_h_ceil[None].T + w_ceil[None]).flatten(),
            ]
            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]

            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())

        idx_tensor = torch.tensor(
            idx_list, dtype=torch.long, device=device,
        )
        weight_tensor = torch.tensor(
            weight_list,
            dtype=self.pos_embed.weight.dtype,
            device=device,
        )
        # [4, total_patches, hidden]: weighted corner embeddings.
        # (The .to(device) is a no-op — idx_tensor already lives on the
        # embedding weight's device, so the lookup output does too.)
        pos_embeds = (
            self.pos_embed(idx_tensor).to(device)
            * weight_tensor[:, :, None]
        )
        # Sum the four weighted corners -> bilinear interpolation.
        patch_pos_embeds = (
            pos_embeds[0] + pos_embeds[1]
            + pos_embeds[2] + pos_embeds[3]
        )

        # Split the concatenated result back into per-image chunks of h*w.
        patch_pos_embeds = patch_pos_embeds.split(
            [h * w for h, w in zip(grid_hs, grid_ws)]
        )

        merge = self.spatial_merge_size
        result = []
        for pe, t, h, w in zip(
            patch_pos_embeds, grid_ts, grid_hs, grid_ws,
        ):
            # Temporal frames reuse the same 2D position embeddings.
            pe = pe.repeat(t, 1)
            # Reorder row-major (h, w) into block-merge order: group patches
            # by merge×merge spatial blocks (same order as the ViT input and
            # as _compute_rotary_pos_emb).
            pe = (
                pe.view(
                    t, h // merge, merge, w // merge, merge, -1,
                )
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            result.append(pe)

        return torch.cat(result)

    # ---------------------------------------------------------------
    # 2D Vision RoPE
    # ---------------------------------------------------------------

    def _compute_rotary_pos_emb(self, grid_thw: Tensor) -> Tensor:
        """Compute 2D Vision RoPE for all patches in block-merge order.

        Matches HF ``Qwen3VLVisionModel.rot_pos_emb``.

        Args:
            grid_thw: ``[num_images, 3]`` (T, H, W) per image.

        Returns:
            ``[total_patches, head_dim // 2]`` raw RoPE frequencies
            (row-axis and col-axis frequencies concatenated per patch).
        """
        merge = self.spatial_merge_size
        grid_thw_list = grid_thw.tolist()

        # One shared frequency table, large enough for the biggest grid edge.
        max_hw = max(max(int(h), int(w)) for _, h, w in grid_thw_list)
        freq_table = self.rot_pos_emb(
            max_hw, device=grid_thw.device,
        )
        device = freq_table.device

        total_tokens = sum(
            int(t) * int(h) * int(w) for t, h, w in grid_thw_list
        )
        # (row, col) index per patch, filled image by image below.
        pos_ids = torch.empty(
            (total_tokens, 2), dtype=torch.long, device=device,
        )

        offset = 0
        for num_frames, height, width in grid_thw_list:
            num_frames = int(num_frames)
            height = int(height)
            width = int(width)
            merged_h = height // merge
            merged_w = width // merge

            # Decompose each patch coordinate into (block index, intra-block
            # index) so the flattened order is block-merge order.
            block_rows = torch.arange(merged_h, device=device)
            block_cols = torch.arange(merged_w, device=device)
            intra_row = torch.arange(merge, device=device)
            intra_col = torch.arange(merge, device=device)

            row_idx = (
                block_rows[:, None, None, None] * merge
                + intra_row[None, None, :, None]
            )
            col_idx = (
                block_cols[None, :, None, None] * merge
                + intra_col[None, None, None, :]
            )

            # Broadcast to the full [merged_h, merged_w, merge, merge] grid,
            # then flatten in block-merge order.
            row_idx = row_idx.expand(
                merged_h, merged_w, merge, merge,
            ).reshape(-1)
            col_idx = col_idx.expand(
                merged_h, merged_w, merge, merge,
            ).reshape(-1)

            coords = torch.stack((row_idx, col_idx), dim=-1)
            if num_frames > 1:
                # Temporal frames reuse the same 2D positions.
                coords = coords.repeat(num_frames, 1)

            n_tokens = coords.shape[0]
            pos_ids[offset: offset + n_tokens] = coords
            offset += n_tokens

        # [total, 2, dim/2] gather, then flatten the (row, col) axis pair.
        embeddings = freq_table[pos_ids]
        embeddings = embeddings.flatten(1)
        return embeddings

    # ---------------------------------------------------------------
    # PackedSeqParams for variable-length attention
    # ---------------------------------------------------------------

    @staticmethod
    def _build_packed_seq_params(grid_thw: Tensor) -> PackedSeqParams:
        """Build ``PackedSeqParams`` from grid dimensions.

        Each temporal frame of each image forms a separate sub-sequence
        in the packed THD layout, matching HF's ``cu_seqlens`` computation.

        Args:
            grid_thw: ``[num_images, 3]``.

        Returns:
            ``PackedSeqParams`` for ``TransformerBlock``.
        """
        # Per-frame token count is H*W; repeat T times per image, then take
        # the cumulative sum and prepend 0 to get cu_seqlens boundaries.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0],
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        # Longest sub-sequence (largest single frame) in the batch.
        max_seqlen = int(
            (grid_thw[:, 1] * grid_thw[:, 2]).max().item()
        )

        return PackedSeqParams(
            qkv_format="thd",
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_kv=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_kv=max_seqlen,
        )

    # ---------------------------------------------------------------
    # Forward
    # ---------------------------------------------------------------

    def forward(
        self,
        pixel_values: Tensor,
        grid_thw: Tensor,
    ) -> Tensor:
        """Encode images / video frames.

        Args:
            pixel_values: ``[total_patches, C * T * pH * pW]``
                pre-extracted flat patches in block-merge order.
            grid_thw: ``[num_images, 3]`` (T, H, W) in patch-grid units.

        Returns:
            ``[total_merged_patches, out_hidden_size]`` visual embeddings.
        """
        # 1. Patch embedding (Conv3d)
        hidden_states = self.patch_embed(pixel_values)

        # 2. Learned position embedding (bilinear interpolation)
        pos_embeds = self._fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds

        # 3. 2D Vision RoPE — duplicate the half-dim frequencies so the
        # table covers the full head_dim, then add singleton batch/head dims.
        rot_freqs = self._compute_rotary_pos_emb(grid_thw)
        emb = torch.cat((rot_freqs, rot_freqs), dim=-1)
        rot_freqs_expanded = emb.unsqueeze(1).unsqueeze(1)

        # 4. Transformer blocks with PackedSeqParams; unsqueeze adds a
        # singleton batch dimension for the TransformerBlock interface.
        packed_seq_params = self._build_packed_seq_params(grid_thw)
        hidden_states = hidden_states.unsqueeze(1)
        hidden_states = self.decoder(
            hidden_states=hidden_states,
            attention_mask=None,
            rotary_pos_emb=rot_freqs_expanded,
            packed_seq_params=packed_seq_params,
        )
        hidden_states = hidden_states.squeeze(1)

        # 5. Patch merger
        return self.merger(hidden_states)
diff --git a/examples/multimodal_dev/pretrain_multimodal.py b/examples/multimodal_dev/pretrain_multimodal.py
new file mode 100644
index 00000000000..7a54bb93271
--- /dev/null
+++ b/examples/multimodal_dev/pretrain_multimodal.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Standalone entry point for multimodal_dev model training (FSDP + EP).
+
+This entry point is **model-agnostic**. All model-specific logic (layer
+specs, model construction, FLOPs metadata, dataset generation) is
+delegated to factory functions registered in
+:data:`multimodal_dev.models.MODEL_REGISTRY`.
+
+Adding a new architecture only requires:
+
1. Creating a new model package under ``multimodal_dev/models/<arch>/``
+ with the appropriate factory functions.
+2. Registering an entry in ``MODEL_REGISTRY``.
+
+No changes to this file are necessary.
+
+Usage::
+
+ torchrun --nproc_per_node=8 multimodal_dev/pretrain_multimodal.py \\
+ --model-arch qwen35_vl \\
+ --dataset-provider mock \\
+ ... (other megatron args)
+"""
+
+import importlib
+import os
+import sys
+
+sys.path.insert(
+ 0,
+ os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")),
+)
+
+from examples.multimodal_dev.arguments import add_multimodal_args
+from examples.multimodal_dev.forward_step import forward_step
+from megatron.core.enums import ModelType
+from megatron.training import get_args, pretrain
+from megatron.training.argument_utils import pretrain_cfg_container_from_args
+from megatron.training.arguments import core_transformer_config_from_args, parse_and_validate_args
+
+
def model_provider(
    pre_process: bool = True,
    post_process: bool = True,
    **kwargs,
):
    """Build a multimodal model for the architecture named by ``--model-arch``.

    The language ``TransformerConfig`` is built from CLI args so that
    parallelism settings, precision, and fusion flags are inherited;
    everything model-specific is delegated to the factory functions
    registered for the selected architecture.
    """
    args = get_args()
    model_arch = getattr(args, "model_arch", "qwen35_vl")

    from examples.multimodal_dev.models import MODEL_REGISTRY

    if model_arch not in MODEL_REGISTRY:
        raise ValueError(
            f"Unknown model arch '{model_arch}'. "
            f"Available: {list(MODEL_REGISTRY.keys())}"
        )

    registry = MODEL_REGISTRY[model_arch]

    # Language config: generic CLI-derived settings, then optional
    # arch-specific post-processing.
    language_config = core_transformer_config_from_args(args)
    if (post_cfg_fn := registry.get("post_language_config_fn")) is not None:
        post_cfg_fn(language_config, args)

    # Vision config from the registry factory; propagate dtype flags so the
    # encoder runs in the same precision as the decoder.
    vision_config = registry["vision_config_fn"](
        num_layers_override=getattr(args, "vision_num_layers", None),
        variant=getattr(args, "model_variant", None),
    )
    vision_config.bf16 = language_config.bf16
    vision_config.fp16 = language_config.fp16

    # Optional full-activation recompute for the vision tower.
    if getattr(args, "recompute_vision", False):
        vision_config.recompute_granularity = "full"
        vision_config.recompute_method = "uniform"
        vision_config.recompute_num_layers = 1

    # Optional vision FLOPs metadata hook.
    if (flops_fn := registry.get("vision_flops_fn")) is not None:
        flops_fn(args, language_config, vision_config)

    # Model construction is fully delegated to the arch factory.
    return registry["model_factory_fn"](
        args=args,
        language_config=language_config,
        vision_config=vision_config,
        **kwargs,
    )
+
+
+def _resolve_provider_fn(provider_fn):
+ """Resolve a provider that may be a dotted import path string."""
+ if isinstance(provider_fn, str):
+ module_path, func_name = provider_fn.rsplit(".", 1)
+ provider_fn = getattr(
+ importlib.import_module(module_path), func_name,
+ )
+ return provider_fn
+
+
def datasets_provider(train_val_test_num_samples):
    """Dataset provider dispatcher.

    Routes to the dataset factory registered for the current
    ``(--model-arch, --dataset-provider)`` combination.
    """
    args = get_args()
    model_arch = getattr(args, "model_arch", "qwen35_vl")
    provider = getattr(args, "dataset_provider", "mock")

    from examples.multimodal_dev.models import MODEL_REGISTRY

    if model_arch not in MODEL_REGISTRY:
        raise ValueError(
            f"Unknown model arch '{model_arch}'. "
            f"Available: {list(MODEL_REGISTRY.keys())}"
        )

    # Per-arch table of dataset factories (may be dotted-path strings).
    available = MODEL_REGISTRY[model_arch].get("dataset_providers", {})
    if provider not in available:
        raise ValueError(
            f"Unknown dataset provider '{provider}' for arch "
            f"'{model_arch}'. Available: {list(available.keys())}"
        )

    dataset_fn = _resolve_provider_fn(available[provider])
    return dataset_fn(train_val_test_num_samples)
+
+
if __name__ == "__main__":
    # Flag read by the training framework; presumably marks the provider as
    # distributed-aware (built on every rank) — confirm against pretrain().
    datasets_provider.is_distributed = True

    # Parse Megatron CLI args plus the multimodal extras (--model-arch,
    # --dataset-provider, vision options, ...).
    args = parse_and_validate_args(
        extra_args_provider=add_multimodal_args,
        args_defaults={},
    )
    full_config = pretrain_cfg_container_from_args(args)
    pretrain(
        full_config,
        datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
    )
diff --git a/examples/multimodal_dev/scripts/run_qwen35_vl.sh b/examples/multimodal_dev/scripts/run_qwen35_vl.sh
new file mode 100755
index 00000000000..974892a46ca
--- /dev/null
+++ b/examples/multimodal_dev/scripts/run_qwen35_vl.sh
@@ -0,0 +1,502 @@
#!/bin/bash

# Launch script for Qwen3.5-VL training via multimodal_dev (FSDP + EP).
#
# Usage (from the Megatron-LM repo root):
#   ./examples/multimodal_dev/scripts/run_qwen35_vl.sh
#
# Environment variables:
#   MODEL_VARIANT: proxy (default), 0.8b, 2b, 4b, 9b, 27b, 35b_a3b, 122b_a10b, 397b_a17b, 35b_a3b_light
#   CKPT_LOAD: path to a pre-converted checkpoint to load (enables --load + --finetune)
#   CKPT_FORMAT: checkpoint format override (e.g. torch_dist); auto-detected when empty
#   TP, EP, PP: parallelism sizes
#   MBS, GBS: micro/global batch sizes
#   NUM_LAYERS, NUM_EXPERTS: override for proxy testing
#   LAUNCHER: torchrun (default) or python
#   PROFILE: set to 1 to enable Nsight Systems profiling (default: 0)
#   PROFILE_STEP_START/PROFILE_STEP_END: profiled iteration window (default: 4-5)
#
# Example (single node, FSDP, packed sequences, 0.8B variant):
#   USE_FSDP=1 EP=1 GBS=16 MODEL_VARIANT=0.8b CKPT_LOAD=/path/to/ckpt \
#     CKPT_RESUME=0 DRY_RUN=0 USE_PACKED_SEQUENCE=1 \
#     bash ./examples/multimodal_dev/scripts/run_qwen35_vl.sh

set -euo pipefail

# NOTE: raised to 8 later in this script when FSDP is enabled.
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_IB_SL=1
export NVTE_FUSED_ATTN=1
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Default DRY_RUN=1: print the assembled command and exit without launching.
DRY_RUN=${DRY_RUN:-1}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
# Under Slurm, trust the scheduler's node count; otherwise fall back to NNODES.
if [ -n "${SLURM_JOB_NUM_NODES:-}" ]; then
    NUM_NODES="$SLURM_JOB_NUM_NODES"
else
    NUM_NODES=${NNODES:-1}
fi
PROFILE=${PROFILE:-0}
PROFILE_STEP_START=${PROFILE_STEP_START:-4}
PROFILE_STEP_END=${PROFILE_STEP_END:-5}
PROFILE_RANKS=${PROFILE_RANKS:-0}
LAUNCHER=${LAUNCHER:-torchrun}

MODEL_VARIANT=${MODEL_VARIANT:-proxy}
VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-}

# Batch sizes
MBS=${MBS:-2}
GBS=${GBS:-16}

# Parallelism
TP=${TP:-1}
EP=${EP:-2}
PP=${PP:-1}
CP=${CP:-1}

# Variant-aware architecture defaults.
# The model provider builds configs from the variant dict in
# multimodal_dev/models/qwen35_vl/configuration.py, but Megatron also
# uses these CLI args internally (PP splits, param counting).
case "$MODEL_VARIANT" in
    0.8b)
        NUM_LAYERS=${NUM_LAYERS:-24}
        NUM_EXPERTS=${NUM_EXPERTS:-0}
        HIDDEN_SIZE=1024
        FFN_HIDDEN_SIZE=3584
        NUM_ATTN_HEADS=8
        NUM_QUERY_GROUPS=2
        LINEAR_NUM_VALUE_HEADS=16
        VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-12}
        ;;
    2b)
        NUM_LAYERS=${NUM_LAYERS:-24}
        NUM_EXPERTS=${NUM_EXPERTS:-0}
        HIDDEN_SIZE=2048
        FFN_HIDDEN_SIZE=6144
        NUM_ATTN_HEADS=8
        NUM_QUERY_GROUPS=2
        LINEAR_NUM_VALUE_HEADS=16
        VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-24}
        ;;
    4b)
        NUM_LAYERS=${NUM_LAYERS:-32}
        NUM_EXPERTS=${NUM_EXPERTS:-0}
        HIDDEN_SIZE=2560
        FFN_HIDDEN_SIZE=9216
        NUM_ATTN_HEADS=16
        NUM_QUERY_GROUPS=4
        LINEAR_NUM_VALUE_HEADS=32
        VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-24}
        ;;
    proxy)
        # Small MoE proxy for fast end-to-end testing.
        NUM_LAYERS=${NUM_LAYERS:-4}
        NUM_EXPERTS=${NUM_EXPERTS:-16}
        HIDDEN_SIZE=4096
        FFN_HIDDEN_SIZE=10240
        NUM_ATTN_HEADS=32
        NUM_QUERY_GROUPS=2
        LINEAR_NUM_VALUE_HEADS=64
        VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-2}
        ;;
    9b)
        NUM_LAYERS=${NUM_LAYERS:-32}
        NUM_EXPERTS=${NUM_EXPERTS:-0}
        HIDDEN_SIZE=4096
        FFN_HIDDEN_SIZE=12288
        NUM_ATTN_HEADS=16
        NUM_QUERY_GROUPS=4
        LINEAR_NUM_VALUE_HEADS=32
        VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-27}
        ;;
    27b)
        NUM_LAYERS=${NUM_LAYERS:-64}
        NUM_EXPERTS=${NUM_EXPERTS:-0}
        HIDDEN_SIZE=5120
        FFN_HIDDEN_SIZE=17408
        NUM_ATTN_HEADS=24
        NUM_QUERY_GROUPS=4
        LINEAR_NUM_VALUE_HEADS=48
        VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-27}
        ;;
    35b_a3b)
        NUM_LAYERS=${NUM_LAYERS:-40}
        NUM_EXPERTS=${NUM_EXPERTS:-256}
        HIDDEN_SIZE=2048
        FFN_HIDDEN_SIZE=4096
        NUM_ATTN_HEADS=16
        NUM_QUERY_GROUPS=2
        LINEAR_NUM_VALUE_HEADS=32
        VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-27}
        ;;
    35b_a3b_light)
        NUM_LAYERS=${NUM_LAYERS:-12}
        NUM_EXPERTS=${NUM_EXPERTS:-128}
        HIDDEN_SIZE=2048
        FFN_HIDDEN_SIZE=4096
        NUM_ATTN_HEADS=16
        NUM_QUERY_GROUPS=2
        LINEAR_NUM_VALUE_HEADS=32
        VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-7}
        ;;
    122b_a10b)
        NUM_LAYERS=${NUM_LAYERS:-48}
        NUM_EXPERTS=${NUM_EXPERTS:-256}
        HIDDEN_SIZE=3072
        FFN_HIDDEN_SIZE=8192
        NUM_ATTN_HEADS=32
        NUM_QUERY_GROUPS=2
        LINEAR_NUM_VALUE_HEADS=64
        VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-27}
        ;;
    397b_a17b)
        NUM_LAYERS=${NUM_LAYERS:-60}
        NUM_EXPERTS=${NUM_EXPERTS:-512}
        HIDDEN_SIZE=4096
        FFN_HIDDEN_SIZE=10240
        NUM_ATTN_HEADS=32
        NUM_QUERY_GROUPS=2
        LINEAR_NUM_VALUE_HEADS=64
        VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-27}
        ;;
    *)
        # Unknown variant: every architecture dimension must be supplied.
        : "${NUM_LAYERS:?NUM_LAYERS must be set for MODEL_VARIANT=$MODEL_VARIANT}"
        : "${NUM_EXPERTS:?NUM_EXPERTS must be set for MODEL_VARIANT=$MODEL_VARIANT}"
        : "${HIDDEN_SIZE:?HIDDEN_SIZE must be set for MODEL_VARIANT=$MODEL_VARIANT}"
        : "${FFN_HIDDEN_SIZE:?FFN_HIDDEN_SIZE must be set for MODEL_VARIANT=$MODEL_VARIANT}"
        : "${NUM_ATTN_HEADS:?NUM_ATTN_HEADS must be set for MODEL_VARIANT=$MODEL_VARIANT}"
        : "${NUM_QUERY_GROUPS:?NUM_QUERY_GROUPS must be set for MODEL_VARIANT=$MODEL_VARIANT}"
        : "${LINEAR_NUM_VALUE_HEADS:?LINEAR_NUM_VALUE_HEADS must be set for MODEL_VARIANT=$MODEL_VARIANT}"
        VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-27}
        ;;
esac
SEQ_LEN=${SEQ_LEN:-4096}
+
WANDB_PROJECT=${WANDB_PROJECT:-'qwen35-vl-0524'}
# Experiment name encodes variant and parallelism; suffixes appended below.
EXP_NAME="qwen35vl_${MODEL_VARIANT}_tp${TP}_ep${EP}_pp${PP}_cp${CP}"

RECOMPUTE_VISION=${RECOMPUTE_VISION:-0}
if [ "$RECOMPUTE_VISION" -eq 1 ]; then
    EXP_NAME+="_recompute_encoder"
fi
RECOMPUTE=${RECOMPUTE:-0}
if [ "$RECOMPUTE" -eq 1 ]; then
    EXP_NAME+="_recompute_decoder"
fi

USE_PACKED_SEQUENCE=${USE_PACKED_SEQUENCE:-0}
if [ "$USE_PACKED_SEQUENCE" -eq 1 ]; then
    EXP_NAME+="_thd"
fi

# Repo root: three levels up from this script unless overridden.
MEGATRON_LM_PATH="${MEGATRON_LM_PATH:-$(cd "$(dirname "$0")/../../.." && pwd)}"
ROOT_DIR="${ROOT_DIR:-${MEGATRON_LM_PATH}/local/}"
CHECKPOINT_STORE_PATH="${ROOT_DIR}${EXP_NAME}"
mkdir -p "$CHECKPOINT_STORE_PATH"

TENSORBOARD_LOGS_PATH="${TENSORBOARD_LOGS_PATH:-${MEGATRON_LM_PATH}/logs}"
mkdir -p "$TENSORBOARD_LOGS_PATH"

DISTRIBUTED_ARGS=(
    --nproc_per_node "$GPUS_PER_NODE"
    --nnodes "$NUM_NODES"
)

# Rendezvous info is only needed for multi-node runs.
if [ "$NUM_NODES" -gt 1 ]; then
    DISTRIBUTED_ARGS+=(
        --master_addr "${MASTER_ADDR:-localhost}"
        --master_port "${MASTER_PORT:-6000}"
    )
fi

# --- Parallelism ---
MODEL_PARALLEL_ARGS=(
    --tensor-model-parallel-size "$TP"
    --pipeline-model-parallel-size "$PP"
    --expert-model-parallel-size "$EP"
    --context-parallel-size "$CP"
    --cp-comm-type "a2a"
    --expert-tensor-parallel-size 1
    --use-distributed-optimizer
    --sequence-parallel
)

# --- Training ---
TRAINING_ARGS=(
    --micro-batch-size "$MBS"
    --global-batch-size "$GBS"
    --train-iters "${TRAIN_ITERS:-500}"
    --adam-beta1 0.9
    --adam-beta2 0.95
    --lr 1.2e-4
    --min-lr 1.2e-5
    --lr-decay-style cosine
    --lr-warmup-iters 100
    --lr-decay-iters 2000
    --weight-decay 0.1
    --clip-grad 1.0
    --bf16
    --use-mcore-models
    --transformer-impl transformer_engine
    --cross-entropy-loss-fusion
    --cross-entropy-fusion-impl te
    --enable-experimental
    --manual-gc
    --manual-gc-interval 5
    --mtp-num-layers 1
    --mtp-loss-scaling-factor 0.1
    --sft
    --use-flash-attn
    # --attention-backend flash
    --calculate-per-token-loss
)

PROFILE_ARGS=()
NSYS_CMD=()
if [ "$PROFILE" = "1" ]; then
    PROFILE_ARGS=(
        --profile
        --profile-step-start "$PROFILE_STEP_START"
        --profile-step-end "$PROFILE_STEP_END"
        --profile-ranks "$PROFILE_RANKS"
    )

    # nsys only captures the window delimited by cudaProfilerApi, which
    # Megatron toggles around the --profile-step-start/end iterations.
    NSYS_OUTPUT_DIR="${CHECKPOINT_STORE_PATH}/nsys"
    mkdir -p "$NSYS_OUTPUT_DIR"
    NSYS_CMD=(
        nsys profile
        --sample=none
        --cpuctxsw=none
        --trace=cuda,nvtx,cublas,cudnn
        --force-overwrite=true
        --capture-range=cudaProfilerApi
        --capture-range-end=stop
        -o "${NSYS_OUTPUT_DIR}/${EXP_NAME}_$(date +%Y%m%d_%H%M%S)"
    )
fi

# --- Logging & Checkpointing ---
SAVE_INTERVAL=${SAVE_INTERVAL:-500}
EVAL_AND_LOGGING_ARGS=(
    --log-interval 1
    --save-interval "$SAVE_INTERVAL"
    --eval-interval 500
    --save "$CHECKPOINT_STORE_PATH"
    --eval-iters 10
    --tensorboard-dir "$TENSORBOARD_LOGS_PATH"
    --wandb-project "$WANDB_PROJECT"
    --wandb-exp-name "$EXP_NAME"
    --wandb-save-dir "$CHECKPOINT_STORE_PATH"
    --log-throughput
)

# --- Tokenizer ---
TOKENIZER_MODEL=${TOKENIZER_MODEL:-Qwen/Qwen3.5-397B-A17B}
TOKENIZER_ARGS=(
    --tokenizer-type HuggingFaceTokenizer
    --tokenizer-model "$TOKENIZER_MODEL"
)

# --- Multimodal-specific ---
MULTIMODAL_ARGS=(
    --model-arch qwen35_vl
    --model-variant "$MODEL_VARIANT"
    --dataset-provider cord_v2
    --hf-processor-path Qwen/Qwen3.5-397B-A17B
    --use-vanilla-collate-fn
    --image-token-id 248056
    --image-size 224
    --total-seq-length "$SEQ_LEN"
    --image-seq-length 256
    --vision-num-layers "$VISION_NUM_LAYERS"
)

if [ "$USE_PACKED_SEQUENCE" -eq 1 ]; then
    MULTIMODAL_ARGS+=( --use-packed-sequence )
fi
+
# --- Qwen3.5 Decoder Architecture (variant-specific dims set above) ---
# These must match examples/multimodal_dev/models/qwen35_vl/configuration.py
GPT_MODEL_ARGS=(
    --num-layers "$NUM_LAYERS"
    --hidden-size "$HIDDEN_SIZE"
    --ffn-hidden-size "$FFN_HIDDEN_SIZE"
    --num-attention-heads "$NUM_ATTN_HEADS"
    --group-query-attention
    --num-query-groups "$NUM_QUERY_GROUPS"
    --kv-channels 256
    --max-position-embeddings 262144
    --seq-length "$SEQ_LEN"
    --normalization RMSNorm
    --apply-layernorm-1p
    --norm-epsilon 1e-06
    --swiglu
    --disable-bias-linear
    --position-embedding-type rope
    --rotary-percent 0.25
    --rotary-base 10000000
    --rotary-seq-len-interpolation-factor 1
    --qk-layernorm
    --attention-output-gate
    --attention-dropout 0.0
    --hidden-dropout 0.0
    --experimental-attention-variant gated_delta_net
    --linear-attention-freq 4
    --linear-conv-kernel-dim 4
    --linear-key-head-dim 128
    --linear-value-head-dim 128
    --linear-num-key-heads 16
    --linear-num-value-heads "$LINEAR_NUM_VALUE_HEADS"
    --make-vocab-size-divisible-by 485
    --moe-router-force-load-balancing
)

# --- Tied / untied embeddings ---
# 0.8B, 2B, 4B use tied embeddings; all other variants untie them.
case "$MODEL_VARIANT" in
    0.8b|2b|4b) ;;
    *) GPT_MODEL_ARGS+=( --untie-embeddings-and-output-weights ) ;;
esac

# --- MoE args (MoE variants only) ---
MOE_ARGS=()
case "$MODEL_VARIANT" in
    proxy)
        MOE_TOPK=2; MOE_FFN_HIDDEN=1024; MOE_SHARED_HIDDEN=1024
        ;;
    35b_a3b|35b_a3b_light)
        MOE_TOPK=8; MOE_FFN_HIDDEN=512; MOE_SHARED_HIDDEN=512
        ;;
    122b_a10b)
        MOE_TOPK=8; MOE_FFN_HIDDEN=1024; MOE_SHARED_HIDDEN=1024
        ;;
    397b_a17b)
        MOE_TOPK=10; MOE_FFN_HIDDEN=1024; MOE_SHARED_HIDDEN=1024
        ;;
    0.8b|2b|4b|9b|27b)
        # Dense variants: no MoE dimension defaults.
        ;;
esac
if [ "${NUM_EXPERTS:-0}" -gt 0 ]; then
    # Fail fast with a clear message instead of an opaque `set -u` unbound
    # variable crash when NUM_EXPERTS is forced >0 for a variant (or custom
    # MODEL_VARIANT) that defines no MoE dimension defaults above.
    : "${MOE_TOPK:?MOE_TOPK must be set when NUM_EXPERTS>0 (MODEL_VARIANT=$MODEL_VARIANT has no MoE defaults)}"
    : "${MOE_FFN_HIDDEN:?MOE_FFN_HIDDEN must be set when NUM_EXPERTS>0 (MODEL_VARIANT=$MODEL_VARIANT has no MoE defaults)}"
    : "${MOE_SHARED_HIDDEN:?MOE_SHARED_HIDDEN must be set when NUM_EXPERTS>0 (MODEL_VARIANT=$MODEL_VARIANT has no MoE defaults)}"
    MOE_ARGS=(
        --num-experts "$NUM_EXPERTS"
        --moe-ffn-hidden-size "$MOE_FFN_HIDDEN"
        --moe-shared-expert-intermediate-size "$MOE_SHARED_HIDDEN"
        --moe-shared-expert-gate
        --moe-router-load-balancing-type aux_loss
        --moe-router-topk "$MOE_TOPK"
        --moe-grouped-gemm
        --moe-aux-loss-coeff 1e-3
        --moe-token-dispatcher-type alltoall
        --moe-router-dtype fp32
    )
fi
+
# --- Recompute ---
if [ "$RECOMPUTE" -eq 1 ]; then
    RECOMPUTE_ARGS=(
        --recompute-granularity full
        --recompute-method uniform
        --recompute-num-layers 1
    )
    # Alternative: selective recompute of MoE-heavy modules only.
    # RECOMPUTE_ARGS=(
    #     --recompute-granularity selective
    #     --recompute-modules moe_act shared_experts layernorm moe
    # )
else
    RECOMPUTE_ARGS=()
fi
if [ "$RECOMPUTE_VISION" -eq 1 ]; then
    RECOMPUTE_ARGS+=( --recompute-vision )
fi

# --- Checkpoint loading ---
# CKPT_LOAD: path to checkpoint directory
# CKPT_FORMAT: override checkpoint format (default: auto-detect)
# CKPT_RESUME: set to 1 to resume training (keep iteration, optimizer, rng);
#   default 0 = finetune mode (reset iteration, skip optim/rng)
CKPT_LOAD=${CKPT_LOAD:-}
CKPT_FORMAT=${CKPT_FORMAT:-}
CKPT_RESUME=${CKPT_RESUME:-0}
CKPT_OVERRIDE_SCHEDULER=${CKPT_OVERRIDE_SCHEDULER:-0}
CKPT_ARGS=()
if [ -n "$CKPT_LOAD" ]; then
    CKPT_ARGS+=( --load "$CKPT_LOAD" )
    if [ "$CKPT_RESUME" -eq 0 ]; then
        CKPT_ARGS+=( --finetune --no-load-optim --no-load-rng )
    fi
    if [ -n "$CKPT_FORMAT" ]; then
        CKPT_ARGS+=( --ckpt-format "$CKPT_FORMAT" )
    fi
    if [ "$CKPT_OVERRIDE_SCHEDULER" -eq 1 ]; then
        CKPT_ARGS+=( --override-opt-param-scheduler )
    fi
fi

# --- FSDP ---
USE_FSDP=${USE_FSDP:-1}
if [ "$USE_FSDP" -eq 1 ]; then
    FSDP_ARGS=(
        --use-megatron-fsdp
        --data-parallel-sharding-strategy optim_grads_params
        --no-gradient-accumulation-fusion
        --init-model-with-meta-device
        --use-distributed-optimizer
        --ckpt-format fsdp_dtensor
    )
    # NOTE(review): overrides CUDA_DEVICE_MAX_CONNECTIONS=1 exported at the
    # top of the script when FSDP is on.
    export CUDA_DEVICE_MAX_CONNECTIONS=8
else
    FSDP_ARGS=()
fi
+
echo "================================================================"
echo "Qwen3.5-VL Multimodal Training (multimodal_dev)"
echo " Variant: $MODEL_VARIANT"
echo " Vision layers: $VISION_NUM_LAYERS"
echo " GPUs per node: $GPUS_PER_NODE"
echo " Num nodes: $NUM_NODES"
echo " TP=$TP EP=$EP PP=$PP CP=$CP"
echo " MBS=$MBS GBS=$GBS"
echo " Launcher: $LAUNCHER"
echo " FSDP: $USE_FSDP"
echo " PROFILE: $PROFILE"
if [ -n "$CKPT_LOAD" ]; then
    echo " CKPT_LOAD: $CKPT_LOAD"
    echo " CKPT_FORMAT: ${CKPT_FORMAT:-auto}"
    echo " CKPT_RESUME: $CKPT_RESUME"
fi
if [ "$PROFILE" = "1" ]; then
    echo " Profile steps: ${PROFILE_STEP_START}-${PROFILE_STEP_END}"
    echo " Profile ranks: $PROFILE_RANKS"
fi
echo "================================================================"

# Quote the script path so a MEGATRON_LM_PATH containing spaces does not
# word-split into multiple argv entries.
if [ "$LAUNCHER" = "python" ]; then
    LAUNCH_CMD=( python "$MEGATRON_LM_PATH/examples/multimodal_dev/pretrain_multimodal.py" )
elif [ "$LAUNCHER" = "torchrun" ]; then
    LAUNCH_CMD=( torchrun "${DISTRIBUTED_ARGS[@]}" "$MEGATRON_LM_PATH/examples/multimodal_dev/pretrain_multimodal.py" )
else
    echo "Unsupported LAUNCHER=$LAUNCHER (expected torchrun or python)" >&2
    exit 1
fi

cmd=( "${NSYS_CMD[@]}" "${LAUNCH_CMD[@]}" \
    "${TRAINING_ARGS[@]}" \
    "${PROFILE_ARGS[@]}" \
    "${MODEL_PARALLEL_ARGS[@]}" \
    "${EVAL_AND_LOGGING_ARGS[@]}" \
    "${TOKENIZER_ARGS[@]}" \
    "${MULTIMODAL_ARGS[@]}" \
    "${GPT_MODEL_ARGS[@]}" \
    "${MOE_ARGS[@]}" \
    "${RECOMPUTE_ARGS[@]}" \
    "${FSDP_ARGS[@]}" \
    "${CKPT_ARGS[@]}" )

# Print the fully-expanded command for reproducibility.
echo "${cmd[@]}"

if [ "$DRY_RUN" -eq 1 ]; then
    echo "=== DRY RUN ==="
    exit 0
else
    "${cmd[@]}"
fi
diff --git a/examples/multimodal_dev/tests/__init__.py b/examples/multimodal_dev/tests/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/examples/multimodal_dev/tests/test_cp_correctness.py b/examples/multimodal_dev/tests/test_cp_correctness.py
new file mode 100644
index 00000000000..8ee71f25eca
--- /dev/null
+++ b/examples/multimodal_dev/tests/test_cp_correctness.py
@@ -0,0 +1,313 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Distributed correctness test for Context Parallelism (CP) support.
+
+Verifies that CP>1 produces the same (or numerically close) loss as CP=1
+for the Qwen3.5-VL multimodal model by running forward passes with
+deterministic data and comparing the per-rank reduced losses.
+
+Launch with torchrun (world size must be a multiple of cp_size, and the sequence length a multiple of 2*cp_size for zigzag splitting):
+
+ # Test CP=2 on 2 GPUs:
+ torchrun --nproc_per_node=2 examples/multimodal_dev/tests/test_cp_correctness.py --cp-size 2
+
+ # Test CP=4 on 4 GPUs:
+ torchrun --nproc_per_node=4 examples/multimodal_dev/tests/test_cp_correctness.py --cp-size 4
+
+The test:
+ 1. Builds a tiny proxy model (2 layers, no MoE, no vision encoder).
+ 2. Generates a deterministic batch (same seed on all ranks).
+ 3. Runs forward with CP=1 (each rank processes the full sequence independently).
+ 4. Re-initialises model-parallel groups with the target CP size.
+ 5. Runs forward with CP=target (sequence is split across ranks).
+ 6. Compares the all-reduced loss values.
+
+Exit code 0 = PASS, 1 = FAIL.
+"""
+
+import argparse
+import os
+import sys
+
+import torch
+import torch.distributed as dist
+
# Ensure the repo root is on the path so that megatron and examples are importable.
# This lets the test run from any working directory without installation.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)
+
+
+def _parse_args():
+ parser = argparse.ArgumentParser(description="CP correctness test")
+ parser.add_argument(
+ "--cp-size", type=int, default=2,
+ help="Target context-parallel size to compare against CP=1 baseline",
+ )
+ parser.add_argument(
+ "--seq-len", type=int, default=128,
+ help="Sequence length (must be divisible by 2*max(cp_size, tp_size*cp_size))",
+ )
+ parser.add_argument(
+ "--atol", type=float, default=1e-4,
+ help="Absolute tolerance for loss comparison",
+ )
+ parser.add_argument(
+ "--rtol", type=float, default=5e-2,
+ help="Relative tolerance for loss comparison (default 5%%)",
+ )
+ parser.add_argument(
+ "--seed", type=int, default=42,
+ help="Random seed for reproducibility",
+ )
+ # Megatron adds extra args; ignore them.
+ args, _ = parser.parse_known_args()
+ return args
+
+
def _init_distributed():
    """Ensure torch.distributed is initialised and bind this process to its GPU.

    Uses the NCCL backend. The device index comes from the ``LOCAL_RANK``
    environment variable (defaulting to 0); that index is also returned.
    """
    device_index = int(os.environ.get("LOCAL_RANK", 0))
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    torch.cuda.set_device(device_index)
    return device_index
+
+
def _init_megatron_parallel(tp_size=1, pp_size=1, cp_size=1, seed=42):
    """Tear down any existing Megatron model-parallel state and rebuild it.

    Re-creates the TP/PP/CP process groups with the requested sizes, then
    re-seeds the model-parallel CUDA RNG tracker so weight initialisation
    is reproducible across the CP=1 and CP=N phases of the test.
    """
    from megatron.core import parallel_state
    from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed

    # A fresh initialise requires destroying the previous groups first.
    parallel_state.destroy_model_parallel()
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size=tp_size,
        pipeline_model_parallel_size=pp_size,
        context_parallel_size=cp_size,
    )
    model_parallel_cuda_manual_seed(seed)
+
+
+def _make_deterministic_batch(seed, batch_size, seq_len, vocab_size, device):
+ """Create a deterministic batch identical on all ranks."""
+ rng = torch.Generator(device="cpu")
+ rng.manual_seed(seed)
+
+ input_ids = torch.randint(
+ 0, vocab_size, (batch_size, seq_len), generator=rng,
+ ).to(device)
+ labels = torch.randint(
+ 0, vocab_size, (batch_size, seq_len), generator=rng,
+ ).to(device)
+ loss_mask = torch.ones(batch_size, seq_len, device=device)
+ # Standard position_ids [B, S]
+ position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
+
+ return {
+ "input_ids": input_ids,
+ "labels": labels,
+ "loss_mask": loss_mask,
+ "position_ids": position_ids,
+ }
+
+
def _build_tiny_model(cp_size, device):
    """Build a minimal GPTModel for testing (no vision, no MoE).

    The proxy model is deliberately tiny (2 layers, hidden 256) so the test
    runs quickly; ``cp_size`` is baked into the TransformerConfig because
    attention needs to know the CP group size.

    Args:
        cp_size: context-parallel size recorded in the config.
        device: CUDA device the bf16 model is moved to.

    Returns:
        ``(model, config)`` tuple.
    """
    from megatron.core.models.gpt import GPTModel
    from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
    from megatron.core.transformer.transformer_config import TransformerConfig
    # NOTE: the previously imported ModuleSpec was unused and has been removed.

    hidden_size = 256
    num_heads = 4
    config = TransformerConfig(
        num_layers=2,
        hidden_size=hidden_size,
        ffn_hidden_size=hidden_size * 4,
        num_attention_heads=num_heads,
        kv_channels=hidden_size // num_heads,
        normalization="RMSNorm",
        layernorm_epsilon=1e-6,
        gated_linear_unit=True,
        activation_func=torch.nn.functional.silu,
        bf16=True,
        context_parallel_size=cp_size,
        add_bias_linear=False,
        # Dropout must be off for a deterministic CP-vs-no-CP comparison.
        attention_dropout=0.0,
        hidden_dropout=0.0,
        sequence_parallel=False,
    )

    spec = get_gpt_layer_with_transformer_engine_spec()

    model = GPTModel(
        config=config,
        transformer_layer_spec=spec,
        vocab_size=1024,
        max_sequence_length=4096,
        pre_process=True,
        post_process=True,
        # parallel_output=False keeps output gathered (TP=1 in this test).
        parallel_output=False,
        share_embeddings_and_output_weights=True,
        position_embedding_type="rope",
        rotary_percent=1.0,
        rotary_base=10000,
    )
    model = model.to(device=device, dtype=torch.bfloat16)
    return model, config
+
+
def _forward_with_cp(model, batch, cp_size):
    """Run forward pass, handling CP splitting of the batch.

    When cp_size > 1, splits the batch tensors using the same zigzag
    logic as multimodal_dev/models/base.py.

    Returns the average masked loss as a Python float, reduced over all
    CP ranks when cp_size > 1.
    """
    from examples.multimodal_dev.models.base import _cp_split_tensor
    from megatron.core import parallel_state as ps

    # Clone so repeated calls never observe tensors split by a prior run.
    input_ids = batch["input_ids"].clone()
    labels = batch["labels"].clone()
    loss_mask = batch["loss_mask"].clone()
    position_ids = batch["position_ids"].clone()

    if cp_size > 1:
        cp_rank = ps.get_context_parallel_rank()
        # labels and loss_mask are split with the same zigzag so each rank
        # scores exactly the tokens it owns.
        input_ids = _cp_split_tensor(input_ids, seq_dim=1, cp_size=cp_size, cp_rank=cp_rank)
        labels = _cp_split_tensor(labels, seq_dim=1, cp_size=cp_size, cp_rank=cp_rank)
        loss_mask = _cp_split_tensor(loss_mask, seq_dim=1, cp_size=cp_size, cp_rank=cp_rank)
        # position_ids are NOT split — the RoPE layer handles CP slicing internally.

    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            position_ids=position_ids,
            labels=labels,
            attention_mask=None,
        )

    # output is the per-token loss [B, S/CP]
    masked_loss = (output.float() * loss_mask.float()).sum()
    num_tokens = loss_mask.sum()

    # All-reduce across CP ranks to get global loss
    if cp_size > 1:
        cp_group = ps.get_context_parallel_group()
        dist.all_reduce(masked_loss, group=cp_group)
        dist.all_reduce(num_tokens, group=cp_group)

    # clamp guards against division by zero for an all-masked batch.
    avg_loss = masked_loss / num_tokens.clamp(min=1)
    return avg_loss.item()
+
+
def main():
    """Compare CP=1 and CP=N forward losses; print PASS/FAIL.

    Exits 0 on pass or skip (world size too small or not divisible by
    cp_size), 1 when the losses disagree beyond tolerance.
    """
    args = _parse_args()
    local_rank = _init_distributed()
    device = torch.device(f"cuda:{local_rank}")
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    target_cp = args.cp_size
    # Treat an incompatible topology as a skip (exit 0), not a failure.
    if world_size < target_cp:
        if rank == 0:
            print(
                f"SKIP: world_size={world_size} < cp_size={target_cp}. "
                f"Need at least {target_cp} GPUs.",
                flush=True,
            )
        dist.destroy_process_group()
        sys.exit(0)
    if world_size % target_cp != 0:
        if rank == 0:
            print(
                f"SKIP: world_size={world_size} is not divisible by cp_size={target_cp}.",
                flush=True,
            )
        dist.destroy_process_group()
        sys.exit(0)

    vocab_size = 1024

    # Ensure seq_len is divisible by 2 * target_cp
    seq_len = args.seq_len
    align = 2 * target_cp
    if seq_len % align != 0:
        # Round up to the next multiple so zigzag splitting is possible.
        seq_len = ((seq_len + align - 1) // align) * align
        if rank == 0:
            print(f"Adjusted seq_len to {seq_len} for alignment with CP={target_cp}", flush=True)

    # --- Step 1: CP=1 baseline ---
    if rank == 0:
        print(f"=== CP=1 baseline (world_size={world_size}) ===", flush=True)

    _init_megatron_parallel(cp_size=1)

    # Set deterministic seed for model init
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    model_cp1, _ = _build_tiny_model(cp_size=1, device=device)

    # Same seed on every rank -> identical batch everywhere.
    batch = _make_deterministic_batch(
        seed=args.seed + 1, batch_size=1, seq_len=seq_len,
        vocab_size=vocab_size, device=device,
    )

    loss_cp1 = _forward_with_cp(model_cp1, batch, cp_size=1)

    if rank == 0:
        print(f" CP=1 loss: {loss_cp1:.6f}", flush=True)

    # Save model state for reuse
    state_dict = model_cp1.state_dict()
    del model_cp1
    torch.cuda.empty_cache()

    # --- Step 2: CP=target ---
    if rank == 0:
        print(f"=== CP={target_cp} (world_size={world_size}) ===", flush=True)

    _init_megatron_parallel(cp_size=target_cp)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    model_cpN, _ = _build_tiny_model(cp_size=target_cp, device=device)

    # Load the same weights to ensure identical model
    model_cpN.load_state_dict(state_dict, strict=True)
    del state_dict

    loss_cpN = _forward_with_cp(model_cpN, batch, cp_size=target_cp)

    if rank == 0:
        print(f" CP={target_cp} loss: {loss_cpN:.6f}", flush=True)

    del model_cpN
    torch.cuda.empty_cache()

    # --- Step 3: Compare ---
    if rank == 0:
        diff = abs(loss_cpN - loss_cp1)
        rel_diff = diff / max(abs(loss_cp1), 1e-10)

        print(f"\n=== Comparison ===", flush=True)
        print(f" CP=1 loss: {loss_cp1:.6f}", flush=True)
        print(f" CP={target_cp} loss: {loss_cpN:.6f}", flush=True)
        print(f" Absolute diff: {diff:.6e}", flush=True)
        print(f" Relative diff: {rel_diff:.6e}", flush=True)
        print(f" Tolerance (atol): {args.atol:.6e}", flush=True)
        print(f" Tolerance (rtol): {args.rtol:.6e}", flush=True)

        passed = diff <= args.atol + args.rtol * abs(loss_cp1)
        if passed:
            print(f"\nPASS: CP={target_cp} matches CP=1 baseline", flush=True)
        else:
            print(f"\nFAIL: CP={target_cp} loss differs from CP=1 beyond tolerance", flush=True)

    dist.barrier()
    dist.destroy_process_group()

    # `passed` exists only on rank 0; the short-circuit on `rank == 0`
    # keeps other ranks from referencing it. torchrun surfaces rank 0's
    # non-zero exit code as a job failure.
    if rank == 0 and not passed:
        sys.exit(1)


if __name__ == "__main__":
    main()
diff --git a/examples/multimodal_dev/tests/test_cp_support.py b/examples/multimodal_dev/tests/test_cp_support.py
new file mode 100644
index 00000000000..a6f6b9b8d2c
--- /dev/null
+++ b/examples/multimodal_dev/tests/test_cp_support.py
@@ -0,0 +1,347 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Unit tests for Context Parallelism (CP) support in multimodal_dev.
+
+Tests cover:
+ 1. _cp_split_tensor — zigzag split correctness, reconstruction, and edge cases
+ 2. _NoCPGroup — dummy process group behaviour
+ 3. _thd_cp_partition_index — TE-based per-sample THD CP partitioning
+ 4. Cross-validation against megatron.core.utils.get_batch_on_this_cp_rank
+
+Run with: pytest examples/multimodal_dev/tests/test_cp_support.py -v
+"""
+
+import pytest
+import torch
+
+from examples.multimodal_dev.models.base import _cp_split_tensor, _NoCPGroup
+
+
class TestCpSplitTensor:
    """Tests for zigzag CP splitting.

    Zigzag splitting divides the sequence into ``2*cp_size`` equal chunks;
    rank ``r`` receives chunks ``[r, 2*cp_size - r - 1]`` so each rank's
    attention workload is balanced for causal masks.
    """

    def test_basic_2d_cp2(self):
        """[B, S] tensor with CP=2 splits and reconstructs correctly."""
        B, S = 2, 16
        t = torch.arange(B * S).reshape(B, S)
        cp_size = 2

        chunks = []
        for rank in range(cp_size):
            chunks.append(_cp_split_tensor(t, seq_dim=1, cp_size=cp_size, cp_rank=rank))

        # Each rank gets S / CP = 8 tokens
        for c in chunks:
            assert c.shape == (B, S // cp_size)

        # Reconstruct: rank 0 gets chunks [0, 3], rank 1 gets chunks [1, 2]
        # Original split into 4 chunks of size 4:
        # chunk0=[0..3], chunk1=[4..7], chunk2=[8..11], chunk3=[12..15]
        # rank0 = [chunk0, chunk3] = [0..3, 12..15]
        # rank1 = [chunk1, chunk2] = [4..7, 8..11]
        assert torch.equal(chunks[0][0], torch.tensor([0, 1, 2, 3, 12, 13, 14, 15]))
        assert torch.equal(chunks[1][0], torch.tensor([4, 5, 6, 7, 8, 9, 10, 11]))

    def test_3d_mrope_cp2(self):
        """[3, B, S] MRoPE tensor with CP=2."""
        B, S = 1, 8
        cp_size = 2
        t = torch.arange(3 * B * S).reshape(3, B, S)

        chunk = _cp_split_tensor(t, seq_dim=2, cp_size=cp_size, cp_rank=0)
        assert chunk.shape == (3, B, S // cp_size)

        # All 3 MRoPE components should be split consistently
        for d in range(3):
            original_row = t[d, 0]  # [S]
            # With S=8, CP=2: 4 chunks of size 2
            # rank0 gets chunks [0, 3] = positions [0,1, 6,7]
            expected = torch.cat([original_row[0:2], original_row[6:8]])
            assert torch.equal(chunk[d, 0], expected)

    def test_sbh_decoder_input(self):
        """[S, B, H] decoder input split along dim=0."""
        S, B, H = 16, 2, 4
        cp_size = 2
        t = torch.randn(S, B, H)

        chunk = _cp_split_tensor(t, seq_dim=0, cp_size=cp_size, cp_rank=0)
        assert chunk.shape == (S // cp_size, B, H)

    def test_cp4(self):
        """CP=4 zigzag pattern."""
        S = 32
        cp_size = 4
        t = torch.arange(S).unsqueeze(0)  # [1, 32]

        all_chunks = []
        for rank in range(cp_size):
            c = _cp_split_tensor(t, seq_dim=1, cp_size=cp_size, cp_rank=rank)
            all_chunks.append(c)
            assert c.shape == (1, S // cp_size)

        # All tokens should appear exactly once across ranks
        combined = torch.cat(all_chunks, dim=1)
        assert torch.equal(combined.sort(dim=1).values, t.sort(dim=1).values)

    def test_cp8(self):
        """CP=8 zigzag pattern — all tokens appear exactly once."""
        S = 64
        cp_size = 8
        t = torch.arange(S).unsqueeze(0)  # [1, 64]

        all_chunks = []
        for rank in range(cp_size):
            c = _cp_split_tensor(t, seq_dim=1, cp_size=cp_size, cp_rank=rank)
            all_chunks.append(c)
            assert c.shape == (1, S // cp_size)

        combined = torch.cat(all_chunks, dim=1)
        assert torch.equal(combined.sort(dim=1).values, t.sort(dim=1).values)

    def test_not_divisible_raises(self):
        """Should raise when seq_len not divisible by 2*cp_size."""
        t = torch.randn(2, 10)  # S=10, not divisible by 4
        with pytest.raises(AssertionError):
            _cp_split_tensor(t, seq_dim=1, cp_size=2, cp_rank=0)

    def test_zigzag_symmetry(self):
        """rank 0 and rank (cp_size-1) should get mirror chunks."""
        S = 16
        cp_size = 2
        t = torch.arange(S).unsqueeze(0)  # [1, 16]

        c0 = _cp_split_tensor(t, seq_dim=1, cp_size=cp_size, cp_rank=0)
        c1 = _cp_split_tensor(t, seq_dim=1, cp_size=cp_size, cp_rank=1)

        # rank0 gets chunks [0, 3], rank1 gets chunks [1, 2]
        # chunk0=[0..3], chunk3=[12..15] -> rank0 gets [0..3, 12..15]
        # chunk1=[4..7], chunk2=[8..11] -> rank1 gets [4..7, 8..11]
        # rank0's first half is earliest, rank1's first half is next
        assert c0[0, 0].item() < c1[0, 0].item()  # rank0 starts earlier

    def test_matches_megatron_core(self):
        """Cross-validate against megatron.core.utils.get_batch_on_this_cp_rank logic.

        We simulate the core function's logic (seq_dim=1, attention_mask seq_dim=2)
        and compare.
        """
        B, S = 2, 32
        cp_size = 4

        input_ids = torch.arange(B * S).reshape(B, S)
        labels = torch.arange(B * S).reshape(B, S) + 1000

        for cp_rank in range(cp_size):
            # Our implementation
            our_ids = _cp_split_tensor(input_ids, seq_dim=1, cp_size=cp_size, cp_rank=cp_rank)
            our_labels = _cp_split_tensor(labels, seq_dim=1, cp_size=cp_size, cp_rank=cp_rank)

            # Simulate megatron core logic inline: view as 2*cp chunks,
            # index_select chunks [rank, 2*cp-rank-1], flatten back.
            def core_split(val, seq_dim):
                val = val.view(
                    *val.shape[0:seq_dim],
                    2 * cp_size,
                    val.shape[seq_dim] // (2 * cp_size),
                    *val.shape[(seq_dim + 1):],
                )
                index = torch.zeros(2, dtype=torch.int64, device=val.device)
                index[0].fill_(cp_rank)
                index[1].fill_(2 * cp_size - cp_rank - 1)
                val = val.index_select(seq_dim, index)
                val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2):])
                return val

            ref_ids = core_split(input_ids.clone(), seq_dim=1)
            ref_labels = core_split(labels.clone(), seq_dim=1)

            assert torch.equal(our_ids, ref_ids), f"input_ids mismatch at rank {cp_rank}"
            assert torch.equal(our_labels, ref_labels), f"labels mismatch at rank {cp_rank}"

    def test_batch_dim_preserved(self):
        """Batch dimension must be unchanged after split."""
        B, S = 4, 32
        cp_size = 4
        t = torch.randn(B, S)

        for rank in range(cp_size):
            c = _cp_split_tensor(t, seq_dim=1, cp_size=cp_size, cp_rank=rank)
            assert c.shape[0] == B
+
+
class TestNoCPGroup:
    """Behaviour of the dummy process group used by the vision encoder
    when context parallelism is disabled: it must look like a singleton
    group (size 1, rank 0)."""

    def test_size_is_one(self):
        assert _NoCPGroup().size() == 1

    def test_rank_is_zero(self):
        assert _NoCPGroup().rank() == 0
+
+
# Probe for TransformerEngine; the THD partition tests below are skipped
# when it is not installed.
try:
    from transformer_engine.pytorch import cpp_extensions as _tex  # noqa: F401
except Exception:
    _HAS_TE = False
else:
    _HAS_TE = True
+
+
@pytest.mark.skipif(not _HAS_TE, reason="TransformerEngine not installed")
class TestThdCpPartition:
    """Verify TE-based per-sample THD + CP partition matches THD semantics.

    Each packed sub-sample of length ``s_i`` (where ``s_i % (2*cp_size) == 0``)
    is split into ``2*cp_size`` zigzag chunks per sample; rank ``r`` gets
    chunks ``[r, 2*cp_size - r - 1]`` of every sample. The union across
    ranks must cover every token position exactly once.
    """

    @staticmethod
    def _make_padded_packed(seqlens, divisor):
        """Concatenate per-sample dummy tokens after padding each sample to a
        multiple of *divisor*. Returns ``(input_ids[1, T], cu_seqlens_padded)``.

        Token IDs start at 1 so that 0 unambiguously marks padding
        (exploited by the loss-mask tests below).
        """
        import math
        padded = [math.ceil(s / divisor) * divisor for s in seqlens]
        chunks = []
        next_id = 1
        for s, p in zip(seqlens, padded):
            chunks.append(torch.arange(next_id, next_id + s, dtype=torch.int64))
            chunks.append(torch.zeros(p - s, dtype=torch.int64))  # padding
            next_id += s
        input_ids = torch.cat(chunks, dim=0).unsqueeze(0)  # [1, T]
        # Prefix-sum of the padded lengths, int32 as TE expects.
        cu_seqlens_padded = torch.tensor(
            [0] + list(torch.tensor(padded).cumsum(0).tolist()),
            dtype=torch.int32,
        )
        return input_ids, cu_seqlens_padded

    def _ensure_cuda(self, x):
        # Move to GPU when available; otherwise run on CPU so the test can
        # still execute in a CPU-only environment.
        return x.cuda() if torch.cuda.is_available() else x

    def test_partition_covers_all_positions_cp2(self):
        from examples.multimodal_dev.models.base import _thd_cp_partition_index

        cp_size = 2
        seqlens = [5, 7, 3]  # valid lengths
        input_ids, cu_seqlens_padded = self._make_padded_packed(
            seqlens, divisor=2 * cp_size,
        )
        input_ids = self._ensure_cuda(input_ids)
        cu_seqlens_padded = self._ensure_cuda(cu_seqlens_padded)
        T = input_ids.shape[1]

        # Union of per-rank indices must be all positions exactly once.
        seen = torch.zeros(T, dtype=torch.long, device=input_ids.device)
        for cp_rank in range(cp_size):
            idx = _thd_cp_partition_index(
                cu_seqlens_padded, T, cp_size, cp_rank,
            )
            assert idx.numel() == T // cp_size, (
                f"rank {cp_rank}: expected {T // cp_size} tokens, got {idx.numel()}"
            )
            seen.scatter_add_(
                0, idx.long(), torch.ones_like(idx, dtype=seen.dtype),
            )
        assert torch.all(seen == 1), (
            f"Position coverage broken: counts={seen.tolist()}"
        )

    def test_index_select_aligns_inputs_and_position_ids_cp2(self):
        """input_ids, loss_mask, and (3, 1, T) position_ids index_select with
        the same partition index produce shape-consistent per-rank tensors."""
        from examples.multimodal_dev.models.base import _thd_cp_partition_index

        cp_size = 2
        seqlens = [8, 4]
        input_ids, cu_seqlens_padded = self._make_padded_packed(
            seqlens, divisor=2 * cp_size,
        )
        input_ids = self._ensure_cuda(input_ids)
        cu_seqlens_padded = self._ensure_cuda(cu_seqlens_padded)
        T = input_ids.shape[1]
        labels = input_ids + 1000
        loss_mask = (input_ids != 0).float()
        # MRoPE-style position_ids: 3 identical arange rows.
        position_ids = (
            torch.arange(T, device=input_ids.device)
            .unsqueeze(0).unsqueeze(0).expand(3, 1, T).contiguous()
        )
        H = 4
        # [T, 1, H] decoder input, split along the token dim (dim 0).
        decoder_input = (
            torch.arange(T * H, dtype=torch.float32, device=input_ids.device)
            .view(T, 1, H)
        )

        for cp_rank in range(cp_size):
            idx = _thd_cp_partition_index(
                cu_seqlens_padded, T, cp_size, cp_rank,
            )
            ii = input_ids.index_select(1, idx)
            ll = labels.index_select(1, idx)
            lm = loss_mask.index_select(1, idx)
            pi = position_ids.index_select(2, idx)
            di = decoder_input.index_select(0, idx)

            assert ii.shape == (1, T // cp_size)
            assert ll.shape == (1, T // cp_size)
            assert lm.shape == (1, T // cp_size)
            assert pi.shape == (3, 1, T // cp_size)
            assert di.shape == (T // cp_size, 1, H)
            # Sliced position_ids is just the partition index itself
            # (since position_ids was arange(T) over all positions).
            assert torch.equal(pi[0, 0], idx.to(pi.dtype))
            # All MRoPE rows agree.
            assert torch.equal(pi[1, 0], pi[0, 0])
            assert torch.equal(pi[2, 0], pi[0, 0])

    def test_partition_cp4_three_samples(self):
        from examples.multimodal_dev.models.base import _thd_cp_partition_index

        cp_size = 4
        seqlens = [12, 4, 8]
        input_ids, cu_seqlens_padded = self._make_padded_packed(
            seqlens, divisor=2 * cp_size,
        )
        input_ids = self._ensure_cuda(input_ids)
        cu_seqlens_padded = self._ensure_cuda(cu_seqlens_padded)
        T = input_ids.shape[1]

        # Same exactly-once coverage check as CP=2, at CP=4.
        seen = torch.zeros(T, dtype=torch.long, device=input_ids.device)
        for cp_rank in range(cp_size):
            idx = _thd_cp_partition_index(
                cu_seqlens_padded, T, cp_size, cp_rank,
            )
            assert idx.numel() == T // cp_size
            seen.scatter_add_(
                0, idx.long(), torch.ones_like(idx, dtype=seen.dtype),
            )
        assert torch.all(seen == 1)

    def test_loss_mask_zero_kept_per_rank(self):
        """Pad-token positions (loss_mask=0) survive as 0 on whichever rank
        they land — sanity check that we don't accidentally discard them."""
        from examples.multimodal_dev.models.base import _thd_cp_partition_index

        cp_size = 2
        seqlens = [5, 3]
        input_ids, cu_seqlens_padded = self._make_padded_packed(
            seqlens, divisor=2 * cp_size,
        )
        input_ids = self._ensure_cuda(input_ids)
        cu_seqlens_padded = self._ensure_cuda(cu_seqlens_padded)
        T = input_ids.shape[1]
        loss_mask = (input_ids != 0).float()
        total_zeros = (loss_mask == 0).sum().item()

        zeros_seen = 0
        for cp_rank in range(cp_size):
            idx = _thd_cp_partition_index(
                cu_seqlens_padded, T, cp_size, cp_rank,
            )
            zeros_seen += (
                loss_mask.index_select(1, idx) == 0
            ).sum().item()
        assert zeros_seen == total_zeros
diff --git a/examples/multimodal_dev/tests/test_mrope_parity.py b/examples/multimodal_dev/tests/test_mrope_parity.py
new file mode 100644
index 00000000000..17cedb0536f
--- /dev/null
+++ b/examples/multimodal_dev/tests/test_mrope_parity.py
@@ -0,0 +1,663 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Parity tests for ``get_rope_index`` (MRoPE position-ID computation).
+
+Two properties are verified:
+
+1. **BSHD backwards compatibility** — the refactored ``get_rope_index``
+ returns bit-identical ``(position_ids, mrope_position_deltas)`` to
+ the pre-refactor implementation on padded ``[B, S]`` batches.
+2. **THD == BSHD on the valid region** — when the same variable-length
+ samples are fed through both layouts (BSHD with right-padding; THD
+ packed with ``cu_seqlens_q`` / ``cu_seqlens_q_padded``), positions at
+ every valid slot agree.
+
+The pre-refactor function is pinned inline as ``_old_get_rope_index``
+so this test stays self-contained. Run with::
+
+ python -m pytest examples/multimodal_dev/tests/test_mrope_parity.py -v
+
+or directly::
+
+ python examples/multimodal_dev/tests/test_mrope_parity.py
+"""
+
+import math
+import os
+import sys
+from itertools import accumulate
+
+import torch
+import torch.nn.functional as F
+
# Move (or add) the repo root to the front of sys.path.
_REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "../../.."),
)
# Insert at position 0 unconditionally — other entries on sys.path
# (e.g. a sibling Megatron-LM checkout) have their own ``examples``
# package that would otherwise shadow ours.
if _REPO_ROOT in sys.path:
    sys.path.remove(_REPO_ROOT)
sys.path.insert(0, _REPO_ROOT)
+
+from megatron.core.packed_seq_params import PackedSeqParams
+
+from examples.multimodal_dev.models.qwen35_vl.mrope import get_rope_index
+
+# -----------------------------------------------------------------------------
+# Token-ID constants (match Qwen3.5-VL, but values are arbitrary for this test)
+# -----------------------------------------------------------------------------
+
# Arbitrary distinct token IDs; only their distinctness matters here.
IMAGE_TOKEN_ID = 248056
VIDEO_TOKEN_ID = 248057
VISION_START_TOKEN_ID = 248053
# h/w grid dims are divided by this when counting LLM-side image tokens.
SPATIAL_MERGE_SIZE = 2
+
+
+# -----------------------------------------------------------------------------
+# Pinned reference implementation (pre-refactor BSHD path)
+# -----------------------------------------------------------------------------
+
def _old_get_rope_index(
    spatial_merge_size,
    image_token_id,
    video_token_id,
    vision_start_token_id,
    input_ids=None,
    image_grid_thw=None,
    video_grid_thw=None,
    attention_mask=None,
):
    """Pre-refactor BSHD implementation of ``get_rope_index``.

    Copied verbatim (modulo the broken cu_seqlens branch, which this
    parity test does not exercise) so we can diff against the new
    implementation on BSHD inputs.

    Returns ``(position_ids [3, B, S], mrope_position_deltas [B, 1])``.
    """
    if video_grid_thw is not None:
        # Expand each video grid into one row per frame (t rows with t=1).
        video_grid_thw = torch.repeat_interleave(
            video_grid_thw, video_grid_thw[:, 0], dim=0,
        )
        video_grid_thw[:, 0] = 1

    mrope_position_deltas = []

    if input_ids is not None and (
        image_grid_thw is not None or video_grid_thw is not None
    ):
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)

        # Pre-fill with 1; positions at mask==1 slots are overwritten below.
        position_ids = torch.ones(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        image_index, video_index = 0, 0
        attention_mask = attention_mask.to(total_input_ids.device)

        for i, sample_input_ids in enumerate(total_input_ids):
            # Drop padding; work on the valid tokens only.
            sample_input_ids = sample_input_ids[attention_mask[i] == 1]
            vision_start_indices = torch.argwhere(
                sample_input_ids == vision_start_token_id,
            ).squeeze(1)
            # The token right after each vision-start marker tells whether
            # the following block is an image or a video.
            vision_tokens = sample_input_ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = sample_input_ids.tolist()
            llm_pos_ids_list = []
            st = 0
            remain_images, remain_videos = image_nums, video_nums

            # Walk the sample, alternating text spans and vision blocks.
            for _ in range(image_nums + video_nums):
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1

                # Whichever vision block comes first is consumed next.
                if ed_image < ed_video:
                    t, h, w = (
                        image_grid_thw[image_index][0],
                        image_grid_thw[image_index][1],
                        image_grid_thw[image_index][2],
                    )
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t, h, w = (
                        video_grid_thw[video_index][0],
                        video_grid_thw[video_index][1],
                        video_grid_thw[video_index][2],
                    )
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video

                # LLM-side grid: h/w are shrunk by the spatial merge factor.
                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t.item(),
                    h.item() // spatial_merge_size,
                    w.item() // spatial_merge_size,
                )
                text_len = ed - st

                # Text positions continue from the running maximum + 1.
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1
                    if llm_pos_ids_list
                    else 0
                )
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1)
                    + st_idx
                )

                # 3D (t, h, w) indices for the vision block, flattened in
                # t-major, then h, then w order.
                t_index = (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    .flatten()
                )
                h_index = (
                    torch.arange(llm_grid_h)
                    .view(1, -1, 1)
                    .expand(llm_grid_t, -1, llm_grid_w)
                    .flatten()
                )
                w_index = (
                    torch.arange(llm_grid_w)
                    .view(1, 1, -1)
                    .expand(llm_grid_t, llm_grid_h, -1)
                    .flatten()
                )
                llm_pos_ids_list.append(
                    torch.stack([t_index, h_index, w_index])
                    + text_len
                    + st_idx
                )
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

            # Trailing text after the last vision block, if any.
            if st < len(input_tokens):
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1
                    if llm_pos_ids_list
                    else 0
                )
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1)
                    + st_idx
                )

            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            # Scatter the computed positions back into the valid slots only.
            position_ids[
                ..., i, attention_mask[i] == 1
            ] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(
                llm_positions.max() + 1 - len(total_input_ids[i]),
            )

        mrope_position_deltas = torch.tensor(
            mrope_position_deltas, device=total_input_ids.device,
        ).unsqueeze(1)
        return position_ids, mrope_position_deltas

    # Text-only fallback.
    if attention_mask is not None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        # Padding slots are filled with 1 (kept as-is from the reference).
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_ids = (
            position_ids.unsqueeze(0)
            .expand(3, -1, -1)
            .to(attention_mask.device)
        )
        max_position_ids = (
            position_ids.max(0, keepdim=False)[0]
            .max(-1, keepdim=True)[0]
        )
        mrope_position_deltas = (
            max_position_ids + 1 - attention_mask.shape[-1]
        )
    else:
        position_ids = (
            torch.arange(
                input_ids.shape[1], device=input_ids.device,
            )
            .view(1, 1, -1)
            .expand(3, input_ids.shape[0], -1)
        )
        mrope_position_deltas = torch.zeros(
            [input_ids.shape[0], 1],
            device=input_ids.device,
            dtype=input_ids.dtype,
        )
    return position_ids, mrope_position_deltas
+
+
+# -----------------------------------------------------------------------------
+# Synthetic-sample builder
+# -----------------------------------------------------------------------------
+
def _build_sample(
    prefix_text_len,
    grids,
    suffix_text_len,
    text_token=100,
):
    """Build one variable-length sample with ``len(grids)`` images.

    Layout per image: ``vision_start_id`` then
    ``llm_grid_t * llm_grid_h * llm_grid_w`` ``image_token_id`` slots
    (where ``llm_grid_* = grid_* // spatial_merge_size`` for h/w).
    Grids use ``t=1``.

    Returns ``(input_ids [L], image_grid_thw [N, 3])``.
    """
    ids = [text_token] * prefix_text_len
    rows = []
    for t, h, w in grids:
        # Number of LLM-side slots the image occupies after spatial merge.
        per_image = t * (h // SPATIAL_MERGE_SIZE) * (w // SPATIAL_MERGE_SIZE)
        ids.append(VISION_START_TOKEN_ID)
        ids += [IMAGE_TOKEN_ID] * per_image
        rows.append([t, h, w])
    # Suffix text uses a different token id than the prefix.
    ids += [text_token + 1] * suffix_text_len
    return (
        torch.tensor(ids, dtype=torch.int64),
        torch.tensor(rows, dtype=torch.int64),
    )
+
+
def _sample_bank():
    """A small bank of samples covering text-only, single-image, multi-image."""
    specs = [
        # (prefix_text_len, grids, suffix_text_len)
        (5, [(1, 4, 4)], 7),         # single image, text on both sides
        (3, [(1, 2, 2), (1, 4, 6)], 4),  # two images
        (10, [], 0),                 # text-only
        (0, [(1, 6, 4)], 2),         # image first, no prefix text
    ]
    return [
        _build_sample(prefix_text_len=p, grids=g, suffix_text_len=s)
        for p, g, s in specs
    ]
+
+
+# -----------------------------------------------------------------------------
+# Test 1: BSHD backwards compatibility
+# -----------------------------------------------------------------------------
+
def test_bshd_matches_old_reference():
    """New ``get_rope_index`` equals the pinned reference on BSHD inputs."""
    samples = _sample_bank()
    max_len = max(s.numel() for s, _ in samples)

    # Right-pad every sample to the common max length; the mask marks
    # the valid prefix of each row.
    input_ids_rows = []
    mask_rows = []
    grid_rows = []
    for tokens, grids in samples:
        L = tokens.numel()
        padded = F.pad(tokens, (0, max_len - L), value=0)
        mask = torch.zeros(max_len, dtype=torch.int64)
        mask[:L] = 1
        input_ids_rows.append(padded)
        mask_rows.append(mask)
        if grids.numel() > 0:
            grid_rows.append(grids)

    input_ids = torch.stack(input_ids_rows)  # [B, S]
    attention_mask = torch.stack(mask_rows)  # [B, S]
    image_grid_thw = (
        torch.cat(grid_rows, dim=0) if grid_rows else None
    )

    old_pos, old_delta = _old_get_rope_index(
        spatial_merge_size=SPATIAL_MERGE_SIZE,
        image_token_id=IMAGE_TOKEN_ID,
        video_token_id=VIDEO_TOKEN_ID,
        vision_start_token_id=VISION_START_TOKEN_ID,
        input_ids=input_ids,
        image_grid_thw=image_grid_thw,
        attention_mask=attention_mask,
    )
    new_pos, new_delta = get_rope_index(
        spatial_merge_size=SPATIAL_MERGE_SIZE,
        image_token_id=IMAGE_TOKEN_ID,
        video_token_id=VIDEO_TOKEN_ID,
        vision_start_token_id=VISION_START_TOKEN_ID,
        input_ids=input_ids,
        image_grid_thw=image_grid_thw,
        attention_mask=attention_mask,
        packed_seq_params=None,
    )

    # Bit-identical equality, not allclose: this is a parity guarantee.
    assert torch.equal(old_pos, new_pos), (
        f"BSHD position_ids differ.\nold:\n{old_pos}\nnew:\n{new_pos}"
    )
    assert torch.equal(old_delta, new_delta), (
        f"BSHD mrope_position_deltas differ.\n"
        f"old: {old_delta}\nnew: {new_delta}"
    )
+
+
+# -----------------------------------------------------------------------------
+# Test 2: THD positions match BSHD positions on the valid region
+# -----------------------------------------------------------------------------
+
def _pack_samples(samples, divisible_by=1):
    """Pack ``samples`` into a single ``[1, T]`` tensor the same way
    ``pack_or_pad_batch`` does, and build ``PackedSeqParams``.

    Each per-sample tensor is right-padded to a multiple of
    ``divisible_by`` before concatenation. ``cu_seqlens_q`` tracks
    unpadded lengths; ``cu_seqlens_q_padded`` tracks the packed layout.

    Returns ``(packed [1, T], PackedSeqParams, image_grid_thw | None,
    seqlens, seqlens_padded)``.
    """
    padded_chunks = []
    seqlens = []
    seqlens_padded = []
    grid_rows = []
    for tokens, grids in samples:
        L = tokens.numel()
        # Round each sample up so downstream chunking constraints hold.
        target_L = math.ceil(L / divisible_by) * divisible_by
        padded_chunks.append(F.pad(tokens, (0, target_L - L), value=0))
        seqlens.append(L)
        seqlens_padded.append(target_L)
        if grids.numel() > 0:
            grid_rows.append(grids)

    packed = torch.cat(padded_chunks, dim=0).unsqueeze(0)  # [1, T]
    # Prefix-sum boundaries, int32 as the attention kernels expect.
    cu_seqlens = torch.tensor(
        list(accumulate(seqlens, initial=0)), dtype=torch.int32,
    )
    cu_seqlens_padded = torch.tensor(
        list(accumulate(seqlens_padded, initial=0)), dtype=torch.int32,
    )
    psp = PackedSeqParams(
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        cu_seqlens_q_padded=cu_seqlens_padded,
        cu_seqlens_kv_padded=cu_seqlens_padded,
        max_seqlen_q=max(seqlens_padded),
        max_seqlen_kv=max(seqlens_padded),
    )
    image_grid_thw = (
        torch.cat(grid_rows, dim=0) if grid_rows else None
    )
    return packed, psp, image_grid_thw, seqlens, seqlens_padded
+
+
def test_thd_matches_bshd_padded():
    """THD positions at every valid slot equal BSHD positions on the
    equivalent right-padded batch.
    """
    samples = _sample_bank()

    # BSHD reference: right-pad every sample to the common max length.
    longest = max(tokens.numel() for tokens, _ in samples)
    rows, masks, grid_chunks = [], [], []
    for tokens, grids in samples:
        n = tokens.numel()
        rows.append(F.pad(tokens, (0, longest - n), value=0))
        row_mask = torch.zeros(longest, dtype=torch.int64)
        row_mask[:n] = 1
        masks.append(row_mask)
        if grids.numel() > 0:
            grid_chunks.append(grids)
    bshd_pos, _ = get_rope_index(
        spatial_merge_size=SPATIAL_MERGE_SIZE,
        image_token_id=IMAGE_TOKEN_ID,
        video_token_id=VIDEO_TOKEN_ID,
        vision_start_token_id=VISION_START_TOKEN_ID,
        input_ids=torch.stack(rows),
        image_grid_thw=torch.cat(grid_chunks, dim=0) if grid_chunks else None,
        attention_mask=torch.stack(masks),
    )
    # bshd_pos: [3, B, S_pad]

    # THD side: a non-trivial divisor makes the padded and unpadded
    # cu_seqlens diverge — this exercises the distinction between them.
    for divisible_by in (1, 4):
        packed_ids, psp, thd_grid, seqlens, seqlens_padded = _pack_samples(
            samples, divisible_by=divisible_by,
        )
        thd_pos, _ = get_rope_index(
            spatial_merge_size=SPATIAL_MERGE_SIZE,
            image_token_id=IMAGE_TOKEN_ID,
            video_token_id=VIDEO_TOKEN_ID,
            vision_start_token_id=VISION_START_TOKEN_ID,
            input_ids=packed_ids,
            image_grid_thw=thd_grid,
            packed_seq_params=psp,
        )
        # thd_pos: [3, 1, T]
        assert thd_pos.shape == (
            3, 1, packed_ids.shape[1],
        ), f"bad THD shape {thd_pos.shape}"

        # Each segment starts at the *padded* offset but only its
        # unpadded prefix is compared against the BSHD row.
        seg_starts = list(accumulate(seqlens_padded, initial=0))
        for k, (valid_len, seg_start) in enumerate(zip(seqlens, seg_starts)):
            thd_slice = thd_pos[:, 0, seg_start:seg_start + valid_len]
            bshd_slice = bshd_pos[:, k, :valid_len]
            assert torch.equal(thd_slice, bshd_slice), (
                f"[divisible_by={divisible_by}] segment {k} "
                f"(valid_len={valid_len}, seg_start={seg_start}) "
                f"disagrees:\nTHD:\n{thd_slice}\nBSHD:\n{bshd_slice}"
            )
+
+
+# -----------------------------------------------------------------------------
+# Test 3: THD with no images (text-only packed)
+# -----------------------------------------------------------------------------
+
def test_thd_text_only_restarts_per_segment():
    """Text-only THD: each segment gets a fresh ``[0..valid_len-1]`` range."""
    text_only = [
        _build_sample(prefix_text_len=n, grids=[], suffix_text_len=0)
        for n in (6, 11, 3)
    ]
    packed_ids, psp, _, seqlens, seqlens_padded = _pack_samples(
        text_only, divisible_by=4,
    )
    thd_pos, _ = get_rope_index(
        spatial_merge_size=SPATIAL_MERGE_SIZE,
        image_token_id=IMAGE_TOKEN_ID,
        video_token_id=VIDEO_TOKEN_ID,
        vision_start_token_id=VISION_START_TOKEN_ID,
        input_ids=packed_ids,
        image_grid_thw=None,
        packed_seq_params=psp,
    )

    # Without vision tokens every MRoPE dim is the plain text position,
    # so all three rows must be an arange restarting at each segment.
    for valid_len, seg_start in zip(
        seqlens, accumulate(seqlens_padded, initial=0)
    ):
        expected = (
            torch.arange(valid_len, dtype=thd_pos.dtype)
            .unsqueeze(0)
            .expand(3, -1)
        )
        got = thd_pos[:, 0, seg_start:seg_start + valid_len]
        assert torch.equal(got, expected), (
            f"text-only segment mismatch at seg_start={seg_start}, "
            f"valid_len={valid_len}:\n{got}\nexpected:\n{expected}"
        )
+
+
+# -----------------------------------------------------------------------------
+# Test 4: Explicit two-sequence batch with vision, both in BSHD and THD
+# -----------------------------------------------------------------------------
+
def _two_image_samples():
    """Two samples, each with one image — the smallest case that can
    expose a bug where segment k > 0 positions leak state from segment
    k - 1 (e.g. non-restarted ``st_idx`` or a stale ``image_index``).
    """
    # (prefix_text_len, suffix_text_len) per sample; each (1, 4, 4)
    # grid yields 4 image tokens after spatial merge.
    text_lens = [(5, 3), (4, 6)]
    return [
        _build_sample(
            prefix_text_len=prefix,
            grids=[(1, 4, 4)],
            suffix_text_len=suffix,
        )
        for prefix, suffix in text_lens
    ]
+
+
def test_bshd_batch_size_2_with_vision():
    """BSHD with ``B == 2``: both rows' positions restart at 0 and match
    the pinned reference.
    """
    samples = _two_image_samples()
    longest = max(tokens.numel() for tokens, _ in samples)

    rows, masks, grid_chunks = [], [], []
    for tokens, grids in samples:
        n = tokens.numel()
        rows.append(F.pad(tokens, (0, longest - n), value=0))
        row_mask = torch.zeros(longest, dtype=torch.int64)
        row_mask[:n] = 1
        masks.append(row_mask)
        grid_chunks.append(grids)

    input_ids = torch.stack(rows)  # [2, S]
    attention_mask = torch.stack(masks)
    image_grid_thw = torch.cat(grid_chunks, dim=0)
    assert input_ids.shape[0] == 2

    common_kwargs = dict(
        spatial_merge_size=SPATIAL_MERGE_SIZE,
        image_token_id=IMAGE_TOKEN_ID,
        video_token_id=VIDEO_TOKEN_ID,
        vision_start_token_id=VISION_START_TOKEN_ID,
        input_ids=input_ids,
        image_grid_thw=image_grid_thw,
        attention_mask=attention_mask,
    )
    old_pos, old_delta = _old_get_rope_index(**common_kwargs)
    new_pos, new_delta = get_rope_index(**common_kwargs)
    assert torch.equal(old_pos, new_pos), (
        f"BSHD B=2 position mismatch vs reference.\n"
        f"old:\n{old_pos}\nnew:\n{new_pos}"
    )
    assert torch.equal(old_delta, new_delta)

    # Both rows must start at position 0.
    for row in range(2):
        valid_len = int(attention_mask[row].sum().item())
        assert torch.all(new_pos[:, row, 0] == 0), (
            f"row {row} does not start at 0: {new_pos[:, row, 0]}"
        )
        # MRoPE can skip positions, so asserting a dense 0..valid_len-1
        # range would be wrong; the safe invariant is that the maximum
        # position stays below the valid length.
        assert new_pos[:, row, :valid_len].max() < valid_len
+
+
def test_thd_batch_size_2_with_vision():
    """THD with 2 packed sequences: seg 1 positions restart at 0 and
    equal BSHD row 1 on the valid region (bit-identical).
    """
    samples = _two_image_samples()

    # BSHD reference.
    longest = max(tokens.numel() for tokens, _ in samples)
    rows, masks, grid_chunks = [], [], []
    for tokens, grids in samples:
        n = tokens.numel()
        rows.append(F.pad(tokens, (0, longest - n), value=0))
        row_mask = torch.zeros(longest, dtype=torch.int64)
        row_mask[:n] = 1
        masks.append(row_mask)
        grid_chunks.append(grids)
    bshd_pos, _ = get_rope_index(
        spatial_merge_size=SPATIAL_MERGE_SIZE,
        image_token_id=IMAGE_TOKEN_ID,
        video_token_id=VIDEO_TOKEN_ID,
        vision_start_token_id=VISION_START_TOKEN_ID,
        input_ids=torch.stack(rows),
        image_grid_thw=torch.cat(grid_chunks, dim=0),
        attention_mask=torch.stack(masks),
    )

    # THD packed version with a non-trivial divisor so padded and
    # unpadded cu_seqlens disagree.
    packed_ids, psp, thd_grid, seqlens, seqlens_padded = _pack_samples(
        samples, divisible_by=4,
    )
    assert len(seqlens) == 2
    thd_pos, _ = get_rope_index(
        spatial_merge_size=SPATIAL_MERGE_SIZE,
        image_token_id=IMAGE_TOKEN_ID,
        video_token_id=VIDEO_TOKEN_ID,
        vision_start_token_id=VISION_START_TOKEN_ID,
        input_ids=packed_ids,
        image_grid_thw=thd_grid,
        packed_seq_params=psp,
    )

    seg_starts = accumulate(seqlens_padded, initial=0)
    for k, (valid_len, seg_start) in enumerate(zip(seqlens, seg_starts)):
        thd_slice = thd_pos[:, 0, seg_start:seg_start + valid_len]
        bshd_slice = bshd_pos[:, k, :valid_len]
        assert torch.equal(thd_slice, bshd_slice), (
            f"seg {k} THD vs BSHD row {k} mismatch.\n"
            f"THD:\n{thd_slice}\nBSHD:\n{bshd_slice}"
        )
        # Critical: seg k must start at position 0 (bug 2 check).
        assert torch.all(thd_slice[:, 0] == 0), (
            f"seg {k} does not start at 0 — positions leaked from "
            f"previous segment: first col = {thd_slice[:, 0]}"
        )
+
+
+if __name__ == "__main__":
+ test_bshd_matches_old_reference()
+ print("[ok] test_bshd_matches_old_reference")
+ test_thd_matches_bshd_padded()
+ print("[ok] test_thd_matches_bshd_padded")
+ test_thd_text_only_restarts_per_segment()
+ print("[ok] test_thd_text_only_restarts_per_segment")
+ test_bshd_batch_size_2_with_vision()
+ print("[ok] test_bshd_batch_size_2_with_vision")
+ test_thd_batch_size_2_with_vision()
+ print("[ok] test_thd_batch_size_2_with_vision")
+ print("All parity tests passed.")
diff --git a/examples/multimodal_dev/tests/test_thd_correctness.py b/examples/multimodal_dev/tests/test_thd_correctness.py
new file mode 100644
index 00000000000..d9c731cf4f2
--- /dev/null
+++ b/examples/multimodal_dev/tests/test_thd_correctness.py
@@ -0,0 +1,388 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""BSHD vs THD correctness test for multimodal_dev packed sequence support.
+
+Validates that packing a [B, S] batch into [1, T] THD format produces
+numerically equivalent loss values and gradient norms through a GPTModel.
+
+The test uses equal-length sequences (no padding) so that BSHD causal
+attention and THD cu_seqlens-based causal attention are mathematically
+identical. This makes any numerical deviation a real bug rather than an
+expected consequence of different padding/masking strategies.
+
+Usage::
+
+ # Single GPU (flash attention):
+ torchrun --nproc_per_node=1 \\
+ examples/multimodal_dev/tests/test_thd_correctness.py
+
+ # Override model size:
+ torchrun --nproc_per_node=1 \\
+ examples/multimodal_dev/tests/test_thd_correctness.py \\
+ --num-layers 4 --hidden-size 512 --num-heads 8 --num-kv-heads 4
+"""
+
import argparse
import math
import os
import sys

import torch
+
+_REPO_ROOT = os.path.abspath(
+ os.path.join(os.path.dirname(__file__), "../../.."),
+)
+if _REPO_ROOT not in sys.path:
+ sys.path.insert(0, _REPO_ROOT)
+
+from megatron.core import parallel_state
+from megatron.core.models.gpt import GPTModel
+from megatron.core.models.gpt.gpt_layer_specs import (
+ get_gpt_layer_with_transformer_engine_spec,
+)
+from megatron.core.tensor_parallel.random import (
+ model_parallel_cuda_manual_seed,
+)
+from megatron.core.transformer.transformer_config import TransformerConfig
+from tests.unit_tests.test_utilities import Utils
+
+from examples.multimodal_dev.forward_step import (
+ _build_packed_seq_params,
+ _pack_batch,
+)
+
+
+# ===================================================================
+# Helpers
+# ===================================================================
+
+
+def _grad_norm(model):
+ """L2 norm of all parameter gradients."""
+ total = 0.0
+ for p in model.parameters():
+ if p.grad is not None:
+ total += p.grad.data.float().norm(2).item() ** 2
+ return total ** 0.5
+
+
+def _mean_loss(per_token_loss, loss_mask):
+ """Mean loss over valid tokens."""
+ flat = per_token_loss.float().view(-1)
+ mask = loss_mask.float().view(-1)
+ return (flat * mask).sum() / mask.sum().clamp(min=1)
+
+
def _build_model(cfg, vocab_size, max_seq_len):
    """Build a small rope-based GPTModel on the current CUDA device."""
    layer_spec = get_gpt_layer_with_transformer_engine_spec()
    gpt = GPTModel(
        config=cfg,
        transformer_layer_spec=layer_spec,
        vocab_size=vocab_size,
        max_sequence_length=max_seq_len,
        pre_process=True,
        post_process=True,
        parallel_output=False,
        position_embedding_type="rope",
    )
    # nn.Module.cuda() moves parameters in place and returns the module.
    return gpt.cuda()
+
+
+# ===================================================================
+# Core test logic
+# ===================================================================
+
+
def run_equal_length_test(
    model,
    batch_size,
    seq_len,
    vocab_size,
    seed,
    atol_loss,
    rtol_grad,
):
    """Compare BSHD and THD with equal-length sequences (no padding).

    The same random batch is run through ``model`` twice: once as a
    plain ``[B, S]`` (BSHD) batch, and once packed into ``[1, B*S]``
    (THD) via ``_pack_batch``.  With every sequence fully valid the two
    attention layouts are mathematically identical, so loss, gradient
    norm, and per-token losses must agree within tolerance.

    Args:
        model: GPTModel-like module that accepts ``packed_seq_params``.
        batch_size: Number of sequences ``B``.
        seq_len: Tokens per sequence ``S`` (no padding anywhere).
        vocab_size: Exclusive upper bound for random ids / labels.
        seed: Base RNG seed (offset by 1 so data differs from init).
        atol_loss: Absolute tolerance for the loss comparison.
        rtol_grad: Relative tolerance for the grad-norm comparison.

    Returns a dict with test metrics for logging.
    """
    B, S = batch_size, seq_len

    # Deterministic data generation.
    torch.manual_seed(seed + 1)
    input_ids = torch.randint(0, vocab_size, (B, S), device="cuda")
    labels = torch.randint(0, vocab_size, (B, S), device="cuda")
    loss_mask = torch.ones(B, S, device="cuda")
    position_ids = (
        torch.arange(S, device="cuda").unsqueeze(0).expand(B, -1).contiguous()
    )

    # ---- BSHD forward / backward ----
    output_bshd = model(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=None,
        labels=labels,
        loss_mask=loss_mask,
    )
    bshd_loss = _mean_loss(output_bshd, loss_mask)
    bshd_loss.backward()
    bshd_gn = _grad_norm(model)
    bshd_lv = bshd_loss.item()
    # Detached copy of per-token losses for elementwise comparison later.
    bshd_per_token = output_bshd.detach().float().view(-1).clone()

    # Reset gradients so the THD backward starts from zero.
    model.zero_grad()

    # ---- THD forward / backward ----
    # Clone inputs so packing cannot mutate the BSHD tensors.
    batch = {
        "input_ids": input_ids.clone(),
        "labels": labels.clone(),
        "loss_mask": loss_mask.clone(),
        "position_ids": position_ids.clone(),
    }
    packed = _pack_batch(batch)
    # packed_seq_params is passed as its own model kwarg, not in the batch.
    psp = packed.pop("packed_seq_params")

    output_thd = model(
        input_ids=packed["input_ids"],
        position_ids=packed["position_ids"],
        attention_mask=None,
        labels=packed["labels"],
        loss_mask=packed["loss_mask"],
        packed_seq_params=psp,
    )
    thd_loss = _mean_loss(output_thd, packed["loss_mask"])
    thd_loss.backward()
    thd_gn = _grad_norm(model)
    thd_lv = thd_loss.item()
    thd_per_token = output_thd.detach().float().view(-1).clone()

    model.zero_grad()

    # ---- Comparison ----
    loss_diff = abs(bshd_lv - thd_lv)
    grad_diff = abs(bshd_gn - thd_gn)
    # Guard the denominator against a near-zero reference grad norm.
    grad_rel = grad_diff / max(bshd_gn, 1e-8)
    # No padding means both runs score the same B*S tokens in the same
    # order, so per-token losses are comparable elementwise.
    token_max_diff = (bshd_per_token - thd_per_token).abs().max().item()
    token_mean_diff = (bshd_per_token - thd_per_token).abs().mean().item()

    loss_ok = loss_diff < atol_loss
    grad_ok = grad_rel < rtol_grad

    metrics = dict(
        bshd_loss=bshd_lv,
        thd_loss=thd_lv,
        loss_diff=loss_diff,
        bshd_grad_norm=bshd_gn,
        thd_grad_norm=thd_gn,
        grad_diff=grad_diff,
        grad_rel=grad_rel,
        token_max_diff=token_max_diff,
        token_mean_diff=token_mean_diff,
        loss_ok=loss_ok,
        grad_ok=grad_ok,
    )
    return metrics
+
+
def run_variable_length_smoke_test(model, vocab_size, seed):
    """Smoke test: variable-length sequences packed to THD.

    Does NOT compare against BSHD (padding in BSHD changes attention
    context). Validates that:
    - Packing produces correct shapes
    - Forward + backward complete without error
    - Loss is finite
    - Gradients are finite and non-zero

    Args:
        model: GPTModel-like module that accepts ``packed_seq_params``.
        vocab_size: Exclusive upper bound for random ids / labels.
        seed: Base RNG seed (offset by 2 so data differs from Test 1).

    Returns a dict with test metrics.
    """
    seq_lengths = [128, 96, 112, 80]
    S = max(seq_lengths)
    B = len(seq_lengths)

    torch.manual_seed(seed + 2)
    input_ids = torch.randint(0, vocab_size, (B, S), device="cuda")
    labels = torch.randint(0, vocab_size, (B, S), device="cuda")
    loss_mask = torch.ones(B, S, device="cuda")
    position_ids = (
        torch.arange(S, device="cuda").unsqueeze(0).expand(B, -1).contiguous()
    )

    # Build attention_mask to indicate valid tokens per sample, and
    # exclude the padded tail from the loss as well.
    attention_mask = torch.zeros(B, S, device="cuda")
    for i, sl in enumerate(seq_lengths):
        attention_mask[i, :sl] = 1.0
        loss_mask[i, sl:] = 0.0

    batch = {
        "input_ids": input_ids,
        "labels": labels,
        "loss_mask": loss_mask,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
    packed = _pack_batch(batch)
    psp = packed.pop("packed_seq_params")

    # Packing must strip exactly the padded tail of every sample.
    T = sum(seq_lengths)
    assert packed["input_ids"].shape == (1, T), (
        f"Expected [1, {T}], got {packed['input_ids'].shape}"
    )
    assert packed["labels"].shape == (1, T)
    assert packed["loss_mask"].shape == (1, T)
    # cu_seqlens is the running prefix sum of the per-sample lengths
    # (computed here instead of hard-coding partial sums, so editing
    # seq_lengths above keeps the check correct).
    expected_cu = [0]
    for sl in seq_lengths:
        expected_cu.append(expected_cu[-1] + sl)
    assert psp.cu_seqlens_q.tolist() == expected_cu

    output = model(
        input_ids=packed["input_ids"],
        position_ids=packed["position_ids"],
        attention_mask=None,
        labels=packed["labels"],
        loss_mask=packed["loss_mask"],
        packed_seq_params=psp,
    )
    loss = _mean_loss(output, packed["loss_mask"])
    loss.backward()
    gn = _grad_norm(model)
    loss_val = loss.item()

    model.zero_grad()

    # loss_val and gn are plain Python floats; math.isfinite avoids the
    # needless round-trip through a torch tensor.
    loss_finite = math.isfinite(loss_val)
    grad_finite = math.isfinite(gn)
    grad_nonzero = gn > 0

    return dict(
        loss=loss_val,
        grad_norm=gn,
        total_tokens=T,
        loss_finite=loss_finite,
        grad_finite=grad_finite,
        grad_nonzero=grad_nonzero,
        passed=loss_finite and grad_finite and grad_nonzero,
    )
+
+
+# ===================================================================
+# Main
+# ===================================================================
+
+
+def _print_banner(title):
+ print(f"\n{'='*60}")
+ print(f" {title}")
+ print(f"{'='*60}")
+
+
def main():
    """Build a small GPTModel and run both THD correctness tests.

    Exits with status 1 when any test fails so callers (CI, launch
    scripts) can detect the failure.
    """
    parser = argparse.ArgumentParser(
        description="BSHD vs THD correctness test",
    )
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--seq-len", type=int, default=128)
    parser.add_argument("--vocab-size", type=int, default=1024)
    parser.add_argument("--hidden-size", type=int, default=256)
    parser.add_argument("--num-layers", type=int, default=2)
    parser.add_argument("--num-heads", type=int, default=4)
    parser.add_argument("--num-kv-heads", type=int, default=2)
    parser.add_argument("--ffn-hidden-size", type=int, default=512)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--atol-loss", type=float, default=1e-5,
                        help="Absolute tolerance for loss comparison")
    parser.add_argument("--rtol-grad", type=float, default=1e-3,
                        help="Relative tolerance for grad norm comparison")
    args = parser.parse_args()

    # Single-GPU model-parallel setup; seed the tensor-parallel RNG
    # tracker so weight init is deterministic across runs.
    Utils.initialize_model_parallel(tensor_model_parallel_size=1)
    model_parallel_cuda_manual_seed(args.seed)

    config = TransformerConfig(
        num_layers=args.num_layers,
        hidden_size=args.hidden_size,
        ffn_hidden_size=args.ffn_hidden_size,
        num_attention_heads=args.num_heads,
        num_query_groups=args.num_kv_heads,
        bf16=True,
        params_dtype=torch.bfloat16,
        pipeline_dtype=torch.bfloat16,
        # Dropout must be off so BSHD and THD runs are comparable.
        hidden_dropout=0.0,
        attention_dropout=0.0,
        tensor_model_parallel_size=1,
        sequence_parallel=False,
    )

    model = _build_model(config, args.vocab_size, args.seq_len)

    all_passed = True

    # ----------------------------------------------------------------
    # Test 1: equal-length correctness (BSHD vs THD)
    # ----------------------------------------------------------------
    _print_banner("Test 1: Equal-length BSHD vs THD correctness")
    m = run_equal_length_test(
        model=model,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        vocab_size=args.vocab_size,
        seed=args.seed,
        atol_loss=args.atol_loss,
        rtol_grad=args.rtol_grad,
    )
    print(f" Config: B={args.batch_size}, S={args.seq_len}, "
          f"H={args.hidden_size}, L={args.num_layers}, "
          f"heads={args.num_heads}/{args.num_kv_heads}")
    print(f" BSHD loss: {m['bshd_loss']:.8f}")
    print(f" THD loss: {m['thd_loss']:.8f}")
    print(f" Loss abs diff: {m['loss_diff']:.2e}")
    print(f" BSHD grad norm: {m['bshd_grad_norm']:.8f}")
    print(f" THD grad norm: {m['thd_grad_norm']:.8f}")
    print(f" Grad norm rel diff: {m['grad_rel']:.2e}")
    print(f" Per-token max diff: {m['token_max_diff']:.2e}")
    print(f" Per-token mean diff: {m['token_mean_diff']:.2e}")
    print(f" Loss match: {'PASS' if m['loss_ok'] else 'FAIL'} "
          f"(atol={args.atol_loss})")
    print(f" Grad norm match: {'PASS' if m['grad_ok'] else 'FAIL'} "
          f"(rtol={args.rtol_grad})")
    if not (m["loss_ok"] and m["grad_ok"]):
        all_passed = False

    # ----------------------------------------------------------------
    # Test 2: variable-length smoke test (THD only)
    # ----------------------------------------------------------------
    _print_banner("Test 2: Variable-length THD smoke test")
    v = run_variable_length_smoke_test(model, args.vocab_size, args.seed)
    print(f" Seq lengths: [128, 96, 112, 80]")
    print(f" Total packed tokens: {v['total_tokens']}")
    print(f" Loss: {v['loss']:.8f}")
    print(f" Grad norm: {v['grad_norm']:.8f}")
    print(f" Loss finite: {'PASS' if v['loss_finite'] else 'FAIL'}")
    print(f" Grad finite: {'PASS' if v['grad_finite'] else 'FAIL'}")
    print(f" Grad nonzero: {'PASS' if v['grad_nonzero'] else 'FAIL'}")
    if not v["passed"]:
        all_passed = False

    # ----------------------------------------------------------------
    # Summary
    # ----------------------------------------------------------------
    _print_banner("Summary")
    if all_passed:
        print(" ALL TESTS PASSED")
    else:
        print(" SOME TESTS FAILED")
    print(f"{'='*60}\n")

    Utils.destroy_model_parallel()

    # Non-zero exit status signals failure to the launcher.
    if not all_passed:
        sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/multimodal_dev/tests/test_thd_e2e.py b/examples/multimodal_dev/tests/test_thd_e2e.py
new file mode 100644
index 00000000000..eef320b0a37
--- /dev/null
+++ b/examples/multimodal_dev/tests/test_thd_e2e.py
@@ -0,0 +1,263 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Tests for THD (packed sequence) support in multimodal_dev."""
+
+import pytest
+import torch
+
+import sys
+import os
+
+# Ensure the repo root is on the path so that the examples package is importable.
+_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
+if _REPO_ROOT not in sys.path:
+ sys.path.insert(0, _REPO_ROOT)
+
+from examples.multimodal_dev.forward_step import (
+ _build_packed_seq_params,
+ _pack_batch,
+)
+
+
+# ===================================================================
+# Unit tests — CPU, no distributed / GPU required
+# ===================================================================
+
+
class TestBuildPackedSeqParams:
    """Tests for ``_build_packed_seq_params``."""

    @staticmethod
    def _build(lengths):
        # Shared helper: build params on CPU from a plain list of lengths.
        return _build_packed_seq_params(
            torch.tensor(lengths, dtype=torch.int32), device="cpu",
        )

    def test_basic(self):
        psp = self._build([5, 3, 7])
        expected_cu = [0, 5, 8, 15]
        assert psp.qkv_format == "thd"
        assert psp.cu_seqlens_q.tolist() == expected_cu
        assert psp.cu_seqlens_kv.tolist() == expected_cu
        assert psp.max_seqlen_q == 7
        assert psp.max_seqlen_kv == 7
        assert psp.total_tokens == 15
        # Padded cu_seqlens mirror actual cu_seqlens (no per-subseq padding).
        assert psp.cu_seqlens_q_padded is not None
        assert psp.cu_seqlens_q_padded.tolist() == expected_cu
        assert psp.cu_seqlens_kv_padded is not None
        assert psp.cu_seqlens_kv_padded.tolist() == expected_cu

    def test_equal_lengths(self):
        psp = self._build([4, 4, 4])
        assert psp.cu_seqlens_q.tolist() == [0, 4, 8, 12]
        assert psp.max_seqlen_q == 4
        assert psp.total_tokens == 12

    def test_single_sample(self):
        psp = self._build([10])
        assert psp.cu_seqlens_q.tolist() == [0, 10]
        assert psp.max_seqlen_q == 10
        assert psp.total_tokens == 10

    def test_dtype_is_int32(self):
        psp = self._build([3, 5])
        assert psp.cu_seqlens_q.dtype == torch.int32

    def test_seq_idx_computed(self):
        """Verify __post_init__ computes seq_idx for Mamba compatibility."""
        psp = self._build([3, 2])
        # seq_idx should be [0,0,0,1,1] (shape [1, 5])
        assert psp.seq_idx is not None
        assert psp.seq_idx.shape == (1, 5)
        assert psp.seq_idx[0].tolist() == [0, 0, 0, 1, 1]
+
+
class TestPackBatch:
    """Tests for ``_pack_batch``.

    ``_pack_batch`` converts a padded ``[B, S]`` batch into a single
    packed ``[1, T]`` sequence plus a ``PackedSeqParams`` describing the
    per-sample segment boundaries.
    """

    def test_no_padding(self):
        """All tokens valid — T == B*S."""
        B, S = 2, 8
        batch = {
            "input_ids": torch.arange(B * S).reshape(B, S),
            "labels": torch.arange(B * S).reshape(B, S) + 100,
            "loss_mask": torch.ones(B, S),
            # MRoPE-style [3, B, S] position ids.
            "position_ids": torch.arange(S).unsqueeze(0).unsqueeze(0).expand(
                3, B, S,
            ).clone(),
        }
        packed = _pack_batch(batch)

        T = B * S
        assert packed["input_ids"].shape == (1, T)
        assert packed["labels"].shape == (1, T)
        assert packed["loss_mask"].shape == (1, T)
        assert packed["position_ids"].shape == (3, 1, T)
        # THD attention uses cu_seqlens, not a dense mask.
        assert packed["attention_mask"] is None
        assert packed["packed_seq_params"].total_tokens == T

    def test_with_padding(self):
        """attention_mask strips padding — T < B*S."""
        B, S = 2, 8
        batch = {
            "input_ids": torch.arange(B * S).reshape(B, S),
            "labels": torch.arange(B * S).reshape(B, S),
            "loss_mask": torch.ones(B, S),
            "position_ids": torch.zeros(3, B, S, dtype=torch.long),
            "attention_mask": torch.tensor([
                [1, 1, 1, 1, 1, 0, 0, 0],  # 5 valid
                [1, 1, 1, 0, 0, 0, 0, 0],  # 3 valid
            ]),
        }
        packed = _pack_batch(batch)

        # Only the 5 + 3 valid tokens survive packing.
        T = 5 + 3
        assert packed["input_ids"].shape == (1, T)
        assert packed["labels"].shape == (1, T)
        assert packed["packed_seq_params"].cu_seqlens_q.tolist() == [0, 5, 8]
        assert packed["packed_seq_params"].max_seqlen_q == 5
        assert packed["packed_seq_params"].total_tokens == T

    def test_token_order_preserved(self):
        """Packed tokens appear in sample-0 then sample-1 order."""
        batch = {
            "input_ids": torch.tensor([[10, 20, 30], [40, 50, 60]]),
            "position_ids": torch.zeros(3, 2, 3, dtype=torch.long),
        }
        packed = _pack_batch(batch)
        assert packed["input_ids"].tolist() == [[10, 20, 30, 40, 50, 60]]

    def test_position_ids_mrope(self):
        """MRoPE [3, B, S] → [3, 1, T] with correct concatenation."""
        B, S = 2, 4
        pos = torch.zeros(3, B, S, dtype=torch.long)
        # Sample 0: positions [0,1,2,3] on all 3 dims
        # Sample 1: positions [10,11,12,13] on all 3 dims
        for d in range(3):
            pos[d, 0] = torch.tensor([0, 1, 2, 3])
            pos[d, 1] = torch.tensor([10, 11, 12, 13])

        batch = {
            "input_ids": torch.zeros(B, S, dtype=torch.long),
            "position_ids": pos,
        }
        packed = _pack_batch(batch)

        assert packed["position_ids"].shape == (3, 1, 8)
        # Each dim: [0,1,2,3,10,11,12,13] — samples concatenated along
        # the sequence axis, MRoPE dims kept separate.
        for d in range(3):
            assert packed["position_ids"][d, 0].tolist() == [
                0, 1, 2, 3, 10, 11, 12, 13,
            ]

    def test_standard_position_ids(self):
        """Standard [B, S] position_ids → [1, T]."""
        B, S = 2, 3
        batch = {
            "input_ids": torch.zeros(B, S, dtype=torch.long),
            "position_ids": torch.tensor([[0, 1, 2], [0, 1, 2]]),
        }
        packed = _pack_batch(batch)
        assert packed["position_ids"].shape == (1, 6)
        # Per-sample positions restart at 0 in the packed layout.
        assert packed["position_ids"].tolist() == [[0, 1, 2, 0, 1, 2]]

    def test_no_labels_no_loss_mask(self):
        """Gracefully handle missing labels and loss_mask."""
        batch = {
            "input_ids": torch.tensor([[1, 2], [3, 4]]),
            "position_ids": torch.zeros(3, 2, 2, dtype=torch.long),
        }
        packed = _pack_batch(batch)
        assert packed["input_ids"].shape == (1, 4)
        assert packed.get("labels") is None
        assert packed.get("loss_mask") is None

    def test_data_provided_cu_seqlens_takes_priority(self):
        """Prefer dataset-provided cu_seqlens over attention_mask."""
        B, S = 2, 6
        batch = {
            "input_ids": torch.tensor([
                [1, 2, 3, 4, 5, 6],
                [7, 8, 9, 10, 11, 12],
            ]),
            "labels": torch.tensor([
                [101, 102, 103, 104, 105, 106],
                [107, 108, 109, 110, 111, 112],
            ]),
            "loss_mask": torch.ones(B, S),
            "position_ids": torch.zeros(3, B, S, dtype=torch.long),
            # Deliberately inconsistent with cu_seqlens.
            "attention_mask": torch.ones(B, S),
            # Per-sample cu_seqlens: sample0 len=4, sample1 len=2.
            "cu_seqlens": torch.tensor([[0, 4], [0, 2]], dtype=torch.int32),
            "cu_seqlens_padded": torch.tensor([[0, 4], [0, 2]], dtype=torch.int32),
            "max_seqlen": torch.tensor([4, 2], dtype=torch.int32),
        }
        packed = _pack_batch(batch)

        # The cu_seqlens lengths (4 and 2) win over the all-ones mask.
        assert packed["input_ids"].shape == (1, 6)
        assert packed["input_ids"].tolist() == [[1, 2, 3, 4, 7, 8]]
        assert packed["labels"].tolist() == [[101, 102, 103, 104, 107, 108]]
        assert packed["packed_seq_params"].cu_seqlens_q.tolist() == [0, 4, 6]
        assert packed["packed_seq_params"].max_seqlen_q == 4
        # Raw data-side metadata is no longer needed after packing.
        assert "cu_seqlens" not in packed
        assert "cu_seqlens_padded" not in packed
        assert "max_seqlen" not in packed

    def test_multisegment_cu_seqlens_rejected(self):
        """[B, N] cu_seqlens with N>2 is explicitly unsupported."""
        batch = {
            "input_ids": torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]),
            "position_ids": torch.zeros(3, 2, 4, dtype=torch.long),
            "cu_seqlens": torch.tensor(
                [[0, 2, 4], [0, 1, 4]], dtype=torch.int32,
            ),
        }
        with pytest.raises(ValueError, match="expected \\[B, 2\\]"):
            _pack_batch(batch)

    def test_variable_length_with_attention_mask(self):
        """Variable-length sequences: attention_mask strips padding."""
        B, S = 3, 10
        seq_lengths = [8, 5, 10]  # valid tokens per sample
        batch = {
            "input_ids": torch.arange(B * S).reshape(B, S),
            "labels": torch.arange(B * S).reshape(B, S) + 100,
            "loss_mask": torch.ones(B, S),
            "position_ids": torch.zeros(3, B, S, dtype=torch.long),
            "attention_mask": torch.zeros(B, S),
        }
        for i, sl in enumerate(seq_lengths):
            batch["attention_mask"][i, :sl] = 1.0

        packed = _pack_batch(batch)

        T = sum(seq_lengths)  # 23
        assert packed["input_ids"].shape == (1, T)
        assert packed["labels"].shape == (1, T)
        assert packed["loss_mask"].shape == (1, T)
        assert packed["position_ids"].shape == (3, 1, T)
        assert packed["packed_seq_params"].cu_seqlens_q.tolist() == [
            0, 8, 13, 23,
        ]
        assert packed["packed_seq_params"].total_tokens == T

        # Verify correct tokens were kept (first sl tokens per sample).
        ids = packed["input_ids"][0].tolist()
        expected = list(range(0, 8)) + list(range(10, 15)) + list(range(20, 30))
        assert ids == expected

    def test_packed_seq_params_cumsum_matches_loop(self):
        """Verify torch.cumsum produces the same cu_seqlens as a Python loop."""
        lengths = torch.tensor([17, 31, 11, 42, 1], dtype=torch.int32)
        params = _build_packed_seq_params(lengths, device="cpu")
        # Manual cumulative sum
        expected = [0]
        for sl in lengths.tolist():
            expected.append(expected[-1] + int(sl))
        assert params.cu_seqlens_q.tolist() == expected
diff --git a/megatron/training/datasets/data_samplers.py b/megatron/training/datasets/data_samplers.py
index 430bd8b85da..8b14b975aba 100644
--- a/megatron/training/datasets/data_samplers.py
+++ b/megatron/training/datasets/data_samplers.py
@@ -106,7 +106,7 @@ def close_nvidia_fds():
maybe_worker_init_fn = worker_init_fn if args.num_workers > 0 else None
# Torch dataloader.
- if args.dynamic_context_parallel:
+ if args.dynamic_context_parallel or getattr(args, "use_vanilla_collate_fn", False):
extra_kwargs = {"collate_fn": lambda x: x}
else:
extra_kwargs = {}