Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `scripts/train.py`: training-loop hook that runs `mot_warm_start_from_text_stack` once at step 0 when `mot_warm_start_from_text=True`, between `ckpt_mgr.load(...)` and `apply_freeze_specs(...)`.
- Tests: `tests/unit/test_mot.py` (Algorithm-1 reference parity, warm-start helper round-trips); `tests/unit/test_model.py::TestMoT` + `TestModalityIdsCrossArgs`; MoT cases in `tests/unit/test_vlm.py`, `test_vlm_config.py`, `test_modality_context.py`; `tests/integration/test_vlm_mot.py`; `tests/distributed/test_vlm_mot_fsdp.py` (gated on multi-GPU).
- Configs: `configs/train/vlm_debug_mot.toml` (1-GPU smoke) and `configs/train/vlm_7b_mot.toml` (4-GPU 7B).
- **Video understanding (all four archs).** Extend the VLM path to ingest video — a clip is an ordered set of frames sampled at a target fps — through the same registry-driven design as the image path; trains end-to-end on WebVid-10M with Joint-Decoder, Cross-Attention, MoT, and MoMa. The text-only and single-image paths are unchanged (bit-exact).
- **Pooling connector.** `kempnerforge/model/adapter.py`: a `VisionAdapter` base with an `output_num_tokens()` contract, plus `avgpool` and `attentional_pool` (Molmo2-style mean-query MHA) connectors registered via `@registry.register_adapter`; `kempnerforge/config/adapter.py` gains `pool_window` / `pool_heads`. The adapter-derived visual-token count is threaded through `build_vlm_wrapper` / `_build_vlm`, the four modality strategies, and the three seq-len checks (`config/job.py`, `distributed/parallel.py`, `model/vlm.py`). Projection adapters keep the count (identity) so the image path is bit-exact.
- **Video data path.** `kempnerforge/data/video_io.py`: timestamp-based frame sampling (target fps, uniform, first & last frame kept — Molmo2 §3.1/§A) + PyAV decode (lazily imported; the wheel bundles FFmpeg, so no system FFmpeg is required). `kempnerforge/data/video_dataset.py`: `WebVidVideoDataset` (CSV manifest + `id[:2]/id[:4]/id[:6]/id.mp4` path mapping, reuses the image preprocessing) and `VideoCollator` → `(B, F, 3, H, W)` + a frame-validity mask; an undecodable clip is masked out (contributes no loss). `kempnerforge/config/video.py`: the `[video]` `VideoConfig` section (`data_root`, `split`, `fps`, `max_frames`, `min_frames`, `frame_size`, `max_samples`), wired into `JobConfig` (+ `is_video`). Adds `av` to dependencies.
- **Frame-aware model + training wiring.** `kempnerforge/model/vlm.py`: `_project_image_features` → `_project_visual_features` folds the frame axis through the encoder + pooler to `(B, F·P′, dim)` (a single image is the `F == 1` case). `VLMWrapper` gains `frames_per_clip`, threaded through `build_parallel_model` / `_build_vlm` / `build_vlm_wrapper` so the static visual-token count equals `F·P′` (drives the residual budget and MoT's positional split; static == runtime). `scripts/train.py` builds the video dataset/collator when `[video]` is set. Adds `configs/train/vlm_video_webvid.toml` (SigLIP2 + avgpool + WebVid).
- Tests: `tests/unit/test_video_io.py`, `test_video_dataset.py`, `test_video_config.py`; video-forward cases (all four archs) + image-path regression in `test_vlm.py`; pooling-adapter cases in `test_adapter.py`. Docs: `docs/how-to/train-on-video.md`.
- Deferred (follow-ups): per-frame timestamp tokens + grounding (`<points>`/`<tracks>` outputs with point-F1 / track-J&F eval), frame-mask-aware attention, bidirectional visual attention, VLM sequence packing, long-context (blocked on context-parallel being wired), and warm-start from a converted image-VLM checkpoint.
- `install-and-verify` plugin skill: runs `uv sync`, asserts Python ≥ 3.12, then runs the four CI gate checks (`ruff check`, `ruff format --check`, `pyright`, `pytest tests/unit/`). Canonical first command after cloning.
- `.python-version` pinned to `>=3.12` so uv resolves the interpreter explicitly. Teammates on 3.13 use 3.13 (no download); 3.11-only users get 3.12 auto-fetched.
- **Dynamic-checkpointing window** (`[checkpoint.dyn_ckpt_window]`). Opt-in dense save phase: inside `[start, stop]` a registered strategy decides which steps to save; outside the window the regular `interval` cadence applies. The default strategy, `"power2"`, saves at `start` and at every `start + 2^k` while `<= stop` — tight near the start of the window, doubling thereafter. Useful for analyzing early-training dynamics, where the loss moves fastest. The default `CheckpointConfig` is unchanged (no `dyn_ckpt_window`, interval-only saves).
Expand Down
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,15 @@ Further reading:

### Vision-language models

KempnerForge supports VLM training via a thin wrapper around the existing `Transformer`. Image tokens come from a frozen HF vision encoder (SigLIP2, CLIP, or a tiny `random` test stub), pass through a 2-layer adapter, and feed the backbone via an arch-specific path:
KempnerForge supports VLM training — images **or video** (a clip is an ordered set of frames) — via a thin wrapper around the existing `Transformer`. Visual tokens come from a frozen HF vision encoder (SigLIP2, CLIP, or a tiny `random` test stub), pass through a connector (a 2-layer adapter, or an `avgpool` / `attentional_pool` connector that pools patches to reduce tokens per frame), and feed the backbone via an arch-specific path:

- **Joint-Decoder** (`arch = "joint_decoder"`): image embeds are prepended to the text sequence; the transformer runs over the concatenated `(image, text)` sequence and the LM head is applied to text positions only.
- **Cross-Attention** (`arch = "cross_attention"`, Llama-3-V style): the residual stream carries text only. Separate `CrossAttentionBlock`s inserted at a configurable cadence let text queries attend to image K/V. CA blocks are zero-initialized so adding the arch on top of a text-only checkpoint is identity at step 0 and learns from there.
- **Mixture-of-Transformers** (`arch = "mot"`, Liang et al. 2024 Algorithm 1): every layer carries per-modality Q/K/V/O projections plus a per-modality FFN; a single global self-attention mixes all modality streams. Image tokens prepend the text sequence (same residual layout as Joint-Decoder); per-modality residual projections are zero-initialized so a fresh MoT block is identity at construction. A warm-start helper (`mot_warm_start_from_text_stack`) translates a JD or text-only checkpoint into per-modality copies — toggle via `[model.vlm].mot_warm_start_from_text` + `mot_warm_start_path`.
- **Mixture of Modality-Aware Experts** (`arch = "moma"`, Lin et al. 2024 arXiv:2407.21770): one shared set of Q/K/V/O projections feeding a global self-attention, plus per-modality MoE FFN groups (paper's optimal default 4 image + 4 text experts per layer). Tokens route deterministically to their modality group (level-1, reusing the same `modality_ids` mechanism MoT uses) and then through a learned expert-choice + Sigmoid router within the group (level-2, with Gumbel-Sigmoid noise during training, paper Eq. 5). Image tokens prepend the text sequence (same residual layout as JD/MoT). v1 supports training only — expert-choice routing is non-causal, so autoregressive generation requires auxiliary routers (paper §2.4) which are deferred to a follow-up.

**Video** works across all four archs with no arch-specific changes: a clip is decoded into frames (sampled by timestamp at a target fps — uniform, with the first and last frame always kept), each frame is encoded and pooled by the connector, and the `F × tokens_per_frame` visual tokens enter the backbone exactly like image tokens. Configure the `[video]` section (`data_root`, `fps`, `max_frames`, `frame_size`); see `configs/train/vlm_video_webvid.toml`.

```bash
# 1-GPU smoke (random encoder, Joint-Decoder)
uv run python scripts/train.py configs/train/vlm_debug.toml \
Expand All @@ -164,9 +166,12 @@ uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_7b_mot.tom

# 4-GPU 7B Mixture of Modality-Aware Experts (4 text + 4 image experts per layer)
uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_7b_moma.toml

# 4-GPU video training on WebVid (Joint-Decoder; flip [vlm].arch for cross_attention / mot / moma)
uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_video_webvid.toml
```

Configs set `[model.vlm]` with `arch`, the encoder registry key, the number of image tokens, and a freeze list (`FreezeSpec`). For Cross-Attention, set `cross_attention_every_n_layers` and optionally `cross_attention_n_kv_heads` (0 → MHA; positive → GQA on the cross path). For MoT, set `mot_modalities` (must include both `"image"` and `"text"`); `mot_image_n_heads` / `mot_image_n_kv_heads` are forward-looking per-modality head fields (v1 enforces equality with the text-side counts since the operator runs a single global SDPA). For MoMa, set `moma_experts_per_modality = {image = N, text = M}` as a nested TOML table (the paper's optimal balanced default is `4t4i`; unbalanced allocations like `{image = 1, text = 7}` match the paper's `moe_7t1i` ablation), and optionally `moma_capacity_factor` (defaults to `1/|E^M|` per modality — the paper's perfect-balance setting) and `moma_gumbel_noise` (`true` by default for paper-faithful EC routing). `model.num_experts` must be `0` when `arch = "moma"`; the per-modality counts supersede it, and JobConfig.validate rejects the combination. The vision encoder stays in its HF-loaded dtype; the transformer, adapter, and CA / MoT / MoMa blocks are cast to `param_dtype`. Pipeline Parallel + VLM is not supported on this branch (raises at startup); MoMa + Expert Parallelism is also rejected in v1. Multi-image and video are reserved slots for follow-up work.
Configs set `[model.vlm]` with `arch`, the encoder registry key, the number of image tokens, and a freeze list (`FreezeSpec`). For Cross-Attention, set `cross_attention_every_n_layers` and optionally `cross_attention_n_kv_heads` (0 → MHA; positive → GQA on the cross path). For MoT, set `mot_modalities` (must include both `"image"` and `"text"`); `mot_image_n_heads` / `mot_image_n_kv_heads` are forward-looking per-modality head fields (v1 enforces equality with the text-side counts since the operator runs a single global SDPA). For MoMa, set `moma_experts_per_modality = {image = N, text = M}` as a nested TOML table (the paper's optimal balanced default is `4t4i`; unbalanced allocations like `{image = 1, text = 7}` match the paper's `moe_7t1i` ablation), and optionally `moma_capacity_factor` (defaults to `1/|E^M|` per modality — the paper's perfect-balance setting) and `moma_gumbel_noise` (`true` by default for paper-faithful EC routing). `model.num_experts` must be `0` when `arch = "moma"`; the per-modality counts supersede it, and JobConfig.validate rejects the combination. The vision encoder stays in its HF-loaded dtype; the transformer, adapter, and CA / MoT / MoMa blocks are cast to `param_dtype`. Pipeline Parallel + VLM is not supported on this branch (raises at startup); MoMa + Expert Parallelism is also rejected in v1. Video is supported across all four archs via the `[video]` section (a clip is decoded into frames, pooled by the connector, and fed like image tokens); multi-image inputs and video *grounding* (point/track outputs with per-frame timestamps) are reserved for follow-up work.

**Adding a new VLM arch.** The discriminated-union dispatch is registry-driven, so a new arch is four small additions, no edits to existing call sites:

Expand Down
93 changes: 93 additions & 0 deletions configs/train/vlm_video_webvid.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Video VLM training on WebVid-10M (Joint-Decoder).
#
# A clip is decoded into `max_frames` frames (sampled by timestamp at `fps`,
# Molmo2-style: uniform, first + last frame kept), each encoded by the vision
# tower, pooled by the connector, and prepended to the caption tokens. The LLM
# is trained from scratch here; warm-starting from a pretrained backbone is a
# later phase.
#
# Visual-token budget (Joint-Decoder): max_frames * tokens_per_frame + max_text_len
# must be <= model.max_seq_len. With SigLIP2 @224/patch16 (14x14 = 196 patches)
# and the avgpool connector at pool_window=2 -> ceil(14/2)^2 = 49 tokens/frame.
# 8 frames -> 8*49 = 392 visual + 64 text = 456 <= 576.
#
# Launch (single node, 4 GPUs):
# uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_video_webvid.toml
#
# Quick smoke (no SigLIP download, a few clips; pair with a small step count):
# ... --vision_encoder.type=random --vision_encoder.num_tokens=196 \
# --vision_encoder.feature_dim=768 --video.max_samples=256 --train.max_steps=20

[model]
dim = 1024
n_layers = 12
n_heads = 16
n_kv_heads = 16
vocab_size = 50257 # gpt2 tokenizer
max_seq_len = 576 # 8 frames * 49 pooled tokens + 64 text + headroom
norm_type = "rmsnorm"
activation = "silu"
rope_theta = 10000.0

[vision_encoder]
type = "siglip2"
path = "google/siglip2-base-patch16-224"
# feature_dim / num_tokens left at 0 -> probed from the HF model at build time.

[adapter]
type = "avgpool" # token-reducing connector; "attentional_pool" is the faithful upgrade
pool_window = 2 # 14x14 patch grid -> 7x7 = 49 tokens/frame

[vlm]
arch = "joint_decoder" # also: cross_attention / mot / moma
max_text_len = 64
freeze = [{module = "vision_encoder", frozen = true}]

[video]
data_root = "/n/holylfs06/LABS/kempner_shared/Everyone/testbed/video/webvid-10m"
split = "train"
fps = 2.0
max_frames = 8
min_frames = 4
frame_size = 224
max_samples = 0 # 0 = full manifest; set small (e.g. 256) for a smoke

[data]
tokenizer_path = "gpt2"
num_workers = 4
pin_memory = true

[train]
batch_size = 4
seq_len = 576
max_steps = 5000
grad_accum_steps = 4
grad_clip_norm = 1.0
seed = 42
compile_model = false
activation_checkpointing = "none"

[optimizer]
name = "adamw"
lr = 2e-4
weight_decay = 0.1
betas = [0.9, 0.95]
fused = false

[scheduler]
name = "cosine"
warmup_steps = 200
min_lr_ratio = 0.1

[distributed]
dp_shard = -1

[checkpoint]
dir = "checkpoints/vlm_video_webvid"
interval = 1000
keep_last_n = 2

[metrics]
log_interval = 10
enable_wandb = false
enable_tensorboard = false
7 changes: 7 additions & 0 deletions docs/how-to/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ with runnable code (or a link to a notebook, config, or script that runs it)
— `ActivationStore`, `extract_representations()`, save to `.npz`,
feed to probing / CKA / SVCCA.

## Multimodal

- [Train on video](train-on-video.md) — ingest a clip as frames through the
VLM path: timestamp frame sampling, the pooling connector, the `[video]`
config + token budget, and how all four archs consume a clip.

```{toctree}
:maxdepth: 1
:hidden:
Expand All @@ -69,4 +75,5 @@ data-mixing-annealing
fp8-training
moe-experiments
mechanistic-interpretability
train-on-video
```
99 changes: 99 additions & 0 deletions docs/how-to/train-on-video.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Train on video

The VLM path ingests **video** through the same wrapper, connectors, and fusion
archs as images — a clip is just an ordered set of frames. This guide covers the
data layout, the `[video]` config, the frame-sampling policy, and how all four
archs consume a clip.

## The model's view of a clip

A clip of `F` frames becomes `F × P′` visual tokens:

1. **Sample** `F` frames from the video by timestamp (target `fps`, uniform,
first and last frame always kept).
2. **Encode** each frame with the frozen vision tower (e.g. SigLIP2), fold the
frame axis into the batch so `B×F` frames run through the encoder once.
3. **Pool + project** each frame with the connector — an `avgpool` or
`attentional_pool` adapter reduces a `grid×grid` patch map to
`P′ = ceil(grid/window)²` tokens per frame (e.g. SigLIP2 @224/patch16 →
14×14 → 49 tokens at `pool_window=2`).
4. **Fuse** the resulting `(B, F·P′, dim)` visual tokens into the backbone the
same way images are fused — so **all four archs work unchanged**:
- `joint_decoder` / `mot` / `moma`: the `F·P′` tokens prepend the text in the
residual stream and are trimmed before the LM head.
- `cross_attention`: the `F·P′` tokens flow as K/V into the cross-attention
blocks; the residual stays text-only (so it fits more frames per
`max_seq_len`).

Temporal order is carried by frame order (sequential positions). Per-frame
timestamp tokens and grounding outputs are a separate follow-up (see below).

## Token budget

For the residual-stream archs (JD / MoT / MoMa):

```
max_frames × tokens_per_frame + max_text_len ≤ model.max_seq_len
```

e.g. 8 frames × 49 + 64 text = 456 ≤ 576. Cross-attention only needs
`max_text_len ≤ max_seq_len` (visual tokens are K/V, not in the residual). The
build- and config-time checks enforce this and fail before any GPU work.

## Configure it

A video run adds a `[video]` section (sibling of `[vision_encoder]` /
`[adapter]` / `[vlm]`) and a token-reducing connector. See
`configs/train/vlm_video_webvid.toml` for a complete example; the key parts:

```toml
[adapter]
type = "avgpool" # or "attentional_pool"; pools patches per frame
pool_window = 2 # 14×14 grid -> 7×7 = 49 tokens/frame

[vlm]
arch = "joint_decoder" # also: cross_attention | mot | moma

[video]
data_root = "/path/to/webvid-10m"
split = "train" # "train" | "validation"
fps = 2.0 # target sampling rate
max_frames = 8 # per-clip frame budget
min_frames = 4
frame_size = 224
max_samples = 0 # 0 = full manifest; set small for a smoke
```

The `[video]` section is decoded by `WebVidVideoDataset` (a WebVid-style layout:
CSV manifests under `raw/webvid-10M/data/<split>/partitions/` and `.mp4` files
under `raw/videos/<split>/`). Decoding uses PyAV (its wheel bundles FFmpeg, so no
system FFmpeg is required); it is imported lazily, so the package imports without
`av` and only actual decoding needs it.

## Launch

```bash
# 4-GPU video training (Joint-Decoder)
uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_video_webvid.toml

# Quick smoke: no SigLIP download, a few clips, few steps
uv run torchrun --nproc_per_node=2 scripts/train.py configs/train/vlm_video_webvid.toml \
--vision_encoder.type=random --vision_encoder.num_tokens=196 \
--vision_encoder.feature_dim=768 --video.max_samples=256 --train.max_steps=20
```

To switch arch, change `[vlm].arch` in the config — everything else (frame
sampling, connector, dataset) is identical. (`arch` is resolved at config-load
time, so it is set in the TOML, not via a `--vlm.arch=` CLI override.)

## Constraints and follow-ups

- **Causal attention; no per-frame timestamps yet** — temporal order is frame
order. Per-frame timestamp tokens + grounding (`<points>`/`<tracks>` outputs
with point-F1 / track-J&F eval) are a follow-up.
- **Padded frames are not yet masked from attention** — short clips pad to
`max_frames` with blank frames; a `frame_mask` is produced but not yet
consumed by the attention mask.
- **Fixed `F` per batch** keeps tensor shapes static (for `torch.compile` and
DP-rank consistency); variable-length clips arrive with VLM sequence packing.
- **Long-context** (many frames) is blocked on context-parallel being wired.
Loading
Loading