Commit bcb435f

Make the per-frame time embedding registry-based and config-selectable
1 parent 45bf42b commit bcb435f

13 files changed

Lines changed: 284 additions & 17 deletions

File tree

CHANGELOG.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6060
- **Frame-aware model + training wiring.** `kempnerforge/model/vlm.py`: `_project_image_features` → `_project_visual_features` folds the frame axis through the encoder + pooler to `(B, F·P′, dim)` (a single image is the `F == 1` case). `VLMWrapper` gains `frames_per_clip`, threaded through `build_parallel_model` / `_build_vlm` / `build_vlm_wrapper` so the static visual-token count equals `F·P′` (drives the residual budget and MoT's positional split; static == runtime). `scripts/train.py` builds the video dataset/collator when `[video]` is set. Adds `configs/train/vlm_video_webvid.toml` (SigLIP2 + avgpool + WebVid).
6161
- Tests: `tests/unit/test_video_io.py`, `test_video_dataset.py`, `test_video_config.py`; video-forward cases (all four archs) + image-path regression in `test_vlm.py`; pooling-adapter cases in `test_adapter.py`. Docs: `docs/how-to/train-on-video.md`.
6262
- Deferred (follow-ups): grounding (`<points>`/`<tracks>` outputs with point-F1 / track-J&F eval), frame-mask-aware attention, bidirectional visual attention, VLM sequence packing, long-context (blocked on context-parallel being wired), and warm-start from a converted image-VLM checkpoint.
63-
- **Per-frame timestamps for video.** Each sampled frame carries its actual presentation time (seconds), embedded and added to that frame's visual tokens so the model can reason about *when* events occur, not just frame order. Zero-initialized so it is identity at step 0 (warm-start) and learned from there.
63+
- **Per-frame timestamps for video.** Each sampled frame carries its actual presentation time (seconds), embedded and added to that frame's visual tokens so the model can reason about *when* events occur, not just frame order. Registry-driven and config-selected (so new techniques drop in as small additions); zero-initialized so it is identity at step 0 (warm-start) and learned from there.
6464
- `kempnerforge/data/video_io.py`: `decode_video_frames` returns `(frames, times)` (the matched frames' presentation times); `kempnerforge/data/video_dataset.py` emits a `frame_times` `(F,)` tensor and `VideoCollator` stacks it to `(B, F)`.
65-
- `kempnerforge/model/frame_time.py`: `FrameTimeEmbedding` (sinusoidal features at log-spaced periods → zero-init projection), a `VLMWrapper` submodule applied per frame in `_project_visual_features` (video only; `None` for the image path) and built + FSDP-sharded + meta-materialized at both build sites (`build_vlm_wrapper`, `_build_vlm`). `scripts/train.py` threads `frame_times` into the VLM forward.
66-
- Tests: `tests/unit/test_frame_time.py` + frame-time forward cases in `test_vlm.py`.
65+
- `kempnerforge/model/frame_time.py`: a `TimeEmbedding` base (the additive `(B, F) seconds → (B, F, dim)` contract) + the `"sinusoidal"` implementation, registered via `@registry.register_time_embedding` and built through `build_time_embedding`. Applied per frame in `_project_visual_features` as a `VLMWrapper` submodule (video only; `None` for the image path) and built + FSDP-sharded + meta-materialized at both build sites (`build_vlm_wrapper`, `_build_vlm`).
66+
- `kempnerforge/config/time_embedding.py`: the `[time_embedding]` `TimeEmbeddingConfig` (`type` selects the registered builder; `type = "none"` disables it), wired into `JobConfig` and threaded through `build_parallel_model`; `scripts/train.py` passes `config.time_embedding` and threads `frame_times` into the forward.
67+
- Sequence-*modifying* time encodings (e.g. Molmo2-style interleaved text time-tokens) are a separate future hook at the sequence-assembly layer, gated on interleaved/variable-length sequence support — out of scope for this additive registry.
68+
- Tests: `tests/unit/test_frame_time.py`, `test_time_embedding_config.py`; frame-time forward + `type="none"` cases in `test_vlm.py`; the video build path in `test_distributed.py`.
6769
- `install-and-verify` plugin skill: runs `uv sync`, asserts Python ≥ 3.12, then runs the four CI gate checks (`ruff check`, `ruff format --check`, `pyright`, `pytest tests/unit/`). Canonical first command after cloning.
6870
- `.python-version` pinned to `>=3.12` so uv resolves the interpreter explicitly. Teammates on 3.13 use 3.13 (no download); 3.11-only users get 3.12 auto-fetched.
6971
- **Dynamic-checkpointing window** (`[checkpoint.dyn_ckpt_window]`). Opt-in dense save phase: inside `[start, stop]` a registered strategy decides which steps to save; outside the window the regular `interval` cadence applies. The default strategy, `"power2"`, saves at `start` and at every `start + 2^k` while `<= stop` — tight near the start of the window, doubling thereafter. Useful for analyzing early-training dynamics, where the loss moves fastest. The default `CheckpointConfig` is unchanged (no `dyn_ckpt_window`, interval-only saves).
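
As a rough illustration of the `"power2"` cadence described in the last entry, the sketch below enumerates the save steps inside a window. The function name `power2_steps` is hypothetical, and starting the exponent at `k = 0` is an assumption, since the entry does not state the first exponent.

```python
def power2_steps(start: int, stop: int) -> list[int]:
    """Steps saved inside [start, stop]: start, then start + 2**k while <= stop."""
    steps = [start] if start <= stop else []
    k = 0  # assumption: the first offset is 2**0 = 1
    while start + 2**k <= stop:
        steps.append(start + 2**k)
        k += 1
    return steps
```

With a window of `[100, 120]` this gives steps 100, 101, 102, 104, 108 and 116: dense right after `start`, then doubling gaps.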

docs/how-to/train-on-video.md

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,13 @@ A clip of `F` frames becomes `F × P′` visual tokens:
2626
`max_seq_len`).
2727

2828
Temporal order is carried by frame order (sequential positions). On top of that,
29-
each frame's **timestamp in seconds** is embedded (sinusoidal features → a
30-
zero-initialized projection) and added to that frame's visual tokens, so the
31-
model sees *when* each frame occurs, not just its order. Grounding outputs are a
32-
separate follow-up (see below).
29+
each frame's **timestamp in seconds** is embedded and added to that frame's
30+
visual tokens, so the model sees *when* each frame occurs, not just its order.
31+
The embedding is registry-driven: `[time_embedding].type` selects it
32+
(`sinusoidal` by default — sinusoidal features at log-spaced periods through a
33+
zero-initialized projection; `none` disables it), so new techniques (learned,
34+
Fourier, …) register as small additions and switch via config. Grounding
35+
outputs are a separate follow-up (see below).
3336

3437
## Token budget
3538

@@ -94,6 +97,11 @@ time, so it is set in the TOML, not via a `--vlm.arch=` CLI override.)
9497
- **Grounding outputs are a follow-up** — per-frame timestamps are encoded (see
9598
above), but structured grounding (`<points>`/`<tracks>` outputs with point-F1
9699
/ track-J&F eval) is not yet implemented.
100+
- **Sequence-modifying time encodings are a separate hook** — the
101+
`[time_embedding]` registry is for *additive* per-frame embeddings (no change
102+
to sequence length). Molmo2-style interleaved text time-tokens change the
103+
token sequence and need interleaved/variable-length sequence support KF does
104+
not have yet; they would hook the sequence-assembly layer, not this registry.
97105
- **Padded frames are not yet masked from attention** — short clips pad to
98106
`max_frames` with blank frames; a `frame_mask` is produced but not yet
99107
consumed by the attention mask.
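
The `sinusoidal` embedding described in this how-to can be sketched in plain Python. This is a minimal illustration, not the project's module: the helper names are hypothetical, the geometric spacing of periods between `min_period` and `max_period` is an assumption about what "log-spaced" means here, and the bias-free projection is likewise assumed. The feature formula (`sin` and `cos` of `2π·t / period`, concatenated) follows the `forward` shown in `frame_time.py`.

```python
import math


def log_spaced_periods(num_bands: int = 16, min_period: float = 0.5,
                       max_period: float = 256.0) -> list[float]:
    """Geometrically spaced periods from min_period to max_period (assumed spacing)."""
    if num_bands == 1:
        return [min_period]
    ratio = (max_period / min_period) ** (1.0 / (num_bands - 1))
    return [min_period * ratio**k for k in range(num_bands)]


def sinusoidal_features(t_seconds: float, periods: list[float]) -> list[float]:
    """[sin(2*pi*t/p) for each p] followed by [cos(2*pi*t/p) for each p]."""
    angles = [2.0 * math.pi * t_seconds / p for p in periods]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


def project(feats: list[float], weight: list[list[float]]) -> list[float]:
    """Bias-free linear projection: one output per weight row."""
    return [sum(w * f for w, f in zip(row, feats)) for row in weight]


periods = log_spaced_periods()
feats = sinusoidal_features(1.25, periods)            # 2 * num_bands = 32 features
zero_weight = [[0.0] * len(feats) for _ in range(8)]  # dim = 8, zero-initialized
embedding = project(feats, zero_weight)               # all zeros at step 0
```

Because the projection starts at zero, adding `embedding` to a frame's visual tokens leaves them unchanged at step 0, which is the warm-start property the changelog describes.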

kempnerforge/config/job.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from kempnerforge.config.optimizer import OptimizerConfig
1515
from kempnerforge.config.profiling import ProfilingConfig
1616
from kempnerforge.config.scheduler import SchedulerConfig
17+
from kempnerforge.config.time_embedding import TimeEmbeddingConfig
1718
from kempnerforge.config.training import TrainConfig
1819
from kempnerforge.config.video import VideoConfig
1920
from kempnerforge.config.vision import VisionEncoderConfig
@@ -53,6 +54,7 @@ class JobConfig:
5354
adapter: AdapterConfig | None = None
5455
vlm: VLMConfig | None = None
5556
video: VideoConfig | None = None
57+
time_embedding: TimeEmbeddingConfig | None = None
5658

5759
def __post_init__(self) -> None:
5860
"""Cross-section invariants that fire at construction time.

kempnerforge/config/registry.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,27 @@ def get_adapter(self, name: str) -> Callable:
186186
def list_adapters(self) -> list[str]:
187187
return self.list("adapter")
188188

189+
def register_time_embedding(self, name: str) -> Callable:
190+
"""Decorator to register a time-embedding builder.
191+
192+
Builders take ``(dim, **kwargs)`` and return an ``nn.Module`` mapping
193+
per-frame timestamps ``(B, F)`` in seconds to an additive embedding
194+
``(B, F, dim)`` (and exposing ``reset_parameters()`` for meta-device
195+
builds). Selected by ``[time_embedding].type`` on the VLM video path.
196+
"""
197+
198+
def decorator(fn: Callable) -> Callable:
199+
self.register("time_embedding", name, fn)
200+
return fn
201+
202+
return decorator
203+
204+
def get_time_embedding(self, name: str) -> Callable:
205+
return self.get("time_embedding", name)
206+
207+
def list_time_embeddings(self) -> list[str]:
208+
return self.list("time_embedding")
209+
189210
def register_dyn_ckpt_strategy(self, name: str) -> Callable:
190211
"""Decorator to register a dynamic-checkpointing-window strategy.
191212
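
For readers unfamiliar with the pattern, here is a minimal sketch of a category-keyed registry with a decorator front end, in the spirit of `register_time_embedding` / `get_time_embedding` / `list_time_embeddings` above. `MiniRegistry` and its storage layout are hypothetical; the internals of the project's `register`, `get` and `list` are not shown in this diff, so the duplicate-name check in particular is an assumption.

```python
from collections.abc import Callable


class MiniRegistry:
    """Hypothetical stand-in: category -> name -> builder."""

    def __init__(self) -> None:
        self._store: dict[str, dict[str, Callable]] = {}

    def register(self, category: str, name: str, fn: Callable) -> None:
        bucket = self._store.setdefault(category, {})
        if name in bucket:  # assumption: duplicate names are rejected
            raise ValueError(f"{category} {name!r} is already registered")
        bucket[name] = fn

    def get(self, category: str, name: str) -> Callable:
        return self._store[category][name]

    def register_time_embedding(self, name: str) -> Callable:
        def decorator(fn: Callable) -> Callable:
            self.register("time_embedding", name, fn)
            return fn  # return the builder unchanged so it stays importable
        return decorator

    # Defined last so the method name does not shadow the builtin in
    # annotations evaluated earlier in the class body.
    def list(self, category: str) -> list[str]:
        return sorted(self._store.get(category, {}))


registry = MiniRegistry()


@registry.register_time_embedding("constant")
def build_constant(dim: int, **_: object) -> Callable:
    """Hypothetical builder: maps (B, F) times to (B, F, dim) zeros."""
    return lambda times: [[[0.0] * dim for _t in row] for row in times]
```

Registration happens as a side effect of importing the module that holds the decorated builder, which is why `TimeEmbeddingConfig.__post_init__` imports `kempnerforge.model.frame_time` before listing the registered names.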

kempnerforge/config/schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from kempnerforge.config.optimizer import OptimizerConfig # noqa: F401
1616
from kempnerforge.config.profiling import ProfilingConfig # noqa: F401
1717
from kempnerforge.config.scheduler import SchedulerConfig, SchedulerType # noqa: F401
18+
from kempnerforge.config.time_embedding import TimeEmbeddingConfig # noqa: F401
1819
from kempnerforge.config.training import ActivationCheckpointing, TrainConfig # noqa: F401
1920
from kempnerforge.config.vision import VisionEncoderConfig # noqa: F401
2021
from kempnerforge.config.vlm import ( # noqa: F401
kempnerforge/config/time_embedding.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Time-embedding (per-frame timestamp) configuration.
2+
3+
``TimeEmbeddingConfig`` selects which per-frame timestamp embedding the VLM
4+
video path uses and parameterizes it. Dispatched via the ``time_embedding``
5+
registry at build time (see ``kempnerforge/model/frame_time.py``).
6+
7+
In TOML, ``[time_embedding]`` is a top-level section parallel to ``[adapter]``.
8+
It is only consumed for video (``frames_per_clip > 1``); the image and text
9+
paths never build one. ``type = "none"`` disables the embedding even for video.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
from dataclasses import dataclass
15+
from typing import Any
16+
17+
from kempnerforge.config.registry import registry
18+
19+
20+
@dataclass
21+
class TimeEmbeddingConfig:
22+
"""Selects the time-embedding type and parameterizes it.
23+
24+
Register a new technique via ``@registry.register_time_embedding`` and select
25+
it with ``type``; ``type = "none"`` disables the embedding entirely.
26+
27+
Fields:
28+
type: Registry key for the builder (``"sinusoidal"`` default, or ``"none"``).
29+
num_bands: Number of sinusoidal frequency bands (``"sinusoidal"`` only).
30+
min_period: Shortest period in seconds (finest temporal resolution).
31+
max_period: Longest period in seconds (coarsest temporal scale).
32+
"""
33+
34+
type: str = "sinusoidal"
35+
num_bands: int = 16
36+
min_period: float = 0.5
37+
max_period: float = 256.0
38+
39+
def __post_init__(self) -> None:
40+
if self.type == "none":
41+
return
42+
# Late import: importing the module triggers the
43+
# ``@registry.register_time_embedding`` decorators. Doing it at module
44+
# scope would create a circular import via the config/model graph.
45+
import kempnerforge.model.frame_time # noqa: F401, PLC0415
46+
47+
registered = tuple(registry.list_time_embeddings())
48+
if self.type not in registered:
49+
raise ValueError(
50+
f"Unknown time_embedding.type: {self.type!r}. "
51+
f"Registered: {sorted(registered)} (or 'none' to disable)."
52+
)
53+
if self.num_bands <= 0:
54+
raise ValueError(f"time_embedding.num_bands must be positive (got {self.num_bands})")
55+
if not 0.0 < self.min_period < self.max_period:
56+
raise ValueError(
57+
f"time_embedding requires 0 < min_period < max_period "
58+
f"(got min_period={self.min_period}, max_period={self.max_period})"
59+
)
60+
61+
@property
62+
def enabled(self) -> bool:
63+
"""Whether a module should be built (``type != "none"``)."""
64+
return self.type != "none"
65+
66+
def extra_kwargs(self) -> dict[str, Any]:
67+
"""Builder kwargs beyond ``dim``. Type-specific builders take what they
68+
need and swallow the rest via ``**_`` (mirrors ``AdapterConfig``)."""
69+
return {
70+
"num_bands": self.num_bands,
71+
"min_period": self.min_period,
72+
"max_period": self.max_period,
73+
}

kempnerforge/distributed/parallel.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ def _build_vlm(
381381
compile_model: bool,
382382
fp8: bool,
383383
frames_per_clip: int = 1,
384+
time_embedding_config=None,
384385
) -> torch.nn.Module:
385386
"""Build a VLM wrapper with parallelism applied in the correct order.
386387
@@ -406,7 +407,7 @@ def _build_vlm(
406407
from kempnerforge.distributed.expert_parallel import apply_expert_parallel
407408
from kempnerforge.distributed.tensor_parallel import apply_tensor_parallel
408409
from kempnerforge.model.adapter import build_adapter
409-
from kempnerforge.model.frame_time import FrameTimeEmbedding
410+
from kempnerforge.model.frame_time import build_time_embedding
410411
from kempnerforge.model.transformer import Transformer
411412
from kempnerforge.model.vlm import (
412413
VLMWrapper,
@@ -442,9 +443,14 @@ def _build_vlm(
442443
transformer = Transformer(
443444
model_config, vlm_config=vlm_config, num_image_tokens=visual_tokens
444445
)
445-
# Video gets a per-frame timestamp embedding; built alongside the adapter
446-
# so it shares the meta/CPU build + materialize path below.
447-
frame_time_embed = FrameTimeEmbedding(model_config.dim) if frames_per_clip > 1 else None
446+
# Video gets a per-frame timestamp embedding (registry-selected via
447+
# [time_embedding]); built alongside the adapter so it shares the
448+
# meta/CPU build + materialize path below.
449+
frame_time_embed = (
450+
build_time_embedding(time_embedding_config, model_config.dim)
451+
if frames_per_clip > 1
452+
else None
453+
)
448454

449455
strategy = build_modality_strategy(vlm_config)
450456
wrapper = VLMWrapper(
@@ -535,6 +541,7 @@ def build_parallel_model(
535541
compile_model: bool = False,
536542
fp8: bool = False,
537543
frames_per_clip: int = 1,
544+
time_embedding_config=None,
538545
) -> torch.nn.Module:
539546
"""Build a Transformer (or a VLMWrapper) with parallelism applied.
540547
@@ -579,6 +586,7 @@ def build_parallel_model(
579586
compile_model=compile_model,
580587
fp8=fp8,
581588
frames_per_clip=frames_per_clip,
589+
time_embedding_config=time_embedding_config,
582590
)
583591

584592
from kempnerforge.distributed.tensor_parallel import apply_tensor_parallel

kempnerforge/model/frame_time.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,42 @@
2323
from __future__ import annotations
2424

2525
import math
26+
from typing import Any
2627

2728
import torch
2829
import torch.nn as nn
2930

31+
from kempnerforge.config.registry import registry
3032

31-
class FrameTimeEmbedding(nn.Module):
33+
34+
class TimeEmbedding(nn.Module):
35+
"""Base for per-frame timestamp embeddings (the *additive* family).
36+
37+
Contract: ``forward(times: (B, F) seconds) -> (B, F, dim)`` — an additive
38+
embedding added to each frame's visual tokens, with **no change to sequence
39+
length** — plus ``reset_parameters()`` so meta-device builds can re-init
40+
after ``to_empty``. Register a new technique with
41+
``@registry.register_time_embedding`` and select it via
42+
``[time_embedding].type``; ``build_time_embedding`` dispatches through the
43+
registry.
44+
45+
Out of scope (a separate, future integration point): sequence-*modifying*
46+
time encodings — e.g. Molmo2-style textual time-tokens interleaved between
47+
frame groups — change the token sequence (count / ``output_slice`` /
48+
``modality_ids`` / MoT split) and need tokenizer + interleaved-sequence
49+
support KF does not have yet. Those would hook the sequence-assembly layer
50+
(``ModalityStrategy.prepare``), not this additive registry; set
51+
``[time_embedding].type = "none"`` to run them instead of an additive one.
52+
"""
53+
54+
def forward(self, times: torch.Tensor) -> torch.Tensor: # pragma: no cover - interface
55+
raise NotImplementedError
56+
57+
def reset_parameters(self) -> None: # pragma: no cover - interface
58+
raise NotImplementedError
59+
60+
61+
class FrameTimeEmbedding(TimeEmbedding):
3262
"""Sinusoidal embedding of a per-frame timestamp (seconds) -> model dim.
3363
3464
Args:
@@ -93,3 +123,37 @@ def forward(self, times: torch.Tensor) -> torch.Tensor:
93123
ang = times.to(torch.float32).unsqueeze(-1) * (2.0 * math.pi / periods) # (B, F, bands)
94124
feats = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1) # (B, F, 2*bands)
95125
return self.proj(feats.to(self.proj.weight.dtype))
126+
127+
128+
@registry.register_time_embedding("sinusoidal")
129+
def _build_sinusoidal(
130+
dim: int,
131+
*,
132+
num_bands: int = 16,
133+
min_period: float = 0.5,
134+
max_period: float = 256.0,
135+
**_: Any,
136+
) -> FrameTimeEmbedding:
137+
"""Registry builder for the sinusoidal time embedding."""
138+
return FrameTimeEmbedding(
139+
dim, num_bands=num_bands, min_period=min_period, max_period=max_period
140+
)
141+
142+
143+
def build_time_embedding(time_embedding_config: Any, dim: int) -> TimeEmbedding | None:
144+
"""Build the per-frame time embedding from a ``TimeEmbeddingConfig``.
145+
146+
Returns ``None`` when disabled (``type == "none"``). A ``None`` config falls
147+
back to the default (sinusoidal) so video callers that pass nothing keep the
148+
default behavior. The config is duck-typed (``.enabled`` / ``.type`` /
149+
``.extra_kwargs()``) to avoid a model->config import cycle, matching
150+
``build_adapter``.
151+
"""
152+
if time_embedding_config is None:
153+
from kempnerforge.config.time_embedding import TimeEmbeddingConfig # noqa: PLC0415
154+
155+
time_embedding_config = TimeEmbeddingConfig()
156+
if not time_embedding_config.enabled:
157+
return None
158+
builder = registry.get_time_embedding(time_embedding_config.type)
159+
return builder(dim, **time_embedding_config.extra_kwargs())
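
The dispatch above can be exercised without the rest of the codebase. The sketch below is an illustration under assumptions: `BUILDERS` stands in for the registry, `"scaled"` is a hypothetical technique that is not part of this commit, and `make_config` builds a duck-typed object exposing only `.enabled`, `.type` and `.extra_kwargs()`. It shows why builders accept `**_`: the config always forwards the sinusoidal kwargs, and a builder that does not need them simply ignores them.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional

BUILDERS: dict[str, Callable[..., Any]] = {}  # stand-in for the registry


def register(name: str) -> Callable:
    def decorator(fn: Callable) -> Callable:
        BUILDERS[name] = fn
        return fn
    return decorator


@register("scaled")
def build_scaled(dim: int, *, scale: float = 1.0, **_: Any) -> Callable:
    """Hypothetical technique: timestamp t becomes a length-dim vector of t * scale."""
    def embed(times: list[list[float]]) -> list[list[list[float]]]:
        return [[[t * scale] * dim for t in row] for row in times]
    return embed


def build(config: Any, dim: int) -> Optional[Callable]:
    """Same control flow as build_time_embedding, minus the None-config default."""
    if not config.enabled:
        return None
    return BUILDERS[config.type](dim, **config.extra_kwargs())


def make_config(type_: str) -> SimpleNamespace:
    # Duck-typed config: only .enabled, .type and .extra_kwargs() are needed.
    return SimpleNamespace(
        type=type_,
        enabled=type_ != "none",
        extra_kwargs=lambda: {"num_bands": 16, "min_period": 0.5, "max_period": 256.0},
    )


embed = build(make_config("scaled"), dim=2)   # sinusoidal kwargs are swallowed by **_
skipped = build(make_config("none"), dim=2)   # None: no module is built
```

Here `embed` maps `(B, F)` times to `(B, F, dim)` nested lists, matching the additive contract, while `skipped` is `None`, so a caller such as `VLMWrapper` would add nothing.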

kempnerforge/model/vlm.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@
4545
from kempnerforge.config.adapter import AdapterConfig
4646
from kempnerforge.config.registry import registry
4747
from kempnerforge.config.schema import ModelConfig
48+
from kempnerforge.config.time_embedding import TimeEmbeddingConfig
4849
from kempnerforge.config.vision import VisionEncoderConfig
4950
from kempnerforge.config.vlm import FreezeSpec, VLMConfig
5051
from kempnerforge.model.adapter import VisionAdapter, build_adapter
51-
from kempnerforge.model.frame_time import FrameTimeEmbedding
52+
from kempnerforge.model.frame_time import TimeEmbedding, build_time_embedding
5253
from kempnerforge.model.modality import ModalityContext
5354
from kempnerforge.model.transformer import Transformer
5455
from kempnerforge.model.vision import VisionEncoder
@@ -309,7 +310,7 @@ def __init__(
309310
transformer: Transformer,
310311
strategy: ModalityStrategy,
311312
frames_per_clip: int = 1,
312-
frame_time_embed: FrameTimeEmbedding | None = None,
313+
frame_time_embed: TimeEmbedding | None = None,
313314
) -> None:
314315
super().__init__()
315316
self.vision_encoder = vision_encoder
@@ -391,6 +392,7 @@ def build_vlm_wrapper(
391392
adapter_config: AdapterConfig,
392393
vlm_config: VLMConfig,
393394
frames_per_clip: int = 1,
395+
time_embedding_config: TimeEmbeddingConfig | None = None,
394396
) -> VLMWrapper:
395397
"""Build a ``VLMWrapper`` from the four top-level configs.
396398
@@ -434,8 +436,13 @@ def build_vlm_wrapper(
434436
)
435437
transformer = Transformer(model_config, vlm_config=vlm_config, num_image_tokens=visual_tokens)
436438
strategy = build_modality_strategy(vlm_config)
437-
# Video clips get a per-frame timestamp embedding; the image path (F=1) does not.
438-
frame_time_embed = FrameTimeEmbedding(model_config.dim) if frames_per_clip > 1 else None
439+
# Video clips get a per-frame timestamp embedding (registry-selected via
440+
# [time_embedding]); the image path (F=1) does not. type="none" disables it.
441+
frame_time_embed = (
442+
build_time_embedding(time_embedding_config, model_config.dim)
443+
if frames_per_clip > 1
444+
else None
445+
)
439446
return VLMWrapper(
440447
encoder,
441448
adapter,
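
The body of `_project_visual_features` is not part of this diff, so the following is only a sketch of what "added to that frame's visual tokens" could look like, written with NumPy rather than PyTorch. It assumes the `(B, F·P′, dim)` tokens are laid out frame-major (all tokens of frame 0, then frame 1, and so on); that layout and all variable names are assumptions.

```python
import numpy as np

B, F, P, dim = 2, 3, 4, 8  # batch, frames, tokens per frame, model dim
tokens = np.random.default_rng(0).normal(size=(B, F * P, dim))  # visual tokens
time_emb = np.random.default_rng(1).normal(size=(B, F, dim))    # one vector per frame

# Assumption: tokens are frame-major, so token i belongs to frame i // P.
per_frame = tokens.reshape(B, F, P, dim) + time_emb[:, :, None, :]
out = per_frame.reshape(B, F * P, dim)  # sequence length is unchanged

# A zero embedding (the state at step 0) leaves the tokens untouched.
identity = (tokens.reshape(B, F, P, dim) + np.zeros((B, F, 1, dim))).reshape(B, F * P, dim)
```

The broadcast over the `P` axis gives every token of a frame the same time vector, and the output keeps the `(B, F·P′, dim)` shape, which is what makes this family "additive" rather than sequence-modifying.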

scripts/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def main() -> None:
175175
adapter_config=adapter_cfg,
176176
vlm_config=vlm_cfg,
177177
frames_per_clip=(config.video.max_frames if config.video is not None else 1),
178+
time_embedding_config=config.time_embedding,
178179
ac_mode=tc.activation_checkpointing,
179180
mp_policy=mp_policy,
180181
param_dtype=tc.param_dtype,
