diff --git a/CHANGELOG.md b/CHANGELOG.md index 49ab032..9d37adb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- **Ragged grids in `attentional_pool`.** The `attentional_pool` connector now pools ragged patch grids (grid not divisible by the window) instead of rejecting them: each partial edge window pools only its real patches via a masked attention (matching `avgpool` and Molmo2 §A). This enables faithful **Molmo2 3×3 video pooling** on a 14×14 SigLIP grid (`14 % 3 != 0`), which previously had to fall back to `avgpool` or a divisible window. `output_num_tokens` is `ceil(grid/window)²`; divisible grids are bit-exact with before (no mask is built). + - `kempnerforge/model/adapter.py`: `AttentionalPoolAdapter.forward` pads the bottom/right edges and masks padded patches out of each edge window's K/V (with a masked-mean query); `DIVISIBLE_ONLY_POOL_TYPES` is now empty. + - Tests: `tests/unit/test_adapter.py` — ragged token count, masked edge-window correctness (a 1-real-patch window equals attention over that patch), config accepts ragged. - **MoE router z-loss** (`moe_router_z_loss_weight`, ST-MoE style). An optional penalty on the router's pre-softmax logits — per MoE layer `mean_token(logsumexp(router_logits))²` — summed across layers and added to the training loss as `moe_router_z_loss_weight × z_loss`. It keeps router logits from growing without bound, targeting *logit-growth stability* (not load balance — that's the aux loss). Default `0.0` is off: the term is never added, so training, outputs, and gradients are unchanged. `z_loss` is a plain attribute like `aux_loss` (not a buffer/parameter), so it never enters `state_dict` — checkpoint-safe. - `kempnerforge/config/model.py`: `moe_router_z_loss_weight: float = 0.0` (with a non-negativity check). - `kempnerforge/model/router.py`: both `SoftmaxTopKRouter` and `SigmoidTopKRouter` set `self.z_loss = (logsumexp(logits, dim=-1) ** 2).mean()`. diff --git a/docs/how-to/train-on-video.md b/docs/how-to/train-on-video.md index 5fc98f3..e2b1a6b 100644 --- a/docs/how-to/train-on-video.md +++ b/docs/how-to/train-on-video.md @@ -16,7 +16,9 @@ A clip of `F` frames becomes `F × P′` visual tokens: 3. **Pool + project** each frame with the connector — an `avgpool` or `attentional_pool` adapter reduces a `grid×grid` patch map to `P′ = ceil(grid/window)²` tokens per frame (e.g. SigLIP2 @224/patch16 → - 14×14 → 49 tokens at `pool_window=2`). + 14×14 → 49 tokens at `pool_window=2`). Ragged windows work too — both + connectors pool partial edge windows over their real patches — so Molmo2-style + 3×3 on the 14×14 grid gives 5×5 = 25 tokens. 4. **Fuse** the resulting `(B, F·P′, dim)` visual tokens into the backbone the same way images are fused — so **all four archs work unchanged**: - `joint_decoder` / `mot` / `moma`: the `F·P′` tokens prepend the text in the diff --git a/kempnerforge/model/adapter.py b/kempnerforge/model/adapter.py index acea41a..35cc2f7 100644 --- a/kempnerforge/model/adapter.py +++ b/kempnerforge/model/adapter.py @@ -45,10 +45,11 @@ # the registered pooling builders below. POOLING_ADAPTER_TYPES: tuple[str, ...] = ("avgpool", "attentional_pool") -# Pooling adapters whose ``forward`` requires the patch grid be divisible by the -# window (no ragged edge windows). Their token count must enforce the same so a -# ragged config is rejected at config/build time, not at the first training step. -DIVISIBLE_ONLY_POOL_TYPES: tuple[str, ...] = ("attentional_pool",) +# Pooling adapters whose ``forward`` cannot pool ragged edge windows (so their +# token count must reject a non-divisible grid at config/build time). Both +# ``avgpool`` and ``attentional_pool`` now mask partial edge windows, so this is +# empty -- kept as a seam for a future connector that genuinely needs divisibility. +DIVISIBLE_ONLY_POOL_TYPES: tuple[str, ...] = () def pooled_token_count( @@ -63,10 +64,11 @@ def pooled_token_count( cover (Molmo2 §A: "the bottom and far-right image patches are pooled with a reduced number of patches"). - Connectors that cannot pool ragged edges (``require_divisible=True``, e.g. - ``attentional_pool``) raise when ``grid`` is not divisible by ``window``, so a - ragged config is rejected at config/build time rather than deterministically - failing in ``forward`` at the first step. + Connectors that genuinely cannot pool ragged edges may pass + ``require_divisible=True`` to raise when ``grid`` is not divisible by + ``window``, rejecting a ragged config at config/build time rather than + deterministically failing in ``forward`` at the first step. (Today both + pooling connectors handle ragged edges, so none set it.) This is the single source of truth for the post-pool count: it must equal the pooling adapters' actual ``forward`` output length, because the build @@ -81,8 +83,8 @@ def pooled_token_count( raise ValueError( f"this pooling connector requires the patch grid ({grid}x{grid}) be " f"divisible by the pool window ({window}); got a ragged grid " - f"(num_tokens={num_input_tokens}). Use avgpool for ragged grids, or pick " - "a divisible window." + f"(num_tokens={num_input_tokens}). Use a ragged-capable connector " + "(avgpool or attentional_pool), or pick a divisible window." ) per_side = math.ceil(grid / window) return per_side * per_side @@ -100,6 +102,34 @@ def _grid_side(num_tokens: int) -> int: return grid +def _pad_grid_to_windows( + x: torch.Tensor, window: int +) -> tuple[torch.Tensor, int, torch.Tensor | None]: + """Reshape patch tokens to a square grid, padded to tile into windows. + + ``x`` is ``(B, grid**2, C)``. Returns ``(x, per, valid)``: ``x`` reshaped to + ``(B, padded, padded, C)`` with ``padded = ceil(grid/window) * window`` + (bottom/right zero-padded so it tiles into ``per × per`` windows of + ``window × window``), and ``valid`` a ``(B, padded, padded, 1)`` bool mask of + the real (non-padded) patches -- or ``None`` when the grid is already + divisible (no padding, so callers skip masking). Shared by the pooling + adapters so their ragged edge handling stays in lock-step. + """ + b, n, c = x.shape + grid = _grid_side(n) + per = math.ceil(grid / window) + padded = per * window + x = x.view(b, grid, grid, c) + if padded == grid: + return x, per, None + pad = padded - grid + valid = torch.ones(b, grid, grid, 1, dtype=torch.bool, device=x.device) + # F.pad pads the last dim backward: (C:0,0)(W:0,pad)(H:0,pad). + x = F.pad(x, (0, 0, 0, pad, 0, pad)) + valid = F.pad(valid, (0, 0, 0, pad, 0, pad)) # (B, padded, padded, 1) bool + return x, per, valid + + class VisionAdapter(nn.Module): """Base class for vision→LLM adapters (the connector). @@ -216,23 +246,15 @@ def forward(self, x: torch.Tensor, pool_window: int | None = None) -> torch.Tens w = pool_window if pool_window is not None else self.pool_window if w <= 0: raise ValueError(f"pool_window must be positive (got {w})") - b, n, c = x.shape - grid = _grid_side(n) - per = math.ceil(grid / w) - padded = per * w - x = x.view(b, grid, grid, c) - if padded != grid: - pad = padded - grid - # F.pad pads from the last dim backward: (C:0,0)(W:0,pad)(H:0,pad). - x = F.pad(x, (0, 0, 0, pad, 0, pad)) - mask = torch.ones(b, grid, grid, 1, dtype=x.dtype, device=x.device) - mask = F.pad(mask, (0, 0, 0, pad, 0, pad)) - else: - mask = torch.ones(b, padded, padded, 1, dtype=x.dtype, device=x.device) + x, per, valid = _pad_grid_to_windows(x, w) + b, _, _, c = x.shape # Group into windows and average over real (unpadded) cells only. sums = x.view(b, per, w, per, w, c).sum(dim=(2, 4)) # (B, per, per, C) - counts = mask.view(b, per, w, per, w, 1).sum(dim=(2, 4)).clamp_(min=1) # (B, per, per, 1) - pooled = (sums / counts).reshape(b, per * per, c) + if valid is None: + pooled = (sums / (w * w)).reshape(b, per * per, c) + else: + counts = valid.view(b, per, w, per, w, 1).to(sums.dtype).sum(dim=(2, 4)).clamp_(min=1) + pooled = (sums / counts).reshape(b, per * per, c) return self.proj(pooled) @@ -245,9 +267,9 @@ class AttentionalPoolAdapter(VisionAdapter): is projected ``in_dim -> out_dim``. Output length is ``ceil(grid/window)**2``. ``window`` is overridable per ``forward`` call (shared params across image - 2×2 and video 3×3 pooling, per the paper). v1 requires the grid be divisible - by the window (no ragged edge windows); ragged attentional pooling is a - follow-up. + 2×2 and video 3×3 pooling, per the paper). Ragged grids are supported: a + partial edge window pools only its real patches (the padded patches are + masked out of the window's K/V), matching ``avgpool`` and Molmo2 §A. """ def __init__( @@ -285,34 +307,50 @@ def reset_parameters(self) -> None: layer.reset_parameters() def output_num_tokens(self, num_input_tokens: int) -> int: - # require_divisible mirrors forward()'s ragged-grid rejection so a bad - # config fails at build / seq-len-check time, not at the first step. - return pooled_token_count(num_input_tokens, self.pool_window, require_divisible=True) + return pooled_token_count(num_input_tokens, self.pool_window) def forward(self, x: torch.Tensor, pool_window: int | None = None) -> torch.Tensor: w = pool_window if pool_window is not None else self.pool_window if w <= 0: raise ValueError(f"pool_window must be positive (got {w})") - b, n, c = x.shape - grid = _grid_side(n) - if grid % w != 0: - raise ValueError( - f"attentional_pool v1 requires the patch grid ({grid}x{grid}) be divisible " - f"by the pool window ({w}); ragged edge windows are not yet supported. " - "Use avgpool for ragged grids, or pick a divisible window." - ) - per = grid // w + x, per, valid = _pad_grid_to_windows(x, w) + b, _, _, c = x.shape k_win = w * w - # (B, grid, grid, C) -> windows (B*per*per, w*w, C): each window's patches contiguous. + # Ragged grid: each padded edge window pools only its real patches via a + # masked attention -- `valid` masks the padded patches out of the window's + # K/V (and the masked-mean query). Divisible grids return valid=None and + # stay bit-identical to the unmasked pooling. + win_mask: torch.Tensor | None + if valid is None: + win_mask = None + else: + # Window order must match `windows` below. + win_mask = ( + valid.view(b, per, w, per, w, 1) + .permute(0, 1, 3, 2, 4, 5) + .reshape(b * per * per, k_win) + ) + # (B, padded, padded, C) -> windows (B*per*per, w*w, C): each window's patches contiguous. windows = ( x.view(b, per, w, per, w, c).permute(0, 1, 3, 2, 4, 5).reshape(b * per * per, k_win, c) ) m = windows.shape[0] - query = windows.mean(dim=1, keepdim=True) # (M, 1, C) — window mean as query + if win_mask is None: + query = windows.mean(dim=1, keepdim=True) # (M, 1, C) — window mean as query + else: + # Query = mean over real patches only (every window has >=1 real patch, + # so the count is never zero). + wf = win_mask.unsqueeze(-1).to(windows.dtype) # (M, k_win, 1) + query = (windows * wf).sum(dim=1, keepdim=True) / wf.sum(dim=1, keepdim=True).clamp_( + min=1 + ) q = self.q_proj(query).view(m, 1, self.pool_heads, self.head_dim).transpose(1, 2) k = self.k_proj(windows).view(m, k_win, self.pool_heads, self.head_dim).transpose(1, 2) v = self.v_proj(windows).view(m, k_win, self.pool_heads, self.head_dim).transpose(1, 2) - attn = F.scaled_dot_product_attention(q, k, v) # (M, H, 1, head_dim) + # Mask padded patches out of the K/V for edge windows; None -> plain SDPA + # (the divisible path, bit-identical to before). + attn_mask = None if win_mask is None else win_mask.view(m, 1, 1, k_win) + attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) # (M, H, 1, head_dim) attn = attn.transpose(1, 2).reshape(m, c) # (M, C) pooled = self.o_proj(attn).view(b, per * per, c) return self.out_proj(pooled) diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 5cb8ee3..6371df4 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -276,7 +276,8 @@ def test_non_positive_tokens_raises(self): pooled_token_count(0, 2) def test_require_divisible_raises_on_ragged(self): - # attentional_pool path: a ragged grid is rejected up front, not at forward. + # The generic require_divisible flag still rejects a ragged grid up front + # (the seam for a future divisible-only connector); no current connector sets it. with pytest.raises(ValueError, match="ragged grid"): pooled_token_count(196, 3, require_divisible=True) # 14x14 not divisible by 3 @@ -383,23 +384,41 @@ def test_forward_shape(self): def test_is_vision_adapter(self): assert isinstance(AttentionalPoolAdapter(in_dim=16, out_dim=8, pool_heads=4), VisionAdapter) - @pytest.mark.parametrize(("n_in", "window"), [(16, 2), (256, 2), (729, 3)]) + # (196, 3) is ragged: 14x14 grid, 14 % 3 != 0 -> ceil(14/3)=5 -> 25. + @pytest.mark.parametrize(("n_in", "window"), [(16, 2), (256, 2), (729, 3), (196, 3)]) def test_output_num_tokens_matches_forward(self, n_in, window): adapter = AttentionalPoolAdapter(in_dim=32, out_dim=16, pool_window=window, pool_heads=4) x = torch.randn(1, n_in, 32) assert adapter(x).shape[1] == adapter.output_num_tokens(n_in) - def test_ragged_grid_raises(self): + def test_ragged_grid_supported(self): + # 4x4 grid, window 3 -> ceil(4/3)=2 -> 2x2=4 tokens; edge windows pool only + # their real patches (padded patches masked out). No longer raises. adapter = AttentionalPoolAdapter(in_dim=96, out_dim=64, pool_window=3, pool_heads=16) - with pytest.raises(ValueError, match="divisible"): - adapter(torch.randn(1, 16, 96)) # 4x4 grid, not divisible by 3 + out = adapter(torch.randn(2, 16, 96)) # 4x4 grid, ragged for window 3 + assert out.shape == (2, 4, 64) + assert torch.isfinite(out).all() - def test_output_num_tokens_rejects_ragged(self): - # The static count must mirror forward()'s ragged rejection so an invalid - # config fails at build / seq-len-check time, not at the first step. + def test_output_num_tokens_ragged(self): + # Ragged grids are supported: the count is ceil(grid/window)**2, matching + # avgpool and the masked forward. adapter = AttentionalPoolAdapter(in_dim=96, out_dim=64, pool_window=3, pool_heads=16) - with pytest.raises(ValueError, match="ragged grid"): - adapter.output_num_tokens(16) # 4x4 grid, not divisible by 3 + assert adapter.output_num_tokens(16) == 4 # 4x4 -> ceil(4/3)=2 -> 2x2 + assert adapter.output_num_tokens(196) == 25 # 14x14 -> ceil(14/3)=5 -> 5x5 + + def test_ragged_edge_window_pools_only_real_patches(self): + # 4x4 grid, window 3 -> 2x2 windows; the bottom-right window (output token + # 3) has exactly one real patch (index 15). With the padded patches masked, + # attention over a single key is the identity on its value, so that + # window's output must equal out_proj(o_proj(v_proj(patch_15))). + torch.manual_seed(0) + adapter = AttentionalPoolAdapter(in_dim=32, out_dim=16, pool_window=3, pool_heads=4).eval() + x = torch.randn(1, 16, 32) # 4x4 grid + with torch.no_grad(): + out = adapter(x) # (1, 4, 16) + real = x[:, 15:16] # (1, 1, 32) — the window's only real patch + ref = adapter.out_proj(adapter.o_proj(adapter.v_proj(real))) # (1, 1, 16) + assert torch.allclose(out[:, 3:4], ref, atol=1e-5) def test_heads_must_divide_dim(self): with pytest.raises(ValueError, match="divisible by"): @@ -478,11 +497,10 @@ def test_output_num_tokens_pools_for_avgpool(self): def test_output_num_tokens_pools_for_attentional(self): assert AdapterConfig(type="attentional_pool", pool_window=3).output_num_tokens(729) == 81 - def test_attentional_output_num_tokens_rejects_ragged(self): - # Config-time check rejects a ragged attentional_pool grid (mirrors forward), - # so the misconfig fails at config load, not at the first training step. - with pytest.raises(ValueError, match="ragged grid"): - AdapterConfig(type="attentional_pool", pool_window=3).output_num_tokens(196) + def test_attentional_output_num_tokens_allows_ragged(self): + # attentional_pool now pools ragged edges (masked partial windows), so the + # config-time count uses the same ceil math as avgpool — no rejection. + assert AdapterConfig(type="attentional_pool", pool_window=3).output_num_tokens(196) == 25 def test_avgpool_output_num_tokens_allows_ragged(self): # avgpool pools ragged edges, so the same ragged grid is fine (ceil math).