diff --git a/src/olmo_core/data/multimodal/__init__.py b/src/olmo_core/data/multimodal/__init__.py index 741a9ae23..b7440f7a9 100644 --- a/src/olmo_core/data/multimodal/__init__.py +++ b/src/olmo_core/data/multimodal/__init__.py @@ -34,3 +34,20 @@ "MultiCropPreprocessor", "MultiCropPreprocessorConfig", ] + +# Training pipeline (added in this PR) +from .collator import MultimodalCollator, MultimodalCollatorConfig +from .data_loader import MultimodalDataLoader, MultimodalDataLoaderConfig +from .pixmo_cap import PixmoCapDataset, PixmoCapDatasetConfig +from .preprocessor import MultimodalPreprocessor, MultimodalPreprocessorConfig + +__all__ += [ + "MultimodalCollator", + "MultimodalCollatorConfig", + "MultimodalDataLoader", + "MultimodalDataLoaderConfig", + "PixmoCapDataset", + "PixmoCapDatasetConfig", + "MultimodalPreprocessor", + "MultimodalPreprocessorConfig", +] diff --git a/src/olmo_core/data/multimodal/collator.py b/src/olmo_core/data/multimodal/collator.py new file mode 100644 index 000000000..33e840542 --- /dev/null +++ b/src/olmo_core/data/multimodal/collator.py @@ -0,0 +1,139 @@ +""" +Multimodal collator. + +Stacks a list of per-example dicts (produced by +:class:`~olmo_core.data.multimodal.preprocessor.MultimodalPreprocessor`) into +the batched tensor dict that :class:`~olmo_core.nn.vision.MultimodalTransformer.forward` +consumes. + +The tricky part is variable image-layout across the batch (different +``n_crops`` and ``n_pooled`` per example, which happens in +``overlap_and_resize`` mode). The model's splice asserts that the total +number of ```` tokens in ``input_ids`` equals the total number of +pooled image features, so we: + +1. Pad ``pooled_patches_idx`` to the max ``n_pooled`` in the batch with all-``-1`` + rows. These rows produce zero features (since every patch index is masked + to zero in the connector). +2. Append ``(max_n_pooled - n_pooled)`` dummy ```` tokens to each + example's ``input_tokens`` (with ``loss_mask = 0``) so the count matches. 
class MultimodalCollator:
    """Stack per-example dicts into a single batched dict of ``torch.Tensor``.

    Examples in one batch may carry different numbers of crops and pooled
    rows (including none at all for text-only examples). The collator pads
    every field to the batch maximum:

    * ``pooled_patches_idx`` rows are padded with ``-1``;
    * one dummy image-patch token (``loss_mask = 0``) is appended to
      ``input_ids`` for every padded pooled row, so the number of image-patch
      tokens added per example matches the number of pooled rows;
    * ``images`` is zero-padded along the crop axis;
    * ``input_ids`` is padded with the base tokenizer's pad id and
      ``loss_masks`` with ``0``.
    """

    def __init__(self, cfg: "MultimodalCollatorConfig"):
        self.cfg = cfg
        # Token id written into the dummy image-patch positions.
        self.image_patch_id = cfg.tokenizer.image_patch_id
        # Fill value for the unused tail of every ``input_ids`` row.
        self.pad_id = cfg.tokenizer.base.pad_token_id

    def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]:
        """Batch ``examples`` into tensors.

        :param examples: List of dicts with keys ``input_tokens``, ``loss_masks``,
            ``images`` and ``pooled_patches_idx``.
        :returns: Dict with keys ``input_ids`` ``(B, S)`` long, ``loss_masks``
            ``(B, S)`` float32, ``images``
            ``(B, max_n_crops, n_patches_per_crop, patch_dim)`` float32 and
            ``pooled_patches_idx`` ``(B, max_n_pooled, pool_size)`` long.
        :raises ValueError: If ``examples`` is empty, or if two examples that
            contain crops disagree on the per-crop patch shape.
        """
        if not examples:
            raise ValueError("collator received an empty batch")

        B = len(examples)
        cfg = self.cfg

        # Determine output shapes.
        max_n_pooled = max(e["pooled_patches_idx"].shape[0] for e in examples)
        pool_size = examples[0]["pooled_patches_idx"].shape[1]
        max_n_crops = max(e["images"].shape[0] for e in examples)

        # The per-crop patch shape is only meaningful on examples that hold at
        # least one crop: a text-only example may carry a zero-sized placeholder
        # such as ``(0, 0, patch_dim)``. Taking the shape from ``examples[0]``
        # unconditionally would reject valid mixed text/image batches.
        with_crops = [e for e in examples if e["images"].shape[0] > 0]
        reference = with_crops[0] if with_crops else examples[0]
        n_patches_per_crop, patch_dim = reference["images"].shape[1:3]
        for e in with_crops[1:]:
            if tuple(e["images"].shape[1:]) != (n_patches_per_crop, patch_dim):
                raise ValueError(
                    "All examples must share patch shape; got "
                    f"{e['images'].shape[1:]} vs ({n_patches_per_crop}, {patch_dim})"
                )

        # Sequence length includes any dummy tokens we'll append.
        max_seq = max(
            e["input_tokens"].shape[0] + (max_n_pooled - e["pooled_patches_idx"].shape[0])
            for e in examples
        )
        if cfg.pad_to_multiple_of > 1:
            mult = cfg.pad_to_multiple_of
            max_seq = mult * ((max_seq + mult - 1) // mult)

        # Allocate batched tensors pre-filled with their padding values.
        input_ids = torch.full((B, max_seq), self.pad_id, dtype=torch.long)
        loss_masks = torch.zeros((B, max_seq), dtype=torch.float32)
        images = torch.zeros((B, max_n_crops, n_patches_per_crop, patch_dim), dtype=torch.float32)
        pooled_patches_idx = torch.full((B, max_n_pooled, pool_size), -1, dtype=torch.long)

        # Fill.
        for i, e in enumerate(examples):
            tokens = e["input_tokens"]
            seq_len = tokens.shape[0]
            n_pooled = e["pooled_patches_idx"].shape[0]
            n_dummy = max_n_pooled - n_pooled

            input_ids[i, :seq_len] = torch.from_numpy(tokens)
            loss_masks[i, :seq_len] = torch.from_numpy(e["loss_masks"])
            # Dummy tokens go after the real content and before the pad tokens;
            # their loss mask stays 0.
            if n_dummy > 0:
                input_ids[i, seq_len : seq_len + n_dummy] = self.image_patch_id

            n_crops = e["images"].shape[0]
            if n_crops > 0:
                images[i, :n_crops] = torch.from_numpy(e["images"])
            if n_pooled > 0:
                pooled_patches_idx[i, :n_pooled] = torch.from_numpy(e["pooled_patches_idx"])

        return {
            "input_ids": input_ids,
            "loss_masks": loss_masks,
            "images": images,
            "pooled_patches_idx": pooled_patches_idx,
        }
The source's ``set_epoch(epoch)`` method (if present) + is called when :meth:`set_epoch` is called externally. + """ + + def __init__( + self, + source, + preprocessor: MultimodalPreprocessor, + dp_rank: int, + dp_world_size: int, + ): + super().__init__() + self.source = source + self.preprocessor = preprocessor + self.dp_rank = dp_rank + self.dp_world_size = dp_world_size + self._epoch = 0 + + def set_epoch(self, epoch: int): + self._epoch = epoch + # If the source supports epoch-based shuffling, propagate. + if hasattr(self.source, "set_epoch"): + self.source.set_epoch(epoch) + + def __iter__(self) -> Iterator[Dict[str, np.ndarray]]: + worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id if worker_info is not None else 0 + num_workers = worker_info.num_workers if worker_info is not None else 1 + + # Global shard ID across DP ranks AND DataLoader workers. + shard_id = self.dp_rank * num_workers + worker_id + total_shards = self.dp_world_size * num_workers + + for i, ex in enumerate(self.source): + if i % total_shards != shard_id: + continue + prompt, response, image = ex + yield self.preprocessor(prompt, response, image) + + +@dataclass +class MultimodalDataLoaderConfig(Config): + """Configuration for :class:`MultimodalDataLoader`.""" + + preprocessor: MultimodalPreprocessorConfig = field(default_factory=MultimodalPreprocessorConfig) + """Preprocessor settings (tokenizer + multicrop + sequence length).""" + + collator: Optional[MultimodalCollatorConfig] = None + """Collator settings. If ``None``, a default collator is built from + ``preprocessor.tokenizer``.""" + + global_batch_size: int = 8 + """Total examples per batch across all DP ranks.""" + + num_workers: int = 0 + """Number of background workers per rank for the wrapped torch DataLoader. + Set to 0 for in-process iteration (simplest, fine for tests). 
class MultimodalDataLoader(DataLoaderBase):
    """Per-example multimodal data loader.

    ``global_batch_size`` is in **examples** (not tokens), so each rank yields
    batches of size ``global_batch_size // dp_world_size``. Each batch is the
    dict of tensors produced by the configured collator.

    NOTE(review): ``rank_batch_size``, ``batches_processed`` and ``_epoch`` are
    read here but not defined in this class; they are presumably provided by
    :class:`DataLoaderBase` — confirm against the base class.
    """

    def __init__(
        self,
        *,
        source,
        preprocessor: MultimodalPreprocessor,
        collator: MultimodalCollator,
        global_batch_size: int,
        num_workers: int = 0,
        prefetch_factor: int = 2,
        seed: int = 0,
        work_dir: str = "/tmp/olmo-core-mm-data",
        dp_world_size: int = 1,
        dp_rank: int = 0,
        fs_local_rank: Optional[int] = None,
    ):
        """Store the pipeline components and validate the batch-size split.

        :param source: Iterable of ``(prompt, response, image)`` triples.
        :param preprocessor: Turns one triple into a per-example dict.
        :param collator: Turns a list of per-example dicts into a batch.
        :param global_batch_size: Examples per batch across all DP ranks.
        :param num_workers: Worker processes for the wrapped torch ``DataLoader``.
        :param prefetch_factor: Batches prefetched per worker (ignored when
            ``num_workers == 0``).
        :param seed: Base seed; added to the epoch before it is handed to the source.
        :param work_dir: Working directory forwarded to :class:`DataLoaderBase`.
        :param dp_world_size: Data-parallel world size.
        :param dp_rank: This process's data-parallel rank.
        :param fs_local_rank: Forwarded to :class:`DataLoaderBase`.
        :raises ValueError: If ``global_batch_size`` is not divisible by
            ``dp_world_size``.
        """
        super().__init__(
            work_dir=work_dir,
            global_batch_size=global_batch_size,
            dp_world_size=dp_world_size,
            dp_rank=dp_rank,
            fs_local_rank=fs_local_rank,
        )
        if global_batch_size % dp_world_size != 0:
            raise ValueError(
                f"global_batch_size ({global_batch_size}) must be divisible by "
                f"dp_world_size ({dp_world_size})"
            )
        self.source = source
        self.preprocessor = preprocessor
        self.collator = collator
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.seed = seed

    # ------------------------------------------------------------------
    # DataLoaderBase API
    # ------------------------------------------------------------------

    @property
    def total_batches(self) -> Optional[int]:
        """Number of global batches in an epoch, or ``None`` if the source is unsized."""
        try:
            n_examples = len(self.source)
        except TypeError:
            # Streaming / generator-style sources have no usable ``__len__``.
            return None
        # One global batch consumes ``global_batch_size`` source examples,
        # whatever the number of ranks they are spread over.
        return n_examples // self.global_batch_size

    def state_dict(self) -> Dict[str, Any]:
        """Return the minimal state needed to resume inside an epoch."""
        return {
            "epoch": self._epoch,
            "batches_processed": self.batches_processed,
            "seed": self.seed,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore the state produced by :meth:`state_dict`."""
        self._epoch = state_dict["epoch"]
        self.batches_processed = state_dict["batches_processed"]
        self.seed = state_dict["seed"]

    def reshuffle(self, epoch: Optional[int] = None, **kwargs):
        """Advance to ``epoch`` (or the next epoch) and reseed the source.

        :param epoch: Epoch to switch to. When ``None`` the current epoch is
            incremented (starting at 1 if no epoch was set yet).
        """
        del kwargs
        if epoch is None:
            epoch = 1 if self._epoch is None else self._epoch + 1
        self._epoch = epoch
        # Propagate per-epoch shuffle seed to the source if supported.
        if hasattr(self.source, "set_epoch"):
            self.source.set_epoch(self.seed + epoch)

    def _iter_batches(self) -> Iterable[Dict[str, Any]]:
        """Yield this rank's remaining batches for the current epoch."""
        # Skip building the iterator if the epoch is already exhausted.
        if self.total_batches is not None and self.batches_processed >= self.total_batches:
            return

        wrapper = _MultimodalIterableWrapper(
            self.source, self.preprocessor, self.dp_rank, self.dp_world_size
        )
        # The wrapper forwards this value to ``source.set_epoch``. Pass the
        # same ``seed + epoch`` that ``reshuffle`` uses so the configured seed
        # is not overwritten by the bare epoch number right before iteration.
        wrapper.set_epoch(self.seed + (self._epoch or 0))

        torch_loader = torch.utils.data.DataLoader(
            wrapper,
            batch_size=self.rank_batch_size,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None,
            persistent_workers=False,
            collate_fn=self._collate,
            # Drop incomplete batches inside the loader. With several workers
            # each worker can end on a short batch while others still hold
            # full ones, so stopping at the first short batch would discard
            # valid data; dropping them here also keeps the resume index
            # below aligned with the batches that were actually yielded.
            drop_last=True,
        )

        # Resume from where state_dict left off in this epoch.
        skip = self.batches_processed
        for i, batch in enumerate(torch_loader):
            if i < skip:
                continue
            yield batch

    def _collate(self, examples) -> Dict[str, torch.Tensor]:
        """Bound-method indirection so the collate function is a loader attribute."""
        return self.collator(examples)

    def get_mock_batch(self) -> Dict[str, torch.Tensor]:
        """Return a small fake batch for the trainer's dry-run pass.

        Uses the configured patch / pool dimensions from the preprocessor so
        the shapes are derived from the same settings as real batches.
        """
        cfg = self.preprocessor.cfg
        mc_cfg = cfg.multicrop
        patch_size = mc_cfg.image_preprocessor.patch_size
        patch_dim = 3 * patch_size**2
        n_patches_per_crop = (mc_cfg.base_image_input_size[0] // patch_size) * (
            mc_cfg.base_image_input_size[1] // patch_size
        )
        pool_size = mc_cfg.pool_h * mc_cfg.pool_w
        n_pooled = n_patches_per_crop // pool_size

        B = self.rank_batch_size
        # Build a sequence of image-patch tokens followed by a few pad tokens
        # standing in for text.
        patch_id = cfg.tokenizer.image_patch_id
        pad_id = cfg.tokenizer.base.pad_token_id
        S = n_pooled + 4
        input_ids = torch.full((B, S), pad_id, dtype=torch.long)
        input_ids[:, :n_pooled] = patch_id
        return {
            "input_ids": input_ids,
            "loss_masks": torch.ones(B, S, dtype=torch.float32),
            "images": torch.zeros(B, 1, n_patches_per_crop, patch_dim, dtype=torch.float32),
            "pooled_patches_idx": (
                # Only the first ``n_pooled * pool_size`` patches are indexed
                # so the reshape also works when the patch count is not an
                # exact multiple of the pool size.
                torch.arange(n_pooled * pool_size)
                .view(n_pooled, pool_size)
                .unsqueeze(0)
                .expand(B, -1, -1)
                .contiguous()
            ),
        }
The dataset on the Hub stores only + image URLs (columns ``image_url`` / ``caption`` / ``transcripts``), not image + bytes, which is why a download step is required. + +- ``"local"``: read a pre-downloaded HF ``Dataset`` saved under + ``$MOLMO_DATA_DIR/torch_datasets/pixmo_datasets/cap`` (the layout Molmo2 + itself uses), where the ``image`` column already holds local file paths. This + is the path used for real large-scale training on AI2 infrastructure — no + network at iteration time. +""" + +import io +import logging +import os +from dataclasses import dataclass +from os.path import join +from typing import Iterator, Optional, Tuple + +from ...config import Config + +__all__ = [ + "PixmoCapDatasetConfig", + "PixmoCapDataset", + "DEFAULT_MOLMO_DATA_DIR", + "HF_DATASET_ID", +] + +log = logging.getLogger(__name__) + +#: Default location of the shared Molmo data tree on AI2 infrastructure. +DEFAULT_MOLMO_DATA_DIR: str = "/weka/oe-training-default/mm-olmo" + +#: HuggingFace Hub dataset id for the ``"hub"`` source. +HF_DATASET_ID: str = "allenai/pixmo-cap" + + +@dataclass +class PixmoCapDatasetConfig(Config): + """Configuration for :class:`PixmoCapDataset`.""" + + source: str = "hub" + """Where to load from: ``"hub"`` (stream from the HuggingFace Hub and + download images per-URL) or ``"local"`` (read a pre-downloaded on-disk HF + ``Dataset``).""" + + split: str = "train" + """HuggingFace split name. The Hub dataset only has ``train``.""" + + prompt: str = "Describe this image in detail." + """Constant prompt prefix paired with each caption.""" + + limit: Optional[int] = None + """If set, stop after this many successfully-loaded examples (handy for + tests / dry runs). Counts examples *yielded*, not rows scanned, so in + ``"hub"`` mode skipped (dead-URL) rows don't count against it.""" + + shuffle: bool = False + """If ``True``, shuffle before iteration.""" + + shuffle_seed: int = 0 + """Seed for the shuffle. 
class PixmoCapDataset:
    """Iterable over PixMo-Cap ``(prompt, response, image)`` triples.

    ``cfg.source`` selects between streaming from the HuggingFace Hub
    (``"hub"``, images downloaded per row) and reading a dataset saved on
    disk (``"local"``). Rows whose image cannot be loaded are skipped and
    reported through the module logger.
    """

    def __init__(self, cfg: "PixmoCapDatasetConfig"):
        """
        :param cfg: Dataset configuration.
        :raises ValueError: If ``cfg.source`` is neither ``"hub"`` nor ``"local"``.
        """
        if cfg.source not in ("hub", "local"):
            raise ValueError(f"source must be 'hub' or 'local', got {cfg.source!r}")
        self.cfg = cfg
        self._epoch = 0

    def set_epoch(self, epoch: int) -> None:
        """Select the epoch used to derive the shuffle seed of the next iteration.

        Without this hook every pass over the dataset would reuse the same
        shuffle order. Has no effect when ``cfg.shuffle`` is ``False``.
        """
        self._epoch = epoch

    def _shuffle_seed(self) -> int:
        """Return ``cfg.shuffle_seed`` offset by the current epoch."""
        return self.cfg.shuffle_seed + self._epoch

    def __iter__(self) -> Iterator[Tuple[str, str, object]]:
        if self.cfg.source == "hub":
            return self._iter_hub()
        return self._iter_local()

    @staticmethod
    def _report_skipped(mode: str, n_skipped: int, n_yielded: int) -> None:
        """Log a one-line summary when rows had to be skipped."""
        if n_skipped:
            log.info(
                "PixMo-Cap (%s): skipped %d row(s) with unloadable images, yielded %d",
                mode,
                n_skipped,
                n_yielded,
            )

    def _iter_hub(self) -> Iterator[Tuple[str, str, object]]:
        """Stream rows from the Hub, downloading each image from ``image_url``."""
        cfg = self.cfg
        limit = cfg.limit
        if limit is not None and limit <= 0:
            # Nothing requested: avoid touching the network at all.
            return

        import requests
        from datasets import load_dataset
        from PIL import Image

        ds = load_dataset(cfg.hub_dataset_id, split=cfg.split, streaming=True)
        if cfg.shuffle:
            ds = ds.shuffle(seed=self._shuffle_seed(), buffer_size=cfg.shuffle_buffer_size)

        n_yielded = 0
        n_skipped = 0
        for row in ds:
            url = row["image_url"]
            try:
                resp = requests.get(url, timeout=cfg.image_timeout)
                resp.raise_for_status()
                with Image.open(io.BytesIO(resp.content)) as handle:
                    image = handle.convert("RGB")
            except Exception as exc:  # noqa: BLE001
                # Dead / blocked / corrupt URL — skip without aborting iteration.
                n_skipped += 1
                log.debug("Skipping PixMo-Cap row, could not load %s: %s", url, exc)
                continue
            n_yielded += 1
            yield cfg.prompt, row["caption"], image
            # Check the limit right after yielding so no further row is
            # fetched from the stream once enough examples were produced.
            if limit is not None and n_yielded >= limit:
                break
        self._report_skipped("hub", n_skipped, n_yielded)

    def _iter_local(self) -> Iterator[Tuple[str, str, object]]:
        """Read rows from the on-disk dataset; ``image`` holds a local file path."""
        cfg = self.cfg
        limit = cfg.limit
        if limit is not None and limit <= 0:
            return

        from datasets import load_from_disk
        from PIL import Image

        path = cfg.resolve_data_dir()
        ds = load_from_disk(path)[cfg.split]
        if cfg.shuffle:
            ds = ds.shuffle(seed=self._shuffle_seed())

        n_yielded = 0
        n_skipped = 0
        for row in ds:
            image_path = row["image"]
            try:
                # ``convert`` returns an independent image, so the file handle
                # opened by ``Image.open`` can be released immediately.
                with Image.open(image_path) as handle:
                    image = handle.convert("RGB")
            except Exception as exc:  # noqa: BLE001
                # Skip rows whose local file is missing / corrupt; don't abort.
                n_skipped += 1
                log.debug("Skipping PixMo-Cap row, could not load %s: %s", image_path, exc)
                continue
            n_yielded += 1
            yield cfg.prompt, row["caption"], image
            if limit is not None and n_yielded >= limit:
                break
        self._report_skipped("local", n_skipped, n_yielded)
class MultimodalPreprocessor:
    """Build per-example training dicts from raw ``(prompt, response, image)``.

    The HuggingFace tokenizer is supplied at construction time so callers can
    decide when to pay the (potentially network-bound) load cost — typically
    once per dataset.
    """

    def __init__(self, cfg: "MultimodalPreprocessorConfig", hf_tokenizer):
        """
        :param cfg: Preprocessor configuration.
        :param hf_tokenizer: Tokenizer exposing
            ``encode(text, add_special_tokens=False)``.
        """
        self.cfg = cfg
        self.tokenizer = hf_tokenizer  # PreTrainedTokenizerBase
        self.multicrop: "MultiCropPreprocessor" = cfg.multicrop.build(cfg.tokenizer)

        # Cache the base EOS id; it is appended to the response when present.
        self._eos_id: Optional[int] = cfg.tokenizer.base.eos_token_id

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _encode(self, text: str) -> np.ndarray:
        """Tokenize ``text`` without special tokens into an ``int64`` array."""
        if not text:
            return np.empty((0,), dtype=np.int64)
        ids = self.tokenizer.encode(text, add_special_tokens=False)
        return np.asarray(ids, dtype=np.int64)

    def _split_prompt_template(self) -> Tuple[str, str]:
        """Split ``prompt_template`` at its ``{image}`` placeholder.

        :raises ValueError: If the template does not contain exactly one
            ``{image}`` placeholder.
        """
        tmpl = self.cfg.prompt_template
        if tmpl.count("{image}") != 1:
            raise ValueError(
                "prompt_template must contain a single '{image}' placeholder; got: " f"{tmpl!r}"
            )
        pre, post = tmpl.split("{image}", 1)
        return pre, post

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------

    def __call__(
        self,
        prompt: str,
        response: str,
        image: Optional[Any] = None,
    ) -> Dict[str, np.ndarray]:
        """Build the per-example dict.

        :param prompt: User-side text.
        :param response: Assistant-side text (the supervised target).
        :param image: Image handed to the multi-crop preprocessor. Pass
            ``None`` for text-only examples — the ``{image}`` placeholder is
            stripped from the prompt template.
        :returns: Dict with keys ``input_tokens``, ``loss_masks``, ``images``,
            ``pooled_patches_idx``. For text-only examples ``images`` has shape
            ``(0, 0, patch_dim)`` and ``pooled_patches_idx`` has shape
            ``(0, pool_size)``.
        :raises ValueError: If a template is malformed, or if
            ``max_sequence_length`` is too small to hold the image tokens.
        """
        cfg = self.cfg

        # Multi-crop preprocessing (if any image).
        if image is not None:
            mc_out = self.multicrop(image)
            image_tokens = mc_out.image_tokens
            images = mc_out.images
            pooled_patches_idx = mc_out.pooled_patches_idx
        else:
            image_tokens = np.empty((0,), dtype=np.int64)
            # Zero-sized placeholders so downstream code finds every key.
            patch_size = cfg.multicrop.image_preprocessor.patch_size
            patch_dim = 3 * patch_size * patch_size
            pool_size = cfg.multicrop.pool_h * cfg.multicrop.pool_w
            images = np.zeros((0, 0, patch_dim), dtype=np.float32)
            pooled_patches_idx = np.zeros((0, pool_size), dtype=np.int64)

        # Compose the text portion. The prompt template is split at {image} so
        # the image-token sequence (which is already token IDs) is spliced in
        # between two tokenized text chunks.
        pre_tmpl, post_tmpl = self._split_prompt_template()
        try:
            pre_text = pre_tmpl.format(prompt=prompt)
            post_text = post_tmpl.format(prompt=prompt)
        except (KeyError, IndexError) as e:
            # KeyError: unknown named field; IndexError: positional field.
            raise ValueError(f"prompt_template references unknown field: {e}") from e
        if image is None:
            # Strip the dangling newline that the {image} placeholder used to anchor.
            post_text = post_text.lstrip("\n")

        try:
            resp_text = cfg.response_template.format(response=response)
        except (KeyError, IndexError) as e:
            raise ValueError(f"response_template references unknown field: {e}") from e

        pre_ids = self._encode(pre_text)
        post_ids = self._encode(post_text)
        resp_ids = self._encode(resp_text)

        if cfg.add_eos and self._eos_id is not None:
            resp_ids = np.concatenate([resp_ids, np.asarray([self._eos_id], dtype=np.int64)])

        # Right-truncation below must never cut into the image-token span:
        # ``images`` / ``pooled_patches_idx`` are returned untruncated, so a
        # shortened image-token run would no longer match them. Fail here with
        # a clear message instead of emitting an inconsistent example.
        image_end = len(pre_ids) + len(image_tokens)
        if len(image_tokens) > 0 and image_end > cfg.max_sequence_length:
            raise ValueError(
                f"max_sequence_length ({cfg.max_sequence_length}) is too small to hold the "
                f"{len(image_tokens)} image tokens of this example (they end at position "
                f"{image_end})"
            )

        # Concatenate everything; only the response positions are supervised.
        n_context = len(pre_ids) + len(image_tokens) + len(post_ids)
        input_tokens = np.concatenate([pre_ids, image_tokens, post_ids, resp_ids])
        loss_masks = np.concatenate(
            [
                np.zeros(n_context, dtype=np.float32),
                np.ones(len(resp_ids), dtype=np.float32),
            ]
        )

        # Truncate from the right; the prompt + image tokens are load-bearing.
        if len(input_tokens) > cfg.max_sequence_length:
            input_tokens = input_tokens[: cfg.max_sequence_length]
            loss_masks = loss_masks[: cfg.max_sequence_length]

        return {
            "input_tokens": input_tokens,
            "loss_masks": loss_masks,
            "images": images,
            "pooled_patches_idx": pooled_patches_idx,
        }
@torch.no_grad()
def init_weights(
    self,
    *,
    max_seq_len: Optional[int] = None,
    max_local_microbatch_size: Optional[int] = None,
    device: Optional[torch.device] = None,
    world_mesh: Optional[DeviceMesh] = None,
    model_part_idx: int = 0,
) -> torch.Generator:
    """Materialise parameters on ``device`` and initialise them.

    Matches :meth:`~olmo_core.nn.transformer.Transformer.init_weights`'s
    signature so the same trainer-level call site works for both. Each
    sub-module (LM, vision, connector) is materialised separately to
    avoid double-``to_empty`` issues under FSDP — each ``to_empty`` on a
    FSDP-wrapped param is a collective, and overlapping/redundant calls
    across ranks deadlock.

    :param max_seq_len: Forwarded unchanged to the LM's ``init_weights``.
    :param max_local_microbatch_size: Forwarded unchanged to the LM's
        ``init_weights``.
    :param device: Device to materialise on. When falsy (``None``), the
        device of this module's first parameter is used.
    :param world_mesh: Forwarded unchanged to the LM's ``init_weights``.
    :param model_part_idx: Forwarded unchanged to the LM's ``init_weights``.
    :returns: The generator returned by the LM's ``init_weights``.
    """
    # Resolve the target once so all three sub-modules land on the same device.
    target_device = device or next(iter(self.parameters())).device

    # LM handles its own materialisation + InitMethod + RoPE cache.
    gen = self.lm.init_weights(
        max_seq_len=max_seq_len,
        max_local_microbatch_size=max_local_microbatch_size,
        device=target_device,
        world_mesh=world_mesh,
        model_part_idx=model_part_idx,
    )

    # Materialise vision and connector separately, then init their params.
    # ``to_empty`` allocates storage without initialising it, so each call is
    # followed directly by that sub-module's ``reset_parameters``.
    self.vision.to_empty(device=target_device)
    self.vision.reset_parameters()
    self.connector.to_empty(device=target_device)
    self.connector.reset_parameters()

    # NOTE(review): the vision / connector initialisation above does not use
    # ``gen``; whether it should share the LM's generator is not visible here.
    return gen
def apply_ddp(
    self,
    dp_mesh: Optional[DeviceMesh] = None,
    param_dtype: Optional[torch.dtype] = None,
) -> None:
    """Replicate the whole composite model with DDP.

    No per-submodule wrapping is done here: the entire
    :class:`MultimodalTransformer` is handed to ``replicate`` as one unit.

    :param dp_mesh: The data-parallel device mesh, passed to ``replicate``
        as ``device_mesh``.
    :param param_dtype: If given and different from ``self.lm.dtype``, the
        whole module is cast to this dtype before replication.
    """
    from torch.distributed._composable.replicate import replicate

    # Cast only when an explicit dtype was requested that differs from the LM's.
    needs_cast = param_dtype is not None and param_dtype != self.lm.dtype
    if needs_cast:
        self.to(dtype=param_dtype)

    replicate(self, device_mesh=dp_mesh, bucket_cap_mb=100)
+ emb_device = self.lm.embeddings.weight.device + if input_ids.device != emb_device: + input_ids = input_ids.to(emb_device) + if images is not None and images.device != emb_device: + images = images.to(emb_device) + if pooled_patches_idx is not None and pooled_patches_idx.device != emb_device: + pooled_patches_idx = pooled_patches_idx.to(emb_device) + # Compute LM token embeddings with any configured scale / norm. h = self.lm.embeddings(input_ids) if self.lm.embed_scale is not None: diff --git a/src/olmo_core/train/train_module/multimodal/__init__.py b/src/olmo_core/train/train_module/multimodal/__init__.py new file mode 100644 index 000000000..2bf67b7e7 --- /dev/null +++ b/src/olmo_core/train/train_module/multimodal/__init__.py @@ -0,0 +1,13 @@ +""" +Train module for multimodal vision-language transformers. +""" + +from .train_module import ( + MultimodalTransformerTrainModule, + MultimodalTransformerTrainModuleConfig, +) + +__all__ = [ + "MultimodalTransformerTrainModule", + "MultimodalTransformerTrainModuleConfig", +] diff --git a/src/olmo_core/train/train_module/multimodal/train_module.py b/src/olmo_core/train/train_module/multimodal/train_module.py new file mode 100644 index 000000000..fad291a64 --- /dev/null +++ b/src/olmo_core/train/train_module/multimodal/train_module.py @@ -0,0 +1,281 @@ +""" +Train module for :class:`~olmo_core.nn.vision.MultimodalTransformer`. + +Reuses :class:`TransformerTrainModule`'s machinery (microbatching, optimizer +step, scheduler, autocast, state dict, loss accounting) and adds: + +- FSDP / DDP support tailored to the composite model — each vision block, + the connector, and the LM (via its own ``apply_fsdp``) become FSDP units, + with a final outer wrap on the whole :class:`MultimodalTransformer`. +- The batch carries ``loss_masks: (B, S) float32``. Before computing + autoregressive labels we convert it to the boolean ``label_mask`` the base + class understands — that's what implements response-only loss. 
+- The batch also carries ``images`` and ``pooled_patches_idx`` which flow + through to ``MultimodalTransformer.forward`` as ``**model_kwargs``. + +TP / CP / PP / EP are still out of scope and silently ignored with a warning. +""" + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast + +import torch +import torch.distributed.checkpoint.state_dict as dist_cp_sd +from torch.optim import Optimizer + +from olmo_core.config import DType +from olmo_core.distributed.parallel import build_world_mesh, get_dp_model_mesh +from olmo_core.distributed.utils import is_distributed +from olmo_core.exceptions import OLMoConfigurationError +from olmo_core.nn.vision import MultimodalTransformer +from olmo_core.optim import OptimConfig +from olmo_core.optim.scheduler import Scheduler +from olmo_core.train.train_module.train_module import TrainModule +from olmo_core.train.train_module.transformer.config import ( + TransformerDataParallelConfig, + TransformerTrainModuleConfig, +) +from olmo_core.train.train_module.transformer.train_module import TransformerTrainModule +from olmo_core.utils import get_default_device + +if TYPE_CHECKING: + pass + +log = logging.getLogger(__name__) + +__all__ = [ + "MultimodalTransformerTrainModule", + "MultimodalTransformerTrainModuleConfig", +] + + +class MultimodalTransformerTrainModule(TransformerTrainModule): + """A :class:`TrainModule` for :class:`MultimodalTransformer`. + + Reuses the full :class:`TransformerTrainModule` pipeline (microbatching, + state dict, autocast, scheduler) and runs the multimodal-specific + parallelization in ``__init__``. The + :func:`~olmo_core.train.train_module.transformer.common.parallelize_model` + helper isn't called because it's :class:`Transformer`-specific; we call + :meth:`MultimodalTransformer.apply_fsdp` / :meth:`apply_ddp` directly. 
+ """ + + def __init__( + self, + model: MultimodalTransformer, + optim: OptimConfig, + rank_microbatch_size: int, + max_sequence_length: int, + dp_config: Optional[TransformerDataParallelConfig] = None, + z_loss_multiplier: Optional[float] = None, + autocast_precision: Optional[torch.dtype] = None, + max_grad_norm: Optional[float] = None, + scheduler: Optional[Scheduler] = None, + device: Optional[torch.device] = None, + state_dict_save_opts: Optional[dist_cp_sd.StateDictOptions] = None, + state_dict_load_opts: Optional[dist_cp_sd.StateDictOptions] = None, + load_key_mapping: Optional[Dict[str, str]] = None, + label_ignore_index: int = -100, + ): + # Skip TransformerTrainModule.__init__ (which calls parallelize_model + # on a Transformer) and run the base TrainModule init manually. + TrainModule.__init__(self) + + if rank_microbatch_size % max_sequence_length != 0: + raise OLMoConfigurationError( + f"'rank_microbatch_size' ({rank_microbatch_size:,d} tokens) must be divisible by " + f"'max_sequence_length' ({max_sequence_length:,d} tokens)" + ) + + self.device = device or get_default_device() + + # Build the world mesh if distributed and a DP config was given. + self.world_mesh = None + if dp_config is not None: + if not is_distributed(): + raise OLMoConfigurationError("dp_config is only valid in a distributed setting") + self.world_mesh = build_world_mesh(dp=dp_config, device_type=self.device.type) + + # Parallelize, then materialize and initialize weights. 
+ if dp_config is None: + self.model: MultimodalTransformer = model.to(self.device) # type: ignore[assignment] + else: + assert self.world_mesh is not None + dp_mesh = get_dp_model_mesh(self.world_mesh) + param_dtype = ( + dp_config.param_dtype.as_pt() if dp_config.param_dtype is not None else None + ) + reduce_dtype = dp_config.reduce_dtype.as_pt() + from olmo_core.distributed.parallel import DataParallelType + + if dp_config.name in (DataParallelType.fsdp, DataParallelType.hsdp): + model.apply_fsdp( + dp_mesh=dp_mesh, + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + wrapping_strategy=dp_config.wrapping_strategy, + prefetch_factor=dp_config.prefetch_factor, + ) + elif dp_config.name == DataParallelType.ddp: + model.apply_ddp(dp_mesh=dp_mesh, param_dtype=param_dtype) + else: + raise NotImplementedError(dp_config.name) + self.model = model + self.model.init_weights( + max_seq_len=max_sequence_length, + max_local_microbatch_size=rank_microbatch_size, + device=self.device, + world_mesh=self.world_mesh, + ) + + self._model_mode = None + + self._dp_config = dp_config + self._cp_config = None + self._tp_config = None + self._ep_config = None + self.label_ignore_index = label_ignore_index + self.z_loss_multiplier = z_loss_multiplier + self.rank_microbatch_size = rank_microbatch_size + self.max_sequence_length = max_sequence_length + self.autocast_precision = autocast_precision + self.max_grad_norm = max_grad_norm + self.scheduler = scheduler + self.state_dict_save_opts = state_dict_save_opts or dist_cp_sd.StateDictOptions( + flatten_optimizer_state_dict=True, cpu_offload=True + ) + self.state_dict_load_opts = state_dict_load_opts or dist_cp_sd.StateDictOptions( + flatten_optimizer_state_dict=True, strict=True + ) + self.load_key_mapping = load_key_mapping + + log.info("Building optimizer for multimodal model...") + self.optim: Optimizer = optim.build(self.model, strict=True) + + # ------------------------------------------------------------------ + # Batch 
handling: loss_masks → label_mask + image kwargs + # ------------------------------------------------------------------ + + def _convert_loss_masks(self, batch: Dict[str, Any]) -> None: + """In-place: convert ``loss_masks`` to ``label_mask`` so the base + class's :func:`get_labels` masks non-response positions.""" + if "loss_masks" in batch and "label_mask" not in batch: + loss_masks = batch.pop("loss_masks") + if not isinstance(loss_masks, torch.Tensor): + loss_masks = torch.as_tensor(loss_masks) + batch["label_mask"] = loss_masks.bool() + + def _move_image_kwargs_to_device(self, batch: Dict[str, Any]) -> None: + """Move multimodal-specific tensors to ``self.device``. + + The base ``Transformer._prepare_inputs`` moves ``input_ids`` and + ``labels`` but doesn't know about our ``images`` / + ``pooled_patches_idx``. Moving them here keeps the train module + agnostic to whatever device the data loader produced batches on.""" + from olmo_core.utils import move_to_device + + for key in ("images", "pooled_patches_idx"): + if key in batch and isinstance(batch[key], torch.Tensor): + batch[key] = move_to_device(batch[key], self.device) + + def pre_train(self): + """Validate sizing in examples, not tokens. + + ``TransformerTrainModule.pre_train`` checks + ``global_batch_size % (rank_microbatch_size * dp_ws) == 0``, assuming + both are in tokens (text loaders' convention). Our multimodal data + loader counts ``global_batch_size`` in **examples**, while + ``rank_microbatch_size`` is still in tokens. We translate to a + per-rank-instance check instead. 
@dataclass
class MultimodalTransformerTrainModuleConfig(TransformerTrainModuleConfig):
    """Configuration for :class:`MultimodalTransformerTrainModule`.

    Inherits every field from :class:`TransformerTrainModuleConfig` for API
    symmetry, but only a subset is honored by :meth:`build`:

    - ``dp_config`` is passed through to the train module (FSDP / HSDP / DDP).
    - ``pp_config`` is rejected: :meth:`build` raises ``NotImplementedError``
      when it is set.
    - ``tp_config``, ``cp_config``, ``ep_config``, ``float8_config``,
      ``ac_config`` and ``compile_model`` are dropped; a warning is logged
      for each one that was actually set.
    """

    def build(  # type: ignore[override]
        self,
        model: MultimodalTransformer,
        device: Optional[torch.device] = None,
    ) -> MultimodalTransformerTrainModule:
        """Instantiate the train module.

        :param model: The multimodal model to train.
        :param device: Device forwarded to the train module constructor.
        :returns: A :class:`MultimodalTransformerTrainModule` built from the
            supported fields of this config.
        :raises NotImplementedError: If ``pp_config`` is set.
        """
        if self.pp_config is not None:
            raise NotImplementedError(
                "Pipeline parallelism is not yet supported for MultimodalTransformer"
            )
        if any(cfg is not None for cfg in (self.tp_config, self.cp_config, self.ep_config)):
            log.warning(
                "TP/CP/EP configs are not yet honored for MultimodalTransformer; "
                "proceeding with DP only."
            )
        if self.compile_model:
            log.warning(
                "compile_model is not yet supported for MultimodalTransformer; "
                "proceeding without torch.compile()."
            )
        # These two used to be stripped below without any notice; tell the
        # user so a requested feature doesn't vanish silently.
        for dropped in ("float8_config", "ac_config"):
            if getattr(self, dropped, None) is not None:
                log.warning(
                    "%s is not yet honored for MultimodalTransformer; ignoring it.",
                    dropped,
                )

        kwargs = self.as_dict(exclude_none=True, recurse=False)
        # Strip fields that don't apply to the multimodal train module.
        for unsupported in (
            "compile_model",
            "float8_config",
            "tp_config",
            "cp_config",
            "ep_config",
            "ac_config",
            "pp_config",
        ):
            kwargs.pop(unsupported, None)
        # dp_config goes through as a nested object, not flattened by as_dict.
        kwargs["dp_config"] = self.dp_config
        # Convert config-level representations into what the constructor expects.
        if (autocast_precision := kwargs.pop("autocast_precision", None)) is not None:
            kwargs["autocast_precision"] = cast(DType, autocast_precision).as_pt()
        if (state_dict_save_opts := kwargs.pop("state_dict_save_opts", None)) is not None:
            kwargs["state_dict_save_opts"] = dist_cp_sd.StateDictOptions(**state_dict_save_opts)
        if (state_dict_load_opts := kwargs.pop("state_dict_load_opts", None)) is not None:
            kwargs["state_dict_load_opts"] = dist_cp_sd.StateDictOptions(**state_dict_load_opts)
        return MultimodalTransformerTrainModule(model=model, device=device, **kwargs)
def test_dtypes_match_model_contract():
    """Collated tensors use the dtypes the model expects for each key."""
    rng = np.random.default_rng(0)
    coll = _collator()
    single = _example(8, 1, 4, 12, 1, 4, rng)
    out = coll([single])

    # Integer-valued keys are long; real-valued keys are float32.
    expected_dtypes = {
        "input_ids": torch.long,
        "loss_masks": torch.float32,
        "images": torch.float32,
        "pooled_patches_idx": torch.long,
    }
    for key, dtype in expected_dtypes.items():
        assert out[key].dtype == dtype
def test_pad_to_multiple_of():
    """With ``pad_to_multiple_of=8`` a length-5 example is padded to exactly 8.

    The example has a single pooled row, so no dummy image-patch tokens are
    needed and the only growth comes from alignment padding.
    """
    rng = np.random.default_rng(0)
    tok = MultimodalTokenizerConfig.dolma2()
    coll = MultimodalCollatorConfig(tokenizer=tok, pad_to_multiple_of=8).build()
    batch = [
        _example(
            seq_len=5,
            n_crops=1,
            n_patches_per_crop=4,
            patch_dim=12,
            n_pooled=1,
            pool_size=4,
            rng=rng,
        )
    ]
    out = coll(batch)
    assert out["input_ids"].shape[1] % 8 == 0
    assert out["input_ids"].shape[1] >= 5
    # Pin the *next* multiple: over-padding (16, 24, ...) would also satisfy
    # the two checks above.
    assert out["input_ids"].shape == (1, 8)
    assert out["loss_masks"].shape == (1, 8)
    # Alignment padding must look like ordinary padding: pad tokens that
    # never contribute to the loss.
    assert (out["input_ids"][0, 5:] == tok.base.pad_token_id).all()
    assert (out["loss_masks"][0, 5:] == 0.0).all()
+ batch[0]["input_tokens"] = np.array([tok.image_patch_id, 5, 6, 7, 0, 0, 0, 0], dtype=np.int64) + batch[1]["input_tokens"] = np.array([tok.image_patch_id] * 3 + [5, 6, 7, 0, 0], dtype=np.int64) + out = coll(batch) + # pooled_patches_idx padded to max=3. + assert out["pooled_patches_idx"].shape == (2, 3, 4) + # Padding rows are all -1. + assert (out["pooled_patches_idx"][0, 1:] == -1).all() + # Each example's count should equal max_n_pooled=3. + for i in range(2): + n_image_patch = (out["input_ids"][i] == tok.image_patch_id).sum().item() + assert n_image_patch == 3, f"example {i} has {n_image_patch} tokens, expected 3" + # Loss mask is zero at dummy positions for the example that received them. + # Example 0: original n_pooled=1; needs +2 dummy. The 2 extra + # tokens land at positions [seq_len, seq_len+1] = [8, 9] (no original loss + # information there because seq_len=8). Loss mask at those positions is 0. + assert out["loss_masks"][0, 8] == 0.0 + assert out["loss_masks"][0, 9] == 0.0 + + +def test_model_contract_total_image_patches_matches_pooled_size(): + """Most important: total across batch == B * max_n_pooled.""" + rng = np.random.default_rng(0) + tok = MultimodalTokenizerConfig.dolma2() + coll = _collator() + batch = [ + _example( + seq_len=8, + n_crops=1, + n_patches_per_crop=4, + patch_dim=12, + n_pooled=k, + pool_size=4, + rng=rng, + ) + for k in (1, 2, 3) + ] + # Each example needs exactly k real tokens. 
def test_text_only_batch():
    """Examples with zero crops and zero pooled rows still collate cleanly."""
    rng = np.random.default_rng(0)
    coll = _collator()

    batch = []
    for _ in range(2):
        # No image content at all: empty crop and pooled-index arrays.
        text_only = {
            "input_tokens": rng.integers(0, 100, size=8, dtype=np.int64),
            "loss_masks": np.ones(8, dtype=np.float32),
            "images": np.zeros((0, 4, 12), dtype=np.float32),
            "pooled_patches_idx": np.zeros((0, 4), dtype=np.int64),
        }
        batch.append(text_only)

    out = coll(batch)
    expected_shapes = {
        "input_ids": (2, 8),
        "images": (2, 0, 4, 12),
        "pooled_patches_idx": (2, 0, 4),
    }
    for key, shape in expected_shapes.items():
        assert out[key].shape == shape
def _loader_cfg(global_batch_size: int = 4, num_workers: int = 0) -> MultimodalDataLoaderConfig:
    """Build a small data-loader config for the tests in this module.

    Uses 28x28 base images with 14-pixel patches, 2x2 pooling, ``resize``
    crop mode, a max sequence length of 128, and a fresh temporary work dir.

    :param global_batch_size: Global batch size given to the loader config.
    :param num_workers: Number of loader workers.
    """
    image_cfg = ImagePreprocessorConfig(patch_size=14)
    multicrop_cfg = MultiCropPreprocessorConfig(
        base_image_input_size=(28, 28),
        crop_mode=CropMode.resize,
        pool_h=2,
        pool_w=2,
        image_preprocessor=image_cfg,
    )
    preprocessor_cfg = MultimodalPreprocessorConfig(
        max_sequence_length=128,
        multicrop=multicrop_cfg,
    )
    scratch_dir = tempfile.mkdtemp(prefix="mm-data-test-")
    return MultimodalDataLoaderConfig(
        preprocessor=preprocessor_cfg,
        global_batch_size=global_batch_size,
        num_workers=num_workers,
        work_dir=scratch_dir,
    )
def test_total_batches_matches_iter_count(hf_tokenizer):
    """``total_batches`` equals the number of batches actually yielded.

    20 examples at a global batch size of 4 give 5 batches.
    """
    cfg = _loader_cfg(global_batch_size=4)
    loader = cfg.build(_source(n_examples=20), hf_tokenizer)
    loader.reshuffle(epoch=1)
    expected = loader.total_batches
    assert expected == 5
    actual = sum(1 for _ in loader)
    # Reset after iterating, as every other test in this module does, so the
    # loader isn't left in a mid/post-epoch state.
    loader.reset()
    assert actual == expected
def test_state_dict_round_trip_resumes(hf_tokenizer):
    """A fresh loader restored from ``state_dict`` reports the saved progress."""
    cfg = _loader_cfg(global_batch_size=2)
    loader = cfg.build(_source(n_examples=8), hf_tokenizer)
    loader.reshuffle(epoch=1)

    # Consume two batches.
    it = iter(loader)
    next(it)
    next(it)
    state = loader.state_dict()
    assert state["batches_processed"] == 2
    # The snapshot is taken; reset the half-consumed loader like the other
    # tests in this module do after iterating.
    loader.reset()

    # Build a fresh loader and load state.
    loader2 = cfg.build(_source(n_examples=8), hf_tokenizer)
    loader2.load_state_dict(state)
    assert loader2.batches_processed == 2
    assert loader2.epoch == state["epoch"]
olmo_core.nn.vision import ( + MultimodalTransformer, + MultimodalTransformerConfig, + VisionBackboneConfig, + VisionBackboneType, + VisionConnectorConfig, + ) + + mm_tok = MultimodalTokenizerConfig.dolma2() + lm_cfg = TransformerConfig.olmo2_1M(vocab_size=mm_tok.padded_vocab_size(128)) + vis_cfg = VisionBackboneConfig( + name=VisionBackboneType.openai, + image_default_input_size=(28, 28), + image_patch_size=14, + image_emb_dim=32, + image_num_heads=2, + image_num_key_value_heads=2, + image_num_layers=2, + image_head_dim=16, + image_mlp_dim=64, + image_num_pos=5, + image_norm_eps=1e-5, + ) + conn_cfg = VisionConnectorConfig.from_vision_backbone( + vis_cfg, output_dim=lm_cfg.d_model, mlp_hidden_size=32 + ) + model = MultimodalTransformer( + MultimodalTransformerConfig( + lm=lm_cfg, + vision=vis_cfg, + connector=conn_cfg, + image_patch_token_id=mm_tok.image_patch_id, + ), + init_device="cpu", + ) + model.eval() + + loader = _loader_cfg(global_batch_size=2).build(_source(n_examples=4), hf_tokenizer) + loader.reshuffle(epoch=1) + for batch in loader: + with torch.inference_mode(): + out = model( + input_ids=batch["input_ids"], + images=batch["images"], + pooled_patches_idx=batch["pooled_patches_idx"], + ) + assert torch.isfinite(out).all() + loader.reset() diff --git a/src/test/data/multimodal/end_to_end_test.py b/src/test/data/multimodal/end_to_end_test.py new file mode 100644 index 000000000..0b58bf9ff --- /dev/null +++ b/src/test/data/multimodal/end_to_end_test.py @@ -0,0 +1,217 @@ +""" +End-to-end test: synthetic dataset → preprocessor → collator → model.forward. + +This is the load-bearing test for PR 4 — it verifies that the data pipeline +produces tensors with the exact shapes / dtypes / token-count invariants that +:class:`~olmo_core.nn.vision.MultimodalTransformer.forward` requires. 
def _build_stack(crop_mode: CropMode):
    """Build the configs for a tiny end-to-end stack using ``crop_mode``.

    Nothing is instantiated here: callers build the preprocessor, collator,
    and model from the returned configs themselves.

    :param crop_mode: Crop mode placed in the multi-crop preprocessor config.
    :returns: ``(prep_cfg, coll_cfg, model_cfg, mm_tok)`` — the preprocessor
        config, collator config, model config, and the multimodal tokenizer
        config they share.
    """
    mm_tok = MultimodalTokenizerConfig.dolma2()

    multicrop_cfg = MultiCropPreprocessorConfig(
        base_image_input_size=(28, 28),
        crop_mode=crop_mode,
        max_crops=4,
        overlap_margins=(0, 0),
        pool_h=2,
        pool_w=2,
        image_preprocessor=ImagePreprocessorConfig(patch_size=14),
    )
    prep_cfg = MultimodalPreprocessorConfig(
        tokenizer=mm_tok,
        multicrop=multicrop_cfg,
        max_sequence_length=256,
    )
    coll_cfg = MultimodalCollatorConfig(tokenizer=mm_tok)

    # Configure the tiny multimodal model whose vocab is sized to fit our extended tokens.
    lm_vocab = mm_tok.padded_vocab_size(128)
    lm_cfg = TransformerConfig.olmo2_1M(vocab_size=lm_vocab)
    vis_cfg = VisionBackboneConfig(
        name=VisionBackboneType.openai,
        image_default_input_size=(28, 28),
        image_patch_size=14,
        image_emb_dim=32,
        image_num_heads=2,
        image_num_key_value_heads=2,
        image_num_layers=2,
        image_head_dim=16,
        image_mlp_dim=64,
        image_num_pos=5,
        image_norm_eps=1e-5,
    )
    # The connector's output width matches the LM's model dimension.
    conn_cfg = VisionConnectorConfig.from_vision_backbone(
        vis_cfg, output_dim=lm_cfg.d_model, mlp_hidden_size=32
    )
    model_cfg = MultimodalTransformerConfig(
        lm=lm_cfg,
        vision=vis_cfg,
        connector=conn_cfg,
        image_patch_token_id=mm_tok.image_patch_id,
    )
    return prep_cfg, coll_cfg, model_cfg, mm_tok
+ examples = [] + for ex in ds: + prompt, response, image = ex + examples.append(prep(prompt, response, image)) + if len(examples) == batch_size: + break + + batch = coll(examples) + with torch.inference_mode(): + out = model( + input_ids=batch["input_ids"], + images=batch["images"], + pooled_patches_idx=batch["pooled_patches_idx"], + ) + return out, batch, model_cfg + + def test_resize_mode_forward(self): + out, batch, cfg = self._run_pipeline(CropMode.resize) + B, S = batch["input_ids"].shape + assert out.shape == (B, S, cfg.lm.vocab_size) + assert torch.isfinite(out).all() + + def test_overlap_mode_forward(self): + out, batch, cfg = self._run_pipeline(CropMode.overlap_and_resize) + B, S = batch["input_ids"].shape + assert out.shape == (B, S, cfg.lm.vocab_size) + assert torch.isfinite(out).all() + + def test_with_labels_returns_loss(self): + prep_cfg, coll_cfg, model_cfg, mm_tok = _build_stack(CropMode.resize) + prep = MultimodalPreprocessor(prep_cfg, self.tok) + coll = MultimodalCollator(coll_cfg) + model = MultimodalTransformer(model_cfg, init_device="cpu") + model.eval() + + ds = SyntheticMultimodalDataset(SyntheticMultimodalDatasetConfig(n_examples=4, seed=0)) + examples = [prep(p, r, i) for p, r, i in ds] + batch = coll(examples) + + # Use input_ids as labels (typical autoregressive setup). 
+ from olmo_core.nn.lm_head import LMOutputWithLoss + + with torch.inference_mode(): + out = model( + input_ids=batch["input_ids"], + images=batch["images"], + pooled_patches_idx=batch["pooled_patches_idx"], + labels=batch["input_ids"], + ) + assert isinstance(out, LMOutputWithLoss) + assert out.loss.shape == () + assert torch.isfinite(out.loss) + + def test_text_only_batch_forward(self): + """When all examples are text-only, the pipeline still produces a working batch.""" + prep_cfg, coll_cfg, model_cfg, mm_tok = _build_stack(CropMode.resize) + prep = MultimodalPreprocessor(prep_cfg, self.tok) + coll = MultimodalCollator(coll_cfg) + model = MultimodalTransformer(model_cfg, init_device="cpu") + model.eval() + + ds = SyntheticMultimodalDataset( + SyntheticMultimodalDatasetConfig(n_examples=3, seed=0, text_only_fraction=1.0) + ) + examples = [prep(p, r, i) for p, r, i in ds] + batch = coll(examples) + with torch.inference_mode(): + # No image kwargs → text-only forward. + out = model(input_ids=batch["input_ids"]) + assert out.shape[0] == 3 + assert torch.isfinite(out).all() + + def test_variable_image_layout(self): + """In overlap mode with jittered image sizes, the collator must pad + n_pooled and the model splice must still satisfy its contract.""" + prep_cfg, coll_cfg, model_cfg, mm_tok = _build_stack(CropMode.overlap_and_resize) + prep = MultimodalPreprocessor(prep_cfg, self.tok) + coll = MultimodalCollator(coll_cfg) + model = MultimodalTransformer(model_cfg, init_device="cpu") + model.eval() + + ds = SyntheticMultimodalDataset( + SyntheticMultimodalDatasetConfig( + n_examples=4, image_size=(56, 56), image_size_jitter=28, seed=0 + ) + ) + examples = [prep(p, r, i) for p, r, i in ds] + batch = coll(examples) + with torch.inference_mode(): + out = model( + input_ids=batch["input_ids"], + images=batch["images"], + pooled_patches_idx=batch["pooled_patches_idx"], + ) + assert out.shape[0] == 4 + assert torch.isfinite(out).all() diff --git 
a/src/test/data/multimodal/pixmo_cap_test.py b/src/test/data/multimodal/pixmo_cap_test.py new file mode 100644 index 000000000..900b29895 --- /dev/null +++ b/src/test/data/multimodal/pixmo_cap_test.py @@ -0,0 +1,140 @@ +""" +Tests for the PixMo-Cap adapter. + +Two loading modes are covered: + +- ``local``: reads a pre-downloaded on-disk HF ``Dataset``. Auto-skips if the + local dataset path isn't present (e.g. running outside the AI2 cluster). +- ``hub``: streams ``allenai/pixmo-cap`` and downloads images per-URL. Tested + with ``datasets.load_dataset`` and ``requests.get`` monkeypatched, so it runs + deterministically without network access. +""" + +import io +import os + +import numpy as np +import pytest + +from olmo_core.data.multimodal.pixmo_cap import PixmoCapDataset, PixmoCapDatasetConfig + +# --------------------------------------------------------------------------- +# local mode (skipped without the on-disk dataset) +# --------------------------------------------------------------------------- + + +def _local_path_present() -> bool: + return os.path.isdir(PixmoCapDatasetConfig(source="local").resolve_data_dir()) + + +local_only = pytest.mark.skipif( + not _local_path_present(), + reason="local PixMo-Cap dataset directory not present", +) + + +@local_only +def test_local_yields_one_example(): + pytest.importorskip("datasets") + pytest.importorskip("PIL") + ds = PixmoCapDatasetConfig(source="local", limit=1).build() + prompt, caption, image = next(iter(ds)) + assert isinstance(prompt, str) and len(prompt) > 0 + assert isinstance(caption, str) and len(caption) > 0 + assert hasattr(image, "size") and len(image.size) == 2 + + +@local_only +def test_local_iteration_stops_at_limit(): + pytest.importorskip("datasets") + pytest.importorskip("PIL") + items = list(PixmoCapDatasetConfig(source="local", limit=3).build()) + assert len(items) == 3 + + +# --------------------------------------------------------------------------- +# hub mode (monkeypatched — no network) 
+# --------------------------------------------------------------------------- + + +def _png_bytes(color=(123, 222, 64)) -> bytes: + from PIL import Image + + buf = io.BytesIO() + Image.fromarray(np.full((8, 8, 3), color, dtype=np.uint8)).save(buf, format="PNG") + return buf.getvalue() + + +class _FakeResponse: + def __init__(self, content: bytes, ok: bool = True): + self.content = content + self._ok = ok + + def raise_for_status(self): + if not self._ok: + raise RuntimeError("HTTP error") + + +def _patch_hub(monkeypatch, rows, *, dead_urls=()): + """Patch datasets.load_dataset → rows, requests.get → fake PNG (or error + for URLs in ``dead_urls``).""" + pytest.importorskip("datasets") + pytest.importorskip("requests") + import datasets + import requests + + def fake_load_dataset(dataset_id, split=None, streaming=False): + assert streaming is True + return list(rows) + + def fake_get(url, timeout=None): + if url in dead_urls: + return _FakeResponse(b"", ok=False) + return _FakeResponse(_png_bytes()) + + monkeypatch.setattr(datasets, "load_dataset", fake_load_dataset) + monkeypatch.setattr(requests, "get", fake_get) + + +def test_hub_yields_triples(monkeypatch): + rows = [ + {"image_url": "http://x/0.jpg", "caption": "a cat", "transcripts": []}, + {"image_url": "http://x/1.jpg", "caption": "a dog", "transcripts": []}, + ] + _patch_hub(monkeypatch, rows) + items = list(PixmoCapDataset(PixmoCapDatasetConfig(source="hub"))) + assert len(items) == 2 + prompt, caption, image = items[0] + assert prompt == PixmoCapDatasetConfig().prompt + assert caption == "a cat" + assert image.size == (8, 8) + + +def test_hub_skips_dead_urls(monkeypatch): + rows = [ + {"image_url": "http://x/ok.jpg", "caption": "good", "transcripts": []}, + {"image_url": "http://x/dead.jpg", "caption": "bad", "transcripts": []}, + {"image_url": "http://x/ok2.jpg", "caption": "good2", "transcripts": []}, + ] + _patch_hub(monkeypatch, rows, dead_urls={"http://x/dead.jpg"}) + captions = [c for _, c, _ in 
PixmoCapDataset(PixmoCapDatasetConfig(source="hub"))] + assert captions == ["good", "good2"] + + +def test_hub_limit_counts_yielded_not_scanned(monkeypatch): + """limit should count successfully-loaded examples, so dead URLs in front + of the limit don't shrink the yielded count.""" + rows = [ + {"image_url": "http://x/dead.jpg", "caption": "skip", "transcripts": []}, + {"image_url": "http://x/a.jpg", "caption": "a", "transcripts": []}, + {"image_url": "http://x/b.jpg", "caption": "b", "transcripts": []}, + {"image_url": "http://x/c.jpg", "caption": "c", "transcripts": []}, + ] + _patch_hub(monkeypatch, rows, dead_urls={"http://x/dead.jpg"}) + captions = [c for _, c, _ in PixmoCapDataset(PixmoCapDatasetConfig(source="hub", limit=2))] + assert captions == ["a", "b"] + + +def test_invalid_source_raises(): + with pytest.raises(ValueError, match="source must be"): + PixmoCapDataset(PixmoCapDatasetConfig(source="nope")) diff --git a/src/test/data/multimodal/preprocessor_test.py b/src/test/data/multimodal/preprocessor_test.py new file mode 100644 index 000000000..43d056e39 --- /dev/null +++ b/src/test/data/multimodal/preprocessor_test.py @@ -0,0 +1,205 @@ +"""Tests for the top-level multimodal preprocessor.""" + +import numpy as np +import pytest + +from olmo_core.data.multimodal.image_preprocessor import ImagePreprocessorConfig +from olmo_core.data.multimodal.multicrop import CropMode, MultiCropPreprocessorConfig +from olmo_core.data.multimodal.preprocessor import ( + MultimodalPreprocessor, + MultimodalPreprocessorConfig, +) +from olmo_core.data.multimodal.tokens import MultimodalTokenizerConfig + +transformers = pytest.importorskip("transformers") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def hf_tokenizer(): + tok_cfg = MultimodalTokenizerConfig.dolma2() + try: + return tok_cfg.load_hf_tokenizer() + 
except Exception as e: # noqa: BLE001 + pytest.skip(f"Could not load dolma2 HF tokenizer: {e}") + + +def _resize_cfg(seq_len: int = 256) -> MultimodalPreprocessorConfig: + return MultimodalPreprocessorConfig( + max_sequence_length=seq_len, + multicrop=MultiCropPreprocessorConfig( + base_image_input_size=(28, 28), + crop_mode=CropMode.resize, + pool_h=2, + pool_w=2, + image_preprocessor=ImagePreprocessorConfig(patch_size=14), + ), + ) + + +def _overlap_cfg(seq_len: int = 512) -> MultimodalPreprocessorConfig: + return MultimodalPreprocessorConfig( + max_sequence_length=seq_len, + multicrop=MultiCropPreprocessorConfig( + base_image_input_size=(28, 28), + crop_mode=CropMode.overlap_and_resize, + max_crops=4, + overlap_margins=(0, 0), + pool_h=1, + pool_w=1, + image_preprocessor=ImagePreprocessorConfig(patch_size=14), + ), + ) + + +def _random_image(h: int = 56, w: int = 56): + rng = np.random.default_rng(0) + return rng.integers(0, 256, size=(h, w, 3), dtype=np.uint8) + + +# --------------------------------------------------------------------------- +# Shapes / dtypes +# --------------------------------------------------------------------------- + + +def test_text_with_image_returns_expected_keys(hf_tokenizer): + cfg = _resize_cfg() + pp = MultimodalPreprocessor(cfg, hf_tokenizer) + out = pp(prompt="What is in this image?", response="A cat.", image=_random_image()) + assert set(out.keys()) == { + "input_tokens", + "loss_masks", + "images", + "pooled_patches_idx", + } + + +def test_text_with_image_dtypes(hf_tokenizer): + pp = MultimodalPreprocessor(_resize_cfg(), hf_tokenizer) + out = pp("Caption:", "A cat.", _random_image()) + assert out["input_tokens"].dtype == np.int64 + assert out["loss_masks"].dtype == np.float32 + assert out["images"].dtype == np.float32 + assert out["pooled_patches_idx"].dtype == np.int64 + + +def test_input_tokens_and_loss_masks_same_length(hf_tokenizer): + pp = MultimodalPreprocessor(_resize_cfg(), hf_tokenizer) + out = pp("Question.", 
"Answer.", _random_image()) + assert out["input_tokens"].shape == out["loss_masks"].shape + + +def test_image_tensors_match_multicrop(hf_tokenizer): + """The image tensors in the preprocessor output should match running + multicrop directly on the same image.""" + cfg = _resize_cfg() + pp = MultimodalPreprocessor(cfg, hf_tokenizer) + img = _random_image() + out = pp("Q", "A", img) + mc_out = pp.multicrop(img) + np.testing.assert_array_equal(out["images"], mc_out.images) + np.testing.assert_array_equal(out["pooled_patches_idx"], mc_out.pooled_patches_idx) + + +# --------------------------------------------------------------------------- +# Loss masks +# --------------------------------------------------------------------------- + + +def test_loss_mask_only_on_response(hf_tokenizer): + """Tokens that come from prompt/image text must have loss_mask = 0.""" + cfg = _resize_cfg() + pp = MultimodalPreprocessor(cfg, hf_tokenizer) + out = pp(prompt="prompt text", response="response text", image=_random_image()) + # Find the response section by looking at where loss_mask flips to 1. + assert (out["loss_masks"] == 1.0).any() + assert (out["loss_masks"] == 0.0).any() + # Loss mask should be contiguous: zeros first, then ones. 
+ flip_points = np.where(np.diff(out["loss_masks"]) != 0)[0] + assert len(flip_points) == 1, f"loss_masks should flip exactly once, got {len(flip_points)}" + + +def test_eos_added_to_response(hf_tokenizer): + """When add_eos=True, the response should end with the base EOS token.""" + cfg = _resize_cfg() + pp = MultimodalPreprocessor(cfg, hf_tokenizer) + out = pp(prompt="Q", response="A", image=_random_image()) + base_eos = cfg.tokenizer.base.eos_token_id + assert out["input_tokens"][-1] == base_eos + assert out["loss_masks"][-1] == 1.0 + + +def test_no_eos_when_disabled(hf_tokenizer): + cfg = _resize_cfg() + cfg.add_eos = False + pp = MultimodalPreprocessor(cfg, hf_tokenizer) + out = pp(prompt="Q", response="A", image=_random_image()) + # The response template prepends a space, so encode " A" to get the + # expected trailing tokens. + expected_resp = hf_tokenizer.encode( + cfg.response_template.format(response="A"), add_special_tokens=False + ) + assert out["input_tokens"][-len(expected_resp) :].tolist() == expected_resp + # And the base EOS should not be the final token. 
+ assert out["input_tokens"][-1] != cfg.tokenizer.base.eos_token_id + + +# --------------------------------------------------------------------------- +# Image-patch token count contract (load-bearing for the model splice) +# --------------------------------------------------------------------------- + + +def test_n_image_patch_tokens_equals_n_pooled(hf_tokenizer): + """The number of image-patch tokens in input_tokens must equal + pooled_patches_idx.shape[0] — model splice contract.""" + cfg = _resize_cfg() + pp = MultimodalPreprocessor(cfg, hf_tokenizer) + out = pp("Q", "A", _random_image()) + tok = cfg.tokenizer + n_image_tokens = int((out["input_tokens"] == tok.image_patch_id).sum()) + assert n_image_tokens == out["pooled_patches_idx"].shape[0] + + +def test_n_image_patch_tokens_equals_n_pooled_overlap(hf_tokenizer): + """Same contract holds in overlap-and-resize mode (with Molmo2 defaults + the global view also uses the image-patch token).""" + cfg = _overlap_cfg() + pp = MultimodalPreprocessor(cfg, hf_tokenizer) + out = pp("Describe this.", "It is wide.", _random_image(28, 84)) + tok = cfg.tokenizer + n_image_tokens = int((out["input_tokens"] == tok.image_patch_id).sum()) + assert n_image_tokens == out["pooled_patches_idx"].shape[0] + + +# --------------------------------------------------------------------------- +# Truncation +# --------------------------------------------------------------------------- + + +def test_truncation_caps_to_max_sequence_length(hf_tokenizer): + cfg = _resize_cfg(seq_len=16) # very short + pp = MultimodalPreprocessor(cfg, hf_tokenizer) + out = pp(prompt="Q", response="A" * 200, image=_random_image()) + assert out["input_tokens"].shape[0] <= 16 + assert out["loss_masks"].shape[0] == out["input_tokens"].shape[0] + + +# --------------------------------------------------------------------------- +# Text-only path +# --------------------------------------------------------------------------- + + +def test_text_only_example(hf_tokenizer): + cfg = _resize_cfg() + pp =
MultimodalPreprocessor(cfg, hf_tokenizer) + out = pp(prompt="Just text.", response="More text.", image=None) + assert out["images"].shape[0] == 0 + assert out["pooled_patches_idx"].shape[0] == 0 + # input_tokens must still be non-empty. + assert out["input_tokens"].shape[0] > 0 + # Loss must still cover the response. + assert (out["loss_masks"] == 1.0).any() diff --git a/src/test/data/multimodal/synthetic_source.py b/src/test/data/multimodal/synthetic_source.py new file mode 100644 index 000000000..3ee54fb52 --- /dev/null +++ b/src/test/data/multimodal/synthetic_source.py @@ -0,0 +1,77 @@ +""" +Test-only synthetic multimodal data source. + +A lightweight, dependency-free iterable that yields ``(prompt, response, image)`` +triples drawn from a fixed RNG. Used by the data-pipeline tests (collator / +loader / preprocessor / train module) to exercise the plumbing without network +or real data. This deliberately lives in the test tree, not in ``olmo_core`` — +the only shipped training data source is +:class:`~olmo_core.data.multimodal.pixmo_cap.PixmoCapDataset`. 
+""" + +from dataclasses import dataclass +from typing import Iterator, Optional, Tuple + +import numpy as np + +__all__ = [ + "SyntheticMultimodalDatasetConfig", + "SyntheticMultimodalDataset", +] + + +@dataclass +class SyntheticMultimodalDatasetConfig: + """Configuration for :class:`SyntheticMultimodalDataset`.""" + + n_examples: int = 64 + image_size: Tuple[int, int] = (56, 56) + image_size_jitter: int = 0 + prompt_words: int = 6 + response_words: int = 4 + seed: int = 0 + text_only_fraction: float = 0.0 + + +class SyntheticMultimodalDataset: + """An iterable of ``(prompt, response, image_or_None)`` triples, deterministic + given :attr:`SyntheticMultimodalDatasetConfig.seed`.""" + + def __init__(self, cfg: SyntheticMultimodalDatasetConfig): + self.cfg = cfg + self._epoch_seed = cfg.seed + self._words = ( + "alpha bravo charlie delta echo foxtrot golf hotel india " + "juliet kilo lima mike november oscar papa quebec romeo " + "sierra tango uniform victor whiskey xray yankee zulu" + ).split() + + def set_epoch(self, seed: int) -> None: + """Re-seed for a new epoch. 
Called by the data loader.""" + self._epoch_seed = seed + + def _random_text(self, rng: np.random.Generator, n_words: int) -> str: + return " ".join(rng.choice(self._words, size=n_words).tolist()) + + def _random_image(self, rng: np.random.Generator) -> np.ndarray: + cfg = self.cfg + h, w = cfg.image_size + if cfg.image_size_jitter > 0: + j = cfg.image_size_jitter + h = max(1, int(rng.integers(h - j, h + j + 1))) + w = max(1, int(rng.integers(w - j, w + j + 1))) + return rng.integers(0, 256, size=(h, w, 3), dtype=np.uint8) + + def __iter__(self) -> Iterator[Tuple[str, str, Optional[np.ndarray]]]: + cfg = self.cfg + rng = np.random.default_rng(self._epoch_seed) + for _ in range(cfg.n_examples): + prompt = self._random_text(rng, cfg.prompt_words) + response = self._random_text(rng, cfg.response_words) + if rng.random() < cfg.text_only_fraction: + yield prompt, response, None + else: + yield prompt, response, self._random_image(rng) + + def __len__(self) -> int: + return self.cfg.n_examples diff --git a/src/test/train/multimodal/__init__.py b/src/test/train/multimodal/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/test/train/multimodal/fsdp_test.py b/src/test/train/multimodal/fsdp_test.py new file mode 100644 index 000000000..adb384583 --- /dev/null +++ b/src/test/train/multimodal/fsdp_test.py @@ -0,0 +1,259 @@ +""" +FSDP tests for MultimodalTransformer + MultimodalTransformerTrainModule. + +Uses :func:`~olmo_core.testing.distributed.run_distributed_test` with the +``gloo`` backend to spawn 2 ranks on CPU, so this runs without GPUs and +covers the FSDP wrapping + materialize-then-init flow. 
+""" + +import tempfile + +import pytest +import torch + +from olmo_core.data.multimodal import ( + CropMode, + ImagePreprocessorConfig, + MultiCropPreprocessorConfig, + MultimodalDataLoaderConfig, + MultimodalPreprocessorConfig, + MultimodalTokenizerConfig, +) +from olmo_core.distributed.parallel import DataParallelType +from olmo_core.nn.transformer.config import TransformerConfig +from olmo_core.nn.vision import ( + MultimodalTransformer, + MultimodalTransformerConfig, + VisionBackboneConfig, + VisionBackboneType, + VisionConnectorConfig, +) +from olmo_core.optim import AdamWConfig +from olmo_core.testing import run_distributed_test +from olmo_core.train.train_module.multimodal import ( + MultimodalTransformerTrainModule, + MultimodalTransformerTrainModuleConfig, +) +from olmo_core.train.train_module.transformer.config import ( + TransformerDataParallelConfig, +) + +transformers = pytest.importorskip("transformers") + + +# --------------------------------------------------------------------------- +# Model and loader builders (same shape as PR 5's tests) +# --------------------------------------------------------------------------- + + +def _tiny_pixmo_cap_sample(n: int = 4): + """A tiny in-memory stand-in for a PixMo-Cap shard: deterministic + ``(prompt, caption, PIL.Image)`` triples, no network or disk required. + + Returns a list (re-iterable) so the data loader can scan it more than once. 
+ """ + import numpy as np + from PIL import Image + + rng = np.random.default_rng(0) + items = [] + for i in range(n): + arr = rng.integers(0, 256, size=(56, 56, 3), dtype=np.uint8) + items.append(("Describe this image in detail.", f"A caption {i}.", Image.fromarray(arr))) + return items + + +def _build_meta_model() -> MultimodalTransformer: + """Build a tiny multimodal model on the meta device so FSDP can wrap it + before materialization.""" + mm_tok = MultimodalTokenizerConfig.dolma2() + lm_cfg = TransformerConfig.olmo2_1M(vocab_size=mm_tok.padded_vocab_size(128)) + vis_cfg = VisionBackboneConfig( + name=VisionBackboneType.openai, + image_default_input_size=(28, 28), + image_patch_size=14, + image_emb_dim=32, + image_num_heads=2, + image_num_key_value_heads=2, + image_num_layers=2, + image_head_dim=16, + image_mlp_dim=64, + image_num_pos=5, + image_norm_eps=1e-5, + ) + conn_cfg = VisionConnectorConfig.from_vision_backbone( + vis_cfg, output_dim=lm_cfg.d_model, mlp_hidden_size=32 + ) + cfg = MultimodalTransformerConfig( + lm=lm_cfg, + vision=vis_cfg, + connector=conn_cfg, + image_patch_token_id=mm_tok.image_patch_id, + ) + return MultimodalTransformer(cfg, init_device="meta") + + +# --------------------------------------------------------------------------- +# Single-rank tests (no torch.distributed): verify apply_fsdp / init_weights +# don't break the model when called as part of train module construction. +# --------------------------------------------------------------------------- + + +def test_init_weights_materializes_from_meta(): + """init_weights moves a meta-device model to a real device and runs + reset_parameters on every submodule.""" + model = _build_meta_model() + # Confirm meta first. + assert any(p.device.type == "meta" for p in model.parameters()) + + model.init_weights(device=torch.device("cpu")) + + # All params should now be on CPU and contain real values. 
+ for p in model.parameters(): + assert p.device.type == "cpu" + assert torch.isfinite(p).all() + + +def test_train_module_without_dp_config_works_after_init_weights_change(): + """The single-device path (dp_config=None) still works exactly as PR 5.""" + mm_tok = MultimodalTokenizerConfig.dolma2() + cfg = MultimodalTransformerTrainModuleConfig( + rank_microbatch_size=512, + max_sequence_length=128, + optim=AdamWConfig(lr=1e-3), + ) + # Build the model on CPU (not meta) so the dp_config=None path works. + lm_cfg = TransformerConfig.olmo2_1M(vocab_size=mm_tok.padded_vocab_size(128)) + vis_cfg = VisionBackboneConfig( + name=VisionBackboneType.openai, + image_default_input_size=(28, 28), + image_patch_size=14, + image_emb_dim=32, + image_num_heads=2, + image_num_key_value_heads=2, + image_num_layers=2, + image_head_dim=16, + image_mlp_dim=64, + image_num_pos=5, + image_norm_eps=1e-5, + ) + conn_cfg = VisionConnectorConfig.from_vision_backbone( + vis_cfg, output_dim=lm_cfg.d_model, mlp_hidden_size=32 + ) + model = MultimodalTransformer( + MultimodalTransformerConfig( + lm=lm_cfg, + vision=vis_cfg, + connector=conn_cfg, + image_patch_token_id=mm_tok.image_patch_id, + ), + init_device="cpu", + ) + tm = cfg.build(model, device=torch.device("cpu")) + assert isinstance(tm, MultimodalTransformerTrainModule) + assert tm.world_mesh is None + + +# --------------------------------------------------------------------------- +# Distributed (2-rank gloo) FSDP test +# --------------------------------------------------------------------------- + + +def _fsdp_wrap_and_init_only(): + """Body for ``run_distributed_test``: just apply_fsdp + init_weights, + no train module, no data, no forward. 
Isolates the FSDP/materialization + flow from the rest of the train stack.""" + import torch.distributed as dist + + from olmo_core.distributed.parallel import build_world_mesh, get_dp_model_mesh + + model = _build_meta_model() + world_mesh = build_world_mesh( + dp=TransformerDataParallelConfig(name=DataParallelType.fsdp), device_type="cpu" + ) + dp_mesh = get_dp_model_mesh(world_mesh) + model.apply_fsdp(dp_mesh=dp_mesh) + model.init_weights( + max_seq_len=128, + device=torch.device("cpu"), + world_mesh=world_mesh, + ) + + # Each rank's local params should be finite. + for p in model.parameters(): + local = p.to_local() if hasattr(p, "to_local") else p + assert torch.isfinite(local).all(), f"non-finite param on rank {dist.get_rank()}" + + +def test_fsdp_2rank_wrap_and_init(): + """Minimal: FSDP wrap + init_weights under 2-rank gloo.""" + # Use spawn to avoid inheriting CUDA state from the test process. + run_distributed_test( + _fsdp_wrap_and_init_only, world_size=2, backend="gloo", start_method="spawn" + ) + + +def _fsdp_smoke(): + """Body for ``run_distributed_test``: build model on meta, apply FSDP via + the train module's dp_config path, then run a single forward + backward.""" + mm_tok = MultimodalTokenizerConfig.dolma2() + try: + hf_tok = mm_tok.load_hf_tokenizer() + except Exception as e: # noqa: BLE001 + pytest.skip(f"Could not load dolma2 HF tokenizer: {e}") + + model = _build_meta_model() + tm_cfg = MultimodalTransformerTrainModuleConfig( + rank_microbatch_size=512, + max_sequence_length=128, + optim=AdamWConfig(lr=1e-3), + dp_config=TransformerDataParallelConfig(name=DataParallelType.fsdp), + ) + tm = tm_cfg.build(model, device=torch.device("cpu")) + # No Trainer is attached; stub metric hooks. + tm.record_metric = lambda *a, **k: None + tm.record_ce_loss = lambda *a, **k: None + + # Build a tiny loader (rank-aware via dp_world_size=2). 
+ loader_cfg = MultimodalDataLoaderConfig( + preprocessor=MultimodalPreprocessorConfig( + max_sequence_length=128, + multicrop=MultiCropPreprocessorConfig( + base_image_input_size=(28, 28), + crop_mode=CropMode.resize, + pool_h=2, + pool_w=2, + image_preprocessor=ImagePreprocessorConfig(patch_size=14), + ), + ), + global_batch_size=2, + work_dir=tempfile.mkdtemp(prefix="mm-fsdp-test-"), + ) + source = _tiny_pixmo_cap_sample(n=4) + import torch.distributed as dist + + loader = loader_cfg.build( + source, + hf_tok, + dp_world_size=dist.get_world_size(), + dp_rank=dist.get_rank(), + ) + loader.reshuffle(epoch=1) + batch = next(iter(loader)) + loader.reset() + + # Run a training step. Should not raise. + tm.zero_grads() + tm.train_batch(batch) + tm.optim_step() + + # Sanity: at least one connector param should be a DTensor (FSDP-sharded). + from torch.distributed.tensor import DTensor + + found_dtensor = any(isinstance(p, DTensor) for p in tm.model.connector.parameters()) + assert found_dtensor, "expected FSDP to leave connector params as DTensors" + + +def test_fsdp_2rank_smoke(): + """Wraps + materializes + runs one training step under 2-rank gloo FSDP.""" + run_distributed_test(_fsdp_smoke, world_size=2, backend="gloo", start_method="spawn") diff --git a/src/test/train/multimodal/train_module_test.py b/src/test/train/multimodal/train_module_test.py new file mode 100644 index 000000000..136c6b89c --- /dev/null +++ b/src/test/train/multimodal/train_module_test.py @@ -0,0 +1,297 @@ +"""Tests for the multimodal train module.""" + +import tempfile +from test.data.multimodal.synthetic_source import ( + SyntheticMultimodalDataset, + SyntheticMultimodalDatasetConfig, +) + +import pytest +import torch + +from olmo_core.data.multimodal import ( + CropMode, + ImagePreprocessorConfig, + MultiCropPreprocessorConfig, + MultimodalDataLoaderConfig, + MultimodalPreprocessorConfig, + MultimodalTokenizerConfig, +) +from olmo_core.nn.transformer.config import TransformerConfig +from 
olmo_core.nn.vision import ( + MultimodalTransformer, + MultimodalTransformerConfig, + VisionBackboneConfig, + VisionBackboneType, + VisionConnectorConfig, +) +from olmo_core.optim import AdamWConfig +from olmo_core.train.train_module.multimodal import ( + MultimodalTransformerTrainModule, + MultimodalTransformerTrainModuleConfig, +) + +transformers = pytest.importorskip("transformers") + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def hf_tokenizer(): + tok_cfg = MultimodalTokenizerConfig.dolma2() + try: + return tok_cfg.load_hf_tokenizer() + except Exception as e: # noqa: BLE001 + pytest.skip(f"Could not load dolma2 HF tokenizer: {e}") + + +def _build_model() -> MultimodalTransformer: + mm_tok = MultimodalTokenizerConfig.dolma2() + lm_cfg = TransformerConfig.olmo2_1M(vocab_size=mm_tok.padded_vocab_size(128)) + vis_cfg = VisionBackboneConfig( + name=VisionBackboneType.openai, + image_default_input_size=(28, 28), + image_patch_size=14, + image_emb_dim=32, + image_num_heads=2, + image_num_key_value_heads=2, + image_num_layers=2, + image_head_dim=16, + image_mlp_dim=64, + image_num_pos=5, + image_norm_eps=1e-5, + ) + conn_cfg = VisionConnectorConfig.from_vision_backbone( + vis_cfg, output_dim=lm_cfg.d_model, mlp_hidden_size=32 + ) + cfg = MultimodalTransformerConfig( + lm=lm_cfg, + vision=vis_cfg, + connector=conn_cfg, + image_patch_token_id=mm_tok.image_patch_id, + ) + return MultimodalTransformer(cfg, init_device="cpu") + + +def _build_train_module(max_grad_norm=None) -> MultimodalTransformerTrainModule: + cfg = MultimodalTransformerTrainModuleConfig( + rank_microbatch_size=512, + max_sequence_length=128, + optim=AdamWConfig(lr=1e-3), + max_grad_norm=max_grad_norm, + ) + return cfg.build(_build_model(), device=torch.device("cpu")) + + +def _build_loader(hf_tokenizer, n_examples=8, global_batch_size=2): 
+ cfg = MultimodalDataLoaderConfig( + preprocessor=MultimodalPreprocessorConfig( + max_sequence_length=128, + multicrop=MultiCropPreprocessorConfig( + base_image_input_size=(28, 28), + crop_mode=CropMode.resize, + pool_h=2, + pool_w=2, + image_preprocessor=ImagePreprocessorConfig(patch_size=14), + ), + ), + global_batch_size=global_batch_size, + work_dir=tempfile.mkdtemp(prefix="mm-tm-test-"), + ) + source = SyntheticMultimodalDataset( + SyntheticMultimodalDatasetConfig(n_examples=n_examples, image_size=(56, 56), seed=0) + ) + return cfg.build(source, hf_tokenizer) + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_config_build_returns_train_module(): + tm = _build_train_module() + assert isinstance(tm, MultimodalTransformerTrainModule) + assert tm.optim is not None + assert tm.rank_microbatch_size == 512 + assert tm.max_sequence_length == 128 + + +def test_model_lives_on_configured_device(): + tm = _build_train_module() + for p in tm.model.parameters(): + assert p.device.type == "cpu" + break + + +# --------------------------------------------------------------------------- +# train_batch — gradient flow + parameter update +# --------------------------------------------------------------------------- + + +def test_train_batch_updates_parameters(hf_tokenizer): + tm = _build_train_module() + # Stub out metric hooks (no Trainer attached in unit tests). + tm.record_metric = lambda *a, **k: None + tm.record_ce_loss = lambda *a, **k: None + loader = _build_loader(hf_tokenizer, n_examples=4, global_batch_size=2) + loader.reshuffle(epoch=1) + + # Snapshot connector + first LM block params before training.
+ before = { + "connector": tm.model.connector.projector.w1.weight.detach().clone(), + "lm_block_0": next(tm.model.lm.blocks["0"].parameters()).detach().clone(), + } + batch = next(iter(loader)) + loader.reset() + + tm.zero_grads() + tm.train_batch(batch) + tm.optim_step() + + after = { + "connector": tm.model.connector.projector.w1.weight.detach().clone(), + "lm_block_0": next(tm.model.lm.blocks["0"].parameters()).detach().clone(), + } + # Both submodules' params should change. + assert not torch.equal(before["connector"], after["connector"]) + assert not torch.equal(before["lm_block_0"], after["lm_block_0"]) + + +def test_train_batch_records_ce_loss(hf_tokenizer): + tm = _build_train_module() + loader = _build_loader(hf_tokenizer, n_examples=4, global_batch_size=2) + loader.reshuffle(epoch=1) + batch = next(iter(loader)) + loader.reset() + # We need a metric recorder; the TrainModule normally has self.trainer. + # For tests we monkeypatch a tiny recorder. + captured = {} + + def fake_record_metric(name, value, *args, **kwargs): + captured[name] = value.detach().item() if isinstance(value, torch.Tensor) else value + + def fake_record_ce_loss(value, *args, **kwargs): + captured["ce_loss"] = value.detach().item() + + tm.record_metric = fake_record_metric + tm.record_ce_loss = fake_record_ce_loss + + tm.zero_grads() + tm.train_batch(batch) + assert "ce_loss" in captured + assert captured["ce_loss"] > 0 # Loss should be finite and positive at init. 
+ + +# --------------------------------------------------------------------------- +# loss_masks → label_mask conversion +# --------------------------------------------------------------------------- + + +def test_loss_masks_converted_to_label_mask(hf_tokenizer): + """The train module pops loss_masks (float) and creates label_mask (bool) + BEFORE delegating to the base class.""" + tm = _build_train_module() + batch = { + "input_ids": torch.zeros(2, 8, dtype=torch.long), + "loss_masks": torch.tensor( + [[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 1, 1, 1, 1, 1, 1]], dtype=torch.float32 + ), + "images": torch.zeros(2, 0, 4, 588, dtype=torch.float32), + "pooled_patches_idx": torch.zeros(2, 0, 4, dtype=torch.long), + } + tm._convert_loss_masks(batch) + assert "loss_masks" not in batch + assert "label_mask" in batch + assert batch["label_mask"].dtype == torch.bool + # Where loss_masks was 1.0, label_mask should be True. + assert batch["label_mask"][0, 4].item() is True + assert batch["label_mask"][0, 0].item() is False + + +def test_label_mask_passthrough_unchanged(hf_tokenizer): + """If the batch already has label_mask, _convert_loss_masks shouldn't clobber it.""" + tm = _build_train_module() + existing = torch.tensor([[True, False, True, False]], dtype=torch.bool) + batch = {"loss_masks": torch.ones(1, 4), "label_mask": existing} + tm._convert_loss_masks(batch) + # label_mask preserved; loss_masks left in place (we only pop when label_mask is missing). + assert torch.equal(batch["label_mask"], existing) + + +# --------------------------------------------------------------------------- +# Multi-step convergence (does loss decrease?) +# --------------------------------------------------------------------------- + + +def test_loss_decreases_over_steps(hf_tokenizer): + """Sanity check: a few training steps on a tiny synthetic batch should + reduce loss (or at least change it; can be noisy with bs=2).""" + # No max_grad_norm so optim_step doesn't try to read self.trainer. 
+ tm = _build_train_module(max_grad_norm=None) + loader = _build_loader(hf_tokenizer, n_examples=8, global_batch_size=2) + + losses = [] + + def fake_record_ce_loss(value, *args, **kwargs): + losses.append(value.detach().item()) + + tm.record_metric = lambda *a, **k: None + tm.record_ce_loss = fake_record_ce_loss + + loader.reshuffle(epoch=1) + for batch in loader: + tm.zero_grads() + tm.train_batch(batch) + tm.optim_step() + loader.reset() + + # Loss should change across steps (gradient is being applied). + assert len(losses) >= 2 + assert losses[0] != losses[-1] + + +# --------------------------------------------------------------------------- +# state_dict round-trip +# --------------------------------------------------------------------------- + + +def test_state_dict_round_trip(): + tm1 = _build_train_module() + state = tm1.state_dict(optim=False) + assert "model" in state + + # Build a fresh module and load. + tm2 = _build_train_module() + sd_to_load = tm2.state_dict_to_load( + # state_dict_to_load needs a Metadata object — we test the simpler model-only path. + metadata=type("M", (), {"state_dict_metadata": {}})(), + optim=False, + ) + # The loaded state_dict_to_load should at least produce the same keys. 
+ assert set(sd_to_load.keys()) == {"model"} + + +# --------------------------------------------------------------------------- +# End-to-end: data loader → train module +# --------------------------------------------------------------------------- + + +def test_end_to_end_train_step(hf_tokenizer): + """Drive the full pipeline: source → loader → train_module.train_batch.""" + tm = _build_train_module() + tm.record_metric = lambda *a, **k: None + tm.record_ce_loss = lambda *a, **k: None + + loader = _build_loader(hf_tokenizer, n_examples=4, global_batch_size=2) + loader.reshuffle(epoch=1) + for batch in loader: + tm.zero_grads() + tm.train_batch(batch) + tm.optim_step() + loader.reset() + # If we got here without exceptions, all the kwargs propagated correctly + # through the train module to the model's forward.