diff --git a/examples/multimodal_dev/README.md b/examples/multimodal_dev/README.md new file mode 100644 index 00000000000..bdc34414da9 --- /dev/null +++ b/examples/multimodal_dev/README.md @@ -0,0 +1,162 @@ +# multimodal_dev — Standalone Multimodal Training + +Standalone, model-agnostic training entry point for multimodal +vision-language models built on Megatron-Core (FSDP + EP). + +## Directory Structure + +``` +multimodal_dev/ +├── pretrain_multimodal.py # Training entry point (model-agnostic) +├── forward_step.py # Forward step, TP broadcast, loss computation +├── arguments.py # Multimodal CLI arguments +├── data/ +│ └── mock.py # Mock dataset for end-to-end testing +├── models/ +│ ├── __init__.py # MODEL_REGISTRY — central model registry +│ ├── base.py # MultimodalModel base class (vision encoder + GPTModel) +│ └── qwen35_vl/ # Qwen3.5-VL architecture +│ ├── factory.py # Factory functions for pretrain entry point +│ ├── model.py # Qwen35VLModel (MRoPE, vision encoder wiring) +│ ├── configuration.py # TransformerConfig builders and constants +│ ├── specs.py # Layer spec builders (hybrid attention, ViT) +│ ├── mrope.py # 3D MRoPE position ID computation +│ └── vision_encoder.py# ViT encoder (patch embed, merger, RoPE) +└── scripts/ # Launch scripts (torchrun, Slurm) +``` + +## Quick Start + +```bash +torchrun --nproc_per_node=8 multimodal_dev/pretrain_multimodal.py \ + --model-arch qwen35_vl \ + --dataset-provider mock \ + ... # other Megatron args (--num-layers, --hidden-size, etc.) +``` + +## Architecture + +`pretrain_multimodal.py` is **model-agnostic**. All model-specific logic +is delegated to factory functions registered in `MODEL_REGISTRY` +(`models/__init__.py`). 
The entry point handles only generic concerns: + +- Building `language_config` from Megatron CLI args +- Constructing `vision_config` via the registry +- Applying vision recompute and dtype propagation +- Routing to model and dataset factories + +The `forward_step` is also model-agnostic — it uses the model's +`compute_position_ids()` method polymorphically and passes a standard +batch dict. + +## Adding a New Model Architecture + +Adding a new model (e.g. `llava_next`) requires **no changes** to +`pretrain_multimodal.py` or `forward_step.py`. Follow these steps: + +### Step 1 — Create the model package + +``` +multimodal_dev/models/llava_next/ +├── __init__.py +├── factory.py # Required: factory functions +├── configuration.py # Vision/language TransformerConfig builders +├── model.py # Model class (subclass MultimodalModel) +├── specs.py # Layer spec builders +└── vision_encoder.py # Vision encoder (if custom) +``` + +### Step 2 — Implement factory functions + +Create `factory.py` with up to three functions: + +```python +# models/llava_next/factory.py + +def post_language_config(language_config, args): + """(Optional) Mutate language_config with model-specific fields.""" + # e.g. language_config.some_field = value + pass + +def set_vision_flops_metadata(args, language_config, vision_config): + """(Optional) Set vision FLOPs metadata on args.""" + args.count_vision_model_flops = True + args.vision_flops_variant = "llava_next" + # ... set dimension fields for FLOPs calculation + +def build_model(args, language_config, vision_config, **kwargs): + """(Required) Build and return the complete model instance.""" + from .model import LlavaNextModel + from .specs import get_llava_next_language_spec + + language_spec = get_llava_next_language_spec( + config=language_config, + vp_stage=kwargs.get("vp_stage", None), + pp_rank=None, + ) + return LlavaNextModel( + language_config=language_config, + language_spec=language_spec, + vision_config=vision_config, + # ... 
model-specific args + ) +``` + +### Step 3 — Register in `MODEL_REGISTRY` + +Add an entry in `models/__init__.py`: + +```python +from multimodal_dev.models.llava_next.configuration import ( + get_llava_next_vision_config, +) +from multimodal_dev.models.llava_next.factory import ( + build_model as _build_llava_next_model, + post_language_config as _llava_next_post_language_config, + set_vision_flops_metadata as _llava_next_vision_flops, +) + +MODEL_REGISTRY["llava_next"] = { + "model_factory_fn": _build_llava_next_model, # required + "vision_config_fn": get_llava_next_vision_config, # required + "post_language_config_fn": _llava_next_post_language_config, # optional + "vision_flops_fn": _llava_next_vision_flops, # optional + "dataset_providers": { # optional + "mock": "multimodal_dev.data.llava_mock.train_valid_test_datasets_provider", + }, +} +``` + +### Step 4 — (Optional) Add a dataset provider + +Create a dataset module under `data/` if the model needs custom data +preprocessing. The provider function signature is: + +```python +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Return (train_dataset, val_dataset, test_dataset).""" + ... +``` + +Register it in the `dataset_providers` dict of the registry entry. +Providers can be either direct callables or dotted import path strings +(resolved lazily at runtime). + +### Step 5 — Launch + +```bash +torchrun --nproc_per_node=8 multimodal_dev/pretrain_multimodal.py \ + --model-arch llava_next \ + --dataset-provider mock \ + ... 
+```
+
+## Registry Entry Reference
+
+| Field | Required | Signature |
+|-------|----------|-----------|
+| `model_factory_fn` | Yes | `(args, language_config, vision_config, **kwargs) -> MegatronModule` |
+| `vision_config_fn` | Yes | `(num_layers_override=None) -> TransformerConfig` |
+| `post_language_config_fn` | No | `(language_config, args) -> None` |
+| `vision_flops_fn` | No | `(args, language_config, vision_config) -> None` |
+| `dataset_providers` | No | `Dict[str, str \| callable]` |
diff --git a/examples/multimodal_dev/__init__.py b/examples/multimodal_dev/__init__.py
new file mode 100644
index 00000000000..e76ed74857b
--- /dev/null
+++ b/examples/multimodal_dev/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
diff --git a/examples/multimodal_dev/arguments.py b/examples/multimodal_dev/arguments.py
new file mode 100644
index 00000000000..16d39c82857
--- /dev/null
+++ b/examples/multimodal_dev/arguments.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Extra CLI arguments for multimodal_dev standalone training."""
+
+
+def add_multimodal_args(parser):
+    """Add multimodal-specific arguments to the Megatron argument parser.
+
+    Args:
+        parser: Megatron's top-level ``argparse.ArgumentParser``.
+
+    Returns:
+        The same parser with a "Multimodal" argument group appended.
+    """
+    group = parser.add_argument_group(
+        "Multimodal", "Multimodal model arguments",
+    )
+
+    group.add_argument(
+        "--model-arch",
+        type=str,
+        default="qwen35_vl",
+        help="Model architecture. Available: qwen35_vl",
+    )
+    group.add_argument(
+        "--model-variant",
+        type=str,
+        default="proxy",
+        help="Model variant (size). E.g. proxy, 9b, 397b_a17b",
+    )
+    group.add_argument(
+        "--dataset-provider",
+        type=str,
+        default="mock",
+        help="Dataset provider: mock",
+    )
+    # Default matches the qwen35_vl image-token constant used by the mock
+    # dataset (248056) — presumably QWEN35_VL_IMAGE_TOKEN_ID; verify if a
+    # different tokenizer/model-arch is used.
+    group.add_argument(
+        "--image-token-id",
+        type=int,
+        default=248056,
+        help="Token ID for image placeholder tokens",
+    )
+    group.add_argument(
+        "--image-size",
+        type=int,
+        default=224,
+        help="Image size (height and width) for mock data",
+    )
+    group.add_argument(
+        "--total-seq-length",
+        type=int,
+        default=1024,
+        help="Total sequence length for mock data",
+    )
+    group.add_argument(
+        "--image-seq-length",
+        type=int,
+        default=256,
+        help="Number of image tokens in mock data",
+    )
+    group.add_argument(
+        "--vision-num-layers",
+        type=int,
+        default=None,
+        help=(
+            "Override for vision backbone depth. "
+            "Useful for proxy perf runs."
+        ),
+    )
+    group.add_argument(
+        "--hf-processor-path",
+        type=str,
+        default=None,
+        help=(
+            "HuggingFace processor path for real VLM datasets "
+            "(e.g. Qwen/Qwen2.5-VL-7B-Instruct)"
+        ),
+    )
+    group.add_argument(
+        "--recompute-vision",
+        action="store_true",
+        default=False,
+        help=(
+            "Enable full activation recomputation for vision encoder layers. "
+            "Uses uniform method and recomputes every layer. "
+            "Independent of the decoder --recompute-* flags."
+        ),
+    )
+    group.add_argument(
+        "--use-packed-sequence",
+        action="store_true",
+        default=False,
+        help=(
+            "Pack variable-length sequences into THD format to eliminate "
+            "padding waste."
+        ),
+    )
+    group.add_argument(
+        "--use-vanilla-collate-fn",
+        action="store_true",
+        default=False,
+        help=(
+            "Use vanilla collate function to collate the data."
+        ),
+    )
+
+    return parser
diff --git a/examples/multimodal_dev/data/__init__.py b/examples/multimodal_dev/data/__init__.py
new file mode 100644
index 00000000000..e76ed74857b
--- /dev/null
+++ b/examples/multimodal_dev/data/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
diff --git a/examples/multimodal_dev/data/mock.py b/examples/multimodal_dev/data/mock.py
new file mode 100644
index 00000000000..2daac8e4eba
--- /dev/null
+++ b/examples/multimodal_dev/data/mock.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Mock dataset for multimodal_dev end-to-end testing.
+
+Generates synthetic image + text data. Each sample has random text
+tokens with image-token placeholders, random pixel values sized for the
+vision encoder, 3D MRoPE position IDs, and shifted labels.
+"""
+
+import torch
+from torch.utils.data import Dataset
+
+from examples.multimodal_dev.models.qwen35_vl.configuration import (
+    QWEN35_VL_IMAGE_TOKEN_ID,
+    QWEN35_VL_VIDEO_TOKEN_ID,
+    QWEN35_VL_VISION_START_TOKEN_ID,
+)
+from examples.multimodal_dev.models.qwen35_vl.mrope import get_rope_index
+
+
+class MockQwen35VLDataset(Dataset):
+    """Synthetic Qwen3.5-VL training samples.
+
+    Args:
+        num_samples: Number of samples.
+        seq_length: Total sequence length (text + image tokens).
+        image_seq_length: Number of image tokens per sample.
+        vocab_size: Vocabulary size for random text tokens.
+        image_token_id: Token ID for image placeholders.
+        video_token_id: Token ID for video placeholders.
+        vision_start_token_id: Token ID marking start of a vision region.
+        image_size: Image height and width in pixels.
+        patch_size: Spatial patch size.
+        temporal_patch_size: Temporal patch size.
+        spatial_merge_size: Spatial merge factor.
+    """
+
+    def __init__(
+        self,
+        num_samples: int = 1000,
+        seq_length: int = 1024,
+        image_seq_length: int = 256,
+        vocab_size: int = 248320,
+        image_token_id: int = QWEN35_VL_IMAGE_TOKEN_ID,
+        video_token_id: int = QWEN35_VL_VIDEO_TOKEN_ID,
+        vision_start_token_id: int = QWEN35_VL_VISION_START_TOKEN_ID,
+        image_size: int = 224,
+        patch_size: int = 16,
+        temporal_patch_size: int = 2,
+        spatial_merge_size: int = 2,
+    ):
+        self.num_samples = num_samples
+        self.seq_length = seq_length
+        self.vocab_size = vocab_size
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id
+        self.vision_start_token_id = vision_start_token_id
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.spatial_merge_size = spatial_merge_size
+
+        # Patch grid for a single synthetic "image" (T, H, W in patches).
+        h_patches = image_size // patch_size
+        w_patches = image_size // patch_size
+        t_patches = temporal_patch_size
+        self.grid_thw = torch.tensor([[t_patches, h_patches, w_patches]])
+
+        # Tokens remaining after the spatial merge stage of the vision
+        # encoder; image_seq_length is clamped so the placeholder run in
+        # input_ids never exceeds what the encoder can actually emit.
+        self.num_merged_tokens = (
+            t_patches
+            * (h_patches // spatial_merge_size)
+            * (w_patches // spatial_merge_size)
+        )
+        self.image_seq_length = min(
+            image_seq_length, self.num_merged_tokens,
+        )
+        self.total_patches = t_patches * h_patches * w_patches
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        # Reserve 1 slot for the vision_start sentinel before image tokens.
+        text_length = self.seq_length - self.image_seq_length - 1
+        text_tokens = torch.randint(
+            1, self.vocab_size, (text_length,), dtype=torch.long,
+        )
+        # Scrub any randomly-drawn special IDs so the only vision tokens
+        # are the ones we insert deliberately below.
+        special_ids = {
+            self.image_token_id,
+            self.video_token_id,
+            self.vision_start_token_id,
+        }
+        for sid in special_ids:
+            text_tokens[text_tokens == sid] = 1
+
+        # Layout: [text prefix] [vision_start] [image tokens] [text suffix]
+        prefix_len = text_length // 2
+        suffix_len = text_length - prefix_len
+        input_ids = torch.cat([
+            text_tokens[:prefix_len],
+            torch.tensor(
+                [self.vision_start_token_id], dtype=torch.long,
+            ),
+            torch.full(
+                (self.image_seq_length,),
+                self.image_token_id,
+                dtype=torch.long,
+            ),
+            text_tokens[prefix_len: prefix_len + suffix_len],
+        ])
+
+        # Next-token targets: shift left by one. The final position gets a
+        # dummy label (0) and is zeroed in loss_mask so it never trains.
+        labels = input_ids.clone()
+        labels[:-1] = input_ids[1:]
+        labels[-1] = 0
+
+        # No loss on image placeholder positions.
+        loss_mask = (input_ids != self.image_token_id).float()
+        loss_mask[-1] = 0
+
+        # Flattened per-patch pixel dim: C * T_patch * P * P.
+        pixel_dim = (
+            3
+            * self.temporal_patch_size
+            * self.patch_size
+            * self.patch_size
+        )
+        pixel_values = torch.randn(self.total_patches, pixel_dim)
+
+        image_grid_thw = self.grid_thw.clone()
+
+        # get_rope_index returns [3, 1, S]; squeeze to [3, S] per sample.
+        position_ids, _ = get_rope_index(
+            spatial_merge_size=self.spatial_merge_size,
+            image_token_id=self.image_token_id,
+            video_token_id=self.video_token_id,
+            vision_start_token_id=self.vision_start_token_id,
+            input_ids=input_ids.unsqueeze(0),
+            image_grid_thw=image_grid_thw,
+        )
+        position_ids = position_ids.squeeze(1)
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+            "loss_mask": loss_mask,
+            # One unpacked sequence per sample: cu_seqlens is just [0, S].
+            "cu_seqlens": torch.tensor([0, self.seq_length], dtype=torch.int32),
+            "cu_seqlens_padded": torch.tensor(
+                [0, self.seq_length], dtype=torch.int32,
+            ),
+            "max_seqlen": torch.tensor(self.seq_length, dtype=torch.int32),
+            "position_ids": position_ids,
+            "pixel_values": pixel_values,
+            "image_grid_thw": image_grid_thw,
+        }
+
+
+def mock_collate_fn(batch):
+    """Collate: handles position_ids ``[3, S]`` stacking."""
+    result = {}
+    keys = batch[0].keys()
+    for key in keys:
+        tensors = [sample[key] for sample in batch]
+        if key == "position_ids":
+            # [3, S] per sample, stacked on dim=1 -> [3, B, S].
+            result[key] = torch.stack(tensors, dim=1)
+        elif key == "image_grid_thw":
+            # [1, 3] per sample -> [B, 3].
+            result[key] = torch.cat(tensors, dim=0)
+        elif key == "pixel_values":
+            # [P, D] per sample -> [B*P, D] (patches from all images).
+            result[key] = torch.cat(tensors, dim=0)
+        else:
+            result[key] = torch.stack(tensors, dim=0)
+    return result
+
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Provide mock train / val / test datasets."""
+    from megatron.training import get_args
+
+    args = get_args()
+    # getattr fallbacks keep this usable even if the multimodal arg group
+    # was not registered on the parser.
+    kwargs = dict(
+        seq_length=getattr(args, "total_seq_length", 1024),
+        image_seq_length=getattr(args, "image_seq_length", 256),
+        vocab_size=getattr(args, "padded_vocab_size", 248320),
+        image_token_id=getattr(args, "image_token_id", 248056),
+        image_size=getattr(args, "image_size", 224),
+    )
+
+    train_ds = MockQwen35VLDataset(
+        num_samples=train_val_test_num_samples[0], **kwargs,
+    )
+    val_ds = MockQwen35VLDataset(
+        num_samples=train_val_test_num_samples[1], **kwargs,
+    )
+    test_ds = MockQwen35VLDataset(
+        num_samples=train_val_test_num_samples[2], **kwargs,
+    )
+
+    return train_ds, val_ds, test_ds
diff --git a/examples/multimodal_dev/data/vlm_dataset.py b/examples/multimodal_dev/data/vlm_dataset.py
new file mode 100644
index 00000000000..47a323d61ad
--- /dev/null
+++ b/examples/multimodal_dev/data/vlm_dataset.py
@@ -0,0 +1,249 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Simple VLM dataset for multimodal_dev training.
+
+Single-turn image-text dataset using a HuggingFace ``AutoProcessor`` for
+tokenization and image preprocessing. Currently supports CORD-V2 (receipt
+OCR). No multi-turn support — each sample is one image + question →
+answer pair.
+
+Because Megatron's DataLoader uses ``default_collate`` (no custom collate
+function), all images are resized to a fixed resolution so that
+``pixel_values`` has a consistent shape across samples.
+
+Usage::
+
+    torchrun ...
pretrain_multimodal.py \\
+        --model-arch qwen35_vl --dataset-provider cord_v2 \\
+        --hf-processor-path Qwen/Qwen2.5-VL-7B-Instruct \\
+        --image-size 448 --seq-length 2048
+"""
+
+import json
+import logging
+import random
+from typing import Dict, List, Optional
+
+import torch
+from torch.utils.data import Dataset
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# CORD-V2 helpers
+# ---------------------------------------------------------------------------
+
+def _json2token(obj, sort_json_key=True):
+    """Convert a JSON object to a token-sequence string (Donut format)."""
+    if isinstance(obj, dict):
+        if len(obj) == 1 and "text_sequence" in obj:
+            return obj["text_sequence"]
+        output = ""
+        keys = sorted(obj.keys(), reverse=True) if sort_json_key else obj.keys()
+        for k in keys:
+            # NOTE(review): these f-strings are empty, so key names are
+            # dropped entirely — in the Donut format this line should emit
+            # opening/closing tags around the value (e.g. <s_{k}>...</s_{k}>).
+            # The angle-bracket tags look like they were lost in a copy/
+            # markup step; restore them before relying on this serializer.
+            output += f"" + _json2token(obj[k], sort_json_key) + f""
+        return output
+    if isinstance(obj, list):
+        return "".join(_json2token(item, sort_json_key) for item in obj)
+    return str(obj)
+
+
+def load_cord_v2(split="train"):
+    """Load CORD-V2 and return a list of ``{image, question, answer}`` dicts."""
+    from datasets import load_dataset
+
+    ds = load_dataset("naver-clova-ix/cord-v2", split=split)
+    # Fixed seed so the gt_parse choice is reproducible across runs/ranks.
+    rng = random.Random(42)
+    examples = []
+    for ex in ds:
+        gt = json.loads(ex["ground_truth"])
+        gt_jsons = gt.get("gt_parses") or [gt["gt_parse"]]
+        text = rng.choice(
+            [_json2token(g, sort_json_key=True) for g in gt_jsons]
+        )
+        examples.append(
+            {"image": ex["image"], "question": "Describe this image.", "answer": text}
+        )
+    return examples
+
+
+# ---------------------------------------------------------------------------
+# Dataset
+# ---------------------------------------------------------------------------
+
+class CordV2VLMDataset(Dataset):
+    """Single-turn VLM dataset backed by CORD-V2.
+
+    Each sample is tokenized by the HF ``AutoProcessor`` and resized to a
+    fixed resolution so that ``pixel_values`` has consistent shape across
+    samples (required by ``default_collate``).
+
+    Args:
+        examples: Output of :func:`load_cord_v2`.
+        processor: ``AutoProcessor`` instance.
+        seq_length: Intended pad/truncate length for ``input_ids``.
+            NOTE(review): currently stored but never applied in
+            ``__getitem__`` — samples are returned at their natural
+            length; confirm downstream padding/packing handles this.
+        image_size: Resize all images to ``(image_size, image_size)``.
+        image_token_id: Token ID for image placeholders.
+        target_length: Virtual dataset length (repeats examples if needed).
+
+    NOTE:
+        For qwen3.5 vl processor, an example is below.
+        - For images, the processor duplicates the frame so T=2, which makes
+          the 3D conv behave exactly like a 2D conv on one image, to support
+          both image and video inputs.
+        - image_array, 448 * 448 * 3
+        - pixel_values.shape (28 * 28, 1536 = 16 * 16 * 3 * 2)
+        - image_grid_thw, 1, 28, 28 (H/16, W/16)
+    """
+
+    def __init__(
+        self,
+        examples: List[Dict],
+        processor,
+        seq_length: int = 2048,
+        image_size: int = 448,
+        image_token_id: Optional[int] = None,
+        target_length: Optional[int] = None,
+    ):
+        self.examples = examples
+        self.processor = processor
+        self.seq_length = seq_length
+        self.image_size = (image_size, image_size)
+        self._length = target_length if target_length else len(examples)
+        tok = processor.tokenizer
+        self.pad_token_id = tok.pad_token_id if tok.pad_token_id is not None else 0
+
+        # Resolve image token ID: explicit arg wins, otherwise probe the
+        # tokenizer vocab for known placeholder tokens; None if not found.
+        if image_token_id is not None:
+            self.image_token_id = image_token_id
+        else:
+            vocab = tok.get_vocab()
+            for candidate in ("<|image_pad|>", "<|placeholder|>"):
+                if candidate in vocab:
+                    self.image_token_id = vocab[candidate]
+                    break
+            else:
+                self.image_token_id = None
+
+    def __len__(self) -> int:
+        return self._length
+
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+        # Modulo so target_length may exceed the real example count.
+        example = self.examples[idx % len(self.examples)]
+
+        # Fixed-resolution image
+        image = example["image"].convert("RGB").resize(self.image_size)
+
+        # Build single-turn conversation
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": example["question"]},
+                ],
+            },
+            {
+                "role": "assistant",
+                "content": example["answer"],
+            },
+        ]
+
+        # Tokenize + extract pixel values via HF processor
+        text = self.processor.apply_chat_template(
+            conversation, tokenize=False, add_generation_prompt=False,
+        )
+        batch = self.processor(
+            text=[text], images=[image], return_tensors="pt",
+        )
+
+        # NOTE(review): input_ids is NOT padded/truncated to self.seq_length
+        # here — length varies per sample.
+        input_ids = batch["input_ids"].squeeze(0)
+        pixel_values = batch["pixel_values"]  # [total_patches, pixel_dim]
+        image_grid_thw = batch["image_grid_thw"]  # [1, 3]
+
+        # Labels: shifted next-token prediction targets
+        labels = input_ids.clone()
+        labels[:-1] = input_ids[1:]
+        labels[-1] = -100
+
+        # Loss mask: 1 for trainable positions, 0 for padding / image tokens
+        loss_mask = torch.ones_like(input_ids, dtype=torch.float32)
+        loss_mask[input_ids == self.pad_token_id] = 0.0
+        if self.image_token_id is not None:
+            loss_mask[input_ids == self.image_token_id] = 0.0
+        loss_mask[-1] = 0.0
+        # Keep labels consistent with the mask (-100 = ignore index).
+        labels[loss_mask == 0] = -100
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+            "loss_mask": loss_mask,
+            "pixel_values": pixel_values,
+            "image_grid_thw": image_grid_thw,
+        }
+
+
+# ---------------------------------------------------------------------------
+# Megatron dataset provider interface
+# ---------------------------------------------------------------------------
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Provide CORD-V2 train / val / test datasets.
+
+    Requires ``--hf-processor-path`` to point to a HuggingFace VL model
+    (e.g. ``Qwen/Qwen2.5-VL-7B-Instruct``) whose processor handles
+    tokenization and image preprocessing.
+    """
+    from transformers import AutoProcessor
+
+    from megatron.training import get_args
+
+    args = get_args()
+
+    processor_path = getattr(args, "hf_processor_path", None)
+    if processor_path is None:
+        raise ValueError(
+            "cord_v2 dataset requires --hf-processor-path "
+            "(e.g. Qwen/Qwen2.5-VL-7B-Instruct)"
+        )
+    processor = AutoProcessor.from_pretrained(
+        processor_path, trust_remote_code=True,
+    )
+
+    seq_length = (
+        getattr(args, "total_seq_length", None)
+        or getattr(args, "seq_length", 2048)
+    )
+    image_size = getattr(args, "image_size", 448)
+    image_token_id = getattr(args, "image_token_id", None)
+
+    # Load real data
+    train_examples = load_cord_v2(split="train")
+    val_examples = load_cord_v2(split="validation")
+    test_examples = load_cord_v2(split="test")
+
+    def _make(examples, num_samples):
+        # One-line helper: build a dataset with the shared config.
+        return CordV2VLMDataset(
+            examples=examples,
+            processor=processor,
+            seq_length=seq_length,
+            image_size=image_size,
+            image_token_id=image_token_id,
+            target_length=num_samples,
+        )
+
+    train_ds = _make(train_examples, train_val_test_num_samples[0])
+    # max(..., 1) guards against a zero eval-sample request.
+    val_ds = _make(val_examples, max(train_val_test_num_samples[1], 1))
+    test_ds = _make(test_examples, max(train_val_test_num_samples[2], 1))
+
+    return train_ds, val_ds, test_ds
+
+
+if __name__ == "__main__":
+    from transformers import AutoProcessor
+    # NOTE(review): "Qwen/Qwen3.5-35B-A3B" — confirm this checkpoint id
+    # exists and ships a VL processor; the docstring examples use
+    # Qwen/Qwen2.5-VL-7B-Instruct.
+    processor = AutoProcessor.from_pretrained("Qwen/Qwen3.5-35B-A3B", trust_remote_code=True)
+    examples = load_cord_v2(split="train")
+    dataset = CordV2VLMDataset(
+        examples=examples,
+        processor=processor,
+    )
+    print(dataset[0])
diff --git a/examples/multimodal_dev/forward_step.py b/examples/multimodal_dev/forward_step.py
new file mode 100644
index 00000000000..4ee7d688a2b
--- /dev/null
+++ b/examples/multimodal_dev/forward_step.py
@@ -0,0 +1,446 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+"""Forward step, TP broadcast, and loss for multimodal_dev training."""
+
+import math
+from functools import partial
+from itertools import accumulate
+from typing import Any, Dict, Iterator
+
+import torch
+import torch.nn.functional as F
+
+from megatron.core import mpu
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.parallel_state import (
+    get_context_parallel_world_size,
+    get_tensor_model_parallel_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_src_rank,
+)
+from megatron.training import get_args
+
+# -------------------------------------------------------------------
+# dtype <-> int mapping for cross-rank broadcast
+# -------------------------------------------------------------------
+
+_DTYPE_MAP = {
+    torch.float32: 0,
+    torch.float16: 1,
+    torch.bfloat16: 2,
+    torch.int64: 3,
+    torch.int32: 4,
+    torch.bool: 5,
+}
+_ID_MAP = {v: k for k, v in _DTYPE_MAP.items()}
+
+
+def _dtype_to_id(dtype):
+    # Unknown dtypes silently map to float32's id (0).
+    return _DTYPE_MAP.get(dtype, 0)
+
+
+def _id_to_dtype(id_val):
+    return _ID_MAP.get(id_val, torch.float32)
+
+
+# -------------------------------------------------------------------
+# Tensor broadcast helper
+# -------------------------------------------------------------------
+
+def _broadcast_tensor(tensor, src, group, device):
+    """Broadcast a single tensor from *src* to all ranks in *group*.
+
+    Wire protocol (four collectives, all ranks participate in the same
+    order): ndim -> shape -> dtype id -> data. ``ndim == 0`` is used as
+    the "no tensor" sentinel.
+
+    NOTE(review): a genuine 0-dim (scalar) tensor on *src* is
+    indistinguishable from ``None`` under this protocol — it broadcasts
+    ndim 0 and every rank (including *src*) returns ``None``. Callers
+    must not pass scalar tensors.
+    """
+    ndim = torch.tensor(
+        [len(tensor.shape) if tensor is not None else 0],
+        dtype=torch.long,
+        device=device,
+    )
+    torch.distributed.broadcast(ndim, src, group=group)
+
+    if ndim.item() == 0:
+        return None
+
+    if tensor is not None:
+        shape_tensor = torch.tensor(
+            list(tensor.shape), dtype=torch.long, device=device,
+        )
+        dtype_id = torch.tensor(
+            [_dtype_to_id(tensor.dtype)],
+            dtype=torch.long,
+            device=device,
+        )
+    else:
+        # Receiver side: allocate placeholders to be filled by broadcast.
+        shape_tensor = torch.zeros(
+            ndim.item(), dtype=torch.long, device=device,
+        )
+        dtype_id = torch.zeros(1, dtype=torch.long, device=device)
+
+    torch.distributed.broadcast(shape_tensor, src, group=group)
+    torch.distributed.broadcast(dtype_id, src, group=group)
+
+    dtype = _id_to_dtype(dtype_id.item())
+    shape = tuple(shape_tensor.tolist())
+
+    if tensor is None:
+        tensor = torch.empty(shape, dtype=dtype, device=device)
+    torch.distributed.broadcast(tensor, src, group=group)
+    return tensor
+
+
+# -------------------------------------------------------------------
+# Batch broadcast across TP ranks
+# -------------------------------------------------------------------
+
+def broadcast_data_batch(data, device="cuda"):
+    """Broadcast a data-batch dict from TP rank 0 to all TP ranks.
+
+    Keys are serialized as a comma-joined UTF-8 string (so keys must not
+    contain commas), then each value is broadcast via
+    :func:`_broadcast_tensor`. Values that are not ``torch.Tensor`` on
+    the source rank are broadcast as ``None`` and arrive as ``None`` on
+    every rank, including the source.
+    """
+    src = get_tensor_model_parallel_src_rank()
+    group = get_tensor_model_parallel_group()
+
+    if data is None:
+        data = {}
+
+    if get_tensor_model_parallel_rank() == 0:
+        keys = list(data.keys())
+        key_str = ",".join(keys)
+        key_bytes = key_str.encode("utf-8")
+        key_len = torch.tensor(
+            [len(key_bytes)], dtype=torch.long, device=device,
+        )
+    else:
+        key_len = torch.zeros(1, dtype=torch.long, device=device)
+        keys = []
+
+    torch.distributed.broadcast(key_len, src, group=group)
+
+    if get_tensor_model_parallel_rank() == 0:
+        key_tensor = torch.tensor(
+            list(key_bytes), dtype=torch.uint8, device=device,
+        )
+    else:
+        key_tensor = torch.zeros(
+            key_len.item(), dtype=torch.uint8, device=device,
+        )
+
+    torch.distributed.broadcast(key_tensor, src, group=group)
+
+    if get_tensor_model_parallel_rank() != 0:
+        key_str = bytes(key_tensor.cpu().tolist()).decode("utf-8")
+        keys = key_str.split(",") if key_str else []
+
+    result = {}
+    for key in keys:
+        tensor = data.get(key, None) if data else None
+        if tensor is not None and isinstance(tensor, torch.Tensor):
+            tensor = tensor.to(device)
+        result[key] = _broadcast_tensor(
+            tensor if isinstance(tensor, torch.Tensor) else None,
+            src, group, device,
+        )
+
+    return result
+
+
+# -------------------------------------------------------------------
+# THD (packed sequence) helpers
+# -------------------------------------------------------------------
+
+def _build_packed_seq_params(
+    seq_lengths: torch.Tensor, device: torch.device,
+) -> PackedSeqParams:
+    """Build ``PackedSeqParams`` from per-sample valid sequence lengths.
+
+    Args:
+        seq_lengths: ``[B]`` valid token counts per sample.
+        device: Target device for cu_seqlens tensors.
+
+    Returns:
+        A ``PackedSeqParams`` instance with ``qkv_format='thd'``.
+    """
+    if not isinstance(seq_lengths, torch.Tensor):
+        seq_lengths = torch.tensor(seq_lengths)
+    lengths_t = seq_lengths.to(device=device, dtype=torch.int32)
+    cu_seqlens = torch.zeros(
+        lengths_t.numel() + 1, dtype=torch.int32, device=device,
+    )
+    torch.cumsum(lengths_t, dim=0, out=cu_seqlens[1:])
+    max_seqlen = int(lengths_t.max().item())
+    return _build_packed_seq_params_from_cu_seqlens(
+        cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
+    )
+
+
+def _build_packed_seq_params_from_cu_seqlens(
+    cu_seqlens: torch.Tensor, max_seqlen: int,
+) -> PackedSeqParams:
+    """Build ``PackedSeqParams`` from packed cumulative sequence lengths.
+
+    ``cu_seqlens`` must already be on the target compute device. The same
+    tensor is reused for q/kv and their padded variants (no per-sample
+    padding assumed here).
+    """
+    cs = cu_seqlens.to(dtype=torch.int32)
+    total_tokens = int(cs[-1].item())
+    return PackedSeqParams(
+        cu_seqlens_q=cs,
+        cu_seqlens_kv=cs,
+        cu_seqlens_q_padded=cs,
+        cu_seqlens_kv_padded=cs,
+        max_seqlen_q=max_seqlen,
+        max_seqlen_kv=max_seqlen,
+        qkv_format='thd',
+        total_tokens=total_tokens,
+    )
+
+
+
+
+def pack_or_pad_batch(batch: list[Dict[str, Any]], use_packed_sequence: bool=False, seq_length: int=None, device = "cuda") -> list[Dict[str, Any]]:
+    """Pack or pad a ``[B, S]`` batch into ``[1, T]`` THD format.
+
+    NOTE(review): ``batch`` is the sample list read on TP rank 0;
+    ``get_batch`` passes ``None`` on other TP ranks, and iterating it
+    below would raise ``TypeError``. The ``PackedSeqParams`` is likewise
+    built from local Python lists only rank 0 has. Confirm this path is
+    only exercised with TP size 1, or broadcast the cu_seqlens as
+    tensors alongside the batch.
+    """
+    tp_size = mpu.get_tensor_model_parallel_world_size()
+    cp_size = mpu.get_context_parallel_world_size()
+
+    # SP is an explicit runtime option; TP>1 does not imply SP is enabled.
+    try:
+        has_sp = bool(get_args().sequence_parallel)
+    except Exception:
+        has_sp = False
+
+    # Alignment factor so sequences split evenly across CP (x2 for the
+    # load-balanced CP split) and, when SP is on, across TP as well.
+    if cp_size > 1:
+        divisible_by = (tp_size * cp_size * 2) if has_sp else (cp_size * 2)
+    else:
+        divisible_by = tp_size if has_sp else 1
+    # NOTE: don't consider fp8 padding now
+
+    if use_packed_sequence:
+        input_ids_list, labels_list, loss_mask_list, pixel_values_list, image_grid_thw_list = [], [], [], [], []
+        seqlens_list, seqlens_padded_list = [], []
+
+        # NOTE: for attention_mask, we don't use attention mask
+        # for position_ids, let model handle it itself
+        # we don't cut input id, althrough it exceeds seq_length
+
+        packed_batch = dict()
+
+        for sample in batch:
+            seqlen = sample["input_ids"].shape[0]
+            assert sample["labels"].shape == sample["input_ids"].shape == sample["loss_mask"].shape, "labels, input_ids, and loss_mask must have the same shape"
+            # Pad each sample up to the alignment factor before packing.
+            target_len = math.ceil(seqlen / divisible_by) * divisible_by
+            input_ids = F.pad(sample["input_ids"], (0, target_len - seqlen), value=0)
+            labels = F.pad(sample["labels"], (0, target_len - seqlen), value=-100)
+            loss_mask = F.pad(sample["loss_mask"], (0, target_len - seqlen), value=0)
+
+            input_ids_list.append(input_ids)
+            labels_list.append(labels)
+            loss_mask_list.append(loss_mask)
+            seqlens_list.append(seqlen)
+            seqlens_padded_list.append(target_len)
+            pixel_values_list.append(sample["pixel_values"])
+            image_grid_thw_list.append(sample["image_grid_thw"])
+
+        cu_seqlens = list(accumulate(seqlens_list, initial=0))
+        cu_seqlens_padded = list(accumulate(seqlens_padded_list, initial=0))
+
+        # padding_mask: True at collate-padded positions within each packed
+        # sample. Real tokens occupy [cu_seqlens_padded[i], +seqlens_list[i]);
+        # the tail up to cu_seqlens_padded[i+1] is padding. Consumed by MoE
+        # routing in megatron.core to exclude padded tokens from aux loss,
+        # z-loss, and expert-bias accumulation.
+        total_tokens_padded = cu_seqlens_padded[-1]
+        padding_mask_thd = torch.zeros(total_tokens_padded, dtype=torch.bool)
+        for i, real_seqlen in enumerate(seqlens_list):
+            pad_start = cu_seqlens_padded[i] + real_seqlen
+            pad_end = cu_seqlens_padded[i + 1]
+            if pad_end > pad_start:
+                padding_mask_thd[pad_start:pad_end] = True
+
+        packed_batch["input_ids"] = torch.concat(input_ids_list, dim=0).unsqueeze(0)
+        packed_batch["labels"] = torch.concat(labels_list, dim=0).unsqueeze(0)
+        packed_batch["loss_mask"] = torch.concat(loss_mask_list, dim=0).unsqueeze(0)
+        packed_batch["padding_mask"] = padding_mask_thd.unsqueeze(0)
+
+        # TODO, maybe pixel_values's seqlens needs to be recorded.
+        packed_batch["pixel_values"] = torch.concat(pixel_values_list)
+        packed_batch["image_grid_thw"] = torch.concat(image_grid_thw_list)
+
+        # broadcast to all tp ranks
+        packed_batch = broadcast_data_batch(packed_batch, device=device)
+
+        # Built AFTER the broadcast because broadcast_data_batch only
+        # handles tensors, not PackedSeqParams objects.
+        packed_batch["packed_seq_params"] = PackedSeqParams(
+            qkv_format='thd',
+            cu_seqlens_q=torch.tensor(cu_seqlens, dtype=torch.int32, device=device),
+            cu_seqlens_kv=torch.tensor(cu_seqlens, dtype=torch.int32, device=device),
+            cu_seqlens_q_padded=torch.tensor(cu_seqlens_padded, dtype=torch.int32, device=device),
+            cu_seqlens_kv_padded=torch.tensor(cu_seqlens_padded, dtype=torch.int32, device=device),
+            max_seqlen_q=max(seqlens_padded_list),
+            max_seqlen_kv=max(seqlens_padded_list),
+            total_tokens=cu_seqlens_padded[-1],
+        )
+        return packed_batch
+    else:
+        assert seq_length is not None, "seq_length must be provided when use_packed_sequence is False"
+        max_seqlens = max([x["input_ids"].shape[0] for x in batch])
+        target_seqlens = min(max_seqlens, seq_length)
+        # Round target seqlen up to the parallelism alignment factor so the
+        # batched tensor is divisible for CP (+SP) splitting downstream.
+        if divisible_by > 1:
+            target_seqlens = math.ceil(target_seqlens / divisible_by) * divisible_by
+        padded_batch = dict()
+        # Capture real lengths before in-place padding so we can build a
+        # padding_mask for MoE routing (True at collate-padded positions).
+        real_seqlens = [s["input_ids"].shape[0] for s in batch]
+
+        # NOTE(review): F.pad with a negative pad would truncate, but if any
+        # sample is longer than target_seqlens the pad amount here is
+        # negative only when max_seqlens > seq_length — confirm samples
+        # never exceed seq_length, otherwise shapes will disagree.
+        for sample in batch:
+            sample["input_ids"] = F.pad(sample["input_ids"], (0, target_seqlens - sample["input_ids"].shape[0]), value=0)
+            sample["labels"] = F.pad(sample["labels"], (0, target_seqlens - sample["labels"].shape[0]), value=-100)
+            sample["loss_mask"] = F.pad(sample["loss_mask"], (0, target_seqlens - sample["loss_mask"].shape[0]), value=0)
+
+        padded_batch["input_ids"] = torch.concat([x["input_ids"].unsqueeze(0) for x in batch], dim=0)
+        padded_batch["labels"] = torch.concat([x["labels"].unsqueeze(0) for x in batch], dim=0)
+        padded_batch["loss_mask"] = torch.concat([x["loss_mask"].unsqueeze(0) for x in batch], dim=0)
+        positions = torch.arange(target_seqlens).unsqueeze(0)
+        padded_batch["padding_mask"] = positions >= torch.tensor(real_seqlens).unsqueeze(1)
+        padded_batch["pixel_values"] = torch.concat([x["pixel_values"] for x in batch])
+        padded_batch["image_grid_thw"] = torch.concat([x["image_grid_thw"] for x in batch])
+        # broadcast to all tp ranks
+        padded_batch = broadcast_data_batch(padded_batch, device=device)
+        return padded_batch
+
+
+# -------------------------------------------------------------------
+# get_batch
+# -------------------------------------------------------------------
+
+def get_batch(data_iterator: Iterator[Dict[str, Any]]):
+    """Get a batch from *data_iterator* and broadcast across TP ranks.
+
+    Only TP rank 0 reads the iterator; exhaustion is signalled to the
+    other ranks via a 1-byte broadcast, and ``None`` is returned on all
+    ranks when the iterator is done.
+    """
+    device = "cuda"
+    args = get_args()
+
+    if get_tensor_model_parallel_rank() == 0:
+        try:
+            data = next(data_iterator)
+            has_data = torch.tensor(
+                [1], dtype=torch.uint8, device=device,
+            )
+        except StopIteration:
+            has_data = torch.tensor(
+                [0], dtype=torch.uint8, device=device,
+            )
+            data = None
+    else:
+        has_data = torch.empty(1, dtype=torch.uint8, device=device)
+        data = None
+
+    src = get_tensor_model_parallel_src_rank()
+    group = get_tensor_model_parallel_group()
+    torch.distributed.broadcast(has_data, src, group=group)
+
+    if has_data.item() == 0:
+        return None
+
+    # Because broadcast will not broadcast packed_seq_params, we move it into pack_or_pad_batch
+    batch = pack_or_pad_batch(data, args.use_packed_sequence, args.seq_length, device=device)
+
+    # Fix shapes produced by default_collate.
+    if "position_ids" in batch and batch["position_ids"] is not None:
+        p = batch["position_ids"]
+        # default_collate stacks [3, S] samples into [B, 3, S]; the model
+        # expects MRoPE axes first: [3, B, S].
+        if p.dim() == 3 and p.shape[1] == 3:
+            batch["position_ids"] = p.permute(1, 0, 2).contiguous()
+
+    if "pixel_values" in batch and batch["pixel_values"] is not None:
+        pv = batch["pixel_values"]
+        # [B, P, D] from default_collate -> flat [B*P, D] patch table.
+        if pv.dim() == 3:
+            B, P, D = pv.shape
+            batch["pixel_values"] = pv.reshape(B * P, D)
+
+    if (
+        "image_grid_thw" in batch
+        and batch["image_grid_thw"] is not None
+    ):
+        g = batch["image_grid_thw"]
+        # [B, 1, 3] -> [B, 3].
+        if g.dim() == 3:
+            batch["image_grid_thw"] = g.squeeze(1)
+
+    return batch
+
+
+# -------------------------------------------------------------------
+# Loss
+# -------------------------------------------------------------------
+
+def loss_func(loss_mask, output_tensor):
+    """Compute masked language model loss.
+
+    Returns ``(total_loss, total_tokens, {"lm loss": reporting_loss})``
+    where ``reporting_loss`` is a 2-element tensor ``[loss_sum,
+    token_count]`` for downstream averaging across microbatches.
+    """
+    losses = output_tensor.float()
+    loss_mask = loss_mask.contiguous().view(-1).float()
+
+    total_tokens = loss_mask.sum().clone().detach().to(torch.int)
+    total_loss = torch.sum(losses.view(-1) * loss_mask)
+    reporting_loss = torch.cat(
+        [total_loss.clone().detach().view(1), total_tokens.view(1)],
+    )
+
+    return (total_loss, total_tokens, {"lm loss": reporting_loss})
+
+
+# -------------------------------------------------------------------
+# Forward step
+# -------------------------------------------------------------------
+
+def forward_step(data_iterator, model):
+    """Forward step for multimodal_dev training."""
+    batch = get_batch(data_iterator)
+
+    if batch is None:
+        return None,
None + + pixel_values = batch.get("pixel_values", None) + if ( + pixel_values is not None + and pixel_values.is_floating_point() + and pixel_values.dtype == torch.float32 + ): + pixel_values = pixel_values.bfloat16() + + # We don't provide position_ids, now. Let model handle it itself. + output_tensor = model( + input_ids=batch["input_ids"], + position_ids=batch.get("position_ids"), + attention_mask=batch.get("attention_mask", None), + labels=batch.get("labels", None), + loss_mask=batch.get("loss_mask", None), + padding_mask=batch.get("padding_mask", None), + pixel_values=pixel_values, + image_grid_thw=batch.get("image_grid_thw", None), + packed_seq_params=batch.get("packed_seq_params", None), + ) + + loss_mask = batch.get("loss_mask", None) + if loss_mask is None: + loss_mask = torch.ones_like( + batch["input_ids"], dtype=torch.float, + ) + + # CP-split loss_mask to match the model output (which is CP-split + # inside MultimodalModel.forward / Qwen35VLModel.forward). + # THD: use the same TE-based per-sample partition index as the model. + # BSHD: use the matching zigzag split. 
+ cp_size = get_context_parallel_world_size() + if cp_size > 1: + from megatron.core.parallel_state import get_context_parallel_rank + + from examples.multimodal_dev.models.base import ( + _cp_split_tensor, + _thd_cp_partition_index, + ) + + cp_rank = get_context_parallel_rank() + psp = batch.get("packed_seq_params", None) + if psp is not None: + idx = _thd_cp_partition_index( + psp.cu_seqlens_q_padded, + loss_mask.shape[1], cp_size, cp_rank, + ) + loss_mask = loss_mask.index_select(1, idx) + else: + loss_mask = _cp_split_tensor( + loss_mask, seq_dim=1, + cp_size=cp_size, cp_rank=cp_rank, + ) + + return output_tensor, partial(loss_func, loss_mask) diff --git a/examples/multimodal_dev/models/__init__.py b/examples/multimodal_dev/models/__init__.py new file mode 100644 index 00000000000..33788bb3619 --- /dev/null +++ b/examples/multimodal_dev/models/__init__.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Model registry for multimodal_dev training. + +Maps ``--model-arch`` to a set of factory functions that fully encapsulate +model-specific logic. The training entry point (``pretrain_multimodal.py``) +remains model-agnostic — adding a new architecture only requires a new +registry entry (and its backing module) without touching the entry point. + +Registry entry fields +--------------------- +``model_factory_fn`` *(required)* + ``(args, language_config, vision_config, **kwargs) -> MegatronModule`` + Builds and returns the complete model instance. + +``vision_config_fn`` *(required)* + ``(num_layers_override=None, variant=None) -> TransformerConfig`` + Returns the vision encoder TransformerConfig. + +``post_language_config_fn`` *(optional)* + ``(language_config, args) -> None`` + Mutates the language TransformerConfig in-place with model-specific + fields (e.g. ``mrope_section``). 
+ +``vision_flops_fn`` *(optional)* + ``(args, language_config, vision_config) -> None`` + Sets vision FLOPs metadata on ``args`` for training throughput logging. + +``dataset_providers`` *(optional)* + ``Dict[str, str | callable]`` + Maps ``--dataset-provider`` names to callables (or dotted import paths + resolved lazily) with signature + ``(train_val_test_num_samples) -> (train_ds, val_ds, test_ds)``. +""" + +from examples.multimodal_dev.models.qwen35_vl.configuration import get_qwen35_vl_vision_config +from examples.multimodal_dev.models.qwen35_vl.factory import build_model as _build_qwen35_vl_model +from examples.multimodal_dev.models.qwen35_vl.factory import ( + post_language_config as _qwen35_vl_post_language_config, +) +from examples.multimodal_dev.models.qwen35_vl.factory import ( + set_vision_flops_metadata as _qwen35_vl_vision_flops, +) + +MODEL_REGISTRY = { + "qwen35_vl": { + "model_factory_fn": _build_qwen35_vl_model, + "vision_config_fn": get_qwen35_vl_vision_config, + "post_language_config_fn": _qwen35_vl_post_language_config, + "vision_flops_fn": _qwen35_vl_vision_flops, + "dataset_providers": { + "mock": ( + "examples.multimodal_dev.data.mock" + ".train_valid_test_datasets_provider" + ), + "cord_v2": ( + "examples.multimodal_dev.data.vlm_dataset" + ".train_valid_test_datasets_provider" + ), + }, + }, +} diff --git a/examples/multimodal_dev/models/base.py b/examples/multimodal_dev/models/base.py new file mode 100644 index 00000000000..4f1ed97b284 --- /dev/null +++ b/examples/multimodal_dev/models/base.py @@ -0,0 +1,404 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Base multimodal model for FSDP + EP training. + +Composes a vision encoder and a ``GPTModel`` language decoder. Designed +for FSDP + EP: always builds the **full** model on every rank (no PP +flags). PP support is only available through the MIMO ``MimoModel`` +assembly path. 
+ +Subclasses override ``compute_position_ids()`` for model-specific +position encoding (e.g. MRoPE for Qwen3.5-VL). +""" + +import contextlib +from typing import Optional + +import torch +from torch import Tensor + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.models.gpt import GPTModel +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig + + +def _cp_split_tensor(tensor, seq_dim, cp_size, cp_rank): + """Zigzag-split *tensor* along *seq_dim* for context parallelism (BSHD). + + Splits the sequence into ``2 * cp_size`` equal chunks, then selects + chunks ``[cp_rank, 2*cp_size - cp_rank - 1]`` and concatenates them. + This mirrors ``megatron.core.utils.get_batch_on_this_cp_rank``. + """ + S = tensor.shape[seq_dim] + assert S % (2 * cp_size) == 0, ( + f"seq_len {S} not divisible by 2*cp_size={2 * cp_size}" + ) + tensor = tensor.view( + *tensor.shape[:seq_dim], + 2 * cp_size, + S // (2 * cp_size), + *tensor.shape[seq_dim + 1 :], + ) + index = torch.zeros(2, dtype=torch.int64, device=tensor.device) + index[0] = cp_rank + index[1] = 2 * cp_size - cp_rank - 1 + tensor = tensor.index_select(seq_dim, index) + tensor = tensor.view( + *tensor.shape[:seq_dim], + -1, + *tensor.shape[seq_dim + 2 :], + ) + return tensor + + +class _NoCPGroup: + """Dummy size-1 process group used to bypass MRoPE's BSHD-style + zigzag of pre-computed THD freqs (Megatron-Core gap: + ``MultimodalRotaryEmbedding.forward`` lacks the ``not packed_seq`` + skip that plain ``RotaryEmbedding`` has). + """ + + def size(self): + return 1 + + def rank(self): + return 0 + + +_NO_CP_GROUP = _NoCPGroup() + +# Note: reported ``mtp_1 loss`` drifts ~1.3% from the CP=1 baseline under +# THD+CP. Megatron-Core's logging averages per-rank pre-divided ratios +# with op=AVG, and per-rank num_tokens are unequal after MTP rolling. 
+# Gradients are correct; only the *logged* value drifts. + + +def _thd_cp_partition_index(cu_seqlens_padded, total_tokens, cp_size, cp_rank): + """Per-rank token index for THD + CP via TE's + ``thd_get_partitioned_indices``. Cast to int64 so the result can be + used directly with ``index_select`` regardless of TE's return dtype. + """ + from transformer_engine.pytorch import cpp_extensions as tex + + idx = tex.thd_get_partitioned_indices( + cu_seqlens_padded, total_tokens, cp_size, cp_rank, + ) + return idx.long() + + +class MultimodalModel(MegatronModule): + """Base class for multimodal vision-language models. + + Composes a pre-constructed vision encoder and a ``GPTModel`` language + decoder. Designed for FSDP + EP; always builds the full model on + every rank. + + Args: + language_config: ``TransformerConfig`` for the language decoder. + language_spec: ``ModuleSpec`` for decoder transformer layers. + vision_encoder: Pre-constructed vision encoder module. + vocab_size: Language model vocabulary size. + max_sequence_length: Maximum sequence length. + image_token_id: Token ID for image placeholder tokens. + position_embedding_type: Position embedding type for the decoder. + rotary_percent: Fraction of hidden dim for RoPE. + rotary_base: Base frequency for RoPE. + mrope_section: MRoPE channel sections. + mtp_block_spec: Optional MTP block spec. + parallel_output: Keep outputs split across TP ranks. + share_embeddings_and_output_weights: Tie input/output embeddings. 
+ """ + + def __init__( + self, + language_config: TransformerConfig, + language_spec: ModuleSpec, + vision_encoder: MegatronModule, + vocab_size: int, + max_sequence_length: int, + image_token_id: int, + position_embedding_type: str = "rope", + rotary_percent: float = 1.0, + rotary_base: int = 10000, + mrope_section: list = None, + mtp_block_spec: ModuleSpec = None, + parallel_output: bool = True, + share_embeddings_and_output_weights: bool = False, + ): + super().__init__(config=language_config) + + self.image_token_id = image_token_id + + self.vision_model = vision_encoder + self.language_model = GPTModel( + config=language_config, + transformer_layer_spec=language_spec, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + pre_process=True, + post_process=True, + parallel_output=parallel_output, + share_embeddings_and_output_weights=( + share_embeddings_and_output_weights + ), + position_embedding_type=position_embedding_type, + rotary_percent=rotary_percent, + rotary_base=rotary_base, + mtp_block_spec=mtp_block_spec, + ) + + def set_input_tensor(self, input_tensor): + """Route input tensors (simplified, no PP routing).""" + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1 + self.language_model.set_input_tensor(input_tensor[0]) + + def _scatter_vision_embeddings( + self, + input_ids: Tensor, + text_embeddings: Tensor, + vision_embeddings: Tensor, + ) -> Tensor: + """Replace image-token positions with vision embeddings. + + Handles sequence parallelism (gather → scatter → re-scatter). + + Args: + input_ids: ``[B, S]`` token IDs. + text_embeddings: ``[S, B, D]`` (or ``[S/TP, B, D]`` with SP). + vision_embeddings: ``[num_visual_tokens, D]``. + + Returns: + Combined embeddings, same shape as *text_embeddings*. 
+ """ + sp = ( + self.config.sequence_parallel + and parallel_state.get_tensor_model_parallel_world_size() + > 1 + ) + + if sp: + text_embeddings = ( + tensor_parallel.gather_from_sequence_parallel_region( + text_embeddings, tensor_parallel_output_grad=False, + ) + ) + + combined = text_embeddings.transpose(0, 1).contiguous() + image_mask = input_ids == self.image_token_id + mask_expanded = image_mask.unsqueeze(-1).expand_as(combined) + combined = combined.masked_scatter( + mask_expanded, vision_embeddings, + ) + combined = combined.transpose(0, 1).contiguous() + + if sp: + combined = ( + tensor_parallel.scatter_to_sequence_parallel_region( + combined, + ) + ) + + return combined + + def compute_position_ids( + self, + input_ids: Tensor, + image_grid_thw: Optional[Tensor] = None, + packed_seq_params=None, + ) -> Tensor: + """Compute position IDs. Override for MRoPE etc. + + Default: simple sequential positions. ``packed_seq_params`` is + accepted for subclass compatibility (e.g. MRoPE in THD mode). + """ + B, S = input_ids.shape + return ( + torch.arange(S, device=input_ids.device) + .unsqueeze(0) + .expand(B, -1) + ) + + def _cp_split_for_forward( + self, + *, + decoder_input, + input_ids, + labels, + loss_mask, + attention_mask, + position_ids, + packed_seq_params, + padding_mask=None, + ): + """Apply CP split to model-forward inputs. + + BSHD path zigzag-splits each tensor along its seq dim. THD path + partitions per-sample via ``tex.thd_get_partitioned_indices`` so + chunks line up with ``cu_seqlens_q_padded`` boundaries. + ``position_ids`` and ``attention_mask`` are NOT split in THD — + MRoPE returns full freqs and TE attention's + ``_apply_rotary_pos_emb_thd`` does the per-sample CP zigzag + itself via ``_get_thd_freqs_on_this_cp_rank``. 
+ """ + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size <= 1: + return ( + decoder_input, input_ids, labels, loss_mask, + attention_mask, position_ids, padding_mask, + ) + cp_rank = parallel_state.get_context_parallel_rank() + + if packed_seq_params is not None: + total_tokens = ( + decoder_input.shape[0] + if decoder_input is not None + else input_ids.shape[1] + ) + idx = _thd_cp_partition_index( + packed_seq_params.cu_seqlens_q_padded, + total_tokens, cp_size, cp_rank, + ) + if decoder_input is not None: + decoder_input = decoder_input.index_select(0, idx) + if input_ids is not None: + input_ids = input_ids.index_select(1, idx) + if labels is not None: + labels = labels.index_select(1, idx) + if loss_mask is not None: + loss_mask = loss_mask.index_select(1, idx) + if padding_mask is not None: + padding_mask = padding_mask.index_select(1, idx) + else: + def _split(t, seq_dim): + return None if t is None else _cp_split_tensor( + t, seq_dim=seq_dim, cp_size=cp_size, cp_rank=cp_rank, + ) + decoder_input = _split(decoder_input, 0) + input_ids = _split(input_ids, 1) + labels = _split(labels, 1) + loss_mask = _split(loss_mask, 1) + attention_mask = _split(attention_mask, 1) + padding_mask = _split(padding_mask, 1) + + return ( + decoder_input, input_ids, labels, loss_mask, + attention_mask, position_ids, padding_mask, + ) + + @contextlib.contextmanager + def _thd_mrope_no_cp_override(self, packed_seq_params): + """Force ``rotary_pos_emb.cp_group`` to size 1 for the wrapped + forward call so MRoPE returns full-length freqs in THD mode. + Attention then applies per-sample CP zigzag itself via + ``_apply_rotary_pos_emb_thd``. Done by direct mutation rather + than via ``packed_seq_params.cp_group`` so MTP's CP-aware roll + (which reads that field) still sees the real CP group. 
+ """ + mrope = ( + getattr(self.language_model, "rotary_pos_emb", None) + if packed_seq_params is not None + and parallel_state.get_context_parallel_world_size() > 1 + else None + ) + saved = getattr(mrope, "cp_group", None) if mrope is not None else None + if mrope is not None: + mrope.cp_group = _NO_CP_GROUP + try: + yield + finally: + if mrope is not None: + mrope.cp_group = saved + + def forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor = None, + labels: Tensor = None, + loss_mask: Tensor = None, + padding_mask: Tensor = None, + pixel_values: Tensor = None, + image_grid_thw: Tensor = None, + decoder_input: Tensor = None, + packed_seq_params=None, + **kwargs, + ): + """Forward pass. + + Args: + input_ids: ``[B, S]`` token IDs (or ``[1, T]`` in THD mode). + position_ids: ``[3, B, S]`` for MRoPE or ``[B, S]`` + (``[3, 1, T]`` / ``[1, T]`` in THD mode). + attention_mask: ``[B, S]`` attention mask (None in THD). + labels: ``[B, S]`` target token IDs (``[1, T]`` in THD). + loss_mask: ``[B, S]`` mask for loss (``[1, T]`` in THD). + padding_mask: ``[B, S]`` bool mask, True at collate-padded + positions (``[1, T]`` in THD). Forwarded to the language + decoder so MoE routing excludes padded tokens from aux + loss / z-loss / expert-bias accumulation. Distinct from + ``loss_mask``: only true padding, not SFT prompt tokens. + pixel_values: Preprocessed image pixels. + image_grid_thw: ``[num_images, 3]`` grid dimensions. + decoder_input: Pre-computed decoder input (skip embed). + packed_seq_params: ``PackedSeqParams`` for THD attention. + + Returns: + Loss tensor (post_process=True) or hidden states. 
+ """ + if position_ids is None: + position_ids = self.compute_position_ids( + input_ids=input_ids, + image_grid_thw=image_grid_thw, + packed_seq_params=packed_seq_params, + ) + + vision_embeddings = None + if ( + self.vision_model is not None + and pixel_values is not None + ): + vision_embeddings = self.vision_model( + pixel_values, image_grid_thw, + ) + + if decoder_input is None and self.language_model is not None: + text_embeddings = self.language_model.embedding( + input_ids=input_ids, position_ids=None, + ) + + if vision_embeddings is not None: + decoder_input = self._scatter_vision_embeddings( + input_ids, text_embeddings, vision_embeddings, + ) + else: + decoder_input = text_embeddings + + ( + decoder_input, input_ids, labels, loss_mask, + attention_mask, position_ids, padding_mask, + ) = self._cp_split_for_forward( + decoder_input=decoder_input, + input_ids=input_ids, + labels=labels, + loss_mask=loss_mask, + attention_mask=attention_mask, + position_ids=position_ids, + packed_seq_params=packed_seq_params, + padding_mask=padding_mask, + ) + + with self._thd_mrope_no_cp_override(packed_seq_params): + return self.language_model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + decoder_input=decoder_input, + labels=labels, + loss_mask=loss_mask, + padding_mask=padding_mask, + packed_seq_params=packed_seq_params, + ) diff --git a/examples/multimodal_dev/models/qwen35_vl/__init__.py b/examples/multimodal_dev/models/qwen35_vl/__init__.py new file mode 100644 index 00000000000..44e82fa14bb --- /dev/null +++ b/examples/multimodal_dev/models/qwen35_vl/__init__.py @@ -0,0 +1,70 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Qwen3.5-VL model components — the single source of truth. + +Both the standalone ``multimodal_dev`` training path and the MIMO path +import from here. 
+""" + +from examples.multimodal_dev.models.qwen35_vl.configuration import ( + MROPE_SECTION, + QWEN35_VL_IMAGE_TOKEN_ID, + QWEN35_VL_VIDEO_TOKEN_ID, + QWEN35_VL_VISION_END_TOKEN_ID, + QWEN35_VL_VISION_START_TOKEN_ID, + QWEN35_VL_VOCAB_SIZE, + ROTARY_BASE, + ROTARY_PERCENT, + VISION_KWARGS, + get_qwen35_vl_language_config, + get_qwen35_vl_vision_config, +) +from examples.multimodal_dev.models.qwen35_vl.factory import ( + build_model, + post_language_config, + set_vision_flops_metadata, +) +from examples.multimodal_dev.models.qwen35_vl.model import Qwen35VLModel +from examples.multimodal_dev.models.qwen35_vl.mrope import get_rope_index +from examples.multimodal_dev.models.qwen35_vl.specs import ( + get_qwen35_vl_language_spec, + get_qwen35_vl_vision_spec, +) +from examples.multimodal_dev.models.qwen35_vl.vision_encoder import ( + Qwen35VLPatchEmbed, + Qwen35VLPatchMerger, + Qwen35VLVisionEncoder, + Qwen35VLVisionRotaryEmbedding, +) + +__all__ = [ + # Model class + "Qwen35VLModel", + # Factory functions + "build_model", + "post_language_config", + "set_vision_flops_metadata", + # Vision encoder + "Qwen35VLVisionEncoder", + "Qwen35VLPatchEmbed", + "Qwen35VLPatchMerger", + "Qwen35VLVisionRotaryEmbedding", + # Config helpers + "get_qwen35_vl_vision_config", + "get_qwen35_vl_language_config", + # Spec helpers + "get_qwen35_vl_language_spec", + "get_qwen35_vl_vision_spec", + # MRoPE + "get_rope_index", + # Constants + "QWEN35_VL_IMAGE_TOKEN_ID", + "QWEN35_VL_VIDEO_TOKEN_ID", + "QWEN35_VL_VISION_START_TOKEN_ID", + "QWEN35_VL_VISION_END_TOKEN_ID", + "QWEN35_VL_VOCAB_SIZE", + "ROTARY_BASE", + "ROTARY_PERCENT", + "MROPE_SECTION", + "VISION_KWARGS", +] diff --git a/examples/multimodal_dev/models/qwen35_vl/configuration.py b/examples/multimodal_dev/models/qwen35_vl/configuration.py new file mode 100644 index 00000000000..41f6dda2524 --- /dev/null +++ b/examples/multimodal_dev/models/qwen35_vl/configuration.py @@ -0,0 +1,355 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
All rights reserved. + +"""Configuration helpers for Qwen3.5-VL vision-language model. + +Provides TransformerConfig builders for the vision encoder and all language +decoder variants. Both the standalone ``multimodal_dev`` training path and the +MIMO path import from here — this is the single source of truth. + +Supported language variants (HuggingFace Qwen3.5 series): + ``0.8b`` Dense 0.8B + ``2b`` Dense 2B + ``4b`` Dense 4B + ``9b`` Dense 9B + ``27b`` Dense 27B + ``35b_a3b`` MoE 35B-A3B (256 experts, top-8) + ``122b_a10b`` MoE 122B-A10B (256 experts, top-8) + ``397b_a17b`` MoE 397B-A17B (512 experts, top-10) + ``35b_a3b_light`` Reduced 35B-A3B for testing + ``proxy`` Reduced proxy based on 397B for single-node testing +""" + +from typing import Optional + +import torch + +from megatron.core.transformer.transformer_config import TransformerConfig + +# --------------------------------------------------------------------------- +# Public constants +# --------------------------------------------------------------------------- + +QWEN35_VL_IMAGE_TOKEN_ID: int = 248056 +QWEN35_VL_VIDEO_TOKEN_ID: int = 248057 +QWEN35_VL_VISION_START_TOKEN_ID: int = 248053 +QWEN35_VL_VISION_END_TOKEN_ID: int = 248054 +QWEN35_VL_VOCAB_SIZE: int = 248320 + +ROTARY_BASE: int = 10_000_000 +ROTARY_PERCENT: float = 0.25 +MROPE_SECTION: list = [11, 11, 10] + +# --------------------------------------------------------------------------- +# Vision config +# --------------------------------------------------------------------------- + +VISION_KWARGS = { + "in_channels": 3, + "patch_size": 16, + "temporal_patch_size": 2, + "spatial_merge_size": 2, + "out_hidden_size": 3584, + "max_num_positions": 2304, +} + +# Three distinct vision encoder architectures in the Qwen3.5 family. 
+_VISION_SMALL = { + "num_layers": 12, "hidden_size": 768, "num_attention_heads": 12, + "kv_channels": 64, "ffn_hidden_size": 3072, +} +_VISION_MEDIUM = { + "num_layers": 24, "hidden_size": 1024, "num_attention_heads": 16, + "kv_channels": 64, "ffn_hidden_size": 4096, +} +_VISION_LARGE = { + "num_layers": 27, "hidden_size": 1152, "num_attention_heads": 16, + "kv_channels": 72, "ffn_hidden_size": 4304, +} + +# Per-variant vision config. ``out_hidden_size`` equals the language model's +# hidden_size and controls the merger projection output dimension. +_VISION_VARIANT_CONFIGS = { + "0.8b": {**_VISION_SMALL, "out_hidden_size": 1024}, + "2b": {**_VISION_MEDIUM, "out_hidden_size": 2048}, + "4b": {**_VISION_MEDIUM, "out_hidden_size": 2560}, + "9b": {**_VISION_LARGE, "out_hidden_size": 4096}, + "27b": {**_VISION_LARGE, "out_hidden_size": 5120}, + "35b_a3b": {**_VISION_LARGE, "out_hidden_size": 2048}, + "122b_a10b": {**_VISION_LARGE, "out_hidden_size": 3072}, + "397b_a17b": {**_VISION_LARGE, "out_hidden_size": 4096}, +} + +# Fallback for proxy/unknown variants (large ViT, generic out_hidden_size). +_VISION_DEFAULT = {**_VISION_LARGE, "out_hidden_size": 3584} + + +def get_qwen35_vl_vision_config( + num_layers_override: Optional[int] = None, + variant: Optional[str] = None, +) -> TransformerConfig: + """TransformerConfig for the Qwen3.5-VL vision encoder. + + Three ViT architectures are used across the family: + - Small (0.8b): depth 12, 768-dim, 12 heads + - Medium (2b, 4b): depth 24, 1024-dim, 16 heads + - Large (9b, 27b, MoE variants): depth 27, 1152-dim, 16 heads + + Args: + num_layers_override: Override vision backbone depth for proxy runs. + variant: Language model variant name. When set, selects the + matching vision config from ``_VISION_VARIANT_CONFIGS`` if one + exists; otherwise the default large-ViT config is used. 
+ """ + vcfg = _VISION_VARIANT_CONFIGS.get(variant, _VISION_DEFAULT) + num_layers = vcfg["num_layers"] + if num_layers_override is not None: + num_layers = num_layers_override + + return TransformerConfig( + num_layers=num_layers, + hidden_size=vcfg["hidden_size"], + num_attention_heads=vcfg["num_attention_heads"], + kv_channels=vcfg["kv_channels"], + ffn_hidden_size=vcfg["ffn_hidden_size"], + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-6, + normalization="LayerNorm", + gated_linear_unit=False, + activation_func=lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + bias_activation_fusion=False, + apply_query_key_layer_scaling=False, + apply_rope_fusion=False, + bf16=False, + ) + + +# --------------------------------------------------------------------------- +# Language config variants +# --------------------------------------------------------------------------- + +_VARIANT_CONFIGS = { + "0.8b": { + "num_layers": 24, + "hidden_size": 1024, + "ffn_hidden_size": 3584, + "num_attention_heads": 8, + "num_query_groups": 2, + "kv_channels": 256, + "linear_num_value_heads": 16, + "num_moe_experts": None, + "moe_router_topk": None, + "moe_ffn_hidden_size": None, + "moe_shared_expert_intermediate_size": None, + }, + "2b": { + "num_layers": 24, + "hidden_size": 2048, + "ffn_hidden_size": 6144, + "num_attention_heads": 8, + "num_query_groups": 2, + "kv_channels": 256, + "linear_num_value_heads": 16, + "num_moe_experts": None, + "moe_router_topk": None, + "moe_ffn_hidden_size": None, + "moe_shared_expert_intermediate_size": None, + }, + "4b": { + "num_layers": 32, + "hidden_size": 2560, + "ffn_hidden_size": 9216, + "num_attention_heads": 16, + "num_query_groups": 4, + "kv_channels": 256, + "linear_num_value_heads": 32, + "num_moe_experts": None, + "moe_router_topk": None, + "moe_ffn_hidden_size": None, + "moe_shared_expert_intermediate_size": None, + }, + "9b": { + "num_layers": 32, + "hidden_size": 4096, + "ffn_hidden_size": 12288, + 
"num_attention_heads": 16, + "num_query_groups": 4, + "kv_channels": 256, + "linear_num_value_heads": 32, + "num_moe_experts": None, + "moe_router_topk": None, + "moe_ffn_hidden_size": None, + "moe_shared_expert_intermediate_size": None, + }, + "27b": { + "num_layers": 64, + "hidden_size": 5120, + "ffn_hidden_size": 17408, + "num_attention_heads": 24, + "num_query_groups": 4, + "kv_channels": 256, + "linear_num_value_heads": 48, + "num_moe_experts": None, + "moe_router_topk": None, + "moe_ffn_hidden_size": None, + "moe_shared_expert_intermediate_size": None, + }, + "35b_a3b": { + "num_layers": 40, + "hidden_size": 2048, + "ffn_hidden_size": 4096, + "num_attention_heads": 16, + "num_query_groups": 2, + "kv_channels": 256, + "linear_num_value_heads": 32, + "num_moe_experts": 256, + "moe_router_topk": 8, + "moe_ffn_hidden_size": 512, + "moe_shared_expert_intermediate_size": 512, + }, + "35b_a3b_light": { + "num_layers": 20, + "hidden_size": 2048, + "ffn_hidden_size": 4096, + "num_attention_heads": 16, + "num_query_groups": 2, + "kv_channels": 256, + "linear_num_value_heads": 32, + "num_moe_experts": 256, + "moe_router_topk": 8, + "moe_ffn_hidden_size": 512, + "moe_shared_expert_intermediate_size": 512, + }, + "122b_a10b": { + "num_layers": 48, + "hidden_size": 3072, + "ffn_hidden_size": 8192, + "num_attention_heads": 32, + "num_query_groups": 2, + "kv_channels": 256, + "linear_num_value_heads": 64, + "num_moe_experts": 256, + "moe_router_topk": 8, + "moe_ffn_hidden_size": 1024, + "moe_shared_expert_intermediate_size": 1024, + }, + "397b_a17b": { + "num_layers": 60, + "hidden_size": 4096, + "ffn_hidden_size": 10240, + "num_attention_heads": 32, + "num_query_groups": 2, + "kv_channels": 256, + "linear_num_value_heads": 64, + "num_moe_experts": 512, + "moe_router_topk": 10, + "moe_ffn_hidden_size": 1024, + "moe_shared_expert_intermediate_size": 1024, + }, + "proxy": { + "num_layers": 4, + "hidden_size": 4096, + "ffn_hidden_size": 10240, + "num_attention_heads": 32, + 
"num_query_groups": 2, + "kv_channels": 256, + "linear_num_value_heads": 64, + "num_moe_experts": 16, + "moe_router_topk": 2, + "moe_ffn_hidden_size": 1024, + "moe_shared_expert_intermediate_size": 1024, + }, +} + + +def get_qwen35_vl_language_config( + variant: str = "proxy", + **overrides, +) -> TransformerConfig: + """TransformerConfig for the Qwen3.5-VL language decoder. + + The ``397b_a17b`` variant reproduces the MIMO + ``get_qwen35_language_model_config()`` output exactly. + + Args: + variant: One of ``0.8b``, ``2b``, ``4b``, ``9b``, ``27b``, + ``35b_a3b``, ``122b_a10b``, ``397b_a17b``, + ``35b_a3b_light``, ``proxy``. + **overrides: Override any TransformerConfig field. + + Returns: + Fully-populated TransformerConfig. + """ + if variant not in _VARIANT_CONFIGS: + raise ValueError( + f"Unknown variant '{variant}'. " + f"Choose from {list(_VARIANT_CONFIGS.keys())}" + ) + + v = _VARIANT_CONFIGS[variant] + + kwargs = dict( + # Architecture + num_layers=v["num_layers"], + hidden_size=v["hidden_size"], + ffn_hidden_size=v["ffn_hidden_size"], + num_attention_heads=v["num_attention_heads"], + num_query_groups=v["num_query_groups"], + kv_channels=v["kv_channels"], + # Normalization & activation + normalization="RMSNorm", + layernorm_epsilon=1e-6, + layernorm_zero_centered_gamma=True, + apply_residual_connection_post_layernorm=False, + gated_linear_unit=True, + activation_func=torch.nn.functional.silu, + # MRoPE section (interleaved T/H/W layout, Qwen3.5-VL style) + mrope_section=list(MROPE_SECTION), + mrope_interleaved=True, + rotary_interleaved=False, + # Attention + qk_layernorm=True, + attention_output_gate=True, + attention_dropout=0.0, + hidden_dropout=0.0, + add_bias_linear=False, + # Hybrid attention (GatedDeltaNet) + experimental_attention_variant="gated_delta_net", + linear_attention_freq=4, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=v["linear_num_value_heads"], + # 
Kernel / TE fusions + bias_activation_fusion=True, + masked_softmax_fusion=True, + persist_layer_norm=True, + bias_dropout_fusion=True, + apply_rope_fusion=False, + # Precision + bf16=True, + ) + + # MoE config (only for MoE variants) + if v["num_moe_experts"] is not None: + kwargs.update( + num_moe_experts=v["num_moe_experts"], + moe_router_topk=v["moe_router_topk"], + moe_ffn_hidden_size=v["moe_ffn_hidden_size"], + moe_shared_expert_intermediate_size=v[ + "moe_shared_expert_intermediate_size" + ], + moe_shared_expert_gate=True, + moe_layer_freq=1, + moe_router_pre_softmax=False, + moe_router_load_balancing_type="global_aux_loss", + moe_permute_fusion=True, + moe_aux_loss_coeff=1e-3, + moe_grouped_gemm=True, + moe_token_dispatcher_type="alltoall", + moe_router_dtype="fp32", + ) + + kwargs.update(overrides) + return TransformerConfig(**kwargs) diff --git a/examples/multimodal_dev/models/qwen35_vl/factory.py b/examples/multimodal_dev/models/qwen35_vl/factory.py new file mode 100644 index 00000000000..1b9076845e7 --- /dev/null +++ b/examples/multimodal_dev/models/qwen35_vl/factory.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Factory functions for Qwen3.5-VL model construction. + +Encapsulates all Qwen3.5-VL-specific logic needed by ``pretrain_multimodal.py`` +so that the training entry point remains model-agnostic. +""" + +from examples.multimodal_dev.models.qwen35_vl.configuration import ( + MROPE_SECTION, + VISION_KWARGS, +) + + +def post_language_config(language_config, args): + """Apply Qwen3.5-VL-specific settings to the language TransformerConfig. + + Called after ``core_transformer_config_from_args`` to inject model-specific + fields that cannot be expressed via CLI args alone. 
+ """ + language_config.mrope_section = list(MROPE_SECTION) + language_config.mrope_interleaved = True + + +def set_vision_flops_metadata(args, language_config, vision_config): + """Expose Qwen3.5-VL vision-model dimensions for FLOPs estimation.""" + args.count_vision_model_flops = True + args.vision_flops_variant = "qwen35_vl_v2" + args.vision_num_layers = vision_config.num_layers + args.vision_hidden_size = vision_config.hidden_size + args.vision_ffn_hidden_size = vision_config.ffn_hidden_size + args.vision_num_attention_heads = vision_config.num_attention_heads + args.vision_kv_channels = vision_config.kv_channels + args.vision_in_channels = VISION_KWARGS["in_channels"] + args.vision_patch_size = VISION_KWARGS["patch_size"] + args.vision_temporal_patch_size = VISION_KWARGS["temporal_patch_size"] + args.vision_spatial_merge_size = VISION_KWARGS["spatial_merge_size"] + args.vision_out_hidden_size = language_config.hidden_size + + +def build_model(args, language_config, vision_config, **kwargs): + """Build a complete Qwen3.5-VL model instance. + + Handles language spec construction, optional MTP block spec, and + model instantiation with Qwen3.5-VL-specific parameters. + + Args: + args: Megatron parsed arguments. + language_config: ``TransformerConfig`` for the language decoder + (already post-processed by :func:`post_language_config`). + vision_config: ``TransformerConfig`` for the vision encoder. + **kwargs: Extra keyword arguments (e.g. ``vp_stage``). + + Returns: + A :class:`Qwen35VLModel` instance. 
+ """ + from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_mtp_block_spec, + ) + + from examples.multimodal_dev.models.qwen35_vl.model import Qwen35VLModel + from examples.multimodal_dev.models.qwen35_vl.specs import ( + get_qwen35_vl_language_spec, + ) + + language_spec = get_qwen35_vl_language_spec( + config=language_config, + vp_stage=kwargs.get("vp_stage", None), + pp_rank=None, + ) + + mtp_block_spec = None + if getattr(args, "mtp_num_layers", None): + mtp_block_spec = get_gpt_mtp_block_spec( + config=language_config, + spec=language_spec, + use_transformer_engine=( + args.transformer_impl == "transformer_engine" + ), + vp_stage=kwargs.get("vp_stage", None), + pp_rank=None, + ) + + # When --untie-embeddings-and-output-weights is NOT passed, Megatron + # defaults to tied embeddings (share_embeddings_and_output_weights=True). + # The 0.8B variant uses tied embeddings, while larger variants untie them. + share_embeddings = not getattr( + args, "untie_embeddings_and_output_weights", False + ) + + return Qwen35VLModel( + language_config=language_config, + language_spec=language_spec, + vision_config=vision_config, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + image_token_id=getattr(args, "image_token_id", 248056), + mtp_block_spec=mtp_block_spec, + parallel_output=True, + share_embeddings_and_output_weights=share_embeddings, + ) diff --git a/examples/multimodal_dev/models/qwen35_vl/model.py b/examples/multimodal_dev/models/qwen35_vl/model.py new file mode 100644 index 00000000000..2e2d1455af8 --- /dev/null +++ b/examples/multimodal_dev/models/qwen35_vl/model.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Qwen3.5-VL multimodal model for standalone FSDP + EP training. + +Composes a Megatron-native Qwen3.5 vision encoder with a ``GPTModel`` +language decoder using MRoPE and hybrid GatedDeltaNet / full-attention +layers. 
+""" + +from typing import Optional + +from torch import Tensor + +from examples.multimodal_dev.models.base import MultimodalModel +from examples.multimodal_dev.models.qwen35_vl.configuration import ( + QWEN35_VL_IMAGE_TOKEN_ID, + QWEN35_VL_VIDEO_TOKEN_ID, + QWEN35_VL_VISION_START_TOKEN_ID, + QWEN35_VL_VOCAB_SIZE, + ROTARY_BASE, + ROTARY_PERCENT, + VISION_KWARGS, +) +from examples.multimodal_dev.models.qwen35_vl.mrope import get_rope_index +from examples.multimodal_dev.models.qwen35_vl.specs import get_qwen35_vl_vision_spec +from examples.multimodal_dev.models.qwen35_vl.vision_encoder import Qwen35VLVisionEncoder +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig + + +class Qwen35VLModel(MultimodalModel): + """Qwen3.5-VL multimodal model. + + Args: + language_config: ``TransformerConfig`` for the language decoder. + language_spec: ``ModuleSpec`` for language decoder layers. + vision_config: ``TransformerConfig`` for the vision encoder. + vision_spec: ``ModuleSpec`` for vision encoder layers. + vocab_size: Vocabulary size. + max_sequence_length: Maximum sequence length. + image_token_id: Token ID for image placeholders. + spatial_merge_size: Vision encoder spatial merge factor. + mtp_block_spec: Optional MTP block spec. + parallel_output: Keep outputs split across TP. + share_embeddings_and_output_weights: Tie embeddings. 
+ """ + + def __init__( + self, + language_config: TransformerConfig, + language_spec: ModuleSpec, + vision_config: TransformerConfig, + vision_spec: ModuleSpec = None, + vocab_size: int = QWEN35_VL_VOCAB_SIZE, + max_sequence_length: int = 262144, + image_token_id: int = QWEN35_VL_IMAGE_TOKEN_ID, + video_token_id: int = QWEN35_VL_VIDEO_TOKEN_ID, + vision_start_token_id: int = QWEN35_VL_VISION_START_TOKEN_ID, + spatial_merge_size: int = 2, + mtp_block_spec: ModuleSpec = None, + parallel_output: bool = True, + share_embeddings_and_output_weights: bool = False, + ): + if vision_spec is None: + vision_spec = get_qwen35_vl_vision_spec() + + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.spatial_merge_size = spatial_merge_size + + vkw = dict(VISION_KWARGS) + vkw["spatial_merge_size"] = spatial_merge_size + vkw["out_hidden_size"] = language_config.hidden_size + + vision_encoder = Qwen35VLVisionEncoder( + config=vision_config, + transformer_layer_spec=vision_spec, + in_channels=vkw["in_channels"], + patch_size=vkw["patch_size"], + temporal_patch_size=vkw["temporal_patch_size"], + spatial_merge_size=vkw["spatial_merge_size"], + out_hidden_size=vkw["out_hidden_size"], + max_num_positions=vkw["max_num_positions"], + ) + + super().__init__( + language_config=language_config, + language_spec=language_spec, + vision_encoder=vision_encoder, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + image_token_id=image_token_id, + position_embedding_type="mrope", + rotary_percent=ROTARY_PERCENT, + rotary_base=ROTARY_BASE, + mrope_section=language_config.mrope_section, + mtp_block_spec=mtp_block_spec, + parallel_output=parallel_output, + share_embeddings_and_output_weights=( + share_embeddings_and_output_weights + ), + ) + + def compute_position_ids( + self, + input_ids: Tensor, + image_grid_thw: Optional[Tensor] = None, + packed_seq_params=None, + ) -> Tensor: + """Compute 3D MRoPE position IDs for Qwen3.5-VL. 
+ + In THD mode ``input_ids`` is ``[1, T]`` and ``packed_seq_params`` + supplies per-segment boundaries; positions restart at 0 per + segment. In BSHD mode ``input_ids`` is ``[B, S]`` and + ``packed_seq_params`` should be ``None``. + + Returns: + ``[3, B, S]`` position IDs for MRoPE (``[3, 1, T]`` in THD). + """ + position_ids, _ = get_rope_index( + spatial_merge_size=self.spatial_merge_size, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vision_start_token_id=self.vision_start_token_id, + input_ids=input_ids, + image_grid_thw=image_grid_thw, + packed_seq_params=packed_seq_params, + ) + return position_ids diff --git a/examples/multimodal_dev/models/qwen35_vl/mrope.py b/examples/multimodal_dev/models/qwen35_vl/mrope.py new file mode 100644 index 00000000000..44806a9bdf0 --- /dev/null +++ b/examples/multimodal_dev/models/qwen35_vl/mrope.py @@ -0,0 +1,378 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""MRoPE (Multimodal Rotary Position Embedding) position ID computation. + +Computes 3D position IDs for Qwen3.5-VL: for text tokens all three +dimensions share sequential positions; for image/video tokens the three +dimensions encode (temporal, height, width) in the merged spatial grid. + +Supports two input layouts: + +* **BSHD** — ``input_ids`` is ``[B, S]``; each row is an independent + sample (possibly padded) and ``attention_mask`` marks valid tokens. +* **THD** — ``input_ids`` is ``[1, T]``, a concatenation of ``N`` + sub-sequences. ``packed_seq_params.cu_seqlens_q_padded`` gives the + physical segment boundaries in the packed tensor and + ``cu_seqlens_q`` gives the valid (unpadded) token count inside each + segment. Position IDs restart at 0 at every segment boundary; image + / video grid rows are consumed in packed order across segments. + +Ported from Megatron-Bridge ``get_rope_index`` (which itself is adapted +from HF ``Qwen3VLForConditionalGeneration.get_rope_index``). 
def _build_sample_mrope_positions(
    sample_input_ids: Tensor,
    image_grid_thw: Optional[Tensor],
    video_grid_thw: Optional[Tensor],
    image_index: int,
    video_index: int,
    spatial_merge_size: int,
    image_token_id: int,
    video_token_id: int,
    vision_start_token_id: int,
) -> tuple[Tensor, int, int]:
    """Compute MRoPE position IDs for one unpadded sub-sequence.

    Scans the vision occurrences in ``sample_input_ids`` and emits a
    ``[3, L]`` tensor of (t, h, w) positions starting at 0. The running
    ``image_index`` / ``video_index`` cursors are advanced through the
    grid tables and returned so callers can process several
    sub-sequences against the same packed grid rows.
    """
    # Each vision region is announced by a start token; the token right
    # after it tells us whether the region is an image or a video.
    start_marks = torch.argwhere(
        sample_input_ids == vision_start_token_id,
    ).squeeze(1)
    marker_followers = sample_input_ids[start_marks + 1]
    n_images = int((marker_followers == image_token_id).sum())
    n_videos = int((marker_followers == video_token_id).sum())

    token_list = sample_input_ids.tolist()
    seq_len = len(token_list)
    chunks: list = []
    cursor = 0
    images_left, videos_left = n_images, n_videos

    # One iteration per vision region, not per token.
    for _ in range(n_images + n_videos):
        next_image = (
            token_list.index(image_token_id, cursor)
            if images_left > 0 and image_token_id in token_list
            else seq_len + 1  # sentinel: "no more images"
        )
        next_video = (
            token_list.index(video_token_id, cursor)
            if videos_left > 0 and video_token_id in token_list
            else seq_len + 1
        )

        # Consume whichever vision region comes first.
        if next_image < next_video:
            grid = image_grid_thw[image_index]
            image_index += 1
            images_left -= 1
            stop = next_image
        else:
            grid = video_grid_thw[video_index]
            video_index += 1
            videos_left -= 1
            stop = next_video
        t, h, w = grid[0], grid[1], grid[2]

        # Spatial dims are divided by the merge factor: the LM sees one
        # token per merged patch block.
        grid_t = t.item()
        grid_h = h.item() // spatial_merge_size
        grid_w = w.item() // spatial_merge_size
        n_text = stop - cursor

        base = chunks[-1].max() + 1 if chunks else 0

        # Text run before this vision region: all three dims sequential.
        chunks.append(
            torch.arange(n_text).view(1, -1).expand(3, -1) + base
        )

        # Vision region: (t, h, w) coordinates over the merged grid,
        # offset past the preceding text positions.
        t_pos = (
            torch.arange(grid_t)
            .view(-1, 1)
            .expand(-1, grid_h * grid_w)
            .flatten()
        )
        h_pos = (
            torch.arange(grid_h)
            .view(1, -1, 1)
            .expand(grid_t, -1, grid_w)
            .flatten()
        )
        w_pos = (
            torch.arange(grid_w)
            .view(1, 1, -1)
            .expand(grid_t, grid_h, -1)
            .flatten()
        )
        chunks.append(
            torch.stack([t_pos, h_pos, w_pos]) + n_text + base
        )

        cursor = stop + grid_t * grid_h * grid_w

    # Trailing text after the last vision region.
    if cursor < seq_len:
        base = chunks[-1].max() + 1 if chunks else 0
        chunks.append(
            torch.arange(seq_len - cursor).view(1, -1).expand(3, -1)
            + base
        )

    if chunks:
        positions = torch.cat(chunks, dim=1).reshape(3, -1)
    else:
        positions = torch.zeros(
            3,
            0,
            dtype=sample_input_ids.dtype,
            device=sample_input_ids.device,
        )
    return positions, image_index, video_index


def get_rope_index(
    spatial_merge_size: int,
    image_token_id: int,
    video_token_id: int,
    vision_start_token_id: int,
    input_ids: Optional[Tensor] = None,
    image_grid_thw: Optional[Tensor] = None,
    video_grid_thw: Optional[Tensor] = None,
    attention_mask: Optional[Tensor] = None,
    packed_seq_params: Optional["PackedSeqParams"] = None,
) -> tuple[Tensor, Tensor]:
    """Compute 3D MRoPE position IDs for Qwen3-VL / Qwen3.5-VL.

    Text tokens get identical sequential positions in all three
    dimensions; vision tokens encode (temporal, height, width) over the
    merged spatial grid.

    Args:
        spatial_merge_size: Merge factor for spatial dimensions.
        image_token_id: Token ID for image placeholders.
        video_token_id: Token ID for video placeholders.
        vision_start_token_id: Token ID opening a vision region.
        input_ids: ``[B, S]`` in BSHD or ``[1, T]`` in THD.
        image_grid_thw: ``[num_images, 3]`` per-image grid dimensions,
            consumed in the order image tokens appear in ``input_ids``.
        video_grid_thw: ``[num_videos, 3]`` per-video grid dimensions.
        attention_mask: ``[B, S]`` mask (1 = keep, 0 = pad). BSHD only.
        packed_seq_params: When given, selects the THD branch; segment
            boundaries come from ``cu_seqlens_q`` (valid lengths) and
            ``cu_seqlens_q_padded`` (packed layout).

    Returns:
        ``(position_ids, mrope_position_deltas)`` — *position_ids* has
        shape ``[3, B, S]`` (``[3, 1, T]`` in THD).
    """
    # Videos are flattened to one grid row per frame with temporal
    # extent 1 (Qwen3-VL uses timestamps, not absolute time positions).
    if video_grid_thw is not None:
        video_grid_thw = torch.repeat_interleave(
            video_grid_thw, video_grid_thw[:, 0], dim=0,
        )
        video_grid_thw[:, 0] = 1

    # --- THD (packed) branch -----------------------------------------
    if packed_seq_params is not None and input_ids is not None:
        cu = packed_seq_params.cu_seqlens_q
        cu_padded = getattr(
            packed_seq_params, "cu_seqlens_q_padded", None,
        )
        if cu_padded is None:
            cu_padded = cu

        assert (
            input_ids.dim() == 2 and input_ids.shape[0] == 1
        ), "THD get_rope_index expects input_ids shape [1, T]"

        total = input_ids.shape[1]
        dev = input_ids.device

        # Padding slots default to 1 (same convention as the BSHD
        # branch, which fills masked positions with 1).
        position_ids = torch.ones(
            3, 1, total, dtype=input_ids.dtype, device=dev,
        )
        deltas: list = []
        img_cursor = 0
        vid_cursor = 0

        for seg in range(cu.numel() - 1):
            begin = int(cu_padded[seg].item())
            n_valid = int(cu[seg + 1].item() - cu[seg].item())
            end = begin + n_valid

            if n_valid == 0:
                deltas.append(0)
                continue

            segment_ids = input_ids[0, begin:end]

            if image_grid_thw is None and video_grid_thw is None:
                seg_pos = (
                    torch.arange(n_valid, device=dev)
                    .view(1, -1)
                    .expand(3, -1)
                )
            else:
                (
                    seg_pos,
                    img_cursor,
                    vid_cursor,
                ) = _build_sample_mrope_positions(
                    sample_input_ids=segment_ids,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    image_index=img_cursor,
                    video_index=vid_cursor,
                    spatial_merge_size=spatial_merge_size,
                    image_token_id=image_token_id,
                    video_token_id=video_token_id,
                    vision_start_token_id=vision_start_token_id,
                )

            position_ids[:, 0, begin:end] = seg_pos.to(
                device=dev, dtype=position_ids.dtype,
            )

            if seg_pos.numel() > 0:
                deltas.append(int(seg_pos.max().item()) + 1 - n_valid)
            else:
                deltas.append(0)

        mrope_position_deltas = torch.tensor(
            deltas, device=dev,
        ).unsqueeze(1)
        return position_ids, mrope_position_deltas

    # --- BSHD branch with vision -------------------------------------
    if input_ids is not None and (
        image_grid_thw is not None or video_grid_thw is not None
    ):
        all_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(all_ids)
        elif attention_mask.dim() > 2:
            # Collapse extra mask dims down to [B, S] before indexing.
            attention_mask = attention_mask.any(dim=-1)
            if attention_mask.dim() == 3:
                attention_mask = attention_mask.squeeze(1)
            attention_mask = attention_mask.to(dtype=all_ids.dtype)

        position_ids = torch.ones(
            3,
            all_ids.shape[0],
            all_ids.shape[1],
            dtype=all_ids.dtype,
            device=all_ids.device,
        )
        deltas = []
        img_cursor, vid_cursor = 0, 0
        attention_mask = attention_mask.to(all_ids.device)

        for row, row_ids in enumerate(all_ids):
            keep = attention_mask[row] == 1
            (
                row_positions,
                img_cursor,
                vid_cursor,
            ) = _build_sample_mrope_positions(
                sample_input_ids=row_ids[keep],
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                image_index=img_cursor,
                video_index=vid_cursor,
                spatial_merge_size=spatial_merge_size,
                image_token_id=image_token_id,
                video_token_id=video_token_id,
                vision_start_token_id=vision_start_token_id,
            )
            position_ids[..., row, keep] = row_positions.to(
                position_ids.device
            )
            deltas.append(
                row_positions.max() + 1 - len(all_ids[row]),
            )

        mrope_position_deltas = torch.tensor(
            deltas, device=all_ids.device,
        ).unsqueeze(1)
        return position_ids, mrope_position_deltas

    # --- Text-only fallback ------------------------------------------
    if attention_mask is not None:
        # NOTE(review): mirrored from the BSHD branch — the 3D squeeze
        # is assumed to apply only after collapsing a >2-D mask; confirm
        # against the original Bridge implementation.
        if attention_mask.dim() > 2:
            attention_mask = attention_mask.any(dim=-1)
            if attention_mask.dim() == 3:
                attention_mask = attention_mask.squeeze(1)
        attention_mask = attention_mask.to(dtype=torch.long)
        position_ids = attention_mask.long().cumsum(-1) - 1
        # Masked positions are parked at 1, matching the ones-fill used
        # by the vision branches.
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_ids = (
            position_ids.unsqueeze(0)
            .expand(3, -1, -1)
            .to(attention_mask.device)
        )
        max_position_ids = (
            position_ids.max(0, keepdim=False)[0]
            .max(-1, keepdim=True)[0]
        )
        mrope_position_deltas = (
            max_position_ids + 1 - attention_mask.shape[-1]
        )
    else:
        position_ids = (
            torch.arange(
                input_ids.shape[1], device=input_ids.device,
            )
            .view(1, 1, -1)
            .expand(3, input_ids.shape[0], -1)
        )
        mrope_position_deltas = torch.zeros(
            [input_ids.shape[0], 1],
            device=input_ids.device,
            dtype=input_ids.dtype,
        )

    return position_ids, mrope_position_deltas
+ """ + from megatron.core import parallel_state + from megatron.core.models.common.embeddings.rope_utils import ( + _apply_rotary_pos_emb_bshd, + _apply_rotary_pos_emb_thd, + ) + + orig_dtype = t.dtype + t_fp32 = t.float() + + if cu_seqlens is None: + out = _apply_rotary_pos_emb_bshd( + t_fp32, + freqs, + rotary_interleaved=config.rotary_interleaved, + multi_latent_attention=getattr(config, 'multi_latent_attention', False), + mscale=mscale, + ) + else: + if cp_group is None: + cp_group = parallel_state.get_context_parallel_group() + out = _apply_rotary_pos_emb_thd( + t_fp32, + cu_seqlens, + freqs, + rotary_interleaved=config.rotary_interleaved, + multi_latent_attention=getattr(config, 'multi_latent_attention', False), + mscale=mscale, + cp_group=cp_group, + ) + return out.to(orig_dtype) + + +def _apply_rope_fp32_no_cp(t, freqs, config, cu_seqlens=None, mscale=1.0, cp_group=None): + """Same as ``_apply_rope_fp32`` but forces CP-size=1. + + The vision encoder uses THD packed sequences for variable-resolution + images. When the language model uses CP>1, the global CP group would + incorrectly split the vision seqlens. This wrapper substitutes a + trivial group so the vision RoPE sees the full packed sequence. + """ + return _apply_rope_fp32( + t, freqs, config, cu_seqlens, mscale, cp_group=_NO_CP_GROUP, + ) + + +class Qwen35VLVisionSelfAttention(SelfAttention): + """ViT self-attention with RoPE applied in fp32. + + Matches Bridge's ``Qwen3VLSelfAttention`` behaviour when + ``apply_rotary_pos_emb_in_fp32=True``: query and key are cast to float32 + before the rotary multiply and cast back to bf16 afterwards. The + monkey-patch approach avoids duplicating the 300-line ``SelfAttention.forward`` + while keeping the change local to this class. 
+ """ + + def forward(self, *args, **kwargs): + import megatron.core.transformer.attention as _attn_mod + + _orig = _attn_mod.apply_rotary_pos_emb + _attn_mod.apply_rotary_pos_emb = _apply_rope_fp32_no_cp + try: + return super().forward(*args, **kwargs) + finally: + _attn_mod.apply_rotary_pos_emb = _orig + + +def get_qwen35_vl_language_spec( + config: TransformerConfig, + vp_stage: Optional[int] = None, + pp_rank: Optional[int] = None, +) -> TransformerBlockSubmodules: + """Transformer block spec for the Qwen3.5-VL language decoder. + + Uses the experimental attention variant infrastructure to build hybrid + GatedDeltaNet + full-attention layers with optional MoE interleaving. + + Args: + config: Language decoder TransformerConfig. + vp_stage: Virtual pipeline stage. + pp_rank: Pipeline parallel rank. + + Returns: + TransformerBlockSubmodules with per-layer specs. + """ + return get_transformer_block_with_experimental_attention_variant_spec( + config=config, + vp_stage=vp_stage, + pp_rank=pp_rank, + ) + + +def get_qwen35_vl_vision_spec() -> ModuleSpec: + """ModuleSpec for vision encoder transformer layers. + + Uses ``TEDotProductAttention`` which supports packed-sequence (THD) + attention via ``PackedSeqParams`` for variable-length images. + + ``Qwen35VLVisionSelfAttention`` replaces the default ``SelfAttention`` so + that RoPE is applied in fp32, matching Bridge's + ``apply_rotary_pos_emb_in_fp32=True`` behaviour. + """ + spec = get_vit_layer_with_transformer_engine_spec() + spec.submodules.self_attention.module = Qwen35VLVisionSelfAttention + return spec diff --git a/examples/multimodal_dev/models/qwen35_vl/vision_encoder.py b/examples/multimodal_dev/models/qwen35_vl/vision_encoder.py new file mode 100644 index 00000000000..8e16417937d --- /dev/null +++ b/examples/multimodal_dev/models/qwen35_vl/vision_encoder.py @@ -0,0 +1,593 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Megatron-native Qwen3.5-VL vision encoder. 
+ +Architecture (matches HF ``Qwen3VLVisionModel`` exactly): + + PatchEmbed (Conv3d) + → learned position embedding (bilinear interpolation) + → 2D Vision RoPE + → TransformerBlock × N (with PackedSeqParams / THD attention) + → PatchMerger (per-token LN → spatial merge → MLP) + +Key design choices: + * ``Conv3d`` patch embedding is replicated across TP ranks (no MCore + equivalent for 3D convolutions). + * ``PatchMerger`` MLP uses ``ColumnParallelLinear`` / ``RowParallelLinear`` + for TP sharding. + * Inherits from ``VisionModule``. + * Expects pixel values in block-merge order (as produced by the HF + processor) so the merger's simple reshape is correct. +""" + +from typing import List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor + +from megatron.core.models.common.vision_module.vision_module import ( + VisionModule, +) +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig + +# ------------------------------------------------------------------- +# PatchEmbed — Conv3d (replicated, no TP sharding) +# ------------------------------------------------------------------- + +class Qwen35VLPatchEmbed(MegatronModule): + """3D convolution patch embedding matching HF ``Qwen3VLVisionPatchEmbed``. + + Uses ``nn.Conv3d`` with kernel = stride = ``[temporal_patch_size, + patch_size, patch_size]`` and ``bias=True``. The module is replicated + across TP ranks (no MCore equivalent for 3D conv). + + Args: + config: TransformerConfig (used by MegatronModule base). 
class Qwen35VLVisionRotaryEmbedding(MegatronModule):
    """1D rotary-frequency table for the vision transformer.

    Produces RoPE frequencies for integer positions ``0 .. seqlen-1``;
    the encoder maps 2D (row, col) patch positions onto this table.
    Matches HF ``Qwen3VLVisionRotaryEmbedding``.

    Args:
        dim: Frequency dimension (``head_dim // 2``).
        theta: RoPE base frequency.
        config: Optional TransformerConfig for the MegatronModule base.
    """

    def __init__(
        self,
        dim: int,
        theta: float = 10000.0,
        config: Optional[TransformerConfig] = None,
    ):
        super().__init__(config=config)
        self.dim = dim
        self.theta = theta
        # Registered non-persistent: only used as a device/dtype probe
        # at runtime (see _get_inv_freq) and excluded from checkpoints.
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        self.register_buffer(
            "inv_freq", 1.0 / (theta ** exponents), persistent=False
        )

    def _get_inv_freq(self, device: torch.device) -> Tensor:
        """Return ``inv_freq`` recomputed in float32 on *device*.

        Deliberately ignores the stored buffer's dtype: this mirrors
        Bridge's lazy-init behaviour where ``inv_freq`` is built fresh
        in float32 on first forward, after any ``model.bfloat16()``
        cast has already happened.
        """
        exponents = (
            torch.arange(0, self.dim, 2, dtype=torch.float32, device=device)
            / self.dim
        )
        return 1.0 / (self.theta ** exponents)

    def forward(
        self,
        seqlen: int,
        device: Optional[torch.device] = None,
    ) -> Tensor:
        """Frequency lookup table for positions ``0 .. seqlen-1``.

        Args:
            seqlen: Number of positions.
            device: Runtime device (required for meta-init safety).

        Returns:
            ``[seqlen, dim // 2]`` frequencies.
        """
        if device is None:
            # Fall back to the buffer's device unless the module is
            # still meta-initialized, in which case use the current GPU.
            if self.inv_freq.device.type != "meta":
                device = self.inv_freq.device
            else:
                device = torch.device("cuda", torch.cuda.current_device())
        inv_freq = self._get_inv_freq(device)
        positions = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
        return torch.outer(positions, inv_freq)
    def forward(self, hidden_states: Tensor) -> Tensor:
        """Merge patches spatially.

        Args:
            hidden_states: ``[total_patches, hidden_size]`` in block-merge
                order from the ViT transformer blocks.

        Returns:
            ``[total_merged_patches, out_hidden_size]``.
        """
        # Per-token norm on hidden_size, applied BEFORE the spatial merge
        # (LN → merge → MLP, per the class docstring).
        hidden_states = self.patch_norm(hidden_states)
        # Collapse each run of spatial_merge_size**2 consecutive tokens into
        # one merge_dim row; this plain reshape is only correct because the
        # encoder expects pixel values in block-merge order.
        merged = hidden_states.view(-1, self.merge_dim)
        merged, _ = self.linear_fc1(merged)
        # NOTE: Official HuggingFace uses default approximate='none' in Qwen3VLVisionPatchMerger.
        # NOTE(review): 'tanh' here therefore diverges numerically from the
        # HF reference this file claims to match — confirm it is intentional.
        merged = torch.nn.functional.gelu(merged, approximate="tanh")
        merged, _ = self.linear_fc2(merged)
        return merged
+ """ + + def __init__( + self, + config: TransformerConfig, + transformer_layer_spec: ModuleSpec = None, + in_channels: int = 3, + patch_size: int = 16, + temporal_patch_size: int = 2, + spatial_merge_size: int = 2, + out_hidden_size: int = 3584, + max_num_positions: int = 2304, + ): + super().__init__(config=config) + + self.hidden_size = config.hidden_size + self.spatial_merge_size = spatial_merge_size + + # --- Patch embedding (Conv3d) --- + self.patch_embed = Qwen35VLPatchEmbed( + config=config, + in_channels=in_channels, + hidden_size=config.hidden_size, + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + ) + + # --- Learned position embedding with bilinear interpolation --- + self.pos_embed = torch.nn.Embedding( + max_num_positions, config.hidden_size, + ) + self.num_grid_per_side = int(max_num_positions ** 0.5) + + # --- Vision rotary embeddings --- + head_dim = config.hidden_size // config.num_attention_heads + self.rot_pos_emb = Qwen35VLVisionRotaryEmbedding( + head_dim // 2, config=config, + ) + + # --- Transformer blocks --- + if transformer_layer_spec is None: + from examples.multimodal_dev.models.qwen35_vl.specs import ( + get_qwen35_vl_vision_spec, + ) + transformer_layer_spec = get_qwen35_vl_vision_spec() + + self.decoder = TransformerBlock( + config=config, + spec=transformer_layer_spec, + pre_process=True, + post_process=True, + post_layer_norm=False, + ) + + # --- Patch merger --- + self.merger = Qwen35VLPatchMerger( + config=config, + hidden_size=config.hidden_size, + out_hidden_size=out_hidden_size, + spatial_merge_size=spatial_merge_size, + ) + + # --------------------------------------------------------------- + # Learned position embedding with bilinear interpolation + # --------------------------------------------------------------- + + def _fast_pos_embed_interpolate( + self, grid_thw: Tensor, + ) -> Tensor: + """Bilinear interpolation of the learned 2D position table. 
+ + Matches HF ``Qwen3VLVisionModel.fast_pos_embed_interpolate``. + + Args: + grid_thw: ``[num_images, 3]`` (T, H, W) in patch-grid units. + + Returns: + ``[total_patches, hidden_size]`` position embeddings in + block-merge order. + """ + grid_thw_list = grid_thw.tolist() + grid_ts = [int(row[0]) for row in grid_thw_list] + grid_hs = [int(row[1]) for row in grid_thw_list] + grid_ws = [int(row[2]) for row in grid_thw_list] + device = self.pos_embed.weight.device + n = self.num_grid_per_side + + idx_list: List[List[int]] = [[] for _ in range(4)] + weight_list: List[List[float]] = [[] for _ in range(4)] + + for t, h, w in grid_thw_list: + t, h, w = int(t), int(h), int(w) + h_idxs = torch.linspace(0, n - 1, h) + w_idxs = torch.linspace(0, n - 1, w) + + h_floor = h_idxs.int() + w_floor = w_idxs.int() + h_ceil = (h_floor + 1).clip(max=n - 1) + w_ceil = (w_floor + 1).clip(max=n - 1) + + dh = h_idxs - h_floor.float() + dw = w_idxs - w_floor.float() + + base_h = h_floor * n + base_h_ceil = h_ceil * n + + indices = [ + (base_h[None].T + w_floor[None]).flatten(), + (base_h[None].T + w_ceil[None]).flatten(), + (base_h_ceil[None].T + w_floor[None]).flatten(), + (base_h_ceil[None].T + w_ceil[None]).flatten(), + ] + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor( + idx_list, dtype=torch.long, device=device, + ) + weight_tensor = torch.tensor( + weight_list, + dtype=self.pos_embed.weight.dtype, + device=device, + ) + pos_embeds = ( + self.pos_embed(idx_tensor).to(device) + * weight_tensor[:, :, None] + ) + patch_pos_embeds = ( + pos_embeds[0] + pos_embeds[1] + + pos_embeds[2] + pos_embeds[3] + ) + + patch_pos_embeds = patch_pos_embeds.split( + [h * w for h, w in zip(grid_hs, grid_ws)] + ) + + merge = 
self.spatial_merge_size + result = [] + for pe, t, h, w in zip( + patch_pos_embeds, grid_ts, grid_hs, grid_ws, + ): + pe = pe.repeat(t, 1) + pe = ( + pe.view( + t, h // merge, merge, w // merge, merge, -1, + ) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + result.append(pe) + + return torch.cat(result) + + # --------------------------------------------------------------- + # 2D Vision RoPE + # --------------------------------------------------------------- + + def _compute_rotary_pos_emb(self, grid_thw: Tensor) -> Tensor: + """Compute 2D Vision RoPE for all patches in block-merge order. + + Matches HF ``Qwen3VLVisionModel.rot_pos_emb``. + + Args: + grid_thw: ``[num_images, 3]`` (T, H, W) per image. + + Returns: + ``[total_patches, head_dim // 2]`` raw RoPE frequencies. + """ + merge = self.spatial_merge_size + grid_thw_list = grid_thw.tolist() + + max_hw = max(max(int(h), int(w)) for _, h, w in grid_thw_list) + freq_table = self.rot_pos_emb( + max_hw, device=grid_thw.device, + ) + device = freq_table.device + + total_tokens = sum( + int(t) * int(h) * int(w) for t, h, w in grid_thw_list + ) + pos_ids = torch.empty( + (total_tokens, 2), dtype=torch.long, device=device, + ) + + offset = 0 + for num_frames, height, width in grid_thw_list: + num_frames = int(num_frames) + height = int(height) + width = int(width) + merged_h = height // merge + merged_w = width // merge + + block_rows = torch.arange(merged_h, device=device) + block_cols = torch.arange(merged_w, device=device) + intra_row = torch.arange(merge, device=device) + intra_col = torch.arange(merge, device=device) + + row_idx = ( + block_rows[:, None, None, None] * merge + + intra_row[None, None, :, None] + ) + col_idx = ( + block_cols[None, :, None, None] * merge + + intra_col[None, None, None, :] + ) + + row_idx = row_idx.expand( + merged_h, merged_w, merge, merge, + ).reshape(-1) + col_idx = col_idx.expand( + merged_h, merged_w, merge, merge, + ).reshape(-1) + + coords = torch.stack((row_idx, col_idx), 
dim=-1) + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + n_tokens = coords.shape[0] + pos_ids[offset: offset + n_tokens] = coords + offset += n_tokens + + embeddings = freq_table[pos_ids] + embeddings = embeddings.flatten(1) + return embeddings + + # --------------------------------------------------------------- + # PackedSeqParams for variable-length attention + # --------------------------------------------------------------- + + @staticmethod + def _build_packed_seq_params(grid_thw: Tensor) -> PackedSeqParams: + """Build ``PackedSeqParams`` from grid dimensions. + + Each temporal frame of each image forms a separate sub-sequence + in the packed THD layout, matching HF's ``cu_seqlens`` computation. + + Args: + grid_thw: ``[num_images, 3]``. + + Returns: + ``PackedSeqParams`` for ``TransformerBlock``. + """ + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0], + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = int( + (grid_thw[:, 1] * grid_thw[:, 2]).max().item() + ) + + return PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + ) + + # --------------------------------------------------------------- + # Forward + # --------------------------------------------------------------- + + def forward( + self, + pixel_values: Tensor, + grid_thw: Tensor, + ) -> Tensor: + """Encode images / video frames. + + Args: + pixel_values: ``[total_patches, C * T * pH * pW]`` + pre-extracted flat patches in block-merge order. + grid_thw: ``[num_images, 3]`` (T, H, W) in patch-grid units. + + Returns: + ``[total_merged_patches, out_hidden_size]`` visual embeddings. + """ + # 1. Patch embedding (Conv3d) + hidden_states = self.patch_embed(pixel_values) + + # 2. 
Learned position embedding (bilinear interpolation) + pos_embeds = self._fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + # 3. 2D Vision RoPE + rot_freqs = self._compute_rotary_pos_emb(grid_thw) + emb = torch.cat((rot_freqs, rot_freqs), dim=-1) + rot_freqs_expanded = emb.unsqueeze(1).unsqueeze(1) + + # 4. Transformer blocks with PackedSeqParams + packed_seq_params = self._build_packed_seq_params(grid_thw) + hidden_states = hidden_states.unsqueeze(1) + hidden_states = self.decoder( + hidden_states=hidden_states, + attention_mask=None, + rotary_pos_emb=rot_freqs_expanded, + packed_seq_params=packed_seq_params, + ) + hidden_states = hidden_states.squeeze(1) + + # 5. Patch merger + return self.merger(hidden_states) diff --git a/examples/multimodal_dev/pretrain_multimodal.py b/examples/multimodal_dev/pretrain_multimodal.py new file mode 100644 index 00000000000..7a54bb93271 --- /dev/null +++ b/examples/multimodal_dev/pretrain_multimodal.py @@ -0,0 +1,158 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Standalone entry point for multimodal_dev model training (FSDP + EP). + +This entry point is **model-agnostic**. All model-specific logic (layer +specs, model construction, FLOPs metadata, dataset generation) is +delegated to factory functions registered in +:data:`multimodal_dev.models.MODEL_REGISTRY`. + +Adding a new architecture only requires: + +1. Creating a new model package under ``multimodal_dev/models//`` + with the appropriate factory functions. +2. Registering an entry in ``MODEL_REGISTRY``. + +No changes to this file are necessary. + +Usage:: + + torchrun --nproc_per_node=8 multimodal_dev/pretrain_multimodal.py \\ + --model-arch qwen35_vl \\ + --dataset-provider mock \\ + ... 
(other megatron args) +""" + +import importlib +import os +import sys + +sys.path.insert( + 0, + os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")), +) + +from examples.multimodal_dev.arguments import add_multimodal_args +from examples.multimodal_dev.forward_step import forward_step +from megatron.core.enums import ModelType +from megatron.training import get_args, pretrain +from megatron.training.argument_utils import pretrain_cfg_container_from_args +from megatron.training.arguments import core_transformer_config_from_args, parse_and_validate_args + + +def model_provider( + pre_process: bool = True, + post_process: bool = True, + **kwargs, +): + """Build a multimodal model from ``--model-arch``. + + The language ``TransformerConfig`` is built from CLI args so that + parallelism settings, precision, and fusion flags are inherited. + Model-specific post-processing and construction are delegated to the + registry factory functions. + """ + args = get_args() + model_arch = getattr(args, "model_arch", "qwen35_vl") + + from examples.multimodal_dev.models import MODEL_REGISTRY + + if model_arch not in MODEL_REGISTRY: + raise ValueError( + f"Unknown model arch '{model_arch}'. 
" + f"Available: {list(MODEL_REGISTRY.keys())}" + ) + + registry = MODEL_REGISTRY[model_arch] + + # --- language config (generic + model-specific post-processing) --- + language_config = core_transformer_config_from_args(args) + post_language_config_fn = registry.get("post_language_config_fn") + if post_language_config_fn is not None: + post_language_config_fn(language_config, args) + + # --- vision config --- + vision_config = registry["vision_config_fn"]( + num_layers_override=getattr(args, "vision_num_layers", None), + variant=getattr(args, "model_variant", None), + ) + vision_config.bf16 = language_config.bf16 + vision_config.fp16 = language_config.fp16 + + if getattr(args, "recompute_vision", False): + vision_config.recompute_granularity = "full" + vision_config.recompute_method = "uniform" + vision_config.recompute_num_layers = 1 + + # --- vision FLOPs metadata --- + vision_flops_fn = registry.get("vision_flops_fn") + if vision_flops_fn is not None: + vision_flops_fn(args, language_config, vision_config) + + # --- build model (fully delegated to the arch factory) --- + model = registry["model_factory_fn"]( + args=args, + language_config=language_config, + vision_config=vision_config, + **kwargs, + ) + + return model + + +def _resolve_provider_fn(provider_fn): + """Resolve a provider that may be a dotted import path string.""" + if isinstance(provider_fn, str): + module_path, func_name = provider_fn.rsplit(".", 1) + provider_fn = getattr( + importlib.import_module(module_path), func_name, + ) + return provider_fn + + +def datasets_provider(train_val_test_num_samples): + """Dataset provider dispatcher. + + Routes to the dataset factory registered for the current + ``(--model-arch, --dataset-provider)`` combination. 
+ """ + args = get_args() + model_arch = getattr(args, "model_arch", "qwen35_vl") + provider = getattr(args, "dataset_provider", "mock") + + from examples.multimodal_dev.models import MODEL_REGISTRY + + if model_arch not in MODEL_REGISTRY: + raise ValueError( + f"Unknown model arch '{model_arch}'. " + f"Available: {list(MODEL_REGISTRY.keys())}" + ) + + registry = MODEL_REGISTRY[model_arch] + available = registry.get("dataset_providers", {}) + + if provider not in available: + raise ValueError( + f"Unknown dataset provider '{provider}' for arch " + f"'{model_arch}'. Available: {list(available.keys())}" + ) + + provider_fn = _resolve_provider_fn(available[provider]) + return provider_fn(train_val_test_num_samples) + + +if __name__ == "__main__": + datasets_provider.is_distributed = True + + args = parse_and_validate_args( + extra_args_provider=add_multimodal_args, + args_defaults={}, + ) + full_config = pretrain_cfg_container_from_args(args) + pretrain( + full_config, + datasets_provider, + model_provider, + ModelType.encoder_or_decoder, + forward_step, + ) diff --git a/examples/multimodal_dev/scripts/run_qwen35_vl.sh b/examples/multimodal_dev/scripts/run_qwen35_vl.sh new file mode 100755 index 00000000000..974892a46ca --- /dev/null +++ b/examples/multimodal_dev/scripts/run_qwen35_vl.sh @@ -0,0 +1,502 @@ +#!/bin/bash + +# Launch script for Qwen3.5-VL training via multimodal_dev (FSDP + EP). +# +# Usage (from the Megatron-LM repo root): +# ./examples/multimodal_dev/scripts/run_qwen35_vl.sh +# +# Environment variables: +# MODEL_VARIANT: proxy (default), 0.8b, 2b, 4b, 9b, 27b, 35b_a3b, 122b_a10b, 397b_a17b, 35b_a3b_light +# CKPT_LOAD: path to a pre-converted checkpoint to load (enables --load + --finetune) +# CKPT_FORMAT: checkpoint format override (e.g. 
torch_dist); auto-detected when empty +# TP, EP, PP: parallelism sizes +# MBS, GBS: micro/global batch sizes +# NUM_LAYERS, NUM_EXPERTS: override for proxy testing +# LAUNCHER: torchrun (default) or python +# PROFILE: set to 1 to enable Nsight Systems profiling (default: 0) +# PROFILE_STEP_START/PROFILE_STEP_END: profiled iteration window (default: 4-5) + +# example script: +# WANDB_PROJECT=qwen35-cp-test WANDB_MODE=online CP=2 GPUS_PER_NODE=8 CKPT_LOAD=/lustre/fs1/portfolios/coreai/users/lit/workspace/dev-project/models/Qwen/Qwen3.5-0.8B-fsdp-0420/ USE_FSDP=1 EP=1 GBS=16 MODEL_VARIANT=0.8b SAVE_INTERVAL=10000 CKPT_RESUME=0 DRY_RUN=0 USE_PACKED_SEQUENCE=1 bash ./examples/multimodal_dev/scripts/run_qwen35_vl.sh + +# WANDB_PROJECT=qwen35-cp-test WANDB_MODE=online CP=1 GPUS_PER_NODE=4 CKPT_LOAD=/lustre/fs1/portfolios/coreai/users/lit/workspace/dev-project/models/Qwen/Qwen3.5-0.8B-fsdp-0420/ USE_FSDP=1 EP=1 GBS=16 MODEL_VARIANT=0.8b SAVE_INTERVAL=10000 CKPT_RESUME=0 DRY_RUN=0 USE_PACKED_SEQUENCE=1 bash ./examples/multimodal_dev/scripts/run_qwen35_vl.sh + +set -euo pipefail + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_IB_SL=1 +export NVTE_FUSED_ATTN=1 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +DRY_RUN=${DRY_RUN:-1} +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +if [ -n "${SLURM_JOB_NUM_NODES:-}" ]; then + NUM_NODES="$SLURM_JOB_NUM_NODES" +else + NUM_NODES=${NNODES:-1} +fi +PROFILE=${PROFILE:-0} +PROFILE_STEP_START=${PROFILE_STEP_START:-4} +PROFILE_STEP_END=${PROFILE_STEP_END:-5} +PROFILE_RANKS=${PROFILE_RANKS:-0} +LAUNCHER=${LAUNCHER:-torchrun} + +MODEL_VARIANT=${MODEL_VARIANT:-proxy} +VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-} + +# Batch sizes +MBS=${MBS:-2} +GBS=${GBS:-16} + +# Parallelism +TP=${TP:-1} +EP=${EP:-2} +PP=${PP:-1} +CP=${CP:-1} + +# Variant-aware architecture defaults. 
+# The model provider builds configs from the variant dict in +# multimodal_dev/models/qwen35_vl/configuration.py, but Megatron also +# uses these CLI args internally (PP splits, param counting). +case "$MODEL_VARIANT" in + 0.8b) + NUM_LAYERS=${NUM_LAYERS:-24} + NUM_EXPERTS=${NUM_EXPERTS:-0} + HIDDEN_SIZE=1024 + FFN_HIDDEN_SIZE=3584 + NUM_ATTN_HEADS=8 + NUM_QUERY_GROUPS=2 + LINEAR_NUM_VALUE_HEADS=16 + VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-12} + ;; + 2b) + NUM_LAYERS=${NUM_LAYERS:-24} + NUM_EXPERTS=${NUM_EXPERTS:-0} + HIDDEN_SIZE=2048 + FFN_HIDDEN_SIZE=6144 + NUM_ATTN_HEADS=8 + NUM_QUERY_GROUPS=2 + LINEAR_NUM_VALUE_HEADS=16 + VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-24} + ;; + 4b) + NUM_LAYERS=${NUM_LAYERS:-32} + NUM_EXPERTS=${NUM_EXPERTS:-0} + HIDDEN_SIZE=2560 + FFN_HIDDEN_SIZE=9216 + NUM_ATTN_HEADS=16 + NUM_QUERY_GROUPS=4 + LINEAR_NUM_VALUE_HEADS=32 + VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-24} + ;; + proxy) + NUM_LAYERS=${NUM_LAYERS:-4} + NUM_EXPERTS=${NUM_EXPERTS:-16} + HIDDEN_SIZE=4096 + FFN_HIDDEN_SIZE=10240 + NUM_ATTN_HEADS=32 + NUM_QUERY_GROUPS=2 + LINEAR_NUM_VALUE_HEADS=64 + VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-2} + ;; + 9b) + NUM_LAYERS=${NUM_LAYERS:-32} + NUM_EXPERTS=${NUM_EXPERTS:-0} + HIDDEN_SIZE=4096 + FFN_HIDDEN_SIZE=12288 + NUM_ATTN_HEADS=16 + NUM_QUERY_GROUPS=4 + LINEAR_NUM_VALUE_HEADS=32 + VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-27} + ;; + 27b) + NUM_LAYERS=${NUM_LAYERS:-64} + NUM_EXPERTS=${NUM_EXPERTS:-0} + HIDDEN_SIZE=5120 + FFN_HIDDEN_SIZE=17408 + NUM_ATTN_HEADS=24 + NUM_QUERY_GROUPS=4 + LINEAR_NUM_VALUE_HEADS=48 + VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-27} + ;; + 35b_a3b) + NUM_LAYERS=${NUM_LAYERS:-40} + NUM_EXPERTS=${NUM_EXPERTS:-256} + HIDDEN_SIZE=2048 + FFN_HIDDEN_SIZE=4096 + NUM_ATTN_HEADS=16 + NUM_QUERY_GROUPS=2 + LINEAR_NUM_VALUE_HEADS=32 + VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-27} + ;; + 35b_a3b_light) + NUM_LAYERS=${NUM_LAYERS:-12} + NUM_EXPERTS=${NUM_EXPERTS:-128} + HIDDEN_SIZE=2048 + FFN_HIDDEN_SIZE=4096 + NUM_ATTN_HEADS=16 + 
NUM_QUERY_GROUPS=2 + LINEAR_NUM_VALUE_HEADS=32 + VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-7} + ;; + 122b_a10b) + NUM_LAYERS=${NUM_LAYERS:-48} + NUM_EXPERTS=${NUM_EXPERTS:-256} + HIDDEN_SIZE=3072 + FFN_HIDDEN_SIZE=8192 + NUM_ATTN_HEADS=32 + NUM_QUERY_GROUPS=2 + LINEAR_NUM_VALUE_HEADS=64 + VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-27} + ;; + 397b_a17b) + NUM_LAYERS=${NUM_LAYERS:-60} + NUM_EXPERTS=${NUM_EXPERTS:-512} + HIDDEN_SIZE=4096 + FFN_HIDDEN_SIZE=10240 + NUM_ATTN_HEADS=32 + NUM_QUERY_GROUPS=2 + LINEAR_NUM_VALUE_HEADS=64 + VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-27} + ;; + *) + : "${NUM_LAYERS:?NUM_LAYERS must be set for MODEL_VARIANT=$MODEL_VARIANT}" + : "${NUM_EXPERTS:?NUM_EXPERTS must be set for MODEL_VARIANT=$MODEL_VARIANT}" + : "${HIDDEN_SIZE:?HIDDEN_SIZE must be set for MODEL_VARIANT=$MODEL_VARIANT}" + : "${FFN_HIDDEN_SIZE:?FFN_HIDDEN_SIZE must be set for MODEL_VARIANT=$MODEL_VARIANT}" + : "${NUM_ATTN_HEADS:?NUM_ATTN_HEADS must be set for MODEL_VARIANT=$MODEL_VARIANT}" + : "${NUM_QUERY_GROUPS:?NUM_QUERY_GROUPS must be set for MODEL_VARIANT=$MODEL_VARIANT}" + : "${LINEAR_NUM_VALUE_HEADS:?LINEAR_NUM_VALUE_HEADS must be set for MODEL_VARIANT=$MODEL_VARIANT}" + VISION_NUM_LAYERS=${VISION_NUM_LAYERS:-27} + ;; +esac +SEQ_LEN=${SEQ_LEN:-4096} + +WANDB_PROJECT=${WANDB_PROJECT:-'qwen35-vl-0524'} +EXP_NAME="qwen35vl_${MODEL_VARIANT}_tp${TP}_ep${EP}_pp${PP}_cp${CP}" + +RECOMPUTE_VISION=${RECOMPUTE_VISION:-0} +if [ "$RECOMPUTE_VISION" -eq 1 ]; then + EXP_NAME+="_recompute_encoder" +fi +RECOMPUTE=${RECOMPUTE:-0} +if [ "$RECOMPUTE" -eq 1 ]; then + EXP_NAME+="_recompute_decoder" +fi + +USE_PACKED_SEQUENCE=${USE_PACKED_SEQUENCE:-0} +if [ "$USE_PACKED_SEQUENCE" -eq 1 ]; then + EXP_NAME+="_thd" +fi + +MEGATRON_LM_PATH="${MEGATRON_LM_PATH:-$(cd "$(dirname "$0")/../../.." 
&& pwd)}" +ROOT_DIR="${ROOT_DIR:-${MEGATRON_LM_PATH}/local/}" +CHECKPOINT_STORE_PATH="${ROOT_DIR}${EXP_NAME}" +mkdir -p "$CHECKPOINT_STORE_PATH" + +TENSORBOARD_LOGS_PATH="${TENSORBOARD_LOGS_PATH:-${MEGATRON_LM_PATH}/logs}" +mkdir -p "$TENSORBOARD_LOGS_PATH" + +DISTRIBUTED_ARGS=( + --nproc_per_node "$GPUS_PER_NODE" + --nnodes "$NUM_NODES" +) + +if [ "$NUM_NODES" -gt 1 ]; then + DISTRIBUTED_ARGS+=( + --master_addr "${MASTER_ADDR:-localhost}" + --master_port "${MASTER_PORT:-6000}" + ) +fi + +# --- Parallelism --- +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size "$TP" + --pipeline-model-parallel-size "$PP" + --expert-model-parallel-size "$EP" + --context-parallel-size "$CP" + --cp-comm-type "a2a" + --expert-tensor-parallel-size 1 + --use-distributed-optimizer + --sequence-parallel +) + +# --- Training --- +TRAINING_ARGS=( + --micro-batch-size "$MBS" + --global-batch-size "$GBS" + --train-iters "${TRAIN_ITERS:-500}" + --adam-beta1 0.9 + --adam-beta2 0.95 + --lr 1.2e-4 + --min-lr 1.2e-5 + --lr-decay-style cosine + --lr-warmup-iters 100 + --lr-decay-iters 2000 + --weight-decay 0.1 + --clip-grad 1.0 + --bf16 + --use-mcore-models + --transformer-impl transformer_engine + --cross-entropy-loss-fusion + --cross-entropy-fusion-impl te + --enable-experimental + --manual-gc + --manual-gc-interval 5 + --mtp-num-layers 1 + --mtp-loss-scaling-factor 0.1 + --sft + --use-flash-attn + # --attention-backend flash + --calculate-per-token-loss +) + +PROFILE_ARGS=() +NSYS_CMD=() +if [ "$PROFILE" = "1" ]; then + PROFILE_ARGS=( + --profile + --profile-step-start "$PROFILE_STEP_START" + --profile-step-end "$PROFILE_STEP_END" + --profile-ranks "$PROFILE_RANKS" + ) + + NSYS_OUTPUT_DIR="${CHECKPOINT_STORE_PATH}/nsys" + mkdir -p "$NSYS_OUTPUT_DIR" + NSYS_CMD=( + nsys profile + --sample=none + --cpuctxsw=none + --trace=cuda,nvtx,cublas,cudnn + --force-overwrite=true + --capture-range=cudaProfilerApi + --capture-range-end=stop + -o "${NSYS_OUTPUT_DIR}/${EXP_NAME}_$(date +%Y%m%d_%H%M%S)" + ) 
+fi + +# --- Logging & Checkpointing --- +SAVE_INTERVAL=${SAVE_INTERVAL:-500} +EVAL_AND_LOGGING_ARGS=( + --log-interval 1 + --save-interval "$SAVE_INTERVAL" + --eval-interval 500 + --save "$CHECKPOINT_STORE_PATH" + --eval-iters 10 + --tensorboard-dir "$TENSORBOARD_LOGS_PATH" + --wandb-project "$WANDB_PROJECT" + --wandb-exp-name "$EXP_NAME" + --wandb-save-dir "$CHECKPOINT_STORE_PATH" + --log-throughput +) + +# --- Tokenizer --- +TOKENIZER_MODEL=${TOKENIZER_MODEL:-Qwen/Qwen3.5-397B-A17B} +TOKENIZER_ARGS=( + --tokenizer-type HuggingFaceTokenizer + --tokenizer-model "$TOKENIZER_MODEL" +) + +# --- Multimodal-specific --- +MULTIMODAL_ARGS=( + --model-arch qwen35_vl + --model-variant "$MODEL_VARIANT" + --dataset-provider cord_v2 + --hf-processor-path Qwen/Qwen3.5-397B-A17B + --use-vanilla-collate-fn + --image-token-id 248056 + --image-size 224 + --total-seq-length "$SEQ_LEN" + --image-seq-length 256 + --vision-num-layers "$VISION_NUM_LAYERS" +) + +if [ "$USE_PACKED_SEQUENCE" -eq 1 ]; then + MULTIMODAL_ARGS+=( --use-packed-sequence ) +fi + +# --- Qwen3.5 Decoder Architecture (variant-specific dims set above) --- +# These must match examples/multimodal_dev/models/qwen35_vl/configuration.py +GPT_MODEL_ARGS=( + --num-layers "$NUM_LAYERS" + --hidden-size "$HIDDEN_SIZE" + --ffn-hidden-size "$FFN_HIDDEN_SIZE" + --num-attention-heads "$NUM_ATTN_HEADS" + --group-query-attention + --num-query-groups "$NUM_QUERY_GROUPS" + --kv-channels 256 + --max-position-embeddings 262144 + --seq-length "$SEQ_LEN" + --normalization RMSNorm + --apply-layernorm-1p + --norm-epsilon 1e-06 + --swiglu + --disable-bias-linear + --position-embedding-type rope + --rotary-percent 0.25 + --rotary-base 10000000 + --rotary-seq-len-interpolation-factor 1 + --qk-layernorm + --attention-output-gate + --attention-dropout 0.0 + --hidden-dropout 0.0 + --experimental-attention-variant gated_delta_net + --linear-attention-freq 4 + --linear-conv-kernel-dim 4 + --linear-key-head-dim 128 + --linear-value-head-dim 128 + 
--linear-num-key-heads 16 + --linear-num-value-heads "$LINEAR_NUM_VALUE_HEADS" + --make-vocab-size-divisible-by 485 + --moe-router-force-load-balancing +) + +# --- Tied / untied embeddings --- +# 0.8B, 2B, 4B use tied embeddings; all other variants untie them. +case "$MODEL_VARIANT" in + 0.8b|2b|4b) ;; + *) GPT_MODEL_ARGS+=( --untie-embeddings-and-output-weights ) ;; +esac + +# --- MoE args (MoE variants only) --- +MOE_ARGS=() +case "$MODEL_VARIANT" in + proxy) + MOE_TOPK=2; MOE_FFN_HIDDEN=1024; MOE_SHARED_HIDDEN=1024 + ;; + 35b_a3b|35b_a3b_light) + MOE_TOPK=8; MOE_FFN_HIDDEN=512; MOE_SHARED_HIDDEN=512 + ;; + 122b_a10b) + MOE_TOPK=8; MOE_FFN_HIDDEN=1024; MOE_SHARED_HIDDEN=1024 + ;; + 397b_a17b) + MOE_TOPK=10; MOE_FFN_HIDDEN=1024; MOE_SHARED_HIDDEN=1024 + ;; + 0.8b|2b|4b|9b|27b) + ;; +esac +if [ "${NUM_EXPERTS:-0}" -gt 0 ]; then + MOE_ARGS=( + --num-experts "$NUM_EXPERTS" + --moe-ffn-hidden-size "$MOE_FFN_HIDDEN" + --moe-shared-expert-intermediate-size "$MOE_SHARED_HIDDEN" + --moe-shared-expert-gate + --moe-router-load-balancing-type aux_loss + --moe-router-topk "$MOE_TOPK" + --moe-grouped-gemm + --moe-aux-loss-coeff 1e-3 + --moe-token-dispatcher-type alltoall + --moe-router-dtype fp32 + ) +fi + +# --- Recompute --- +if [ "$RECOMPUTE" -eq 1 ]; then + RECOMPUTE_ARGS=( + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + ) + # RECOMPUTE_ARGS=( + # --recompute-granularity selective + # --recompute-modules moe_act shared_experts layernorm moe + # ) +else + RECOMPUTE_ARGS=() +fi +if [ "$RECOMPUTE_VISION" -eq 1 ]; then + RECOMPUTE_ARGS+=( --recompute-vision ) +fi + +# --- Checkpoint loading --- +# CKPT_LOAD: path to checkpoint directory +# CKPT_FORMAT: override checkpoint format (default: auto-detect) +# CKPT_RESUME: set to 1 to resume training (keep iteration, optimizer, rng); +# default 0 = finetune mode (reset iteration, skip optim/rng) +CKPT_LOAD=${CKPT_LOAD:-} +CKPT_FORMAT=${CKPT_FORMAT:-} +CKPT_RESUME=${CKPT_RESUME:-0} 
+CKPT_OVERRIDE_SCHEDULER=${CKPT_OVERRIDE_SCHEDULER:-0} +CKPT_ARGS=() +if [ -n "$CKPT_LOAD" ]; then + CKPT_ARGS+=( --load "$CKPT_LOAD" ) + if [ "$CKPT_RESUME" -eq 0 ]; then + CKPT_ARGS+=( --finetune --no-load-optim --no-load-rng ) + fi + if [ -n "$CKPT_FORMAT" ]; then + CKPT_ARGS+=( --ckpt-format "$CKPT_FORMAT" ) + fi + if [ "$CKPT_OVERRIDE_SCHEDULER" -eq 1 ]; then + CKPT_ARGS+=( --override-opt-param-scheduler ) + fi +fi + +# --- FSDP --- +USE_FSDP=${USE_FSDP:-1} +if [ "$USE_FSDP" -eq 1 ]; then + FSDP_ARGS=( + --use-megatron-fsdp + --data-parallel-sharding-strategy optim_grads_params + --no-gradient-accumulation-fusion + --init-model-with-meta-device + --use-distributed-optimizer + --ckpt-format fsdp_dtensor + ) + export CUDA_DEVICE_MAX_CONNECTIONS=8 +else + FSDP_ARGS=() +fi + +echo "================================================================" +echo "Qwen3.5-VL Multimodal Training (multimodal_dev)" +echo " Variant: $MODEL_VARIANT" +echo " Vision layers: $VISION_NUM_LAYERS" +echo " GPUs per node: $GPUS_PER_NODE" +echo " Num nodes: $NUM_NODES" +echo " TP=$TP EP=$EP PP=$PP CP=$CP" +echo " MBS=$MBS GBS=$GBS" +echo " Launcher: $LAUNCHER" +echo " FSDP: $USE_FSDP" +echo " PROFILE: $PROFILE" +if [ -n "$CKPT_LOAD" ]; then + echo " CKPT_LOAD: $CKPT_LOAD" + echo " CKPT_FORMAT: ${CKPT_FORMAT:-auto}" + echo " CKPT_RESUME: $CKPT_RESUME" +fi +if [ "$PROFILE" = "1" ]; then + echo " Profile steps: ${PROFILE_STEP_START}-${PROFILE_STEP_END}" + echo " Profile ranks: $PROFILE_RANKS" +fi +echo "================================================================" + +if [ "$LAUNCHER" = "python" ]; then + LAUNCH_CMD=( python $MEGATRON_LM_PATH/examples/multimodal_dev/pretrain_multimodal.py ) +elif [ "$LAUNCHER" = "torchrun" ]; then + LAUNCH_CMD=( torchrun "${DISTRIBUTED_ARGS[@]}" $MEGATRON_LM_PATH/examples/multimodal_dev/pretrain_multimodal.py ) +else + echo "Unsupported LAUNCHER=$LAUNCHER (expected torchrun or python)" >&2 + exit 1 +fi + +cmd=( "${NSYS_CMD[@]}" "${LAUNCH_CMD[@]}" \ + 
"${TRAINING_ARGS[@]}" \ + "${PROFILE_ARGS[@]}" \ + "${MODEL_PARALLEL_ARGS[@]}" \ + "${EVAL_AND_LOGGING_ARGS[@]}" \ + "${TOKENIZER_ARGS[@]}" \ + "${MULTIMODAL_ARGS[@]}" \ + "${GPT_MODEL_ARGS[@]}" \ + "${MOE_ARGS[@]}" \ + "${RECOMPUTE_ARGS[@]}" \ + "${FSDP_ARGS[@]}" \ + "${CKPT_ARGS[@]}" ) + +echo "${cmd[@]}" + +if [ "$DRY_RUN" -eq 1 ]; then + echo "=== DRY RUN ===" + exit 0 +else + "${cmd[@]}" +fi diff --git a/examples/multimodal_dev/tests/__init__.py b/examples/multimodal_dev/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/multimodal_dev/tests/test_cp_correctness.py b/examples/multimodal_dev/tests/test_cp_correctness.py new file mode 100644 index 00000000000..8ee71f25eca --- /dev/null +++ b/examples/multimodal_dev/tests/test_cp_correctness.py @@ -0,0 +1,313 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Distributed correctness test for Context Parallelism (CP) support. + +Verifies that CP>1 produces the same (or numerically close) loss as CP=1 +for the Qwen3.5-VL multimodal model by running forward passes with +deterministic data and comparing the per-rank reduced losses. + +Launch with torchrun (N must be >= 2*max_cp_size for zigzag splitting): + + # Test CP=2 on 2 GPUs: + torchrun --nproc_per_node=2 examples/multimodal_dev/tests/test_cp_correctness.py --cp-size 2 + + # Test CP=4 on 4 GPUs: + torchrun --nproc_per_node=4 examples/multimodal_dev/tests/test_cp_correctness.py --cp-size 4 + +The test: + 1. Builds a tiny proxy model (2 layers, no MoE, no vision encoder). + 2. Generates a deterministic batch (same seed on all ranks). + 3. Runs forward with CP=1 (each rank processes the full sequence independently). + 4. Re-initialises model-parallel groups with the target CP size. + 5. Runs forward with CP=target (sequence is split across ranks). + 6. Compares the all-reduced loss values. + +Exit code 0 = PASS, 1 = FAIL. 
+""" + +import argparse +import os +import sys + +import torch +import torch.distributed as dist + +# Ensure the repo root is on the path so that megatron and examples are importable. +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + + +def _parse_args(): + parser = argparse.ArgumentParser(description="CP correctness test") + parser.add_argument( + "--cp-size", type=int, default=2, + help="Target context-parallel size to compare against CP=1 baseline", + ) + parser.add_argument( + "--seq-len", type=int, default=128, + help="Sequence length (must be divisible by 2*max(cp_size, tp_size*cp_size))", + ) + parser.add_argument( + "--atol", type=float, default=1e-4, + help="Absolute tolerance for loss comparison", + ) + parser.add_argument( + "--rtol", type=float, default=5e-2, + help="Relative tolerance for loss comparison (default 5%%)", + ) + parser.add_argument( + "--seed", type=int, default=42, + help="Random seed for reproducibility", + ) + # Megatron adds extra args; ignore them. 
+ args, _ = parser.parse_known_args() + return args + + +def _init_distributed(): + """Initialise torch.distributed if not already done.""" + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + return local_rank + + +def _init_megatron_parallel(tp_size=1, pp_size=1, cp_size=1, seed=42): + """(Re-)initialise Megatron model-parallel groups and RNG tracker.""" + from megatron.core import parallel_state as ps + from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + ps.destroy_model_parallel() + ps.initialize_model_parallel( + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + context_parallel_size=cp_size, + ) + model_parallel_cuda_manual_seed(seed) + + +def _make_deterministic_batch(seed, batch_size, seq_len, vocab_size, device): + """Create a deterministic batch identical on all ranks.""" + rng = torch.Generator(device="cpu") + rng.manual_seed(seed) + + input_ids = torch.randint( + 0, vocab_size, (batch_size, seq_len), generator=rng, + ).to(device) + labels = torch.randint( + 0, vocab_size, (batch_size, seq_len), generator=rng, + ).to(device) + loss_mask = torch.ones(batch_size, seq_len, device=device) + # Standard position_ids [B, S] + position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) + + return { + "input_ids": input_ids, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": position_ids, + } + + +def _build_tiny_model(cp_size, device): + """Build a minimal GPTModel for testing (no vision, no MoE).""" + from megatron.core.models.gpt import GPTModel + from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec + from megatron.core.transformer.spec_utils import ModuleSpec + from megatron.core.transformer.transformer_config import TransformerConfig + + hidden_size = 256 + num_heads = 4 + config = TransformerConfig( + 
num_layers=2, + hidden_size=hidden_size, + ffn_hidden_size=hidden_size * 4, + num_attention_heads=num_heads, + kv_channels=hidden_size // num_heads, + normalization="RMSNorm", + layernorm_epsilon=1e-6, + gated_linear_unit=True, + activation_func=torch.nn.functional.silu, + bf16=True, + context_parallel_size=cp_size, + add_bias_linear=False, + attention_dropout=0.0, + hidden_dropout=0.0, + sequence_parallel=False, + ) + + spec = get_gpt_layer_with_transformer_engine_spec() + + model = GPTModel( + config=config, + transformer_layer_spec=spec, + vocab_size=1024, + max_sequence_length=4096, + pre_process=True, + post_process=True, + parallel_output=False, + share_embeddings_and_output_weights=True, + position_embedding_type="rope", + rotary_percent=1.0, + rotary_base=10000, + ) + model = model.to(device=device, dtype=torch.bfloat16) + return model, config + + +def _forward_with_cp(model, batch, cp_size): + """Run forward pass, handling CP splitting of the batch. + + When cp_size > 1, splits the batch tensors using the same zigzag + logic as multimodal_dev/models/base.py. + """ + from examples.multimodal_dev.models.base import _cp_split_tensor + from megatron.core import parallel_state as ps + + input_ids = batch["input_ids"].clone() + labels = batch["labels"].clone() + loss_mask = batch["loss_mask"].clone() + position_ids = batch["position_ids"].clone() + + if cp_size > 1: + cp_rank = ps.get_context_parallel_rank() + input_ids = _cp_split_tensor(input_ids, seq_dim=1, cp_size=cp_size, cp_rank=cp_rank) + labels = _cp_split_tensor(labels, seq_dim=1, cp_size=cp_size, cp_rank=cp_rank) + loss_mask = _cp_split_tensor(loss_mask, seq_dim=1, cp_size=cp_size, cp_rank=cp_rank) + # position_ids are NOT split — the RoPE layer handles CP slicing internally. 
+ + with torch.no_grad(): + output = model( + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + attention_mask=None, + ) + + # output is the per-token loss [B, S/CP] + masked_loss = (output.float() * loss_mask.float()).sum() + num_tokens = loss_mask.sum() + + # All-reduce across CP ranks to get global loss + if cp_size > 1: + cp_group = ps.get_context_parallel_group() + dist.all_reduce(masked_loss, group=cp_group) + dist.all_reduce(num_tokens, group=cp_group) + + avg_loss = masked_loss / num_tokens.clamp(min=1) + return avg_loss.item() + + +def main(): + args = _parse_args() + local_rank = _init_distributed() + device = torch.device(f"cuda:{local_rank}") + world_size = dist.get_world_size() + rank = dist.get_rank() + + target_cp = args.cp_size + if world_size < target_cp: + if rank == 0: + print( + f"SKIP: world_size={world_size} < cp_size={target_cp}. " + f"Need at least {target_cp} GPUs.", + flush=True, + ) + dist.destroy_process_group() + sys.exit(0) + if world_size % target_cp != 0: + if rank == 0: + print( + f"SKIP: world_size={world_size} is not divisible by cp_size={target_cp}.", + flush=True, + ) + dist.destroy_process_group() + sys.exit(0) + + vocab_size = 1024 + + # Ensure seq_len is divisible by 2 * target_cp + seq_len = args.seq_len + align = 2 * target_cp + if seq_len % align != 0: + seq_len = ((seq_len + align - 1) // align) * align + if rank == 0: + print(f"Adjusted seq_len to {seq_len} for alignment with CP={target_cp}", flush=True) + + # --- Step 1: CP=1 baseline --- + if rank == 0: + print(f"=== CP=1 baseline (world_size={world_size}) ===", flush=True) + + _init_megatron_parallel(cp_size=1) + + # Set deterministic seed for model init + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + model_cp1, _ = _build_tiny_model(cp_size=1, device=device) + + batch = _make_deterministic_batch( + seed=args.seed + 1, batch_size=1, seq_len=seq_len, + vocab_size=vocab_size, device=device, + ) + + loss_cp1 = 
_forward_with_cp(model_cp1, batch, cp_size=1) + + if rank == 0: + print(f" CP=1 loss: {loss_cp1:.6f}", flush=True) + + # Save model state for reuse + state_dict = model_cp1.state_dict() + del model_cp1 + torch.cuda.empty_cache() + + # --- Step 2: CP=target --- + if rank == 0: + print(f"=== CP={target_cp} (world_size={world_size}) ===", flush=True) + + _init_megatron_parallel(cp_size=target_cp) + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + model_cpN, _ = _build_tiny_model(cp_size=target_cp, device=device) + + # Load the same weights to ensure identical model + model_cpN.load_state_dict(state_dict, strict=True) + del state_dict + + loss_cpN = _forward_with_cp(model_cpN, batch, cp_size=target_cp) + + if rank == 0: + print(f" CP={target_cp} loss: {loss_cpN:.6f}", flush=True) + + del model_cpN + torch.cuda.empty_cache() + + # --- Step 3: Compare --- + if rank == 0: + diff = abs(loss_cpN - loss_cp1) + rel_diff = diff / max(abs(loss_cp1), 1e-10) + + print(f"\n=== Comparison ===", flush=True) + print(f" CP=1 loss: {loss_cp1:.6f}", flush=True) + print(f" CP={target_cp} loss: {loss_cpN:.6f}", flush=True) + print(f" Absolute diff: {diff:.6e}", flush=True) + print(f" Relative diff: {rel_diff:.6e}", flush=True) + print(f" Tolerance (atol): {args.atol:.6e}", flush=True) + print(f" Tolerance (rtol): {args.rtol:.6e}", flush=True) + + passed = diff <= args.atol + args.rtol * abs(loss_cp1) + if passed: + print(f"\nPASS: CP={target_cp} matches CP=1 baseline", flush=True) + else: + print(f"\nFAIL: CP={target_cp} loss differs from CP=1 beyond tolerance", flush=True) + + dist.barrier() + dist.destroy_process_group() + + if rank == 0 and not passed: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal_dev/tests/test_cp_support.py b/examples/multimodal_dev/tests/test_cp_support.py new file mode 100644 index 00000000000..a6f6b9b8d2c --- /dev/null +++ b/examples/multimodal_dev/tests/test_cp_support.py @@ -0,0 +1,347 @@ +# 
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Unit tests for Context Parallelism (CP) support in multimodal_dev.

Tests cover:
  1. _cp_split_tensor — zigzag split correctness, reconstruction, and edge cases
  2. _NoCPGroup — dummy process group behaviour
  3. _thd_cp_partition_index — TE-based per-sample THD CP partitioning
  4. Cross-validation against megatron.core.utils.get_batch_on_this_cp_rank

Run with: pytest examples/multimodal_dev/tests/test_cp_support.py -v
"""

import pytest
import torch

from examples.multimodal_dev.models.base import _cp_split_tensor, _NoCPGroup


class TestCpSplitTensor:
    """Tests for zigzag CP splitting."""

    @staticmethod
    def _assert_zigzag_covers(S, cp_size):
        """Split arange(S) across all ranks; every token must appear once and
        every rank must hold exactly S // cp_size tokens."""
        t = torch.arange(S).unsqueeze(0)  # [1, S]

        all_chunks = []
        for rank in range(cp_size):
            c = _cp_split_tensor(t, seq_dim=1, cp_size=cp_size, cp_rank=rank)
            all_chunks.append(c)
            assert c.shape == (1, S // cp_size)

        combined = torch.cat(all_chunks, dim=1)
        assert torch.equal(combined.sort(dim=1).values, t.sort(dim=1).values)

    def test_basic_2d_cp2(self):
        """[B, S] tensor with CP=2 splits and reconstructs correctly."""
        B, S = 2, 16
        t = torch.arange(B * S).reshape(B, S)
        cp_size = 2

        chunks = []
        for rank in range(cp_size):
            chunks.append(_cp_split_tensor(t, seq_dim=1, cp_size=cp_size, cp_rank=rank))

        # Each rank gets S / CP = 8 tokens
        for c in chunks:
            assert c.shape == (B, S // cp_size)

        # Reconstruct: rank 0 gets chunks [0, 3], rank 1 gets chunks [1, 2]
        # Original split into 4 chunks of size 4:
        #   chunk0=[0..3], chunk1=[4..7], chunk2=[8..11], chunk3=[12..15]
        #   rank0 = [chunk0, chunk3] = [0..3, 12..15]
        #   rank1 = [chunk1, chunk2] = [4..7, 8..11]
        assert torch.equal(chunks[0][0], torch.tensor([0, 1, 2, 3, 12, 13, 14, 15]))
        assert torch.equal(chunks[1][0], torch.tensor([4, 5, 6, 7, 8, 9, 10, 11]))

    def test_3d_mrope_cp2(self):
        """[3, B, S] MRoPE tensor with CP=2."""
        B, S = 1, 8
        cp_size = 2
        t = torch.arange(3 * B * S).reshape(3, B, S)

        chunk = _cp_split_tensor(t, seq_dim=2, cp_size=cp_size, cp_rank=0)
        assert chunk.shape == (3, B, S // cp_size)

        # All 3 MRoPE components should be split consistently
        for d in range(3):
            original_row = t[d, 0]  # [S]
            # With S=8, CP=2: 4 chunks of size 2
            # rank0 gets chunks [0, 3] = positions [0,1, 6,7]
            expected = torch.cat([original_row[0:2], original_row[6:8]])
            assert torch.equal(chunk[d, 0], expected)

    def test_sbh_decoder_input(self):
        """[S, B, H] decoder input split along dim=0."""
        S, B, H = 16, 2, 4
        cp_size = 2
        t = torch.randn(S, B, H)

        chunk = _cp_split_tensor(t, seq_dim=0, cp_size=cp_size, cp_rank=0)
        assert chunk.shape == (S // cp_size, B, H)

    def test_cp4(self):
        """CP=4 zigzag pattern — all tokens appear exactly once."""
        self._assert_zigzag_covers(S=32, cp_size=4)

    def test_cp8(self):
        """CP=8 zigzag pattern — all tokens appear exactly once."""
        self._assert_zigzag_covers(S=64, cp_size=8)

    def test_not_divisible_raises(self):
        """Should raise when seq_len not divisible by 2*cp_size.

        NOTE: the implementation enforces this with ``assert``, hence
        AssertionError (would be skipped under ``python -O``).
        """
        t = torch.randn(2, 10)  # S=10, not divisible by 4
        with pytest.raises(AssertionError):
            _cp_split_tensor(t, seq_dim=1, cp_size=2, cp_rank=0)

    def test_zigzag_symmetry(self):
        """rank 0 and rank (cp_size-1) should get mirror chunks."""
        S = 16
        cp_size = 2
        t = torch.arange(S).unsqueeze(0)  # [1, 16]

        c0 = _cp_split_tensor(t, seq_dim=1, cp_size=cp_size, cp_rank=0)
        c1 = _cp_split_tensor(t, seq_dim=1, cp_size=cp_size, cp_rank=1)

        # rank0 gets chunks [0, 3], rank1 gets chunks [1, 2]
        # chunk0=[0..3], chunk3=[12..15] -> rank0 gets [0..3, 12..15]
        # chunk1=[4..7], chunk2=[8..11] -> rank1 gets [4..7, 8..11]
        # rank0's first half is earliest, rank1's first half is next
        assert c0[0, 0].item() < c1[0, 0].item()  # rank0 starts earlier

    def test_matches_megatron_core(self):
        """Cross-validate against megatron.core.utils.get_batch_on_this_cp_rank logic.

        We simulate the core function's logic (seq_dim=1, attention_mask seq_dim=2)
        and compare.
        """
        B, S = 2, 32
        cp_size = 4

        input_ids = torch.arange(B * S).reshape(B, S)
        labels = torch.arange(B * S).reshape(B, S) + 1000

        def core_split(val, seq_dim, cp_rank):
            # Inline re-implementation of the megatron.core zigzag selection:
            # view as 2*CP chunks, pick chunks [cp_rank, 2*CP - cp_rank - 1].
            val = val.view(
                *val.shape[0:seq_dim],
                2 * cp_size,
                val.shape[seq_dim] // (2 * cp_size),
                *val.shape[(seq_dim + 1):],
            )
            index = torch.zeros(2, dtype=torch.int64, device=val.device)
            index[0].fill_(cp_rank)
            index[1].fill_(2 * cp_size - cp_rank - 1)
            val = val.index_select(seq_dim, index)
            val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2):])
            return val

        for cp_rank in range(cp_size):
            # Our implementation
            our_ids = _cp_split_tensor(input_ids, seq_dim=1, cp_size=cp_size, cp_rank=cp_rank)
            our_labels = _cp_split_tensor(labels, seq_dim=1, cp_size=cp_size, cp_rank=cp_rank)

            # Reference (simulated megatron core logic)
            ref_ids = core_split(input_ids.clone(), seq_dim=1, cp_rank=cp_rank)
            ref_labels = core_split(labels.clone(), seq_dim=1, cp_rank=cp_rank)

            assert torch.equal(our_ids, ref_ids), f"input_ids mismatch at rank {cp_rank}"
            assert torch.equal(our_labels, ref_labels), f"labels mismatch at rank {cp_rank}"

    def test_batch_dim_preserved(self):
        """Batch dimension must be unchanged after split."""
        B, S = 4, 32
        cp_size = 4
        t = torch.randn(B, S)

        for rank in range(cp_size):
            c = _cp_split_tensor(t, seq_dim=1, cp_size=cp_size, cp_rank=rank)
            assert c.shape[0] == B


class TestNoCPGroup:
    """Tests for the dummy CP group used by the vision encoder."""

    def test_size_is_one(self):
        g = _NoCPGroup()
        assert g.size() == 1

    def test_rank_is_zero(self):
        g = _NoCPGroup()
        assert g.rank() == 0


try:
    # Probe for TransformerEngine; the THD partition helper requires it.
    from transformer_engine.pytorch import cpp_extensions as _tex  # noqa: F401

    _HAS_TE = True
except Exception:
    _HAS_TE = False


@pytest.mark.skipif(not _HAS_TE, reason="TransformerEngine not installed")
class TestThdCpPartition:
    """Verify TE-based per-sample THD + CP partition matches THD semantics.

    Each packed sub-sample of length ``s_i`` (where ``s_i % (2*cp_size) == 0``)
    is split into ``2*cp_size`` zigzag chunks per sample; rank ``r`` gets
    chunks ``[r, 2*cp_size - r - 1]`` of every sample. The union across
    ranks must cover every token position exactly once.
    """

    @staticmethod
    def _make_padded_packed(seqlens, divisor):
        """Concatenate per-sample dummy tokens after padding each sample to a
        multiple of *divisor*. Returns ``(input_ids[1, T], cu_seqlens_padded)``.
        """
        import math
        from itertools import accumulate

        padded = [math.ceil(s / divisor) * divisor for s in seqlens]
        chunks = []
        next_id = 1
        for s, p in zip(seqlens, padded):
            chunks.append(torch.arange(next_id, next_id + s, dtype=torch.int64))
            chunks.append(torch.zeros(p - s, dtype=torch.int64))  # padding
            next_id += s
        input_ids = torch.cat(chunks, dim=0).unsqueeze(0)  # [1, T]
        cu_seqlens_padded = torch.tensor(
            list(accumulate(padded, initial=0)), dtype=torch.int32,
        )
        return input_ids, cu_seqlens_padded

    def _ensure_cuda(self, x):
        # TE kernels need CUDA tensors; fall back to CPU when unavailable.
        return x.cuda() if torch.cuda.is_available() else x

    def test_partition_covers_all_positions_cp2(self):
        from examples.multimodal_dev.models.base import _thd_cp_partition_index

        cp_size = 2
        seqlens = [5, 7, 3]  # valid lengths
        input_ids, cu_seqlens_padded = self._make_padded_packed(
            seqlens, divisor=2 * cp_size,
        )
        input_ids = self._ensure_cuda(input_ids)
        cu_seqlens_padded = self._ensure_cuda(cu_seqlens_padded)
        T = input_ids.shape[1]

        # Union of per-rank indices must be all positions exactly once.
        seen = torch.zeros(T, dtype=torch.long, device=input_ids.device)
        for cp_rank in range(cp_size):
            idx = _thd_cp_partition_index(
                cu_seqlens_padded, T, cp_size, cp_rank,
            )
            assert idx.numel() == T // cp_size, (
                f"rank {cp_rank}: expected {T // cp_size} tokens, got {idx.numel()}"
            )
            seen.scatter_add_(
                0, idx.long(), torch.ones_like(idx, dtype=seen.dtype),
            )
        assert torch.all(seen == 1), (
            f"Position coverage broken: counts={seen.tolist()}"
        )

    def test_index_select_aligns_inputs_and_position_ids_cp2(self):
        """input_ids, loss_mask, and (3, 1, T) position_ids index_select with
        the same partition index produce shape-consistent per-rank tensors."""
        from examples.multimodal_dev.models.base import _thd_cp_partition_index

        cp_size = 2
        seqlens = [8, 4]
        input_ids, cu_seqlens_padded = self._make_padded_packed(
            seqlens, divisor=2 * cp_size,
        )
        input_ids = self._ensure_cuda(input_ids)
        cu_seqlens_padded = self._ensure_cuda(cu_seqlens_padded)
        T = input_ids.shape[1]
        labels = input_ids + 1000
        loss_mask = (input_ids != 0).float()
        position_ids = (
            torch.arange(T, device=input_ids.device)
            .unsqueeze(0).unsqueeze(0).expand(3, 1, T).contiguous()
        )
        H = 4
        decoder_input = (
            torch.arange(T * H, dtype=torch.float32, device=input_ids.device)
            .view(T, 1, H)
        )

        for cp_rank in range(cp_size):
            idx = _thd_cp_partition_index(
                cu_seqlens_padded, T, cp_size, cp_rank,
            )
            ii = input_ids.index_select(1, idx)
            ll = labels.index_select(1, idx)
            lm = loss_mask.index_select(1, idx)
            pi = position_ids.index_select(2, idx)
            di = decoder_input.index_select(0, idx)

            assert ii.shape == (1, T // cp_size)
            assert ll.shape == (1, T // cp_size)
            assert lm.shape == (1, T // cp_size)
            assert pi.shape == (3, 1, T // cp_size)
            assert di.shape == (T // cp_size, 1, H)
            # Sliced position_ids is just the partition index itself
            # (since position_ids was arange(T) over all positions).
            assert torch.equal(pi[0, 0], idx.to(pi.dtype))
            # All MRoPE rows agree.
            assert torch.equal(pi[1, 0], pi[0, 0])
            assert torch.equal(pi[2, 0], pi[0, 0])

    def test_partition_cp4_three_samples(self):
        from examples.multimodal_dev.models.base import _thd_cp_partition_index

        cp_size = 4
        seqlens = [12, 4, 8]
        input_ids, cu_seqlens_padded = self._make_padded_packed(
            seqlens, divisor=2 * cp_size,
        )
        input_ids = self._ensure_cuda(input_ids)
        cu_seqlens_padded = self._ensure_cuda(cu_seqlens_padded)
        T = input_ids.shape[1]

        seen = torch.zeros(T, dtype=torch.long, device=input_ids.device)
        for cp_rank in range(cp_size):
            idx = _thd_cp_partition_index(
                cu_seqlens_padded, T, cp_size, cp_rank,
            )
            assert idx.numel() == T // cp_size
            seen.scatter_add_(
                0, idx.long(), torch.ones_like(idx, dtype=seen.dtype),
            )
        assert torch.all(seen == 1)

    def test_loss_mask_zero_kept_per_rank(self):
        """Pad-token positions (loss_mask=0) survive as 0 on whichever rank
        they land — sanity check that we don't accidentally discard them."""
        from examples.multimodal_dev.models.base import _thd_cp_partition_index

        cp_size = 2
        seqlens = [5, 3]
        input_ids, cu_seqlens_padded = self._make_padded_packed(
            seqlens, divisor=2 * cp_size,
        )
        input_ids = self._ensure_cuda(input_ids)
        cu_seqlens_padded = self._ensure_cuda(cu_seqlens_padded)
        T = input_ids.shape[1]
        loss_mask = (input_ids != 0).float()
        total_zeros = (loss_mask == 0).sum().item()

        zeros_seen = 0
        for cp_rank in range(cp_size):
            idx = _thd_cp_partition_index(
                cu_seqlens_padded, T, cp_size, cp_rank,
            )
            zeros_seen += (
                loss_mask.index_select(1, idx) == 0
            ).sum().item()
        assert zeros_seen == total_zeros


# ----- end of file; next file in this patch: -----
# examples/multimodal_dev/tests/test_mrope_parity.py
# Copyright (c) 2025, 
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Parity tests for ``get_rope_index`` (MRoPE position-ID computation).

Two properties are verified:

1. **BSHD backwards compatibility** — the refactored ``get_rope_index``
   returns bit-identical ``(position_ids, mrope_position_deltas)`` to
   the pre-refactor implementation on padded ``[B, S]`` batches.
2. **THD == BSHD on the valid region** — when the same variable-length
   samples are fed through both layouts (BSHD with right-padding; THD
   packed with ``cu_seqlens_q`` / ``cu_seqlens_q_padded``), positions at
   every valid slot agree.

The pre-refactor function is pinned inline as ``_old_get_rope_index``
so this test stays self-contained. Run with::

    python -m pytest examples/multimodal_dev/tests/test_mrope_parity.py -v

or directly::

    python examples/multimodal_dev/tests/test_mrope_parity.py
"""

import math
import os
import sys
from itertools import accumulate

import torch
import torch.nn.functional as F

_REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "../../.."),
)
# Insert at position 0 unconditionally — other entries on sys.path
# (e.g. a sibling Megatron-LM checkout) have their own ``examples``
# package that would otherwise shadow ours.
if _REPO_ROOT in sys.path:
    sys.path.remove(_REPO_ROOT)
sys.path.insert(0, _REPO_ROOT)

from megatron.core.packed_seq_params import PackedSeqParams

from examples.multimodal_dev.models.qwen35_vl.mrope import get_rope_index

# -----------------------------------------------------------------------------
# Token-ID constants (match Qwen3.5-VL, but values are arbitrary for this test)
# -----------------------------------------------------------------------------

IMAGE_TOKEN_ID = 248056
VIDEO_TOKEN_ID = 248057
VISION_START_TOKEN_ID = 248053
SPATIAL_MERGE_SIZE = 2


# -----------------------------------------------------------------------------
# Pinned reference implementation (pre-refactor BSHD path)
# -----------------------------------------------------------------------------

def _old_get_rope_index(
    spatial_merge_size,
    image_token_id,
    video_token_id,
    vision_start_token_id,
    input_ids=None,
    image_grid_thw=None,
    video_grid_thw=None,
    attention_mask=None,
):
    """Pre-refactor BSHD implementation of ``get_rope_index``.

    Copied verbatim (modulo the broken cu_seqlens branch, which this
    parity test does not exercise) so we can diff against the new
    implementation on BSHD inputs.

    NOTE(review): this function is a pinned reference — do not "clean it
    up"; any change here silently weakens the parity guarantee.
    """
    if video_grid_thw is not None:
        # Expand each video into per-frame grid rows with t=1.
        video_grid_thw = torch.repeat_interleave(
            video_grid_thw, video_grid_thw[:, 0], dim=0,
        )
        video_grid_thw[:, 0] = 1

    mrope_position_deltas = []

    if input_ids is not None and (
        image_grid_thw is not None or video_grid_thw is not None
    ):
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)

        position_ids = torch.ones(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        image_index, video_index = 0, 0
        attention_mask = attention_mask.to(total_input_ids.device)

        for i, sample_input_ids in enumerate(total_input_ids):
            # Restrict to the valid (unpadded) region of this row.
            sample_input_ids = sample_input_ids[attention_mask[i] == 1]
            vision_start_indices = torch.argwhere(
                sample_input_ids == vision_start_token_id,
            ).squeeze(1)
            vision_tokens = sample_input_ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = sample_input_ids.tolist()
            llm_pos_ids_list = []
            st = 0
            remain_images, remain_videos = image_nums, video_nums

            for _ in range(image_nums + video_nums):
                # Locate the next image/video token at or after `st`;
                # len+1 acts as "not found" so the min comparison works.
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1

                if ed_image < ed_video:
                    t, h, w = (
                        image_grid_thw[image_index][0],
                        image_grid_thw[image_index][1],
                        image_grid_thw[image_index][2],
                    )
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t, h, w = (
                        video_grid_thw[video_index][0],
                        video_grid_thw[video_index][1],
                        video_grid_thw[video_index][2],
                    )
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video

                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t.item(),
                    h.item() // spatial_merge_size,
                    w.item() // spatial_merge_size,
                )
                text_len = ed - st

                # Text span before the vision block: 1-D positions on all
                # three MRoPE axes, continuing from the running maximum.
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1
                    if llm_pos_ids_list
                    else 0
                )
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1)
                    + st_idx
                )

                # Vision block: (t, h, w) indices flattened in t-major order.
                t_index = (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    .flatten()
                )
                h_index = (
                    torch.arange(llm_grid_h)
                    .view(1, -1, 1)
                    .expand(llm_grid_t, -1, llm_grid_w)
                    .flatten()
                )
                w_index = (
                    torch.arange(llm_grid_w)
                    .view(1, 1, -1)
                    .expand(llm_grid_t, llm_grid_h, -1)
                    .flatten()
                )
                llm_pos_ids_list.append(
                    torch.stack([t_index, h_index, w_index])
                    + text_len
                    + st_idx
                )
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

            # Trailing text after the last vision block, if any.
            if st < len(input_tokens):
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1
                    if llm_pos_ids_list
                    else 0
                )
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1)
                    + st_idx
                )

            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            position_ids[
                ..., i, attention_mask[i] == 1
            ] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(
                llm_positions.max() + 1 - len(total_input_ids[i]),
            )

        mrope_position_deltas = torch.tensor(
            mrope_position_deltas, device=total_input_ids.device,
        ).unsqueeze(1)
        return position_ids, mrope_position_deltas

    # Text-only fallback.
    if attention_mask is not None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_ids = (
            position_ids.unsqueeze(0)
            .expand(3, -1, -1)
            .to(attention_mask.device)
        )
        max_position_ids = (
            position_ids.max(0, keepdim=False)[0]
            .max(-1, keepdim=True)[0]
        )
        mrope_position_deltas = (
            max_position_ids + 1 - attention_mask.shape[-1]
        )
    else:
        position_ids = (
            torch.arange(
                input_ids.shape[1], device=input_ids.device,
            )
            .view(1, 1, -1)
            .expand(3, input_ids.shape[0], -1)
        )
        mrope_position_deltas = torch.zeros(
            [input_ids.shape[0], 1],
            device=input_ids.device,
            dtype=input_ids.dtype,
        )
    return position_ids, mrope_position_deltas


# -----------------------------------------------------------------------------
# Synthetic-sample builder
# -----------------------------------------------------------------------------

def _build_sample(
    prefix_text_len,
    grids,
    suffix_text_len,
    text_token=100,
):
    """Build one variable-length sample with ``len(grids)`` images.

    Layout per image: ``vision_start_id`` then
    ``llm_grid_t * llm_grid_h * llm_grid_w`` ``image_token_id`` slots
    (where ``llm_grid_* = grid_* // spatial_merge_size`` for h/w).
    Grids use ``t=1``.

    Returns ``(input_ids [L], image_grid_thw [N, 3])``.
    """
    tokens = [text_token] * prefix_text_len
    grid_rows = []
    for t, h, w in grids:
        n_image_tokens = (
            t * (h // SPATIAL_MERGE_SIZE) * (w // SPATIAL_MERGE_SIZE)
        )
        tokens.append(VISION_START_TOKEN_ID)
        tokens.extend([IMAGE_TOKEN_ID] * n_image_tokens)
        grid_rows.append([t, h, w])
    tokens.extend([text_token + 1] * suffix_text_len)
    input_ids = torch.tensor(tokens, dtype=torch.int64)
    image_grid_thw = torch.tensor(grid_rows, dtype=torch.int64)
    return input_ids, image_grid_thw


def _sample_bank():
    """A small bank of samples covering text-only, single-image, multi-image."""
    return [
        _build_sample(
            prefix_text_len=5,
            grids=[(1, 4, 4)],
            suffix_text_len=7,
        ),
        _build_sample(
            prefix_text_len=3,
            grids=[(1, 2, 2), (1, 4, 6)],
            suffix_text_len=4,
        ),
        _build_sample(
            prefix_text_len=10,
            grids=[],
            suffix_text_len=0,
        ),
        _build_sample(
            prefix_text_len=0,
            grids=[(1, 6, 4)],
            suffix_text_len=2,
        ),
    ]


# -----------------------------------------------------------------------------
# Test 1: BSHD backwards compatibility
# -----------------------------------------------------------------------------

def test_bshd_matches_old_reference():
    """New ``get_rope_index`` equals the pinned reference on BSHD inputs."""
    samples = _sample_bank()
    max_len = max(s.numel() for s, _ in samples)

    input_ids_rows = []
    mask_rows = []
    grid_rows = []
    for tokens, grids in samples:
        L = tokens.numel()
        padded = F.pad(tokens, (0, max_len - L), value=0)
        mask = torch.zeros(max_len, dtype=torch.int64)
        mask[:L] = 1
        input_ids_rows.append(padded)
        mask_rows.append(mask)
        if grids.numel() > 0:
            grid_rows.append(grids)

    input_ids = torch.stack(input_ids_rows)      # [B, S]
    attention_mask = torch.stack(mask_rows)      # [B, S]
    image_grid_thw = (
        torch.cat(grid_rows, dim=0) if grid_rows else None
    )

    old_pos, old_delta = _old_get_rope_index(
        spatial_merge_size=SPATIAL_MERGE_SIZE,
        image_token_id=IMAGE_TOKEN_ID,
        video_token_id=VIDEO_TOKEN_ID,
        vision_start_token_id=VISION_START_TOKEN_ID,
        input_ids=input_ids,
        image_grid_thw=image_grid_thw,
        attention_mask=attention_mask,
    )
    new_pos, new_delta = get_rope_index(
        spatial_merge_size=SPATIAL_MERGE_SIZE,
        image_token_id=IMAGE_TOKEN_ID,
        video_token_id=VIDEO_TOKEN_ID,
        vision_start_token_id=VISION_START_TOKEN_ID,
        input_ids=input_ids,
        image_grid_thw=image_grid_thw,
        attention_mask=attention_mask,
        packed_seq_params=None,
    )

    assert torch.equal(old_pos, new_pos), (
        f"BSHD position_ids differ.\nold:\n{old_pos}\nnew:\n{new_pos}"
    )
    assert torch.equal(old_delta, new_delta), (
        f"BSHD mrope_position_deltas differ.\n"
        f"old: {old_delta}\nnew: {new_delta}"
    )


# -----------------------------------------------------------------------------
# Test 2: THD positions match BSHD positions on the valid region
# -----------------------------------------------------------------------------

def _pack_samples(samples, divisible_by=1):
    """Pack ``samples`` into a single ``[1, T]`` tensor the same way
    ``pack_or_pad_batch`` does, and build ``PackedSeqParams``.

    Each per-sample tensor is right-padded to a multiple of
    ``divisible_by`` before concatenation. ``cu_seqlens_q`` tracks
    unpadded lengths; ``cu_seqlens_q_padded`` tracks the packed layout.
    """
    padded_chunks = []
    seqlens = []
    seqlens_padded = []
    grid_rows = []
    for tokens, grids in samples:
        L = tokens.numel()
        target_L = math.ceil(L / divisible_by) * divisible_by
        padded_chunks.append(F.pad(tokens, (0, target_L - L), value=0))
        seqlens.append(L)
        seqlens_padded.append(target_L)
        if grids.numel() > 0:
            grid_rows.append(grids)

    packed = torch.cat(padded_chunks, dim=0).unsqueeze(0)  # [1, T]
    cu_seqlens = torch.tensor(
        list(accumulate(seqlens, initial=0)), dtype=torch.int32,
    )
    cu_seqlens_padded = torch.tensor(
        list(accumulate(seqlens_padded, initial=0)), dtype=torch.int32,
    )
    psp = PackedSeqParams(
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        cu_seqlens_q_padded=cu_seqlens_padded,
        cu_seqlens_kv_padded=cu_seqlens_padded,
        max_seqlen_q=max(seqlens_padded),
        max_seqlen_kv=max(seqlens_padded),
    )
    image_grid_thw = (
        torch.cat(grid_rows, dim=0) if grid_rows else None
    )
    return packed, psp, image_grid_thw, seqlens, seqlens_padded


def test_thd_matches_bshd_padded():
    """THD positions at every valid slot equal BSHD positions on the
    equivalent right-padded batch.
    """
    samples = _sample_bank()

    # BSHD side: right-pad to common max_len.
    max_len = max(s.numel() for s, _ in samples)
    input_ids_rows = []
    mask_rows = []
    grid_rows = []
    for tokens, grids in samples:
        L = tokens.numel()
        input_ids_rows.append(F.pad(tokens, (0, max_len - L), value=0))
        m = torch.zeros(max_len, dtype=torch.int64)
        m[:L] = 1
        mask_rows.append(m)
        if grids.numel() > 0:
            grid_rows.append(grids)
    bshd_input_ids = torch.stack(input_ids_rows)
    bshd_mask = torch.stack(mask_rows)
    bshd_grid = torch.cat(grid_rows, dim=0) if grid_rows else None

    bshd_pos, _ = get_rope_index(
        spatial_merge_size=SPATIAL_MERGE_SIZE,
        image_token_id=IMAGE_TOKEN_ID,
        video_token_id=VIDEO_TOKEN_ID,
        vision_start_token_id=VISION_START_TOKEN_ID,
        input_ids=bshd_input_ids,
        image_grid_thw=bshd_grid,
        attention_mask=bshd_mask,
    )
    # bshd_pos: [3, B, S_pad]

    # THD side: pack with a non-trivial divisor so the padded and
    # unpadded cu_seqlens diverge — this exercises the distinction.
    for divisible_by in (1, 4):
        packed_input_ids, psp, thd_grid, seqlens, seqlens_padded = (
            _pack_samples(samples, divisible_by=divisible_by)
        )
        thd_pos, _ = get_rope_index(
            spatial_merge_size=SPATIAL_MERGE_SIZE,
            image_token_id=IMAGE_TOKEN_ID,
            video_token_id=VIDEO_TOKEN_ID,
            vision_start_token_id=VISION_START_TOKEN_ID,
            input_ids=packed_input_ids,
            image_grid_thw=thd_grid,
            packed_seq_params=psp,
        )
        # thd_pos: [3, 1, T]
        assert thd_pos.shape == (
            3, 1, packed_input_ids.shape[1],
        ), f"bad THD shape {thd_pos.shape}"

        seg_starts = list(accumulate(seqlens_padded, initial=0))
        for k, (valid_len, seg_start) in enumerate(
            zip(seqlens, seg_starts)
        ):
            thd_slice = thd_pos[:, 0, seg_start:seg_start + valid_len]
            bshd_slice = bshd_pos[:, k, :valid_len]
            assert torch.equal(thd_slice, bshd_slice), (
                f"[divisible_by={divisible_by}] segment {k} "
                f"(valid_len={valid_len}, seg_start={seg_start}) "
                f"disagrees:\nTHD:\n{thd_slice}\nBSHD:\n{bshd_slice}"
            )


# 
----------------------------------------------------------------------------- +# Test 3: THD with no images (text-only packed) +# ----------------------------------------------------------------------------- + +def test_thd_text_only_restarts_per_segment(): + """Text-only THD: each segment gets a fresh ``[0..valid_len-1]`` range.""" + samples = [ + _build_sample(prefix_text_len=6, grids=[], suffix_text_len=0), + _build_sample(prefix_text_len=11, grids=[], suffix_text_len=0), + _build_sample(prefix_text_len=3, grids=[], suffix_text_len=0), + ] + packed_input_ids, psp, _, seqlens, seqlens_padded = _pack_samples( + samples, divisible_by=4, + ) + thd_pos, _ = get_rope_index( + spatial_merge_size=SPATIAL_MERGE_SIZE, + image_token_id=IMAGE_TOKEN_ID, + video_token_id=VIDEO_TOKEN_ID, + vision_start_token_id=VISION_START_TOKEN_ID, + input_ids=packed_input_ids, + image_grid_thw=None, + packed_seq_params=psp, + ) + + seg_starts = list(accumulate(seqlens_padded, initial=0)) + for valid_len, seg_start in zip(seqlens, seg_starts): + expected = ( + torch.arange(valid_len, dtype=thd_pos.dtype) + .view(1, -1) + .expand(3, -1) + ) + got = thd_pos[:, 0, seg_start:seg_start + valid_len] + assert torch.equal(got, expected), ( + f"text-only segment mismatch at seg_start={seg_start}, " + f"valid_len={valid_len}:\n{got}\nexpected:\n{expected}" + ) + + +# ----------------------------------------------------------------------------- +# Test 4: Explicit two-sequence batch with vision, both in BSHD and THD +# ----------------------------------------------------------------------------- + +def _two_image_samples(): + """Two samples, each with one image — the smallest case that can + expose a bug where segment k > 0 positions leak state from segment + k - 1 (e.g. non-restarted ``st_idx`` or a stale ``image_index``). 
+ """ + return [ + _build_sample( + prefix_text_len=5, + grids=[(1, 4, 4)], # 4 image tokens after spatial merge + suffix_text_len=3, + ), + _build_sample( + prefix_text_len=4, + grids=[(1, 4, 4)], + suffix_text_len=6, + ), + ] + + +def test_bshd_batch_size_2_with_vision(): + """BSHD with ``B == 2``: both rows' positions restart at 0 and match + the pinned reference. + """ + samples = _two_image_samples() + max_len = max(s.numel() for s, _ in samples) + + input_ids_rows, mask_rows, grid_rows = [], [], [] + for tokens, grids in samples: + L = tokens.numel() + input_ids_rows.append(F.pad(tokens, (0, max_len - L), value=0)) + m = torch.zeros(max_len, dtype=torch.int64) + m[:L] = 1 + mask_rows.append(m) + grid_rows.append(grids) + + input_ids = torch.stack(input_ids_rows) # [2, S] + attention_mask = torch.stack(mask_rows) + image_grid_thw = torch.cat(grid_rows, dim=0) + assert input_ids.shape[0] == 2 + + old_pos, old_delta = _old_get_rope_index( + spatial_merge_size=SPATIAL_MERGE_SIZE, + image_token_id=IMAGE_TOKEN_ID, + video_token_id=VIDEO_TOKEN_ID, + vision_start_token_id=VISION_START_TOKEN_ID, + input_ids=input_ids, + image_grid_thw=image_grid_thw, + attention_mask=attention_mask, + ) + new_pos, new_delta = get_rope_index( + spatial_merge_size=SPATIAL_MERGE_SIZE, + image_token_id=IMAGE_TOKEN_ID, + video_token_id=VIDEO_TOKEN_ID, + vision_start_token_id=VISION_START_TOKEN_ID, + input_ids=input_ids, + image_grid_thw=image_grid_thw, + attention_mask=attention_mask, + ) + assert torch.equal(old_pos, new_pos), ( + f"BSHD B=2 position mismatch vs reference.\n" + f"old:\n{old_pos}\nnew:\n{new_pos}" + ) + assert torch.equal(old_delta, new_delta) + + # Both rows must start at position 0. 
+    for i in range(2):
+        valid_len = int(attention_mask[i].sum().item())
+        assert torch.all(new_pos[:, i, 0] == 0), (
+            f"row {i} does not start at 0: {new_pos[:, i, 0]}"
+        )
+        # Sanity: requiring positions to be exactly arange(valid_len)
+        # would be wrong (MRoPE compresses positions for vision tokens,
+        # so some values may be skipped) — only bound the maximum.
+        assert new_pos[:, i, :valid_len].max() < valid_len
+
+
+def test_thd_batch_size_2_with_vision():
+    """THD with 2 packed sequences: seg 1 positions restart at 0 and
+    equal BSHD row 1 on the valid region (bit-identical).
+    """
+    samples = _two_image_samples()
+
+    # BSHD reference.
+    max_len = max(s.numel() for s, _ in samples)
+    rows, masks, grids_bshd = [], [], []
+    for tokens, grids in samples:
+        L = tokens.numel()
+        rows.append(F.pad(tokens, (0, max_len - L), value=0))
+        m = torch.zeros(max_len, dtype=torch.int64)
+        m[:L] = 1
+        masks.append(m)
+        grids_bshd.append(grids)
+    bshd_input_ids = torch.stack(rows)
+    bshd_mask = torch.stack(masks)
+    bshd_grid = torch.cat(grids_bshd, dim=0)
+    bshd_pos, _ = get_rope_index(
+        spatial_merge_size=SPATIAL_MERGE_SIZE,
+        image_token_id=IMAGE_TOKEN_ID,
+        video_token_id=VIDEO_TOKEN_ID,
+        vision_start_token_id=VISION_START_TOKEN_ID,
+        input_ids=bshd_input_ids,
+        image_grid_thw=bshd_grid,
+        attention_mask=bshd_mask,
+    )
+
+    # THD packed version with a non-trivial divisor so padded and
+    # unpadded cu_seqlens disagree.
+ packed_input_ids, psp, thd_grid, seqlens, seqlens_padded = ( + _pack_samples(samples, divisible_by=4) + ) + assert len(seqlens) == 2 + thd_pos, _ = get_rope_index( + spatial_merge_size=SPATIAL_MERGE_SIZE, + image_token_id=IMAGE_TOKEN_ID, + video_token_id=VIDEO_TOKEN_ID, + vision_start_token_id=VISION_START_TOKEN_ID, + input_ids=packed_input_ids, + image_grid_thw=thd_grid, + packed_seq_params=psp, + ) + + seg_starts = list(accumulate(seqlens_padded, initial=0)) + for k, (valid_len, seg_start) in enumerate( + zip(seqlens, seg_starts) + ): + thd_slice = thd_pos[:, 0, seg_start:seg_start + valid_len] + bshd_slice = bshd_pos[:, k, :valid_len] + assert torch.equal(thd_slice, bshd_slice), ( + f"seg {k} THD vs BSHD row {k} mismatch.\n" + f"THD:\n{thd_slice}\nBSHD:\n{bshd_slice}" + ) + # Critical: seg k must start at position 0 (bug 2 check). + assert torch.all(thd_slice[:, 0] == 0), ( + f"seg {k} does not start at 0 — positions leaked from " + f"previous segment: first col = {thd_slice[:, 0]}" + ) + + +if __name__ == "__main__": + test_bshd_matches_old_reference() + print("[ok] test_bshd_matches_old_reference") + test_thd_matches_bshd_padded() + print("[ok] test_thd_matches_bshd_padded") + test_thd_text_only_restarts_per_segment() + print("[ok] test_thd_text_only_restarts_per_segment") + test_bshd_batch_size_2_with_vision() + print("[ok] test_bshd_batch_size_2_with_vision") + test_thd_batch_size_2_with_vision() + print("[ok] test_thd_batch_size_2_with_vision") + print("All parity tests passed.") diff --git a/examples/multimodal_dev/tests/test_thd_correctness.py b/examples/multimodal_dev/tests/test_thd_correctness.py new file mode 100644 index 00000000000..d9c731cf4f2 --- /dev/null +++ b/examples/multimodal_dev/tests/test_thd_correctness.py @@ -0,0 +1,388 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""BSHD vs THD correctness test for multimodal_dev packed sequence support. 
+ +Validates that packing a [B, S] batch into [1, T] THD format produces +numerically equivalent loss values and gradient norms through a GPTModel. + +The test uses equal-length sequences (no padding) so that BSHD causal +attention and THD cu_seqlens-based causal attention are mathematically +identical. This makes any numerical deviation a real bug rather than an +expected consequence of different padding/masking strategies. + +Usage:: + + # Single GPU (flash attention): + torchrun --nproc_per_node=1 \\ + examples/multimodal_dev/tests/test_thd_correctness.py + + # Override model size: + torchrun --nproc_per_node=1 \\ + examples/multimodal_dev/tests/test_thd_correctness.py \\ + --num-layers 4 --hidden-size 512 --num-heads 8 --num-kv-heads 4 +""" + +import argparse +import os +import sys + +import torch + +_REPO_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../.."), +) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from megatron.core import parallel_state +from megatron.core.models.gpt import GPTModel +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_with_transformer_engine_spec, +) +from megatron.core.tensor_parallel.random import ( + model_parallel_cuda_manual_seed, +) +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils + +from examples.multimodal_dev.forward_step import ( + _build_packed_seq_params, + _pack_batch, +) + + +# =================================================================== +# Helpers +# =================================================================== + + +def _grad_norm(model): + """L2 norm of all parameter gradients.""" + total = 0.0 + for p in model.parameters(): + if p.grad is not None: + total += p.grad.data.float().norm(2).item() ** 2 + return total ** 0.5 + + +def _mean_loss(per_token_loss, loss_mask): + """Mean loss over valid tokens.""" + flat = per_token_loss.float().view(-1) + mask = 
loss_mask.float().view(-1) + return (flat * mask).sum() / mask.sum().clamp(min=1) + + +def _build_model(cfg, vocab_size, max_seq_len): + """Build a small GPTModel for testing.""" + spec = get_gpt_layer_with_transformer_engine_spec() + model = GPTModel( + config=cfg, + transformer_layer_spec=spec, + vocab_size=vocab_size, + max_sequence_length=max_seq_len, + pre_process=True, + post_process=True, + parallel_output=False, + position_embedding_type="rope", + ) + model.cuda() + return model + + +# =================================================================== +# Core test logic +# =================================================================== + + +def run_equal_length_test( + model, + batch_size, + seq_len, + vocab_size, + seed, + atol_loss, + rtol_grad, +): + """Compare BSHD and THD with equal-length sequences (no padding). + + Returns a dict with test metrics for logging. + """ + B, S = batch_size, seq_len + + # Deterministic data generation. + torch.manual_seed(seed + 1) + input_ids = torch.randint(0, vocab_size, (B, S), device="cuda") + labels = torch.randint(0, vocab_size, (B, S), device="cuda") + loss_mask = torch.ones(B, S, device="cuda") + position_ids = ( + torch.arange(S, device="cuda").unsqueeze(0).expand(B, -1).contiguous() + ) + + # ---- BSHD forward / backward ---- + output_bshd = model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=None, + labels=labels, + loss_mask=loss_mask, + ) + bshd_loss = _mean_loss(output_bshd, loss_mask) + bshd_loss.backward() + bshd_gn = _grad_norm(model) + bshd_lv = bshd_loss.item() + bshd_per_token = output_bshd.detach().float().view(-1).clone() + + model.zero_grad() + + # ---- THD forward / backward ---- + batch = { + "input_ids": input_ids.clone(), + "labels": labels.clone(), + "loss_mask": loss_mask.clone(), + "position_ids": position_ids.clone(), + } + packed = _pack_batch(batch) + psp = packed.pop("packed_seq_params") + + output_thd = model( + input_ids=packed["input_ids"], + 
position_ids=packed["position_ids"], + attention_mask=None, + labels=packed["labels"], + loss_mask=packed["loss_mask"], + packed_seq_params=psp, + ) + thd_loss = _mean_loss(output_thd, packed["loss_mask"]) + thd_loss.backward() + thd_gn = _grad_norm(model) + thd_lv = thd_loss.item() + thd_per_token = output_thd.detach().float().view(-1).clone() + + model.zero_grad() + + # ---- Comparison ---- + loss_diff = abs(bshd_lv - thd_lv) + grad_diff = abs(bshd_gn - thd_gn) + grad_rel = grad_diff / max(bshd_gn, 1e-8) + token_max_diff = (bshd_per_token - thd_per_token).abs().max().item() + token_mean_diff = (bshd_per_token - thd_per_token).abs().mean().item() + + loss_ok = loss_diff < atol_loss + grad_ok = grad_rel < rtol_grad + + metrics = dict( + bshd_loss=bshd_lv, + thd_loss=thd_lv, + loss_diff=loss_diff, + bshd_grad_norm=bshd_gn, + thd_grad_norm=thd_gn, + grad_diff=grad_diff, + grad_rel=grad_rel, + token_max_diff=token_max_diff, + token_mean_diff=token_mean_diff, + loss_ok=loss_ok, + grad_ok=grad_ok, + ) + return metrics + + +def run_variable_length_smoke_test(model, vocab_size, seed): + """Smoke test: variable-length sequences packed to THD. + + Does NOT compare against BSHD (padding in BSHD changes attention + context). Validates that: + - Packing produces correct shapes + - Forward + backward complete without error + - Loss is finite + - Gradients are finite and non-zero + + Returns a dict with test metrics. + """ + seq_lengths = [128, 96, 112, 80] + S = max(seq_lengths) + B = len(seq_lengths) + + torch.manual_seed(seed + 2) + input_ids = torch.randint(0, vocab_size, (B, S), device="cuda") + labels = torch.randint(0, vocab_size, (B, S), device="cuda") + loss_mask = torch.ones(B, S, device="cuda") + position_ids = ( + torch.arange(S, device="cuda").unsqueeze(0).expand(B, -1).contiguous() + ) + + # Build attention_mask to indicate valid tokens per sample. 
+ attention_mask = torch.zeros(B, S, device="cuda") + for i, sl in enumerate(seq_lengths): + attention_mask[i, :sl] = 1.0 + loss_mask[i, sl:] = 0.0 + + batch = { + "input_ids": input_ids, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": position_ids, + "attention_mask": attention_mask, + } + packed = _pack_batch(batch) + psp = packed.pop("packed_seq_params") + + T = sum(seq_lengths) + assert packed["input_ids"].shape == (1, T), ( + f"Expected [1, {T}], got {packed['input_ids'].shape}" + ) + assert packed["labels"].shape == (1, T) + assert packed["loss_mask"].shape == (1, T) + assert psp.cu_seqlens_q.tolist() == [ + 0, + seq_lengths[0], + seq_lengths[0] + seq_lengths[1], + seq_lengths[0] + seq_lengths[1] + seq_lengths[2], + T, + ] + + output = model( + input_ids=packed["input_ids"], + position_ids=packed["position_ids"], + attention_mask=None, + labels=packed["labels"], + loss_mask=packed["loss_mask"], + packed_seq_params=psp, + ) + loss = _mean_loss(output, packed["loss_mask"]) + loss.backward() + gn = _grad_norm(model) + loss_val = loss.item() + + model.zero_grad() + + loss_finite = torch.isfinite(torch.tensor(loss_val)).item() + grad_finite = torch.isfinite(torch.tensor(gn)).item() + grad_nonzero = gn > 0 + + return dict( + loss=loss_val, + grad_norm=gn, + total_tokens=T, + loss_finite=loss_finite, + grad_finite=grad_finite, + grad_nonzero=grad_nonzero, + passed=loss_finite and grad_finite and grad_nonzero, + ) + + +# =================================================================== +# Main +# =================================================================== + + +def _print_banner(title): + print(f"\n{'='*60}") + print(f" {title}") + print(f"{'='*60}") + + +def main(): + parser = argparse.ArgumentParser( + description="BSHD vs THD correctness test", + ) + parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument("--seq-len", type=int, default=128) + parser.add_argument("--vocab-size", type=int, default=1024) + 
parser.add_argument("--hidden-size", type=int, default=256) + parser.add_argument("--num-layers", type=int, default=2) + parser.add_argument("--num-heads", type=int, default=4) + parser.add_argument("--num-kv-heads", type=int, default=2) + parser.add_argument("--ffn-hidden-size", type=int, default=512) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--atol-loss", type=float, default=1e-5, + help="Absolute tolerance for loss comparison") + parser.add_argument("--rtol-grad", type=float, default=1e-3, + help="Relative tolerance for grad norm comparison") + args = parser.parse_args() + + Utils.initialize_model_parallel(tensor_model_parallel_size=1) + model_parallel_cuda_manual_seed(args.seed) + + config = TransformerConfig( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + ffn_hidden_size=args.ffn_hidden_size, + num_attention_heads=args.num_heads, + num_query_groups=args.num_kv_heads, + bf16=True, + params_dtype=torch.bfloat16, + pipeline_dtype=torch.bfloat16, + hidden_dropout=0.0, + attention_dropout=0.0, + tensor_model_parallel_size=1, + sequence_parallel=False, + ) + + model = _build_model(config, args.vocab_size, args.seq_len) + + all_passed = True + + # ---------------------------------------------------------------- + # Test 1: equal-length correctness (BSHD vs THD) + # ---------------------------------------------------------------- + _print_banner("Test 1: Equal-length BSHD vs THD correctness") + m = run_equal_length_test( + model=model, + batch_size=args.batch_size, + seq_len=args.seq_len, + vocab_size=args.vocab_size, + seed=args.seed, + atol_loss=args.atol_loss, + rtol_grad=args.rtol_grad, + ) + print(f" Config: B={args.batch_size}, S={args.seq_len}, " + f"H={args.hidden_size}, L={args.num_layers}, " + f"heads={args.num_heads}/{args.num_kv_heads}") + print(f" BSHD loss: {m['bshd_loss']:.8f}") + print(f" THD loss: {m['thd_loss']:.8f}") + print(f" Loss abs diff: {m['loss_diff']:.2e}") + print(f" BSHD grad norm: 
{m['bshd_grad_norm']:.8f}") + print(f" THD grad norm: {m['thd_grad_norm']:.8f}") + print(f" Grad norm rel diff: {m['grad_rel']:.2e}") + print(f" Per-token max diff: {m['token_max_diff']:.2e}") + print(f" Per-token mean diff: {m['token_mean_diff']:.2e}") + print(f" Loss match: {'PASS' if m['loss_ok'] else 'FAIL'} " + f"(atol={args.atol_loss})") + print(f" Grad norm match: {'PASS' if m['grad_ok'] else 'FAIL'} " + f"(rtol={args.rtol_grad})") + if not (m["loss_ok"] and m["grad_ok"]): + all_passed = False + + # ---------------------------------------------------------------- + # Test 2: variable-length smoke test (THD only) + # ---------------------------------------------------------------- + _print_banner("Test 2: Variable-length THD smoke test") + v = run_variable_length_smoke_test(model, args.vocab_size, args.seed) + print(f" Seq lengths: [128, 96, 112, 80]") + print(f" Total packed tokens: {v['total_tokens']}") + print(f" Loss: {v['loss']:.8f}") + print(f" Grad norm: {v['grad_norm']:.8f}") + print(f" Loss finite: {'PASS' if v['loss_finite'] else 'FAIL'}") + print(f" Grad finite: {'PASS' if v['grad_finite'] else 'FAIL'}") + print(f" Grad nonzero: {'PASS' if v['grad_nonzero'] else 'FAIL'}") + if not v["passed"]: + all_passed = False + + # ---------------------------------------------------------------- + # Summary + # ---------------------------------------------------------------- + _print_banner("Summary") + if all_passed: + print(" ALL TESTS PASSED") + else: + print(" SOME TESTS FAILED") + print(f"{'='*60}\n") + + Utils.destroy_model_parallel() + + if not all_passed: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal_dev/tests/test_thd_e2e.py b/examples/multimodal_dev/tests/test_thd_e2e.py new file mode 100644 index 00000000000..eef320b0a37 --- /dev/null +++ b/examples/multimodal_dev/tests/test_thd_e2e.py @@ -0,0 +1,263 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ +"""Tests for THD (packed sequence) support in multimodal_dev.""" + +import pytest +import torch + +import sys +import os + +# Ensure the repo root is on the path so that the examples package is importable. +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from examples.multimodal_dev.forward_step import ( + _build_packed_seq_params, + _pack_batch, +) + + +# =================================================================== +# Unit tests — CPU, no distributed / GPU required +# =================================================================== + + +class TestBuildPackedSeqParams: + """Tests for ``_build_packed_seq_params``.""" + + def test_basic(self): + params = _build_packed_seq_params( + torch.tensor([5, 3, 7], dtype=torch.int32), device="cpu", + ) + assert params.qkv_format == "thd" + assert params.cu_seqlens_q.tolist() == [0, 5, 8, 15] + assert params.cu_seqlens_kv.tolist() == [0, 5, 8, 15] + assert params.max_seqlen_q == 7 + assert params.max_seqlen_kv == 7 + assert params.total_tokens == 15 + # Padded cu_seqlens mirror actual cu_seqlens (no per-subseq padding). 
+ assert params.cu_seqlens_q_padded is not None + assert params.cu_seqlens_q_padded.tolist() == [0, 5, 8, 15] + assert params.cu_seqlens_kv_padded is not None + assert params.cu_seqlens_kv_padded.tolist() == [0, 5, 8, 15] + + def test_equal_lengths(self): + params = _build_packed_seq_params( + torch.tensor([4, 4, 4], dtype=torch.int32), device="cpu", + ) + assert params.cu_seqlens_q.tolist() == [0, 4, 8, 12] + assert params.max_seqlen_q == 4 + assert params.total_tokens == 12 + + def test_single_sample(self): + params = _build_packed_seq_params( + torch.tensor([10], dtype=torch.int32), device="cpu", + ) + assert params.cu_seqlens_q.tolist() == [0, 10] + assert params.max_seqlen_q == 10 + assert params.total_tokens == 10 + + def test_dtype_is_int32(self): + params = _build_packed_seq_params( + torch.tensor([3, 5], dtype=torch.int32), device="cpu", + ) + assert params.cu_seqlens_q.dtype == torch.int32 + + def test_seq_idx_computed(self): + """Verify __post_init__ computes seq_idx for Mamba compatibility.""" + params = _build_packed_seq_params( + torch.tensor([3, 2], dtype=torch.int32), device="cpu", + ) + # seq_idx should be [0,0,0,1,1] (shape [1, 5]) + assert params.seq_idx is not None + assert params.seq_idx.shape == (1, 5) + assert params.seq_idx[0].tolist() == [0, 0, 0, 1, 1] + + +class TestPackBatch: + """Tests for ``_pack_batch``.""" + + def test_no_padding(self): + """All tokens valid — T == B*S.""" + B, S = 2, 8 + batch = { + "input_ids": torch.arange(B * S).reshape(B, S), + "labels": torch.arange(B * S).reshape(B, S) + 100, + "loss_mask": torch.ones(B, S), + "position_ids": torch.arange(S).unsqueeze(0).unsqueeze(0).expand( + 3, B, S, + ).clone(), + } + packed = _pack_batch(batch) + + T = B * S + assert packed["input_ids"].shape == (1, T) + assert packed["labels"].shape == (1, T) + assert packed["loss_mask"].shape == (1, T) + assert packed["position_ids"].shape == (3, 1, T) + assert packed["attention_mask"] is None + assert 
packed["packed_seq_params"].total_tokens == T + + def test_with_padding(self): + """attention_mask strips padding — T < B*S.""" + B, S = 2, 8 + batch = { + "input_ids": torch.arange(B * S).reshape(B, S), + "labels": torch.arange(B * S).reshape(B, S), + "loss_mask": torch.ones(B, S), + "position_ids": torch.zeros(3, B, S, dtype=torch.long), + "attention_mask": torch.tensor([ + [1, 1, 1, 1, 1, 0, 0, 0], # 5 valid + [1, 1, 1, 0, 0, 0, 0, 0], # 3 valid + ]), + } + packed = _pack_batch(batch) + + T = 5 + 3 + assert packed["input_ids"].shape == (1, T) + assert packed["labels"].shape == (1, T) + assert packed["packed_seq_params"].cu_seqlens_q.tolist() == [0, 5, 8] + assert packed["packed_seq_params"].max_seqlen_q == 5 + assert packed["packed_seq_params"].total_tokens == T + + def test_token_order_preserved(self): + """Packed tokens appear in sample-0 then sample-1 order.""" + batch = { + "input_ids": torch.tensor([[10, 20, 30], [40, 50, 60]]), + "position_ids": torch.zeros(3, 2, 3, dtype=torch.long), + } + packed = _pack_batch(batch) + assert packed["input_ids"].tolist() == [[10, 20, 30, 40, 50, 60]] + + def test_position_ids_mrope(self): + """MRoPE [3, B, S] → [3, 1, T] with correct concatenation.""" + B, S = 2, 4 + pos = torch.zeros(3, B, S, dtype=torch.long) + # Sample 0: positions [0,1,2,3] on all 3 dims + # Sample 1: positions [10,11,12,13] on all 3 dims + for d in range(3): + pos[d, 0] = torch.tensor([0, 1, 2, 3]) + pos[d, 1] = torch.tensor([10, 11, 12, 13]) + + batch = { + "input_ids": torch.zeros(B, S, dtype=torch.long), + "position_ids": pos, + } + packed = _pack_batch(batch) + + assert packed["position_ids"].shape == (3, 1, 8) + # Each dim: [0,1,2,3,10,11,12,13] + for d in range(3): + assert packed["position_ids"][d, 0].tolist() == [ + 0, 1, 2, 3, 10, 11, 12, 13, + ] + + def test_standard_position_ids(self): + """Standard [B, S] position_ids → [1, T].""" + B, S = 2, 3 + batch = { + "input_ids": torch.zeros(B, S, dtype=torch.long), + "position_ids": 
torch.tensor([[0, 1, 2], [0, 1, 2]]), + } + packed = _pack_batch(batch) + assert packed["position_ids"].shape == (1, 6) + assert packed["position_ids"].tolist() == [[0, 1, 2, 0, 1, 2]] + + def test_no_labels_no_loss_mask(self): + """Gracefully handle missing labels and loss_mask.""" + batch = { + "input_ids": torch.tensor([[1, 2], [3, 4]]), + "position_ids": torch.zeros(3, 2, 2, dtype=torch.long), + } + packed = _pack_batch(batch) + assert packed["input_ids"].shape == (1, 4) + assert packed.get("labels") is None + assert packed.get("loss_mask") is None + + def test_data_provided_cu_seqlens_takes_priority(self): + """Prefer dataset-provided cu_seqlens over attention_mask.""" + B, S = 2, 6 + batch = { + "input_ids": torch.tensor([ + [1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + ]), + "labels": torch.tensor([ + [101, 102, 103, 104, 105, 106], + [107, 108, 109, 110, 111, 112], + ]), + "loss_mask": torch.ones(B, S), + "position_ids": torch.zeros(3, B, S, dtype=torch.long), + # Deliberately inconsistent with cu_seqlens. + "attention_mask": torch.ones(B, S), + # Per-sample cu_seqlens: sample0 len=4, sample1 len=2. + "cu_seqlens": torch.tensor([[0, 4], [0, 2]], dtype=torch.int32), + "cu_seqlens_padded": torch.tensor([[0, 4], [0, 2]], dtype=torch.int32), + "max_seqlen": torch.tensor([4, 2], dtype=torch.int32), + } + packed = _pack_batch(batch) + + assert packed["input_ids"].shape == (1, 6) + assert packed["input_ids"].tolist() == [[1, 2, 3, 4, 7, 8]] + assert packed["labels"].tolist() == [[101, 102, 103, 104, 107, 108]] + assert packed["packed_seq_params"].cu_seqlens_q.tolist() == [0, 4, 6] + assert packed["packed_seq_params"].max_seqlen_q == 4 + # Raw data-side metadata is no longer needed after packing. 
+ assert "cu_seqlens" not in packed + assert "cu_seqlens_padded" not in packed + assert "max_seqlen" not in packed + + def test_multisegment_cu_seqlens_rejected(self): + """[B, N] cu_seqlens with N>2 is explicitly unsupported.""" + batch = { + "input_ids": torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), + "position_ids": torch.zeros(3, 2, 4, dtype=torch.long), + "cu_seqlens": torch.tensor( + [[0, 2, 4], [0, 1, 4]], dtype=torch.int32, + ), + } + with pytest.raises(ValueError, match="expected \\[B, 2\\]"): + _pack_batch(batch) + + def test_variable_length_with_attention_mask(self): + """Variable-length sequences: attention_mask strips padding.""" + B, S = 3, 10 + seq_lengths = [8, 5, 10] # valid tokens per sample + batch = { + "input_ids": torch.arange(B * S).reshape(B, S), + "labels": torch.arange(B * S).reshape(B, S) + 100, + "loss_mask": torch.ones(B, S), + "position_ids": torch.zeros(3, B, S, dtype=torch.long), + "attention_mask": torch.zeros(B, S), + } + for i, sl in enumerate(seq_lengths): + batch["attention_mask"][i, :sl] = 1.0 + + packed = _pack_batch(batch) + + T = sum(seq_lengths) # 23 + assert packed["input_ids"].shape == (1, T) + assert packed["labels"].shape == (1, T) + assert packed["loss_mask"].shape == (1, T) + assert packed["position_ids"].shape == (3, 1, T) + assert packed["packed_seq_params"].cu_seqlens_q.tolist() == [ + 0, 8, 13, 23, + ] + assert packed["packed_seq_params"].total_tokens == T + + # Verify correct tokens were kept (first sl tokens per sample). 
+ ids = packed["input_ids"][0].tolist() + expected = list(range(0, 8)) + list(range(10, 15)) + list(range(20, 30)) + assert ids == expected + + def test_packed_seq_params_cumsum_matches_loop(self): + """Verify torch.cumsum produces the same cu_seqlens as a Python loop.""" + lengths = torch.tensor([17, 31, 11, 42, 1], dtype=torch.int32) + params = _build_packed_seq_params(lengths, device="cpu") + # Manual cumulative sum + expected = [0] + for sl in lengths.tolist(): + expected.append(expected[-1] + int(sl)) + assert params.cu_seqlens_q.tolist() == expected diff --git a/megatron/training/datasets/data_samplers.py b/megatron/training/datasets/data_samplers.py index 430bd8b85da..8b14b975aba 100644 --- a/megatron/training/datasets/data_samplers.py +++ b/megatron/training/datasets/data_samplers.py @@ -106,7 +106,7 @@ def close_nvidia_fds(): maybe_worker_init_fn = worker_init_fn if args.num_workers > 0 else None # Torch dataloader. - if args.dynamic_context_parallel: + if args.dynamic_context_parallel or getattr(args, "use_vanilla_collate_fn", False): extra_kwargs = {"collate_fn": lambda x: x} else: extra_kwargs = {}