6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -65,7 +65,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **Video data path (pluggable, registry-driven).** `kempnerforge/data/video_io.py`: timestamp-based frame sampling (target fps, first & last frame kept — Molmo2 §3.1/§A) registered as the `"uniform"` sampling policy and selectable via `[video].sampling_policy`; PyAV decode (lazy-imported). `kempnerforge/data/video_dataset.py`: a `VideoDataset` base + the WebVid-style `WebVidVideoDataset` (CSV manifest + `id[:2]/id[:4]/id[:6]/id.mp4` mapping) registered as `"webvid"` via `@registry.register_video_dataset`, plus a `build_video_dataset` dispatch — so other dataset styles are additive registrations selected by `[video].dataset_type`. The WebVid corpus directory is parameterized by `[video].dataset_name` (no longer hardcoded to `webvid-10M`). `VideoCollator` → `(B, F, 3, H, W)` + a frame-validity mask; an undecodable clip is masked out (no loss). `kempnerforge/config/registry.py`: `register_video_dataset` / `register_sampling_policy` registries. `kempnerforge/config/video.py`: the `[video]` `VideoConfig` section (`data_root`, `dataset_type`, `dataset_name`, `sampling_policy`, `split`, `fps`, `max_frames`, `min_frames`, `frame_size`, `max_samples`), wired into `JobConfig` (+ `is_video`). `av` is an optional `video` dependency group (`uv sync --group video`); CI installs it for the lint + unit-test jobs.
- **Frame-aware model + training wiring.** `kempnerforge/model/vlm.py`: `_project_image_features` → `_project_visual_features` folds the frame axis through the encoder + pooler to `(B, F·P′, dim)` (a single image is the `F == 1` case). `VLMWrapper` gains `frames_per_clip`, threaded through `build_parallel_model` / `_build_vlm` / `build_vlm_wrapper` so the static visual-token count equals `F·P′` (drives the residual budget and MoT's positional split; static == runtime). `scripts/train.py` builds the video dataset/collator when `[video]` is set. Adds `configs/train/vlm_video_webvid.toml` (SigLIP2 + avgpool + WebVid).
- Tests: `tests/unit/test_video_io.py`, `test_video_dataset.py`, `test_video_config.py`; video-forward cases (all four archs) + image-path regression in `test_vlm.py`; pooling-adapter cases in `test_adapter.py`. Docs: `docs/how-to/train-on-video.md`.
-- Deferred (follow-ups; the registries make these additive): more video dataset styles (HuggingFace video sets, flat folders, alternate manifests) and frame-sampling policies; per-frame timestamp tokens + grounding (`<points>`/`<tracks>` outputs with point-F1 / track-J&F eval), frame-mask-aware attention, bidirectional visual attention, VLM sequence packing, long-context (blocked on context-parallel being wired), and warm-start from a converted image-VLM checkpoint.
+- Deferred (follow-ups; the registries make these additive): more video dataset styles (HuggingFace video sets, flat folders, alternate manifests) and frame-sampling policies; per-frame timestamp tokens + grounding (`<points>`/`<tracks>` outputs with point-F1 / track-J&F eval), bidirectional visual attention, VLM sequence packing, long-context (blocked on context-parallel being wired), and warm-start from a converted image-VLM checkpoint.
+- **Padded frames masked from attention (all four archs).** Short/undecodable video clips pad to `max_frames` with blank frames; the `frame_mask` is now consumed so real tokens never attend to padded-frame visual tokens. One per-token validity mask, `ModalityContext.key_padding_mask` `(B, S)`, threads through the model: the shared `Attention` ANDs it with the causal (and doc) mask (covering Joint-Decoder and MoMa), `MoTAttention` builds an explicit causal-AND-valid mask, and Cross-Attention masks the padded image K/V via its existing `image_mask`. A NaN guard unmasks fully-masked query rows (an all-padded clip) so softmax stays finite. It is a **pure mask: no new state-dict keys**, so checkpoints stay compatible both ways; the image (`F=1`) and text paths are unchanged (no mask is built, so they keep the FlashAttention-2 path). Note: for the image-prefix arches (Joint-Decoder/MoT/MoMa) video self-attention always takes the explicit-mask SDPA path (FA2 disabled, a `(B,1,S,S)` mask materialized) even for fully-decoded clips. The result is identical to causal-only but not free; this is a deliberate `torch.compile`/DP-friendly trade-off (one graph, no host sync), with FA2-recovery / FlexAttention left as a follow-up. (Cross-Attention keeps FA2 on its text self-attention and only masks padded image K/V in the cross-attention blocks.) Foundation for variable-length / mixed image+video batches.
+- `kempnerforge/model/modality.py`: `ModalityContext.key_padding_mask` field (+ invariant). `kempnerforge/model/vlm.py`: `_visual_token_mask` expands `frame_mask (B,F)` → `(B, F·P′)`; the four strategies place it (image-prefix arches → `key_padding_mask`, Cross-Attention → `image_mask`). `attention.py` / `mot.py` / `cross_attention.py` consume the mask (+ NaN guard); `moma.py`'s `MoMaFFN` also excludes padded positions from expert-choice routing (so padded tokens never consume expert capacity). `scripts/train.py` threads `batch["frame_mask"]` into the forward.
+- Deferred: MoT configured with an *MoE* FFN still routes padded tokens through the shared `MoEMLP`; a follow-up ("generic token-validity in MoE") would mask that and padded text alike. MoT-dense (the default) and MoMa are fully masked.
+- Tests: `tests/unit/test_vlm.py` (per-arch masking invariance, image no-op, undecodable-clip NaN guard, mask expansion); `test_moma.py` (FFN routing exclusion); `test_modality_context.py` (the new invariant).
- `install-and-verify` plugin skill: runs `uv sync`, asserts Python ≥ 3.12, then runs the four CI gate checks (`ruff check`, `ruff format --check`, `pyright`, `pytest tests/unit/`). Canonical first command after cloning.
- `.python-version` pinned to `>=3.12` so uv resolves the interpreter explicitly. Teammates on 3.13 use 3.13 (no download); 3.11-only users get 3.12 auto-fetched.
- **Dynamic-checkpointing window** (`[checkpoint.dyn_ckpt_window]`). Opt-in dense save phase: inside `[start, stop]` a registered strategy decides which steps to save; outside the window the regular `interval` cadence applies. The default strategy, `"power2"`, saves at `start` and at every `start + 2^k` while `<= stop` — tight near the start of the window, doubling thereafter. Useful for analyzing early-training dynamics, where the loss moves fastest. The default `CheckpointConfig` is unchanged (no `dyn_ckpt_window`, interval-only saves).
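The `"power2"` schedule described in the last entry can be sketched as below. This is a minimal illustration rather than the project's implementation: the function name `power2_save_steps` is hypothetical, and it assumes that `k` starts at 0 and that `start` and `stop` are inclusive step indices.

```python
def power2_save_steps(start: int, stop: int) -> list[int]:
    """Steps saved inside [start, stop]: start, then start + 2**k while <= stop."""
    if start > stop:
        return []  # empty window (assumed behavior for this sketch)
    steps = [start]
    offset = 1  # 2**0
    while start + offset <= stop:
        steps.append(start + offset)
        offset *= 2  # gaps double: 1, 2, 4, 8, ...
    return steps


print(power2_save_steps(100, 120))  # [100, 101, 102, 104, 108, 116]
```

Saves are densest right after `start` and thin out geometrically, which matches the stated goal of capturing early-training dynamics.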
16 changes: 13 additions & 3 deletions docs/how-to/train-on-video.md
@@ -102,9 +102,19 @@ time, so it is set in the TOML, not via a `--vlm.arch=` CLI override.)
- **Causal attention; no per-frame timestamps yet** — temporal order is frame
order. Per-frame timestamp tokens + grounding (`<points>`/`<tracks>` outputs
with point-F1 / track-J&F eval) are a follow-up.
-- **Padded frames are not yet masked from attention**: short clips pad to
-`max_frames` with blank frames; a `frame_mask` is produced but not yet
-consumed by the attention mask.
+- **Padded frames are masked from attention**: short/undecodable clips pad to
+`max_frames` with blank frames, and the `frame_mask` is consumed so real
+tokens never attend to padded-frame visual tokens (MoMa also drops them from
+expert-choice routing); a NaN guard keeps an all-padded clip finite. It is a
+pure mask (no new checkpoint keys); image/text keep the FlashAttention-2 path.
+For the image-prefix arches (Joint-Decoder/MoT/MoMa), video self-attention
+always takes the explicit-mask SDPA path (FA2 disabled, a `(B,1,S,S)` mask
+built) even for fully-decoded clips. This is a deliberate compile/DP-friendly
+trade-off; recovering FA2 (or moving to FlexAttention) is a follow-up.
+(Cross-Attention keeps FA2 on its text self-attention; it masks padded image
+K/V in the cross-attention blocks instead.) *Remaining:* MoT
+configured with an MoE FFN still routes padded tokens through the shared MoE
+(a "generic token-validity in MoE" follow-up).
- **Fixed `F` per batch** keeps tensor shapes static (for `torch.compile` and
DP-rank consistency); variable-length clips arrive with VLM sequence packing.
- **Long-context** (many frames) is blocked on context-parallel being wired.
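As a rough illustration of how a per-frame mask becomes a per-token mask for the image-prefix arches, here is a minimal sketch. It is not the repository's `_visual_token_mask`: the helper name `expand_frame_mask` is hypothetical, and it assumes visual tokens are laid out frame-major (all tokens of frame 0, then frame 1, and so on) and placed before the text tokens.

```python
import torch


def expand_frame_mask(frame_mask: torch.Tensor, tokens_per_frame: int) -> torch.Tensor:
    """Expand (B, F) frame validity to (B, F * tokens_per_frame) token validity."""
    # Assumes a frame-major token layout, so each frame's flag repeats contiguously.
    return frame_mask.repeat_interleave(tokens_per_frame, dim=1)


# Two clips with F = 3 frames; False marks a padded (blank) frame.
frame_mask = torch.tensor([[True, True, False], [True, False, False]])
visual_mask = expand_frame_mask(frame_mask, tokens_per_frame=2)  # (2, 6)

# Text tokens are always valid; prepend the visual part (image-prefix assumption).
text_len = 4
text_mask = torch.ones(frame_mask.shape[0], text_len, dtype=torch.bool)
key_padding_mask = torch.cat([visual_mask, text_mask], dim=1)  # (2, 10)

print(visual_mask.tolist())
print(key_padding_mask.shape)
```

Because `F` is fixed per batch, the expanded mask has the same static shape for every batch, which is what keeps this approach compatible with `torch.compile`.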
61 changes: 52 additions & 9 deletions kempnerforge/model/attention.py
@@ -121,6 +121,7 @@ def forward(
*,
kv_cache: KVCache | None = None,
doc_ids: torch.Tensor | None = None,
+key_padding_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass.

@@ -132,6 +133,11 @@ def forward(
doc_ids: Optional per-token document IDs for packed sequences,
shape (batch, seq_len). When provided, constructs a block-diagonal
causal mask so tokens only attend within their document.
+key_padding_mask: Optional per-key validity mask, shape
+(batch, seq_len); ``True`` = attend, ``False`` = drop (e.g. the
+visual tokens of padded video frames). Combined with the causal
+(and doc) mask; fully-masked query rows are unmasked to keep
+softmax finite.

Returns:
Output tensor of shape (batch, seq_len, dim).
@@ -177,14 +183,46 @@ def forward(
# restrict attention to only the first key position).
if self.capture_attention_weights:
# Manual attention for weight extraction (analysis only, not for training)
-out, attn_weights = self._attention_with_weights(q, k, v, seq_len, doc_ids, kv_cache)
+out, attn_weights = self._attention_with_weights(
+q, k, v, seq_len, doc_ids, kv_cache, key_padding_mask
+)
 self.last_attention_weights = attn_weights.detach().cpu()
-elif doc_ids is not None:
+elif doc_ids is not None or key_padding_mask is not None:
+# An explicit attn_mask is not a FlashAttention-2 shape, so SDPA falls
+# back to the mem-efficient/math kernel here. The image-prefix video
+# arches (Joint-Decoder/MoMa here, MoT in mot.py) always pass a
+# key_padding_mask (all-True when unpadded), so their self-attention
+# always takes this branch -- losing FA2 and materializing a (B, 1, S, S)
+# mask even for fully-decoded clips. (Cross-Attention sets no
+# key_padding_mask on this text self-attention -- it masks padded image
+# K/V in the cross-attention blocks instead -- so it keeps FA2 here.)
+# Deliberate: always-masking is torch.compile / DP-friendly (one graph,
+# no host sync). Recovering FA2 for unpadded batches (or moving to
+# FlexAttention) is a follow-up.
+#
+# Asserts no kv_cache: neither doc_ids (packed training) nor
+# key_padding_mask (VLM video) co-occurs with decode today, and this
+# branch's full-sequence causal mask would mis-handle a cached
+# (seq_len=1) decode rather than attend to all cached positions.
+assert kv_cache is None, (
+"doc_ids / key_padding_mask are not supported with kv_cache decode "
+"(would build an incorrect causal mask)."
+)
 seq_len_kv = k.shape[2]
-# Block-diagonal mask: same-document AND causal
-doc_mask = doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1) # (B, S, S)
+# Explicit bool mask: causal, AND same-document (doc_ids), AND valid
+# keys (key_padding_mask, e.g. dropping padded video frames' tokens).
 causal = torch.ones(seq_len, seq_len_kv, dtype=torch.bool, device=q.device).tril()
-attn_mask = (doc_mask & causal).unsqueeze(1) # (B, 1, S, S)
+attn_mask = causal.unsqueeze(0).unsqueeze(0) # (1, 1, S, S_kv)
+if doc_ids is not None:
+doc_mask = (doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)).unsqueeze(1) # (B,1,S,S)
+attn_mask = attn_mask & doc_mask
+if key_padding_mask is not None:
+attn_mask = attn_mask & key_padding_mask.view(batch, 1, 1, seq_len_kv)
+# NaN guard: a query row with no reachable valid key (e.g. the leading
+# positions of an all-padded / undecodable clip) would softmax over all
+# -inf -> NaN. Unmask such rows; their outputs are discarded (trimmed by
+# output_slice, or the clip's labels are all -100).
+attn_mask = attn_mask | ~attn_mask.any(dim=-1, keepdim=True)
with self._sdpa_context():
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
else:
@@ -205,6 +243,7 @@ def _attention_with_weights(
seq_len: int,
doc_ids: torch.Tensor | None,
kv_cache: KVCache | None,
+key_padding_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute attention output and weights manually (for analysis).

@@ -221,11 +260,15 @@ def _attention_with_weights(
attn = torch.matmul(q, k.transpose(-2, -1)) * scale

seq_len_kv = k.shape[2]
-if doc_ids is not None:
-doc_mask = doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)
+if doc_ids is not None or key_padding_mask is not None:
 causal = torch.ones(seq_len, seq_len_kv, dtype=torch.bool, device=q.device).tril()
-mask = ~(doc_mask & causal).unsqueeze(1)
-attn = attn.masked_fill(mask, float("-inf"))
+valid = causal.unsqueeze(0).unsqueeze(0) # (1, 1, S, S_kv)
+if doc_ids is not None:
+valid = valid & (doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)).unsqueeze(1)
+if key_padding_mask is not None:
+valid = valid & key_padding_mask.view(q.shape[0], 1, 1, seq_len_kv)
+valid = valid | ~valid.any(dim=-1, keepdim=True) # NaN guard (see forward)
+attn = attn.masked_fill(~valid, float("-inf"))
elif kv_cache is None or seq_len > 1:
causal = torch.ones(seq_len, seq_len_kv, dtype=torch.bool, device=q.device).triu(
diagonal=1
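The mask construction in the `forward` hunk above can be exercised in isolation. The following is a standalone sketch, not the module's code: `build_attn_mask` is a hypothetical helper, it assumes equal query and key lengths (no KV cache), and it applies the NaN guard unconditionally after both optional terms, which is an assumption about indentation that the flattened diff does not settle.

```python
import torch
import torch.nn.functional as F


def build_attn_mask(seq_len, doc_ids=None, key_padding_mask=None):
    """Bool mask (True = attend): causal AND same-document AND valid key."""
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    mask = causal.unsqueeze(0).unsqueeze(0)  # (1, 1, S, S)
    if doc_ids is not None:
        same_doc = (doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)).unsqueeze(1)
        mask = mask & same_doc  # (B, 1, S, S)
    if key_padding_mask is not None:
        mask = mask & key_padding_mask[:, None, None, :]  # broadcast over queries
    # NaN guard: a row with no reachable key is fully unmasked so softmax is finite.
    return mask | ~mask.any(dim=-1, keepdim=True)


torch.manual_seed(0)
B, H, S, D = 2, 2, 4, 8
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# Sample 0: first two keys are padded. Sample 1: every key is padded.
kpm = torch.tensor([[False, False, True, True], [False, False, False, False]])
mask = build_attn_mask(S, key_padding_mask=kpm)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(mask.shape)                       # torch.Size([2, 1, 4, 4])
print(mask[0, 0, 2].tolist())           # [False, False, True, False]
print(torch.isfinite(out).all().item())  # True
```

Note that a guarded row attends to every key, including future and padded ones. That is only acceptable because, as the diff's comment says, the outputs at those positions are discarded downstream.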
7 changes: 6 additions & 1 deletion kempnerforge/model/cross_attention.py
@@ -101,7 +101,12 @@ def forward(
# the right thing across heads and text Q positions.
attn_mask = None
if image_mask is not None:
-attn_mask = image_mask.view(batch, 1, 1, num_image_tokens)
+# NaN guard: a sample with no valid image tokens (e.g. an undecodable
+# clip, with all frames padded) would softmax over all -inf -> NaN. Unmask
+# such rows so softmax stays finite; their text outputs are discarded
+# (the clip's labels are all -100).
+safe = image_mask | ~image_mask.any(dim=1, keepdim=True) # (B, N)
+attn_mask = safe.view(batch, 1, 1, num_image_tokens)

# Cross-attention: no causal mask on the image axis.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
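To see why the guard in this hunk matters, the sketch below uses a hand-written masked softmax instead of SDPA, so the behavior does not depend on which attention kernel is selected. The helper `masked_softmax` and the toy shapes are illustrative assumptions, not code from the repository.

```python
import torch


def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Softmax over image keys; mask is (B, N) with True = valid image token."""
    keep = mask[:, None, None, :]  # (B, 1, 1, N), broadcast over heads and queries
    return scores.masked_fill(~keep, float("-inf")).softmax(dim=-1)


# Sample 0 has two valid image tokens; sample 1 is an all-padded clip.
image_mask = torch.tensor([[True, True, False], [False, False, False]])
scores = torch.zeros(2, 1, 4, 3)  # (B, heads, text_len, num_image_tokens)

unsafe = masked_softmax(scores, image_mask)

# The guard from the diff: if a sample has no valid token, unmask all of them.
safe_mask = image_mask | ~image_mask.any(dim=1, keepdim=True)
safe = masked_softmax(scores, safe_mask)

print(unsafe[1].isnan().all().item())  # True: softmax over all -inf
print(safe.isnan().any().item())       # False
print(safe[0, 0, 0].tolist())          # [0.5, 0.5, 0.0]
```

The guard leaves samples with at least one valid token untouched, so it only changes rows whose outputs are already excluded from the loss.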
13 changes: 13 additions & 0 deletions kempnerforge/model/modality.py
@@ -46,6 +46,9 @@ class ModalityContext:
error).
- ``modality_ids`` requires ``prefix_embeds`` or ``inputs_embeds``
to be set (routing without a residual extension is meaningless).
+- ``key_padding_mask`` requires ``prefix_embeds`` or ``inputs_embeds``
+to be set (it is a key-validity mask over the residual sequence; the
+Cross-Attention arch masks image K/V via ``image_mask`` instead).

``output_slice`` composes with the ``tokens`` path AND with the
``inputs_embeds`` path; it is not constrained intra-context. The
@@ -59,6 +62,7 @@ class ModalityContext:
image_features: torch.Tensor | None = None
image_mask: torch.Tensor | None = None
modality_ids: torch.Tensor | None = None
+key_padding_mask: torch.Tensor | None = None

def __post_init__(self) -> None:
residual_routes = sum(
@@ -80,3 +84,12 @@ def __post_init__(self) -> None:
"ModalityContext: modality_ids requires prefix_embeds OR "
"inputs_embeds to be set (routing without a residual extension is meaningless)"
)
+if (
+self.key_padding_mask is not None
+and self.prefix_embeds is None
+and self.inputs_embeds is None
+):
+raise ValueError(
+"ModalityContext: key_padding_mask requires prefix_embeds OR inputs_embeds "
+"to be set (it masks the residual sequence; Cross-Attention uses image_mask)"
+)
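The new invariant can be checked with a stripped-down stand-in for the dataclass. `MiniModalityContext` and `accepts` below are hypothetical names for this sketch; the real `ModalityContext` has more fields and checks, and plain lists stand in for tensors here.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class MiniModalityContext:
    """Reduced stand-in: only the fields involved in the new invariant."""

    prefix_embeds: Any = None
    inputs_embeds: Any = None
    key_padding_mask: Any = None

    def __post_init__(self) -> None:
        # A key-validity mask only makes sense over a residual sequence.
        if (
            self.key_padding_mask is not None
            and self.prefix_embeds is None
            and self.inputs_embeds is None
        ):
            raise ValueError("key_padding_mask requires prefix_embeds OR inputs_embeds")


def accepts(**kwargs: Any) -> bool:
    """Return True if the context can be constructed with these fields."""
    try:
        MiniModalityContext(**kwargs)
        return True
    except ValueError:
        return False


print(accepts(key_padding_mask=[[True, False]]))                              # False
print(accepts(prefix_embeds=[[0.0, 0.0]], key_padding_mask=[[True, False]]))  # True
print(accepts())                                                              # True
```

Validating in `__post_init__` means a misconfigured context fails at construction time rather than deep inside the attention call.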