diff --git a/CHANGELOG.md b/CHANGELOG.md
index 28c7284..49ab032 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -65,7 +65,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - **Video data path (pluggable, registry-driven).** `kempnerforge/data/video_io.py`: timestamp-based frame sampling (target fps, first & last frame kept — Molmo2 §3.1/§A) registered as the `"uniform"` sampling policy and selectable via `[video].sampling_policy`; PyAV decode (lazy-imported). `kempnerforge/data/video_dataset.py`: a `VideoDataset` base + the WebVid-style `WebVidVideoDataset` (CSV manifest + `id[:2]/id[:4]/id[:6]/id.mp4` mapping) registered as `"webvid"` via `@registry.register_video_dataset`, plus a `build_video_dataset` dispatch — so other dataset styles are additive registrations selected by `[video].dataset_type`. The WebVid corpus directory is parameterized by `[video].dataset_name` (no longer hardcoded to `webvid-10M`). `VideoCollator` → `(B, F, 3, H, W)` + a frame-validity mask; an undecodable clip is masked out (no loss). `kempnerforge/config/registry.py`: `register_video_dataset` / `register_sampling_policy` registries. `kempnerforge/config/video.py`: the `[video]` `VideoConfig` section (`data_root`, `dataset_type`, `dataset_name`, `sampling_policy`, `split`, `fps`, `max_frames`, `min_frames`, `frame_size`, `max_samples`), wired into `JobConfig` (+ `is_video`). `av` is an optional `video` dependency group (`uv sync --group video`); CI installs it for the lint + unit-test jobs.
 - **Frame-aware model + training wiring.** `kempnerforge/model/vlm.py`: `_project_image_features` → `_project_visual_features` folds the frame axis through the encoder + pooler to `(B, F·P′, dim)` (a single image is the `F == 1` case). `VLMWrapper` gains `frames_per_clip`, threaded through `build_parallel_model` / `_build_vlm` / `build_vlm_wrapper` so the static visual-token count equals `F·P′` (drives the residual budget and MoT's positional split; static == runtime). `scripts/train.py` builds the video dataset/collator when `[video]` is set. Adds `configs/train/vlm_video_webvid.toml` (SigLIP2 + avgpool + WebVid).
   - Tests: `tests/unit/test_video_io.py`, `test_video_dataset.py`, `test_video_config.py`; video-forward cases (all four archs) + image-path regression in `test_vlm.py`; pooling-adapter cases in `test_adapter.py`. Docs: `docs/how-to/train-on-video.md`.
-  - Deferred (follow-ups; the registries make these additive): more video dataset styles (HuggingFace video sets, flat folders, alternate manifests) and frame-sampling policies; per-frame timestamp tokens + grounding (``/`` outputs with point-F1 / track-J&F eval), frame-mask-aware attention, bidirectional visual attention, VLM sequence packing, long-context (blocked on context-parallel being wired), and warm-start from a converted image-VLM checkpoint.
+  - Deferred (follow-ups; the registries make these additive): more video dataset styles (HuggingFace video sets, flat folders, alternate manifests) and frame-sampling policies; per-frame timestamp tokens + grounding (``/`` outputs with point-F1 / track-J&F eval), bidirectional visual attention, VLM sequence packing, long-context (blocked on context-parallel being wired), and warm-start from a converted image-VLM checkpoint.
+- **Padded frames masked from attention (all four archs).** Short/undecodable video clips pad to `max_frames` with blank frames; the `frame_mask` is now consumed so real tokens never attend to padded-frame visual tokens. One per-token validity mask, `ModalityContext.key_padding_mask` `(B, S)`, threads through the model: the shared `Attention` ANDs it with the causal (and doc) mask — covering Joint-Decoder and MoMa — `MoTAttention` builds an explicit causal-AND-valid mask, and Cross-Attention masks the padded image K/V via its existing `image_mask`. A NaN guard unmasks fully-masked query rows (an all-padded clip) so softmax stays finite. It is a **pure mask — no new state-dict keys**, so checkpoints stay compatible both ways; the image (`F=1`) and text paths are unchanged (no mask is built, so they keep the FlashAttention-2 path). Note: for the image-prefix arches (Joint-Decoder/MoT/MoMa) video self-attention always takes the explicit-mask SDPA path (FA2 disabled, a `(B,1,S,S)` mask materialized) even for fully-decoded clips — the result is identical to causal-only but not free; a deliberate `torch.compile`/DP-friendly trade-off (one graph, no host sync), with FA2-recovery / FlexAttention left as a follow-up. (Cross-Attention keeps FA2 on its text self-attention and only masks padded image K/V in the cross-attention blocks.) Foundation for variable-length / mixed image+video batches.
+  - `kempnerforge/model/modality.py`: `ModalityContext.key_padding_mask` field (+ invariant). `kempnerforge/model/vlm.py`: `_visual_token_mask` expands `frame_mask (B,F)` → `(B, F·P′)`; the four strategies place it (image-prefix arches → `key_padding_mask`, Cross-Attention → `image_mask`). `attention.py` / `mot.py` / `cross_attention.py` consume the mask (+ NaN guard); `moma.py`'s `MoMaFFN` also excludes padded positions from expert-choice routing (so padded tokens never consume expert capacity). `scripts/train.py` threads `batch["frame_mask"]` into the forward.
+  - Deferred: MoT configured with an *MoE* FFN still routes padded tokens through the shared `MoEMLP` — a follow-up ("generic token-validity in MoE") would mask that and padded text alike. MoT-dense (the default) and MoMa are fully masked.
+  - Tests: `tests/unit/test_vlm.py` (per-arch masking invariance, image no-op, undecodable-clip NaN guard, mask expansion); `test_moma.py` (FFN routing exclusion); `test_modality_context.py` (the new invariant).
 - `install-and-verify` plugin skill: runs `uv sync`, asserts Python ≥ 3.12, then runs the four CI gate checks (`ruff check`, `ruff format --check`, `pyright`, `pytest tests/unit/`). Canonical first command after cloning.
 - `.python-version` pinned to `>=3.12` so uv resolves the interpreter explicitly. Teammates on 3.13 use 3.13 (no download); 3.11-only users get 3.12 auto-fetched.
 - **Dynamic-checkpointing window** (`[checkpoint.dyn_ckpt_window]`). Opt-in dense save phase: inside `[start, stop]` a registered strategy decides which steps to save; outside the window the regular `interval` cadence applies. The default strategy, `"power2"`, saves at `start` and at every `start + 2^k` while `<= stop` — tight near the start of the window, doubling thereafter. Useful for analyzing early-training dynamics, where the loss moves fastest. The default `CheckpointConfig` is unchanged (no `dyn_ckpt_window`, interval-only saves).
diff --git a/docs/how-to/train-on-video.md b/docs/how-to/train-on-video.md
index d2e304a..5fc98f3 100644
--- a/docs/how-to/train-on-video.md
+++ b/docs/how-to/train-on-video.md
@@ -102,9 +102,19 @@ time, so it is set in the TOML, not via a `--vlm.arch=` CLI override.)
 - **Causal attention; no per-frame timestamps yet** — temporal order is frame
   order. Per-frame timestamp tokens + grounding (``/`` outputs with point-F1 /
   track-J&F eval) are a follow-up.
-- **Padded frames are not yet masked from attention** — short clips pad to
-  `max_frames` with blank frames; a `frame_mask` is produced but not yet
-  consumed by the attention mask.
+- **Padded frames are masked from attention** — short/undecodable clips pad to
+  `max_frames` with blank frames, and the `frame_mask` is consumed so real
+  tokens never attend to padded-frame visual tokens (MoMa also drops them from
+  expert-choice routing); a NaN guard keeps an all-padded clip finite. It is a
+  pure mask (no new checkpoint keys); image/text keep the FlashAttention-2 path.
+  For the image-prefix arches (Joint-Decoder/MoT/MoMa), video self-attention
+  always takes the explicit-mask SDPA path (FA2 disabled, a `(B,1,S,S)` mask
+  built) even for fully-decoded clips — a deliberate compile/DP-friendly
+  trade-off; recovering FA2 / FlexAttention is a follow-up. (Cross-Attention
+  keeps FA2 on its text self-attention; it masks padded image K/V in the
+  cross-attention blocks instead.) *Remaining:* MoT
+  configured with an MoE FFN still routes padded tokens through the shared MoE
+  (a "generic token-validity in MoE" follow-up).
 - **Fixed `F` per batch** keeps tensor shapes static (for `torch.compile` and
   DP-rank consistency); variable-length clips arrive with VLM sequence packing.
 - **Long-context** (many frames) is blocked on context-parallel being wired.
diff --git a/kempnerforge/model/attention.py b/kempnerforge/model/attention.py
index 0acab2c..b4b226c 100644
--- a/kempnerforge/model/attention.py
+++ b/kempnerforge/model/attention.py
@@ -121,6 +121,7 @@ def forward(
         *,
         kv_cache: KVCache | None = None,
         doc_ids: torch.Tensor | None = None,
+        key_padding_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Forward pass.
 
@@ -132,6 +133,11 @@ def forward(
             doc_ids: Optional per-token document IDs for packed sequences,
                 shape (batch, seq_len). When provided, constructs a block-diagonal
                 causal mask so tokens only attend within their document.
+            key_padding_mask: Optional per-key validity mask, shape
+                (batch, seq_len); ``True`` = attend, ``False`` = drop (e.g. the
+                visual tokens of padded video frames). Combined with the causal
+                (and doc) mask; fully-masked query rows are unmasked to keep
+                softmax finite.
 
         Returns:
             Output tensor of shape (batch, seq_len, dim).
@@ -177,14 +183,46 @@ def forward(
         # restrict attention to only the first key position).
         if self.capture_attention_weights:
             # Manual attention for weight extraction (analysis only, not for training)
-            out, attn_weights = self._attention_with_weights(q, k, v, seq_len, doc_ids, kv_cache)
+            out, attn_weights = self._attention_with_weights(
+                q, k, v, seq_len, doc_ids, kv_cache, key_padding_mask
+            )
             self.last_attention_weights = attn_weights.detach().cpu()
-        elif doc_ids is not None:
+        elif doc_ids is not None or key_padding_mask is not None:
+            # An explicit attn_mask is not a FlashAttention-2 shape, so SDPA falls
+            # back to the mem-efficient/math kernel here. The image-prefix video
+            # arches (Joint-Decoder/MoMa here, MoT in mot.py) always pass a
+            # key_padding_mask (all-True when unpadded), so their self-attention
+            # always takes this branch -- losing FA2 and materializing a (B, 1, S, S)
+            # mask even for fully-decoded clips. (Cross-Attention sets no
+            # key_padding_mask on this text self-attention -- it masks padded image
+            # K/V in the cross-attention blocks instead -- so it keeps FA2 here.)
+            # Deliberate: always-masking is torch.compile / DP-friendly (one graph,
+            # no host sync). Recovering FA2 for unpadded batches (or moving to
+            # FlexAttention) is a follow-up.
+            #
+            # Asserts no kv_cache: neither doc_ids (packed training) nor
+            # key_padding_mask (VLM video) co-occurs with decode today, and this
+            # branch's full-sequence causal mask would mis-handle a cached
+            # (seq_len=1) decode rather than attend to all cached positions.
+            assert kv_cache is None, (
+                "doc_ids / key_padding_mask are not supported with kv_cache decode "
+                "(would build an incorrect causal mask)."
+            )
             seq_len_kv = k.shape[2]
-            # Block-diagonal mask: same-document AND causal
-            doc_mask = doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)  # (B, S, S)
+            # Explicit bool mask: causal, AND same-document (doc_ids), AND valid
+            # keys (key_padding_mask, e.g. dropping padded video frames' tokens).
             causal = torch.ones(seq_len, seq_len_kv, dtype=torch.bool, device=q.device).tril()
-            attn_mask = (doc_mask & causal).unsqueeze(1)  # (B, 1, S, S)
+            attn_mask = causal.unsqueeze(0).unsqueeze(0)  # (1, 1, S, S_kv)
+            if doc_ids is not None:
+                doc_mask = (doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)).unsqueeze(1)  # (B,1,S,S)
+                attn_mask = attn_mask & doc_mask
+            if key_padding_mask is not None:
+                attn_mask = attn_mask & key_padding_mask.view(batch, 1, 1, seq_len_kv)
+            # NaN guard: a query row with no reachable valid key (e.g. the leading
+            # positions of an all-padded / undecodable clip) would softmax over all
+            # -inf -> NaN. Unmask such rows; their outputs are discarded (trimmed by
+            # output_slice, or the clip's labels are all -100).
+            attn_mask = attn_mask | ~attn_mask.any(dim=-1, keepdim=True)
             with self._sdpa_context():
                 out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
         else:
@@ -205,6 +243,7 @@ def _attention_with_weights(
         seq_len: int,
         doc_ids: torch.Tensor | None,
         kv_cache: KVCache | None,
+        key_padding_mask: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute attention output and weights manually (for analysis).
 
@@ -221,11 +260,15 @@ def _attention_with_weights(
         attn = torch.matmul(q, k.transpose(-2, -1)) * scale
 
         seq_len_kv = k.shape[2]
-        if doc_ids is not None:
-            doc_mask = doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)
+        if doc_ids is not None or key_padding_mask is not None:
             causal = torch.ones(seq_len, seq_len_kv, dtype=torch.bool, device=q.device).tril()
-            mask = ~(doc_mask & causal).unsqueeze(1)
-            attn = attn.masked_fill(mask, float("-inf"))
+            valid = causal.unsqueeze(0).unsqueeze(0)  # (1, 1, S, S_kv)
+            if doc_ids is not None:
+                valid = valid & (doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)).unsqueeze(1)
+            if key_padding_mask is not None:
+                valid = valid & key_padding_mask.view(q.shape[0], 1, 1, seq_len_kv)
+            valid = valid | ~valid.any(dim=-1, keepdim=True)  # NaN guard (see forward)
+            attn = attn.masked_fill(~valid, float("-inf"))
         elif kv_cache is None or seq_len > 1:
             causal = torch.ones(seq_len, seq_len_kv, dtype=torch.bool, device=q.device).triu(
                 diagonal=1
diff --git a/kempnerforge/model/cross_attention.py b/kempnerforge/model/cross_attention.py
index 1cce4f4..4083b75 100644
--- a/kempnerforge/model/cross_attention.py
+++ b/kempnerforge/model/cross_attention.py
@@ -101,7 +101,12 @@ def forward(
         # the right thing across heads and text Q positions.
         attn_mask = None
         if image_mask is not None:
-            attn_mask = image_mask.view(batch, 1, 1, num_image_tokens)
+            # NaN guard: a sample with no valid image tokens (e.g. an undecodable
+            # clip — all frames padded) would softmax over all -inf -> NaN. Unmask
+            # such rows so softmax stays finite; their text outputs are discarded
+            # (the clip's labels are all -100).
+            safe = image_mask | ~image_mask.any(dim=1, keepdim=True)  # (B, N)
+            attn_mask = safe.view(batch, 1, 1, num_image_tokens)
 
         # Cross-attention: no causal mask on the image axis.
         out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
diff --git a/kempnerforge/model/modality.py b/kempnerforge/model/modality.py
index e281723..168bd59 100644
--- a/kempnerforge/model/modality.py
+++ b/kempnerforge/model/modality.py
@@ -46,6 +46,9 @@ class ModalityContext:
       error).
     - ``modality_ids`` requires ``prefix_embeds`` or ``inputs_embeds`` to
       be set (routing without a residual extension is meaningless).
+    - ``key_padding_mask`` requires ``prefix_embeds`` or ``inputs_embeds``
+      to be set (it is a key-validity mask over the residual sequence; the
+      Cross-Attention arch masks image K/V via ``image_mask`` instead).
 
     ``output_slice`` composes with the ``tokens`` path AND with the
     ``inputs_embeds`` path; it is not constrained intra-context. The
@@ -59,6 +62,7 @@ class ModalityContext:
     image_features: torch.Tensor | None = None
     image_mask: torch.Tensor | None = None
     modality_ids: torch.Tensor | None = None
+    key_padding_mask: torch.Tensor | None = None
 
     def __post_init__(self) -> None:
         residual_routes = sum(
@@ -80,3 +84,12 @@ def __post_init__(self) -> None:
                 "ModalityContext: modality_ids requires prefix_embeds OR "
                 "inputs_embeds to be set (routing without a residual extension is meaningless)"
             )
+        if (
+            self.key_padding_mask is not None
+            and self.prefix_embeds is None
+            and self.inputs_embeds is None
+        ):
+            raise ValueError(
+                "ModalityContext: key_padding_mask requires prefix_embeds OR inputs_embeds "
+                "to be set (it masks the residual sequence; Cross-Attention uses image_mask)"
+            )
diff --git a/kempnerforge/model/moma.py b/kempnerforge/model/moma.py
index b65fee9..36e5a40 100644
--- a/kempnerforge/model/moma.py
+++ b/kempnerforge/model/moma.py
@@ -299,7 +299,12 @@ def __init__(
             }
         )
 
-    def forward(self, x: torch.Tensor, modality_ids: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        modality_ids: torch.Tensor,
+        key_padding_mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
         """Dispatch tokens by modality and run per-modality EC-MoE.
 
         Args:
@@ -307,6 +312,10 @@ def forward(self, x: torch.Tensor, modality_ids: torch.Tensor) -> torch.Tensor:
             modality_ids: ``(B, S)`` long tensor. ``modality_ids == i``
                 routes that token to ``self.modalities[i]``'s expert
                 group.
+            key_padding_mask: Optional ``(B, S)`` bool mask; ``False`` positions
+                (e.g. padded video frames) are excluded from the expert-choice
+                routing so they neither consume expert capacity nor perturb which
+                real tokens the experts select. ``None`` = all positions routed.
 
         Returns:
             ``(B, S, D)`` tensor with each modality's positions filled
@@ -329,31 +338,36 @@ def forward(self, x: torch.Tensor, modality_ids: torch.Tensor) -> torch.Tensor:
         b, s, d = x.shape
         x_flat = x.reshape(b * s, d)
         mod_flat = modality_ids.reshape(b * s)
+        valid_flat = key_padding_mask.reshape(b * s) if key_padding_mask is not None else None
         out = torch.zeros_like(x_flat)
 
-        # Tracks how many positions actually got routed to *some* modality
-        # group. With well-formed modality_ids (values in [0, len(modalities)))
-        # this equals b*s at the end. We accumulate Python ints from
-        # ``idx.numel()`` (tensor metadata, no host sync) and compare after
-        # the loop — much cheaper than an upfront ``.all()`` reduction which
-        # would force a device->host sync every step. The error fires
-        # post-FFN, but the in-range work is the same either way and the
-        # failure mode without this check is silent zero-output on the
-        # affected positions (residual still carries them through, so the
-        # bug would only surface as quietly wrong training).
+        # ``total_routed`` counts positions matching *some* modality group (the
+        # modality_ids range check below); with well-formed modality_ids it
+        # equals b*s. We accumulate Python ints from ``idx.numel()`` (tensor
+        # metadata, no host sync) rather than a ``.all()`` reduction that would
+        # force a device->host sync every step. The per-modality routing set
+        # additionally drops padded positions so they never compete for capacity.
         total_routed = 0
         for i, m in enumerate(self.modalities):
-            # nonzero() avoids the boolean-mask copy and gives us a 1-D index
+            # nonzero() avoids the boolean-mask copy and gives a 1-D index
             # tensor we can feed to index_select + scatter.
-            idx = (mod_flat == i).nonzero(as_tuple=False).squeeze(-1)  # (N_m,)
+            mod_idx = (mod_flat == i).nonzero(as_tuple=False).squeeze(-1)  # (N_m,)
+            total_routed += mod_idx.numel()
+            idx = mod_idx
+            if valid_flat is not None:
+                # Drop padded (e.g. blank-frame) positions from the expert-choice
+                # competition: padded tokens must not consume expert capacity or
+                # change which real tokens the experts pick. They get zero FFN
+                # output; the outer residual skip carries them through unchanged.
+                keep = valid_flat.index_select(0, mod_idx).nonzero(as_tuple=False).squeeze(-1)
+                idx = mod_idx.index_select(0, keep)
             if idx.numel() == 0:
                 continue
-            total_routed += idx.numel()
-            x_m = x_flat.index_select(0, idx)  # (N_m, D)
-            y_m = self.experts[m](x_m)  # (N_m, D)
-            # The modality groups partition the position space, so indices
-            # are guaranteed unique across iterations. index_copy on
-            # disjoint indices is safe and autograd-friendly.
+            x_m = x_flat.index_select(0, idx)  # (N_routed, D)
+            y_m = self.experts[m](x_m)  # (N_routed, D)
+            # The modality groups partition the position space, so indices are
+            # unique across iterations; index_copy on disjoint indices is safe
+            # and autograd-friendly.
             out = out.index_copy(0, idx, y_m)
 
         if total_routed != b * s:
@@ -435,10 +449,20 @@ def forward(
         modality_ids: torch.Tensor,
         *,
         doc_ids: torch.Tensor | None = None,
+        key_padding_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
         # Pre-norm attention with residual (shared QKVO, single SDPA).
         # kv_cache is intentionally omitted: EC routing is non-causal in v1.
-        x = x + self.attention(self.attention_norm(x), rope_cos, rope_sin, doc_ids=doc_ids)
-        # Pre-norm MoMa FFN with residual (per-modality EC-MoE groups).
-        x = x + self.mlp(self.mlp_norm(x), modality_ids=modality_ids)
+        x = x + self.attention(
+            self.attention_norm(x),
+            rope_cos,
+            rope_sin,
+            doc_ids=doc_ids,
+            key_padding_mask=key_padding_mask,
+        )
+        # Pre-norm MoMa FFN with residual (per-modality EC-MoE groups). The
+        # key_padding_mask drops padded positions from expert-choice routing.
+        x = x + self.mlp(
+            self.mlp_norm(x), modality_ids=modality_ids, key_padding_mask=key_padding_mask
+        )
         return x
diff --git a/kempnerforge/model/mot.py b/kempnerforge/model/mot.py
index 069b86f..d16ef6d 100644
--- a/kempnerforge/model/mot.py
+++ b/kempnerforge/model/mot.py
@@ -111,6 +111,7 @@ def forward(
         streams: dict[str, torch.Tensor],
         rope: dict[str, tuple[torch.Tensor, torch.Tensor]],
         is_causal: bool = True,
+        key_padding_mask: torch.Tensor | None = None,
     ) -> dict[str, torch.Tensor]:
         """Run per-modality projections, global SDPA, per-modality output.
 
@@ -123,7 +124,12 @@ def forward(
                 ``(seq_m, head_dim // 2)`` — counts from position 0
                 within that modality's axis.
             is_causal: passed through to
-                ``F.scaled_dot_product_attention``.
+                ``F.scaled_dot_product_attention`` (used only when
+                ``key_padding_mask`` is None).
+            key_padding_mask: Optional ``(B, S)`` bool over the concatenated
+                ``[image | text]`` sequence; ``True`` = attend. When given,
+                builds an explicit causal-AND-valid mask (dropping padded video
+                frames' visual tokens) instead of the ``is_causal`` fast path.
 
         Returns:
             per-modality output of shape ``(batch, seq_m, dim)``.
@@ -175,7 +181,19 @@ def forward(
             k = k.repeat_interleave(self.n_rep, dim=1)
             v = v.repeat_interleave(self.n_rep, dim=1)
 
-        out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
+        if key_padding_mask is not None:
+            # Explicit mask: causal AND valid keys (drop padded video frames'
+            # visual tokens). Replaces the is_causal fast path.
+            total_seq = q.shape[2]
+            causal = torch.ones(total_seq, total_seq, dtype=torch.bool, device=q.device).tril()
+            kv = key_padding_mask.view(batch, 1, 1, total_seq)
+            attn_mask = causal.unsqueeze(0).unsqueeze(0) & kv
+            # NaN guard: unmask fully-masked query rows (e.g. an all-padded clip)
+            # so softmax stays finite; those outputs are discarded downstream.
+            attn_mask = attn_mask | ~attn_mask.any(dim=-1, keepdim=True)
+            out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+        else:
+            out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
 
         # (B, n_heads, total_seq, head_dim) -> (B, total_seq, n_heads, head_dim)
         out = out.transpose(1, 2).contiguous()
@@ -293,6 +311,7 @@ def forward(
         self,
         streams: dict[str, torch.Tensor],
         rope: dict[str, tuple[torch.Tensor, torch.Tensor]],
+        key_padding_mask: torch.Tensor | None = None,
     ) -> dict[str, torch.Tensor]:
         if set(streams.keys()) != set(self.modalities):
             raise ValueError(
@@ -300,7 +319,7 @@ def forward(
                 f"do not match construction-time modalities {sorted(self.modalities)}"
             )
         normed_attn = {m: self.attn_norm[m](streams[m]) for m in self.modalities}
-        attn_out = self.attn(normed_attn, rope, is_causal=True)
+        attn_out = self.attn(normed_attn, rope, is_causal=True, key_padding_mask=key_padding_mask)
         post_attn = {m: streams[m] + attn_out[m] for m in self.modalities}
         normed_mlp = {m: self.mlp_norm[m](post_attn[m]) for m in self.modalities}
         mlp_out = {m: self.mlp[m](normed_mlp[m]) for m in self.modalities}
diff --git a/kempnerforge/model/transformer.py b/kempnerforge/model/transformer.py
index 89ba45c..20d26b8 100644
--- a/kempnerforge/model/transformer.py
+++ b/kempnerforge/model/transformer.py
@@ -87,10 +87,16 @@ def forward(
         *,
         kv_cache: KVCache | None = None,
         doc_ids: torch.Tensor | None = None,
+        key_padding_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
         # Pre-norm attention with residual
         x = x + self.attention(
-            self.attention_norm(x), rope_cos, rope_sin, kv_cache=kv_cache, doc_ids=doc_ids
+            self.attention_norm(x),
+            rope_cos,
+            rope_sin,
+            kv_cache=kv_cache,
+            doc_ids=doc_ids,
+            key_padding_mask=key_padding_mask,
         )
         # Pre-norm MLP with residual
         x = x + self.mlp(self.mlp_norm(x))
@@ -355,6 +361,7 @@ def forward(
         image_features = modality.image_features if modality is not None else None
         image_mask = modality.image_mask if modality is not None else None
         modality_ids = modality.modality_ids if modality is not None else None
+        key_padding_mask = modality.key_padding_mask if modality is not None else None
 
         if (tokens is None) == (inputs_embeds is None):
             raise ValueError(
@@ -436,7 +443,9 @@ def forward(
                     f"match residual shape {tuple(h.shape[:2])}"
                 )
             for layer in self.layers.values():
-                h = layer(h, cos, sin, modality_ids, doc_ids=doc_ids)
+                h = layer(
+                    h, cos, sin, modality_ids, doc_ids=doc_ids, key_padding_mask=key_padding_mask
+                )
             h = self.norm(h)
         # MoT path: position-based image-then-text split, per-modality
         # streams through the MoTBlock stack, single global SDPA per
@@ -471,7 +480,7 @@ def forward(
                 "text": (cos[:t_text], sin[:t_text]),
             }
             for layer in self.layers.values():
-                streams = layer(streams, rope)
+                streams = layer(streams, rope, key_padding_mask=key_padding_mask)
             streams = {m: self.mot_norms[m](streams[m]) for m in self._mot_modalities}
             # Re-concat in image-then-text order to match the residual
             # layout the rest of forward expects (output_slice + head).
@@ -485,7 +494,9 @@ def forward(
             ca_iter = iter(self.cross_attention_layers.values()) if self._ca_cadence else None
             for i, layer in enumerate(self.layers.values()):
                 cache = kv_caches[i] if kv_caches is not None else None
-                h = layer(h, cos, sin, kv_cache=cache, doc_ids=doc_ids)
+                h = layer(
+                    h, cos, sin, kv_cache=cache, doc_ids=doc_ids, key_padding_mask=key_padding_mask
+                )
                 if ca_iter is not None and (i + 1) % self._ca_cadence == 0:
                     ca = next(ca_iter, None)
                     if ca is not None:
diff --git a/kempnerforge/model/vlm.py b/kempnerforge/model/vlm.py
index 65daf9c..c6ca8f9 100644
--- a/kempnerforge/model/vlm.py
+++ b/kempnerforge/model/vlm.py
@@ -68,6 +68,7 @@ def prepare(
         wrapper: VLMWrapper,
         pixel_values: torch.Tensor,
         input_ids: torch.Tensor,
+        frame_mask: torch.Tensor | None = None,
     ) -> ModalityContext: ...
 
     def num_image_tokens(self, wrapper: VLMWrapper) -> int: ...
@@ -117,6 +118,53 @@ def _project_visual_features(wrapper: VLMWrapper, pixel_values: torch.Tensor) ->
     return embeds
 
 
+def _visual_token_mask(
+    frame_mask: torch.Tensor | None, num_visual_tokens: int
+) -> torch.Tensor | None:
+    """Expand a per-frame validity mask to per-visual-token.
+
+    ``frame_mask`` is ``(B, F)`` bool (``True`` = real frame). Each frame maps to
+    ``num_visual_tokens // F`` visual tokens (frame-contiguous, see
+    ``_project_visual_features``), so each frame's bit is repeated over its
+    tokens -> ``(B, num_visual_tokens)``. Returns ``None`` when no mask is given
+    (the image path, or a caller that passes nothing), read downstream as "all
+    tokens valid".
+    """
+    if frame_mask is None:
+        return None
+    num_frames = frame_mask.shape[1]
+    if num_visual_tokens % num_frames != 0:
+        # Visual tokens are frame-contiguous (F * tokens_per_frame), so the count
+        # must be divisible by the frame count. A future adapter that adds a
+        # non-per-frame token (e.g. a global/CLS token) would break this and
+        # silently misalign the mask -- fail loudly here instead.
+        raise ValueError(
+            f"_visual_token_mask: num_visual_tokens ({num_visual_tokens}) is not a "
+            f"multiple of num_frames ({num_frames}); the per-frame expansion assumes "
+            "frame-contiguous visual tokens."
+        )
+    tokens_per_frame = num_visual_tokens // num_frames
+    return frame_mask.repeat_interleave(tokens_per_frame, dim=1)
+
+
+def _prefix_key_padding_mask(
+    frame_mask: torch.Tensor | None, num_visual_tokens: int, input_ids: torch.Tensor
+) -> torch.Tensor | None:
+    """Residual key-validity mask ``(B, S)`` for the image-prefix arches.
+
+    ``S = num_visual_tokens + T_text``. Visual positions follow the expanded
+    per-frame mask; text positions are always valid (trailing text padding is
+    causal-safe and is not masked here). Returns ``None`` when no frame_mask is
+    given.
+    """
+    vmask = _visual_token_mask(frame_mask, num_visual_tokens)
+    if vmask is None:
+        return None
+    b, t_text = input_ids.shape
+    text_valid = torch.ones(b, t_text, dtype=torch.bool, device=vmask.device)
+    return torch.cat([vmask, text_valid], dim=1)
+
+
 @registry.register_modality_strategy("joint_decoder")
 class JointDecoderStrategy:
     """Joint-Decoder: image embeds prepended to the text sequence.
@@ -132,11 +180,16 @@ def prepare(
         self,
         wrapper: VLMWrapper,
         pixel_values: torch.Tensor,
-        input_ids: torch.Tensor,  # noqa: ARG002
+        input_ids: torch.Tensor,
+        frame_mask: torch.Tensor | None = None,
     ) -> ModalityContext:
         img_embeds = _project_visual_features(wrapper, pixel_values)
         n = img_embeds.shape[1]  # pooling-aware: the adapter's actual visual-token count
-        return ModalityContext(prefix_embeds=img_embeds, output_slice=slice(n, None))
+        return ModalityContext(
+            prefix_embeds=img_embeds,
+            output_slice=slice(n, None),
+            key_padding_mask=_prefix_key_padding_mask(frame_mask, n, input_ids),
+        )
 
     def num_image_tokens(self, wrapper: VLMWrapper) -> int:
         return wrapper.frames_per_clip * wrapper.adapter.output_num_tokens(
@@ -152,8 +205,9 @@ class CrossAttentionStrategy:
 
     Forward path: ``feats = vision_encoder(pixel_values)``;
     ``img_embeds = adapter(feats)``; ``ModalityContext(image_features,
-    image_mask=None)``. ``image_mask=None`` means "all image tokens
-    valid"; multi-image variants will fill it in later.
+    image_mask)``. ``image_mask`` carries per-visual-token validity (padded
+    video frames are masked out of the image K/V); ``None`` means all image
+    tokens are valid (e.g. a single image or a full clip).
     """
 
     def prepare(
@@ -161,9 +215,13 @@ def prepare(
         wrapper: VLMWrapper,
         pixel_values: torch.Tensor,
         input_ids: torch.Tensor,  # noqa: ARG002
+        frame_mask: torch.Tensor | None = None,
     ) -> ModalityContext:
         img_embeds = _project_visual_features(wrapper, pixel_values)
-        return ModalityContext(image_features=img_embeds, image_mask=None)
+        return ModalityContext(
+            image_features=img_embeds,
+            image_mask=_visual_token_mask(frame_mask, img_embeds.shape[1]),
+        )
 
     def num_image_tokens(self, wrapper: VLMWrapper) -> int:  # noqa: ARG002
         # Cross-Attention does not extend the residual stream.
@@ -195,6 +253,7 @@ def prepare(
         wrapper: VLMWrapper,
         pixel_values: torch.Tensor,
         input_ids: torch.Tensor,
+        frame_mask: torch.Tensor | None = None,
     ) -> ModalityContext:
         img_embeds = _project_visual_features(wrapper, pixel_values)
         n = img_embeds.shape[1]  # pooling-aware: the adapter's actual visual-token count
@@ -205,6 +264,7 @@ def prepare(
             prefix_embeds=img_embeds,
             output_slice=slice(n, None),
             modality_ids=modality_ids,
+            key_padding_mask=_prefix_key_padding_mask(frame_mask, n, input_ids),
         )
 
     def num_image_tokens(self, wrapper: VLMWrapper) -> int:
@@ -239,6 +299,7 @@ def prepare(
         wrapper: VLMWrapper,
         pixel_values: torch.Tensor,
         input_ids: torch.Tensor,
+        frame_mask: torch.Tensor | None = None,
     ) -> ModalityContext:
         img_embeds = _project_visual_features(wrapper, pixel_values)
         n = img_embeds.shape[1]  # pooling-aware: the adapter's actual visual-token count
@@ -249,6 +310,7 @@ def prepare(
             prefix_embeds=img_embeds,
             output_slice=slice(n, None),
             modality_ids=modality_ids,
+            key_padding_mask=_prefix_key_padding_mask(frame_mask, n, input_ids),
         )
 
     def num_image_tokens(self, wrapper: VLMWrapper) -> int:
@@ -311,13 +373,14 @@ def forward(
         pixel_values: torch.Tensor,
         input_ids: torch.Tensor,
         labels: torch.Tensor | None = None,
+        frame_mask: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         # Route the text embedding through Transformer.forward so FSDP2's
         # per-module hook intercepts the token_embedding call and
         # materializes the DTensor weight before F.embedding runs. Doing
         # the embedding externally (transformer.token_embedding(input_ids))
         # bypasses FSDP and fails with "mixed torch.Tensor and DTensor".
-        modality = self.strategy.prepare(self, pixel_values, input_ids)
+        modality = self.strategy.prepare(self, pixel_values, input_ids, frame_mask=frame_mask)
         logits = self.transformer(tokens=input_ids, modality=modality)
         return logits, labels
 
diff --git a/scripts/train.py b/scripts/train.py
index 54e6f3c..b081310 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -756,11 +756,15 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
             pixel_values = batch["pixel_values"].to(device)
             input_ids = batch["input_ids"].to(device)
             labels = batch["labels"].to(device)
+            # Video batches carry a per-frame validity mask; image batches do not.
+            frame_mask = batch["frame_mask"].to(device) if "frame_mask" in batch else None
 
             with maybe_no_sync(model, micro_step, tc.grad_accum_steps):
                 if mc.is_moe:
                     inner_transformer(model).set_moe_step(step, tc.max_steps)  # type: ignore[attr-defined]
-                logits, labels_out = model(pixel_values, input_ids, labels)
+                logits, labels_out = model(
+                    pixel_values, input_ids, labels, frame_mask=frame_mask
+                )
                 loss = loss_fn(logits, labels_out)
 
                 total_text_tokens += int((labels_out != -100).sum().item())
diff --git a/tests/unit/test_modality_context.py b/tests/unit/test_modality_context.py
index 46f4a62..e3a13ab 100644
--- a/tests/unit/test_modality_context.py
+++ b/tests/unit/test_modality_context.py
@@ -73,6 +73,17 @@ def test_image_mask_without_image_features_raises(self):
         with pytest.raises(ValueError, match="image_mask requires image_features"):
             ModalityContext(image_mask=torch.ones(1, 4, dtype=torch.bool))
 
+    def test_key_padding_mask_with_prefix_embeds_is_valid(self):
+        ctx = ModalityContext(
+            prefix_embeds=torch.zeros(1, 4, 8),
+            key_padding_mask=torch.ones(1, 10, dtype=torch.bool),
+        )
+        assert ctx.key_padding_mask is not None
+
+    def test_key_padding_mask_without_residual_raises(self):
+        with pytest.raises(ValueError, match="key_padding_mask requires"):
+            ModalityContext(key_padding_mask=torch.ones(1, 10, dtype=torch.bool))
+
     def test_three_residual_routes_raises(self):
         """Setting all three residual-route fields surfaces the same error."""
         with pytest.raises(ValueError, match="mutually exclusive"):
diff --git a/tests/unit/test_moma.py b/tests/unit/test_moma.py
index 4bf05cc..46b300e 100644
--- a/tests/unit/test_moma.py
+++ b/tests/unit/test_moma.py
@@ -81,6 +81,35 @@ def test_moma_finegrained_experts():
     assert expert.gate_proj.weight.shape[0] == cfg.computed_ffn_hidden_dim // 2
 
 
+def test_moma_ffn_excludes_padded_from_routing():
+    """key_padding_mask drops padded positions from expert-choice routing: they
+    get zero FFN output, and real tokens' outputs don't depend on padded tokens'
+    content (no capacity competition)."""
+    torch.manual_seed(0)
+    cfg = ModelConfig(dim=32, n_layers=2, n_heads=4, vocab_size=128, max_seq_len=64)
+    ffn = MoMaFFN(
+        cfg,
+        modalities=("image", "text"),
+        experts_per_modality={"image": 2, "text": 2},
+        capacity_factor_per_modality={"image": 1.0, "text": 1.0},
+        gumbel_noise=False,  # deterministic routing for the comparison
+    )
+    # positions 0..3 image, 4..7 text; mark image positions 2,3 as padded.
+    modality_ids = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1]])
+    kpm = torch.tensor([[True, True, False, False, True, True, True, True]])
+    x = torch.randn(1, 8, cfg.dim)
+    out = ffn(x, modality_ids, key_padding_mask=kpm)
+    # Padded image positions are excluded from routing -> zero FFN output.
+    assert torch.count_nonzero(out[0, 2]) == 0
+    assert torch.count_nonzero(out[0, 3]) == 0
+    # Real tokens' outputs are invariant to the padded tokens' content.
+    x2 = x.clone()
+    x2[0, 2:4] = torch.randn(2, cfg.dim)
+    out2 = ffn(x2, modality_ids, key_padding_mask=kpm)
+    real = [0, 1, 4, 5, 6, 7]
+    assert torch.equal(out[0, real], out2[0, real])
+
+
 # ---------------------------------------------------------------------------
 # MoMaConfig
 # ---------------------------------------------------------------------------
diff --git a/tests/unit/test_vlm.py b/tests/unit/test_vlm.py
index 19a1760..e9fd04c 100644
--- a/tests/unit/test_vlm.py
+++ b/tests/unit/test_vlm.py
@@ -648,3 +648,81 @@ def test_frame_count_mismatch_raises(self):
         # 4D single-image batch into a video (frames_per_clip>1) wrapper.
         with pytest.raises(ValueError, match="frames-per-clip mismatch"):
             wrapper(torch.randn(2, 3, 16, 16, device=DEVICE), input_ids)
+
+
+class TestFramePaddingMask:
+    """frame_mask hides padded video frames from attention (and, for MoMa, from
+    expert-choice routing), so real-token outputs are invariant to padded-frame
+    content. The image (F=1) path is a no-op, and an all-padded (undecodable)
+    clip stays finite via the NaN guard."""
+
+    @staticmethod
+    def _arch_wrapper(arch):
+        ffn = 128 if arch in ("mot", "moma") else None
+        cfgs = {
+            "joint_decoder": JointDecoderConfig(max_text_len=8),
+            "cross_attention": CrossAttentionConfig(
+                max_text_len=8, cross_attention_every_n_layers=2
+            ),
+            "mot": MoTConfig(max_text_len=8),
+            "moma": MoMaConfig(max_text_len=8),
+        }
+        w = _video_wrapper(cfgs[arch], frames=4, ffn_hidden_dim=ffn).to(DEVICE).eval()
+        # Move off zero-init warm-start gating (e.g. MoT/CA zero-init o_proj) so
+        # the image path actually contributes — otherwise the no-mask control is
+        # vacuous (image gated off at init).
+        with torch.no_grad():
+            for p in w.parameters():
+                p.add_(0.02 * torch.randn_like(p))
+        return w
+
+    def test_visual_token_mask_expands_per_frame(self):
+        from kempnerforge.model.vlm import _visual_token_mask
+
+        fm = torch.tensor([[True, False, True]])  # 3 frames
+        out = _visual_token_mask(fm, num_visual_tokens=6)  # 2 tokens/frame
+        assert out.tolist() == [[True, True, False, False, True, True]]
+        assert _visual_token_mask(None, 6) is None
+        # Non-divisible count (e.g. a future non-per-frame token) must fail loudly.
+        with pytest.raises(ValueError, match="multiple of num_frames"):
+            _visual_token_mask(fm, num_visual_tokens=7)
+
+    @pytest.mark.parametrize("arch", ["joint_decoder", "cross_attention", "mot", "moma"])
+    def test_masked_frames_do_not_affect_real_tokens(self, arch):
+        torch.manual_seed(0)
+        w = self._arch_wrapper(arch)
+        ids = torch.randint(0, 256, (1, 6), device=DEVICE)
+        pix = torch.randn(1, 4, 3, 16, 16, device=DEVICE)
+        fm = torch.tensor([[True, True, False, False]], device=DEVICE)  # frames 2,3 padded
+        pix2 = pix.clone()
+        pix2[:, 2:] = torch.randn(1, 2, 3, 16, 16, device=DEVICE)  # corrupt the padded frames
+        with torch.no_grad():
+            masked_a, _ = w(pix, ids, frame_mask=fm)
+            masked_b, _ = w(pix2, ids, frame_mask=fm)
+            nomask_a, _ = w(pix, ids)
+            nomask_b, _ = w(pix2, ids)
+        assert torch.equal(masked_a, masked_b), f"{arch}: masked output depends on padded frames"
+        assert not torch.equal(nomask_a, nomask_b), f"{arch}: control — pads should leak unmasked"
+
+    def test_image_f1_mask_is_noop(self):
+        torch.manual_seed(0)
+        w = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=1).to(DEVICE).eval()
+        img = torch.randn(1, 3, 16, 16, device=DEVICE)
+        ids = torch.randint(0, 256, (1, 6), device=DEVICE)
+        with torch.no_grad():
+            no_mask, _ = w(img, ids)
+            all_true, _ = w(img, ids, frame_mask=torch.tensor([[True]], device=DEVICE))
+        assert torch.equal(no_mask, all_true)
+
+    @pytest.mark.parametrize("arch", ["joint_decoder", "cross_attention", "mot", "moma"])
+    def test_undecodable_clip_stays_finite(self, arch):
+        # An undecodable clip has frame_mask all-False; the NaN guard must keep
+        # softmax finite (no NaN poisoning the batch loss).
+        torch.manual_seed(0)
+        w = self._arch_wrapper(arch)
+        ids = torch.randint(0, 256, (2, 6), device=DEVICE)
+        pix = torch.randn(2, 4, 3, 16, 16, device=DEVICE)
+        fm = torch.tensor([[False, False, False, False], [True, True, True, True]], device=DEVICE)
+        with torch.no_grad():
+            logits, _ = w(pix, ids, frame_mask=fm)
+        assert torch.isfinite(logits).all(), f"{arch}: NaN/inf with an all-padded clip"
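
Reviewer note: the bodies of `_visual_token_mask` and `_prefix_key_padding_mask` are not part of the hunks above, so the sketch below is only a guess at their contract, inferred from the call sites and from `test_visual_token_mask_expands_per_frame`. The `_sketch` names are hypothetical, and the all-True text half of the prefix mask is an assumption (the real helper may derive text padding from `input_ids`).

```python
import torch


def visual_token_mask_sketch(frame_mask, num_visual_tokens):
    """Expand a (B, F) frame-validity mask to a (B, F * P) per-token mask.

    Hypothetical stand-in for `_visual_token_mask`; behavior inferred from the tests.
    """
    if frame_mask is None:
        return None  # no mask supplied: every visual token counts as valid
    num_frames = frame_mask.shape[1]
    if num_visual_tokens % num_frames != 0:
        raise ValueError(
            f"num_visual_tokens ({num_visual_tokens}) must be a multiple of "
            f"num_frames ({num_frames})"
        )
    tokens_per_frame = num_visual_tokens // num_frames
    # Each frame's validity bit is repeated once per token that frame produced.
    return frame_mask.repeat_interleave(tokens_per_frame, dim=1)


def prefix_key_padding_mask_sketch(frame_mask, num_visual_tokens, input_ids):
    """Build a (B, n + T) key-padding mask for a [visual prefix | text] sequence.

    Assumption: text positions are always valid; only padded frames are hidden.
    """
    visual = visual_token_mask_sketch(frame_mask, num_visual_tokens)
    if visual is None:
        return None
    text = torch.ones_like(input_ids, dtype=torch.bool)
    return torch.cat([visual, text], dim=1)


frame_mask = torch.tensor([[True, False, True]])  # 3 frames, middle one padded
token_mask = visual_token_mask_sketch(frame_mask, num_visual_tokens=6)
print(token_mask.tolist())  # [[True, True, False, False, True, True]]
```

Under this reading, the `(1, 10)` mask paired with a 4-token prefix in `test_key_padding_mask_with_prefix_embeds_is_valid` spans the full `n + T` sequence, which matches the shape the second helper returns.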