3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]

 ### Added
+- **Ragged grids in `attentional_pool`.** The `attentional_pool` connector now pools ragged patch grids (grid not divisible by the window) instead of rejecting them: each partial edge window pools only its real patches via masked attention (matching `avgpool` and Molmo2 §A). This enables faithful **Molmo2 3×3 video pooling** on a 14×14 SigLIP grid (`14 % 3 != 0`), which previously had to fall back to `avgpool` or a divisible window. `output_num_tokens` is `ceil(grid/window)²`; divisible grids are bit-exact with before (no mask is built).
+  - `kempnerforge/model/adapter.py`: `AttentionalPoolAdapter.forward` pads the bottom/right edges and masks padded patches out of each edge window's K/V (with a masked-mean query); `DIVISIBLE_ONLY_POOL_TYPES` is now empty.
+  - Tests: `tests/unit/test_adapter.py`: ragged token count, masked edge-window correctness (a 1-real-patch window equals attention over that patch), config accepts ragged.
 - **MoE router z-loss** (`moe_router_z_loss_weight`, ST-MoE style). An optional penalty on the router's pre-softmax logits (per MoE layer, `mean_token(logsumexp(router_logits)²)`), summed across layers and added to the training loss as `moe_router_z_loss_weight × z_loss`. It keeps router logits from growing without bound, targeting *logit-growth stability* (not load balance; that's the aux loss). Default `0.0` is off: the term is never added, so training, outputs, and gradients are unchanged. `z_loss` is a plain attribute like `aux_loss` (not a buffer/parameter), so it never enters `state_dict` and is checkpoint-safe.
   - `kempnerforge/config/model.py`: `moe_router_z_loss_weight: float = 0.0` (with a non-negativity check).
   - `kempnerforge/model/router.py`: both `SoftmaxTopKRouter` and `SigmoidTopKRouter` set `self.z_loss = (logsumexp(logits, dim=-1) ** 2).mean()`.
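The router z-loss entry above can be illustrated with a small dependency-free sketch. It is not the project's implementation: `logsumexp`, `router_z_loss`, and `total_loss` are hypothetical helper names, logits are plain nested lists rather than tensors, and the only behavior taken from the changelog is the formula (mean over tokens of the squared logsumexp, summed across layers, scaled by the weight, skipped entirely when the weight is `0.0`).

```python
import math


def logsumexp(row):
    """Numerically stable log(sum(exp(row))) for one token's router logits."""
    top = max(row)
    return top + math.log(sum(math.exp(v - top) for v in row))


def router_z_loss(router_logits):
    """Per-layer z-loss: mean over tokens of logsumexp(logits) ** 2.

    `router_logits` is a list of per-token lists (tokens x experts).
    """
    return sum(logsumexp(row) ** 2 for row in router_logits) / len(router_logits)


def total_loss(task_loss, per_layer_logits, z_loss_weight=0.0):
    """Add weight * (sum of per-layer z-losses) to the task loss.

    With the default weight of 0.0 the term is never added, so the task
    loss is returned unchanged.
    """
    if z_loss_weight == 0.0:
        return task_loss
    z_loss = sum(router_z_loss(layer) for layer in per_layer_logits)
    return task_loss + z_loss_weight * z_loss


# Two tokens, two experts, one MoE layer (made-up numbers).
small = [[0.0, 0.0], [0.0, 0.0]]
large = [[10.0, 10.0], [10.0, 10.0]]
print(router_z_loss(small), router_z_loss(large))
```

Larger logits give a larger penalty even when the softmax distribution is identical, which is why this term targets logit growth rather than load balance.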
4 changes: 3 additions & 1 deletion docs/how-to/train-on-video.md
@@ -16,7 +16,9 @@ A clip of `F` frames becomes `F × P′` visual tokens:
 3. **Pool + project** each frame with the connector: an `avgpool` or
    `attentional_pool` adapter reduces a `grid×grid` patch map to
    `P′ = ceil(grid/window)²` tokens per frame (e.g. SigLIP2 @224/patch16 →
-   14×14 → 49 tokens at `pool_window=2`).
+   14×14 → 49 tokens at `pool_window=2`). Ragged windows work too: both
+   connectors pool partial edge windows over their real patches, so Molmo2-style
+   3×3 on the 14×14 grid gives 5×5 = 25 tokens.
 4. **Fuse** the resulting `(B, F·P′, dim)` visual tokens into the backbone the
    same way images are fused, so **all four archs work unchanged**:
    - `joint_decoder` / `mot` / `moma`: the `F·P′` tokens prepend the text in the
124 changes: 81 additions & 43 deletions kempnerforge/model/adapter.py
@@ -45,10 +45,11 @@
 # the registered pooling builders below.
 POOLING_ADAPTER_TYPES: tuple[str, ...] = ("avgpool", "attentional_pool")

-# Pooling adapters whose ``forward`` requires the patch grid be divisible by the
-# window (no ragged edge windows). Their token count must enforce the same so a
-# ragged config is rejected at config/build time, not at the first training step.
-DIVISIBLE_ONLY_POOL_TYPES: tuple[str, ...] = ("attentional_pool",)
+# Pooling adapters whose ``forward`` cannot pool ragged edge windows (so their
+# token count must reject a non-divisible grid at config/build time). Both
+# ``avgpool`` and ``attentional_pool`` now mask partial edge windows, so this is
+# empty -- kept as a seam for a future connector that genuinely needs divisibility.
+DIVISIBLE_ONLY_POOL_TYPES: tuple[str, ...] = ()


 def pooled_token_count(
@@ -63,10 +64,11 @@ def pooled_token_count(
     cover (Molmo2 §A: "the bottom and far-right image patches are pooled with a
     reduced number of patches").

-    Connectors that cannot pool ragged edges (``require_divisible=True``, e.g.
-    ``attentional_pool``) raise when ``grid`` is not divisible by ``window``, so a
-    ragged config is rejected at config/build time rather than deterministically
-    failing in ``forward`` at the first step.
+    Connectors that genuinely cannot pool ragged edges may pass
+    ``require_divisible=True`` to raise when ``grid`` is not divisible by
+    ``window``, rejecting a ragged config at config/build time rather than
+    deterministically failing in ``forward`` at the first step. (Today both
+    pooling connectors handle ragged edges, so none set it.)

     This is the single source of truth for the post-pool count: it must equal
     the pooling adapters' actual ``forward`` output length, because the build
@@ -81,8 +83,8 @@ def pooled_token_count(
         raise ValueError(
             f"this pooling connector requires the patch grid ({grid}x{grid}) be "
             f"divisible by the pool window ({window}); got a ragged grid "
-            f"(num_tokens={num_input_tokens}). Use avgpool for ragged grids, or pick "
-            "a divisible window."
+            f"(num_tokens={num_input_tokens}). Use a ragged-capable connector "
+            "(avgpool or attentional_pool), or pick a divisible window."
         )
     per_side = math.ceil(grid / window)
     return per_side * per_side
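As a worked example of the `ceil(grid/window)²` count used in this hunk, the following is a minimal sketch in plain Python. It is not the repository's `pooled_token_count`: `pooled_count_sketch` and `window_sizes_along_axis` are hypothetical names, and the square-grid check is an assumption modeled on what `_grid_side` appears to do.

```python
import math


def pooled_count_sketch(num_input_tokens, window):
    """Post-pool token count for a square patch grid: ceil(grid / window) ** 2."""
    grid = math.isqrt(num_input_tokens)
    if grid * grid != num_input_tokens:
        raise ValueError(f"{num_input_tokens} tokens do not form a square grid")
    per_side = math.ceil(grid / window)
    return per_side * per_side


def window_sizes_along_axis(grid, window):
    """Number of real patches each window covers along one axis.

    The last entry is smaller than `window` when the grid is ragged.
    """
    return [min(window, grid - start) for start in range(0, grid, window)]


print(pooled_count_sketch(196, 2))     # 14x14 grid, 2x2 windows
print(pooled_count_sketch(196, 3))     # 14x14 grid, 3x3 windows (ragged)
print(window_sizes_along_axis(14, 3))  # real patches per window along one axis
```

For the 14×14 grid with a 3×3 window, the last window along each axis covers only 2 real rows or columns, so the bottom-right corner window pools 2×2 = 4 real patches instead of 9.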
@@ -100,6 +102,34 @@ def _grid_side(num_tokens: int) -> int:
     return grid


+def _pad_grid_to_windows(
+    x: torch.Tensor, window: int
+) -> tuple[torch.Tensor, int, torch.Tensor | None]:
+    """Reshape patch tokens to a square grid, padded to tile into windows.
+
+    ``x`` is ``(B, grid**2, C)``. Returns ``(x, per, valid)``: ``x`` reshaped to
+    ``(B, padded, padded, C)`` with ``padded = ceil(grid/window) * window``
+    (bottom/right zero-padded so it tiles into ``per × per`` windows of
+    ``window × window``), and ``valid`` a ``(B, padded, padded, 1)`` bool mask of
+    the real (non-padded) patches -- or ``None`` when the grid is already
+    divisible (no padding, so callers skip masking). Shared by the pooling
+    adapters so their ragged edge handling stays in lock-step.
+    """
+    b, n, c = x.shape
+    grid = _grid_side(n)
+    per = math.ceil(grid / window)
+    padded = per * window
+    x = x.view(b, grid, grid, c)
+    if padded == grid:
+        return x, per, None
+    pad = padded - grid
+    valid = torch.ones(b, grid, grid, 1, dtype=torch.bool, device=x.device)
+    # F.pad pads from the last dim backward: (C:0,0)(W:0,pad)(H:0,pad).
+    x = F.pad(x, (0, 0, 0, pad, 0, pad))
+    valid = F.pad(valid, (0, 0, 0, pad, 0, pad))  # (B, padded, padded, 1) bool
+    return x, per, valid
+
+
 class VisionAdapter(nn.Module):
     """Base class for vision→LLM adapters (the connector).

@@ -216,23 +246,15 @@ def forward(self, x: torch.Tensor, pool_window: int | None = None) -> torch.Tens
         w = pool_window if pool_window is not None else self.pool_window
         if w <= 0:
             raise ValueError(f"pool_window must be positive (got {w})")
-        b, n, c = x.shape
-        grid = _grid_side(n)
-        per = math.ceil(grid / w)
-        padded = per * w
-        x = x.view(b, grid, grid, c)
-        if padded != grid:
-            pad = padded - grid
-            # F.pad pads from the last dim backward: (C:0,0)(W:0,pad)(H:0,pad).
-            x = F.pad(x, (0, 0, 0, pad, 0, pad))
-            mask = torch.ones(b, grid, grid, 1, dtype=x.dtype, device=x.device)
-            mask = F.pad(mask, (0, 0, 0, pad, 0, pad))
-        else:
-            mask = torch.ones(b, padded, padded, 1, dtype=x.dtype, device=x.device)
+        x, per, valid = _pad_grid_to_windows(x, w)
+        b, _, _, c = x.shape
         # Group into windows and average over real (unpadded) cells only.
         sums = x.view(b, per, w, per, w, c).sum(dim=(2, 4))  # (B, per, per, C)
-        counts = mask.view(b, per, w, per, w, 1).sum(dim=(2, 4)).clamp_(min=1)  # (B, per, per, 1)
-        pooled = (sums / counts).reshape(b, per * per, c)
+        if valid is None:
+            pooled = (sums / (w * w)).reshape(b, per * per, c)
+        else:
+            counts = valid.view(b, per, w, per, w, 1).to(sums.dtype).sum(dim=(2, 4)).clamp_(min=1)
+            pooled = (sums / counts).reshape(b, per * per, c)
         return self.proj(pooled)


@@ -245,9 +267,9 @@ class AttentionalPoolAdapter(VisionAdapter):
     is projected ``in_dim -> out_dim``. Output length is ``ceil(grid/window)**2``.

     ``window`` is overridable per ``forward`` call (shared params across image
-    2×2 and video 3×3 pooling, per the paper). v1 requires the grid be divisible
-    by the window (no ragged edge windows); ragged attentional pooling is a
-    follow-up.
+    2×2 and video 3×3 pooling, per the paper). Ragged grids are supported: a
+    partial edge window pools only its real patches (the padded patches are
+    masked out of the window's K/V), matching ``avgpool`` and Molmo2 §A.
     """

     def __init__(
@@ -285,34 +307,50 @@ def reset_parameters(self) -> None:
             layer.reset_parameters()

     def output_num_tokens(self, num_input_tokens: int) -> int:
-        # require_divisible mirrors forward()'s ragged-grid rejection so a bad
-        # config fails at build / seq-len-check time, not at the first step.
-        return pooled_token_count(num_input_tokens, self.pool_window, require_divisible=True)
+        return pooled_token_count(num_input_tokens, self.pool_window)

     def forward(self, x: torch.Tensor, pool_window: int | None = None) -> torch.Tensor:
         w = pool_window if pool_window is not None else self.pool_window
         if w <= 0:
             raise ValueError(f"pool_window must be positive (got {w})")
-        b, n, c = x.shape
-        grid = _grid_side(n)
-        if grid % w != 0:
-            raise ValueError(
-                f"attentional_pool v1 requires the patch grid ({grid}x{grid}) be divisible "
-                f"by the pool window ({w}); ragged edge windows are not yet supported. "
-                "Use avgpool for ragged grids, or pick a divisible window."
-            )
-        per = grid // w
+        x, per, valid = _pad_grid_to_windows(x, w)
+        b, _, _, c = x.shape
         k_win = w * w
-        # (B, grid, grid, C) -> windows (B*per*per, w*w, C): each window's patches contiguous.
+        # Ragged grid: each padded edge window pools only its real patches via a
+        # masked attention -- `valid` masks the padded patches out of the window's
+        # K/V (and the masked-mean query). Divisible grids return valid=None and
+        # stay bit-identical to the unmasked pooling.
+        win_mask: torch.Tensor | None
+        if valid is None:
+            win_mask = None
+        else:
+            # Window order must match `windows` below.
+            win_mask = (
+                valid.view(b, per, w, per, w, 1)
+                .permute(0, 1, 3, 2, 4, 5)
+                .reshape(b * per * per, k_win)
+            )
+        # (B, padded, padded, C) -> windows (B*per*per, w*w, C): each window's patches contiguous.
         windows = (
             x.view(b, per, w, per, w, c).permute(0, 1, 3, 2, 4, 5).reshape(b * per * per, k_win, c)
         )
         m = windows.shape[0]
-        query = windows.mean(dim=1, keepdim=True)  # (M, 1, C) -- window mean as query
+        if win_mask is None:
+            query = windows.mean(dim=1, keepdim=True)  # (M, 1, C) -- window mean as query
+        else:
+            # Query = mean over real patches only (every window has >=1 real patch,
+            # so the count is never zero).
+            wf = win_mask.unsqueeze(-1).to(windows.dtype)  # (M, k_win, 1)
+            query = (windows * wf).sum(dim=1, keepdim=True) / wf.sum(dim=1, keepdim=True).clamp_(
+                min=1
+            )
         q = self.q_proj(query).view(m, 1, self.pool_heads, self.head_dim).transpose(1, 2)
         k = self.k_proj(windows).view(m, k_win, self.pool_heads, self.head_dim).transpose(1, 2)
         v = self.v_proj(windows).view(m, k_win, self.pool_heads, self.head_dim).transpose(1, 2)
-        attn = F.scaled_dot_product_attention(q, k, v)  # (M, H, 1, head_dim)
+        # Mask padded patches out of the K/V for edge windows; None -> plain SDPA
+        # (the divisible path, bit-identical to before).
+        attn_mask = None if win_mask is None else win_mask.view(m, 1, 1, k_win)
+        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)  # (M, H, 1, head_dim)
         attn = attn.transpose(1, 2).reshape(m, c)  # (M, C)
         pooled = self.o_proj(attn).view(b, per * per, c)
         return self.out_proj(pooled)
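The masking idea in `forward` can be seen in isolation with a small single-head sketch in plain Python. This is an illustration under stated assumptions, not the adapter's code: `masked_attention` is a hypothetical helper, it skips the q/k/v/o projections and multi-head reshaping, and it follows the convention of PyTorch's boolean `attn_mask`, where `True` means the key takes part in attention.

```python
import math


def masked_attention(query, keys, values, valid):
    """Single-query scaled dot-product attention that ignores masked keys.

    query: list[float]; keys, values: list[list[float]]; valid: list[bool]
    with True for real patches and False for padded ones.
    """
    if not any(valid):
        raise ValueError("at least one key must be valid")
    scale = 1.0 / math.sqrt(len(query))
    # Masked keys get a score of -inf, so their softmax weight is exactly 0.
    scores = [
        scale * sum(q * k for q, k in zip(query, key)) if ok else -math.inf
        for key, ok in zip(keys, valid)
    ]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]


# One window with three slots; the last slot stands in for zero padding
# (its value is deliberately large so any leak would be obvious).
keys = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
values = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
query = [0.5, 0.5]

print(masked_attention(query, keys, values, [True, False, False]))
print(masked_attention(query, keys, values, [True, True, False]))
```

With a single valid key the softmax weight on that key is 1, so the output is exactly its value. That is the property the new `test_ragged_edge_window_pools_only_real_patches` test relies on for the one-real-patch corner window.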
48 changes: 33 additions & 15 deletions tests/unit/test_adapter.py
@@ -276,7 +276,8 @@ def test_non_positive_tokens_raises(self):
             pooled_token_count(0, 2)

     def test_require_divisible_raises_on_ragged(self):
-        # attentional_pool path: a ragged grid is rejected up front, not at forward.
+        # The generic require_divisible flag still rejects a ragged grid up front
+        # (the seam for a future divisible-only connector); no current connector sets it.
         with pytest.raises(ValueError, match="ragged grid"):
             pooled_token_count(196, 3, require_divisible=True)  # 14x14 not divisible by 3

@@ -383,23 +384,41 @@ def test_forward_shape(self):
     def test_is_vision_adapter(self):
         assert isinstance(AttentionalPoolAdapter(in_dim=16, out_dim=8, pool_heads=4), VisionAdapter)

-    @pytest.mark.parametrize(("n_in", "window"), [(16, 2), (256, 2), (729, 3)])
+    # (196, 3) is ragged: 14x14 grid, 14 % 3 != 0 -> ceil(14/3)=5 -> 25.
+    @pytest.mark.parametrize(("n_in", "window"), [(16, 2), (256, 2), (729, 3), (196, 3)])
     def test_output_num_tokens_matches_forward(self, n_in, window):
         adapter = AttentionalPoolAdapter(in_dim=32, out_dim=16, pool_window=window, pool_heads=4)
         x = torch.randn(1, n_in, 32)
         assert adapter(x).shape[1] == adapter.output_num_tokens(n_in)

-    def test_ragged_grid_raises(self):
+    def test_ragged_grid_supported(self):
+        # 4x4 grid, window 3 -> ceil(4/3)=2 -> 2x2=4 tokens; edge windows pool only
+        # their real patches (padded patches masked out). No longer raises.
         adapter = AttentionalPoolAdapter(in_dim=96, out_dim=64, pool_window=3, pool_heads=16)
-        with pytest.raises(ValueError, match="divisible"):
-            adapter(torch.randn(1, 16, 96))  # 4x4 grid, not divisible by 3
+        out = adapter(torch.randn(2, 16, 96))  # 4x4 grid, ragged for window 3
+        assert out.shape == (2, 4, 64)
+        assert torch.isfinite(out).all()

-    def test_output_num_tokens_rejects_ragged(self):
-        # The static count must mirror forward()'s ragged rejection so an invalid
-        # config fails at build / seq-len-check time, not at the first step.
+    def test_output_num_tokens_ragged(self):
+        # Ragged grids are supported: the count is ceil(grid/window)**2, matching
+        # avgpool and the masked forward.
         adapter = AttentionalPoolAdapter(in_dim=96, out_dim=64, pool_window=3, pool_heads=16)
-        with pytest.raises(ValueError, match="ragged grid"):
-            adapter.output_num_tokens(16)  # 4x4 grid, not divisible by 3
+        assert adapter.output_num_tokens(16) == 4  # 4x4 -> ceil(4/3)=2 -> 2x2
+        assert adapter.output_num_tokens(196) == 25  # 14x14 -> ceil(14/3)=5 -> 5x5
+
+    def test_ragged_edge_window_pools_only_real_patches(self):
+        # 4x4 grid, window 3 -> 2x2 windows; the bottom-right window (output token
+        # 3) has exactly one real patch (index 15). With the padded patches masked,
+        # attention over a single key is the identity on its value, so that
+        # window's output must equal out_proj(o_proj(v_proj(patch_15))).
+        torch.manual_seed(0)
+        adapter = AttentionalPoolAdapter(in_dim=32, out_dim=16, pool_window=3, pool_heads=4).eval()
+        x = torch.randn(1, 16, 32)  # 4x4 grid
+        with torch.no_grad():
+            out = adapter(x)  # (1, 4, 16)
+            real = x[:, 15:16]  # (1, 1, 32) -- the window's only real patch
+            ref = adapter.out_proj(adapter.o_proj(adapter.v_proj(real)))  # (1, 1, 16)
+        assert torch.allclose(out[:, 3:4], ref, atol=1e-5)

     def test_heads_must_divide_dim(self):
         with pytest.raises(ValueError, match="divisible by"):
@@ -478,11 +497,10 @@ def test_output_num_tokens_pools_for_avgpool(self):
     def test_output_num_tokens_pools_for_attentional(self):
         assert AdapterConfig(type="attentional_pool", pool_window=3).output_num_tokens(729) == 81

-    def test_attentional_output_num_tokens_rejects_ragged(self):
-        # Config-time check rejects a ragged attentional_pool grid (mirrors forward),
-        # so the misconfig fails at config load, not at the first training step.
-        with pytest.raises(ValueError, match="ragged grid"):
-            AdapterConfig(type="attentional_pool", pool_window=3).output_num_tokens(196)
+    def test_attentional_output_num_tokens_allows_ragged(self):
+        # attentional_pool now pools ragged edges (masked partial windows), so the
+        # config-time count uses the same ceil math as avgpool -- no rejection.
+        assert AdapterConfig(type="attentional_pool", pool_window=3).output_num_tokens(196) == 25

     def test_avgpool_output_num_tokens_allows_ragged(self):
         # avgpool pools ragged edges, so the same ragged grid is fine (ceil math).