Commit c7b2ebb

Add vision transformer, connector, and MultimodalTransformer
VLM/multimodal vision-language architecture stack:
- VisionBackbone (OpenAI CLIP / SigLIP / SigLIP2): OpenAI-style ViT encoder with configurable image size, patch size, embedding dim, and attention heads. Supports CLIP (openai), SigLIP (siglip), and SigLIP2 (siglip2) initialisation.
- VisionConnector: attention-pooling (2×2) + SwiGLU MLP projector that maps vision embeddings to the language-model hidden dimension.
- MultimodalTransformer: composite model that fuses image patch tokens into the LM token stream at image-patch positions, then runs the full LM forward pass.
- Removed DINOv2 backbone variants (not used in Molmo2).
- HF parity tests for CLIP, SigLIP, and SigLIP2 vision encoders.
1 parent 754d58d commit c7b2ebb
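The connector described in the commit message (2×2 attention pooling followed by a SwiGLU projector) can be sketched roughly as follows. This is an illustration only, not the code added by this commit: the class name `ConnectorSketch`, the helper `group_2x2`, the use of the window mean as the attention query, and every parameter name are assumptions made for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def group_2x2(x: torch.Tensor, grid: tuple) -> torch.Tensor:
    """Group a row-major patch sequence (B, H*W, D) into 2x2 windows.

    Returns a tensor of shape (B * H/2 * W/2, 4, D). H and W must be even.
    """
    B, _, D = x.shape
    H, W = grid
    x = x.view(B, H // 2, 2, W // 2, 2, D)
    x = x.permute(0, 1, 3, 2, 4, 5)  # (B, H/2, W/2, 2, 2, D)
    return x.reshape(B * (H // 2) * (W // 2), 4, D)


class ConnectorSketch(nn.Module):
    """Hypothetical connector: 2x2 attention pooling, then a SwiGLU MLP."""

    def __init__(self, d_vision: int, d_model: int, hidden: int, n_heads: int = 2):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_vision, n_heads, batch_first=True)
        self.w_gate = nn.Linear(d_vision, hidden, bias=False)
        self.w_up = nn.Linear(d_vision, hidden, bias=False)
        self.w_down = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor, grid: tuple) -> torch.Tensor:
        B = x.shape[0]
        H, W = grid
        windows = group_2x2(x, grid)               # (B*H/2*W/2, 4, D)
        # Assumption: the mean of each window serves as the pooling query.
        query = windows.mean(dim=1, keepdim=True)  # (B*H/2*W/2, 1, D)
        pooled, _ = self.attn(query, windows, windows)
        pooled = pooled.reshape(B, (H // 2) * (W // 2), -1)
        # SwiGLU: silu(gate) * up, projected to the LM hidden dimension.
        return self.w_down(F.silu(self.w_gate(pooled)) * self.w_up(pooled))
```

Under these assumptions, a 4×4 patch grid of 8-dimensional vision features is reduced to 4 pooled tokens, each projected to the language-model width passed as `d_model`.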

14 files changed

Lines changed: 2428 additions & 5 deletions

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

+- Added vision transformer encoder (`VisionTransformer`, `SiglipVisionTransformer`), vision-to-LM connector (`VisionConnector`), and `MultimodalTransformer` — a composite vision-language model that fuses image patch tokens into the LM token stream. Supports OpenAI CLIP, SigLIP, and SigLIP2 backbone variants with factory configs for all standard Molmo2 checkpoints.
 - Added `HFConverterCallback`, which can be used to convert models to huggingface format at the end of the training run.
 - Trainer now records checkpoint save and load durations as `train/checkpoint_save_duration_s` and `train/checkpoint_load_duration_s` metrics.
 - Added `PowerLR`, a power-law learning rate scheduler with linear warmup, power-decay phase (`lr = initial_lr * (current / warmup) ** b` for negative `b`, making the LR independent of the training horizon), and an optional linear decay tail. Registered as `"power_lr"`.
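The `PowerLR` entry above quotes the formula `lr = initial_lr * (current / warmup) ** b`. A minimal sketch of such a schedule is given here as a plain function. The function name `power_lr`, its parameter names, and the shape of the linear tail (decaying from the power-law value at `decay_start` to zero at `total`) are assumptions for illustration, not the library's actual API.

```python
from typing import Optional


def power_lr(
    step: int,
    initial_lr: float,
    warmup: int,
    b: float = -0.5,
    decay_start: Optional[int] = None,  # hypothetical: first step of the linear tail
    total: Optional[int] = None,        # hypothetical: step at which the tail reaches 0
) -> float:
    """Linear warmup, then lr = initial_lr * (step / warmup) ** b, then an optional tail."""
    if warmup <= 0:
        raise ValueError("warmup must be positive")
    if step < warmup:
        # Linear warmup from 0 up to initial_lr.
        return initial_lr * step / warmup
    if decay_start is not None and total is not None and step >= decay_start:
        # Assumed tail: scale the power-law value at decay_start linearly down to 0.
        base = initial_lr * (decay_start / warmup) ** b
        frac = max(total - step, 0) / (total - decay_start)
        return base * frac
    # Power-decay phase; depends only on step and warmup, not on the training horizon.
    return initial_lr * (step / warmup) ** b
```

With `b = -0.5` and `warmup = 100`, the rate at step 400 is `initial_lr * 4 ** -0.5`, which is half of `initial_lr`, whatever the total number of training steps.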

src/olmo_core/nn/__init__.py

Lines changed: 26 additions & 0 deletions
@@ -1,3 +1,29 @@
 """
 Common :class:`torch.nn.Module` implementations.
 """
+
+from .vision import (
+    ImagePoolingType,
+    ImageProjectorType,
+    MultimodalTransformer,
+    MultimodalTransformerConfig,
+    SiglipVisionTransformer,
+    VisionBackboneConfig,
+    VisionBackboneType,
+    VisionConnector,
+    VisionConnectorConfig,
+    VisionTransformer,
+)
+
+__all__ = [
+    "VisionBackboneType",
+    "VisionBackboneConfig",
+    "VisionTransformer",
+    "SiglipVisionTransformer",
+    "ImagePoolingType",
+    "ImageProjectorType",
+    "VisionConnectorConfig",
+    "VisionConnector",
+    "MultimodalTransformerConfig",
+    "MultimodalTransformer",
+]

src/olmo_core/nn/transformer/model.py

Lines changed: 15 additions & 5 deletions
@@ -506,6 +506,7 @@ def forward(
         self,
         input_ids: torch.Tensor,
         *,
+        input_embeddings: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         ignore_index: int = -100,
         loss_reduction: Literal["mean", "sum", "none"] = "mean",
@@ -519,6 +520,12 @@ def forward(
         Run the transformer on the token input IDs.

         :param input_ids: The token input IDs, shape ``(batch_size, seq_len)``.
+        :param input_embeddings: Pre-computed embeddings to use instead of looking up
+            ``input_ids`` in the embedding table, shape
+            ``(batch_size, seq_len, d_model)``. When provided the embedding lookup,
+            scale, and norm steps are all skipped. Intended for multimodal use-cases
+            where image features have already been spliced into the embedding sequence.
+            Not supported with context parallelism.
         :param labels: The token labels, shape ``(batch_size, seq_len)``.
         :param ignore_index: The index to ignore in the loss computation. Default is -100.
         :param loss_reduction: The reduction method for the loss. Can be "mean", "sum", or "none".
@@ -550,11 +557,14 @@ def forward(

         # Get embeddings but pass-through for non-existent layers to allow easy
         # pipeline parallel configuration.
-        h = self.embeddings(input_ids) if self.embeddings is not None else input_ids
-        if self.embeddings is not None and self.embed_scale is not None:
-            h = h * self.embed_scale
-        if self.embedding_norm is not None:
-            h = self.embedding_norm(h)
+        if input_embeddings is not None:
+            h = move_to_device(input_embeddings, self.device)
+        else:
+            h = self.embeddings(input_ids) if self.embeddings is not None else input_ids
+            if self.embeddings is not None and self.embed_scale is not None:
+                h = h * self.embed_scale
+            if self.embedding_norm is not None:
+                h = self.embedding_norm(h)

         # Run each block.
         for block_key, block in self.blocks.items():
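The `input_embeddings` docstring above mentions image features that "have already been spliced into the embedding sequence". One way a caller might prepare such a tensor is sketched here. The helper name `splice_image_features`, the boolean mask argument, and the choice to overwrite (rather than add to) the text embeddings at image-patch positions are all assumptions for illustration; the `MultimodalTransformer` in this commit may fuse features differently.

```python
import torch


def splice_image_features(
    text_embeds: torch.Tensor,     # (batch_size, seq_len, d_model)
    image_feats: torch.Tensor,     # (num_image_patches, d_model), already projected
    is_image_patch: torch.Tensor,  # (batch_size, seq_len) bool mask
) -> torch.Tensor:
    """Return a copy of text_embeds with image features written at masked positions.

    Features are consumed in row-major order of the mask (batch first, then
    sequence position). Overwriting is an assumption of this sketch.
    """
    num_slots = int(is_image_patch.sum())
    if num_slots != image_feats.shape[0]:
        raise ValueError(
            f"mask selects {num_slots} positions but got {image_feats.shape[0]} features"
        )
    fused = text_embeds.clone()
    fused[is_image_patch] = image_feats.to(fused.dtype)
    return fused
```

Going by the signature in the hunk above, the result would then be passed as `model(input_ids, input_embeddings=fused)`, in which case the embedding lookup, scale, and norm are skipped and `fused` enters the first block as-is.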
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+"""
+Vision encoder modules for multimodal (VLM) training.
+"""
+
+from .config import VisionBackboneConfig, VisionBackboneType
+from .connector import (
+    ImagePoolingType,
+    ImageProjectorType,
+    VisionConnector,
+    VisionConnectorConfig,
+)
+from .image_vit import SiglipVisionTransformer, VisionTransformer
+from .multimodal import MultimodalTransformer, MultimodalTransformerConfig
+
+__all__ = [
+    "VisionBackboneType",
+    "VisionBackboneConfig",
+    "VisionTransformer",
+    "SiglipVisionTransformer",
+    "ImagePoolingType",
+    "ImageProjectorType",
+    "VisionConnectorConfig",
+    "VisionConnector",
+    "MultimodalTransformerConfig",
+    "MultimodalTransformer",
+]
