diff --git a/CHANGELOG.md b/CHANGELOG.md index 49ab032..8500aba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,7 +65,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Video data path (pluggable, registry-driven).** `kempnerforge/data/video_io.py`: timestamp-based frame sampling (target fps, first & last frame kept — Molmo2 §3.1/§A) registered as the `"uniform"` sampling policy and selectable via `[video].sampling_policy`; PyAV decode (lazy-imported). `kempnerforge/data/video_dataset.py`: a `VideoDataset` base + the WebVid-style `WebVidVideoDataset` (CSV manifest + `id[:2]/id[:4]/id[:6]/id.mp4` mapping) registered as `"webvid"` via `@registry.register_video_dataset`, plus a `build_video_dataset` dispatch — so other dataset styles are additive registrations selected by `[video].dataset_type`. The WebVid corpus directory is parameterized by `[video].dataset_name` (no longer hardcoded to `webvid-10M`). `VideoCollator` → `(B, F, 3, H, W)` + a frame-validity mask; an undecodable clip is masked out (no loss). `kempnerforge/config/registry.py`: `register_video_dataset` / `register_sampling_policy` registries. `kempnerforge/config/video.py`: the `[video]` `VideoConfig` section (`data_root`, `dataset_type`, `dataset_name`, `sampling_policy`, `split`, `fps`, `max_frames`, `min_frames`, `frame_size`, `max_samples`), wired into `JobConfig` (+ `is_video`). `av` is an optional `video` dependency group (`uv sync --group video`); CI installs it for the lint + unit-test jobs. - **Frame-aware model + training wiring.** `kempnerforge/model/vlm.py`: `_project_image_features` → `_project_visual_features` folds the frame axis through the encoder + pooler to `(B, F·P′, dim)` (a single image is the `F == 1` case). 
`VLMWrapper` gains `frames_per_clip`, threaded through `build_parallel_model` / `_build_vlm` / `build_vlm_wrapper` so the static visual-token count equals `F·P′` (drives the residual budget and MoT's positional split; static == runtime). `scripts/train.py` builds the video dataset/collator when `[video]` is set. Adds `configs/train/vlm_video_webvid.toml` (SigLIP2 + avgpool + WebVid). - Tests: `tests/unit/test_video_io.py`, `test_video_dataset.py`, `test_video_config.py`; video-forward cases (all four archs) + image-path regression in `test_vlm.py`; pooling-adapter cases in `test_adapter.py`. Docs: `docs/how-to/train-on-video.md`. - - Deferred (follow-ups; the registries make these additive): more video dataset styles (HuggingFace video sets, flat folders, alternate manifests) and frame-sampling policies; per-frame timestamp tokens + grounding (``/`` outputs with point-F1 / track-J&F eval), bidirectional visual attention, VLM sequence packing, long-context (blocked on context-parallel being wired), and warm-start from a converted image-VLM checkpoint. + - Deferred (follow-ups; the registries make these additive): more video dataset styles (HuggingFace video sets, flat folders, alternate manifests) and frame-sampling policies; interleaved text time-tokens (Molmo2-style, sequence-modifying), grounding (``/`` outputs with point-F1 / track-J&F eval), bidirectional visual attention, VLM sequence packing, long-context (blocked on context-parallel being wired), and warm-start from a converted image-VLM checkpoint. +- **Per-frame timestamps for video.** Each sampled frame carries its actual presentation time (seconds), embedded and added to that frame's visual tokens so the model can reason about *when* events occur, not just frame order. It is registry-driven and config-selected (so new techniques drop in as small additions); the output projection is zero-initialized, so the embedding is a no-op at step 0 (warm-start friendly) and is learned from there. 
+ - `kempnerforge/data/video_io.py`: `decode_video_frames` returns `(frames, times)` (the matched frames' presentation times); `kempnerforge/data/video_dataset.py` emits a `frame_times` `(F,)` tensor and `VideoCollator` stacks it to `(B, F)`. + - `kempnerforge/model/frame_time.py`: a `TimeEmbedding` base (the additive `(B, F) seconds → (B, F, dim)` contract) + the `"sinusoidal"` implementation, registered via `@registry.register_time_embedding` and built through `build_time_embedding`. Applied per frame in `_project_visual_features` as a `VLMWrapper` submodule (video only; `None` for the image path) and built + FSDP-sharded + meta-materialized at both build sites (`build_vlm_wrapper`, `_build_vlm`). + - `kempnerforge/config/time_embedding.py`: the `[time_embedding]` `TimeEmbeddingConfig` (`type` selects the registered builder; `type = "none"` disables it), wired into `JobConfig` and threaded through `build_parallel_model`; `scripts/train.py` passes `config.time_embedding` and threads `frame_times` into the forward. A non-video config that still sets a non-default `[time_embedding]` gets a set-but-ineffective warning (`config/job.py`, mirroring the HF-encoder-override warning). + - `kempnerforge/config/vlm.py`: `frame_time_embed` added to `DEFAULT_MODULE_PATTERNS`, so the submodule is freeze-addressable (a freeze spec or schedule can stage it, like the adapter). + - Sequence-*modifying* time encodings (e.g. Molmo2-style interleaved text time-tokens) are a separate future hook at the sequence-assembly layer, gated on interleaved/variable-length sequence support — out of scope for this additive registry. + - Tests: `tests/unit/test_frame_time.py`, `test_time_embedding_config.py`; frame-time forward + `type="none"` + state_dict round-trip + freeze-addressability cases in `test_vlm.py`; the set-but-ineffective warning in `test_config.py`; the video build path in `test_distributed.py`. 
- **Padded frames masked from attention (all four archs).** Short/undecodable video clips pad to `max_frames` with blank frames; the `frame_mask` is now consumed so real tokens never attend to padded-frame visual tokens. One per-token validity mask, `ModalityContext.key_padding_mask` `(B, S)`, threads through the model: the shared `Attention` ANDs it with the causal (and doc) mask — covering Joint-Decoder and MoMa — `MoTAttention` builds an explicit causal-AND-valid mask, and Cross-Attention masks the padded image K/V via its existing `image_mask`. A NaN guard unmasks fully-masked query rows (an all-padded clip) so softmax stays finite. It is a **pure mask — no new state-dict keys**, so checkpoints stay compatible both ways; the image (`F=1`) and text paths are unchanged (no mask is built, so they keep the FlashAttention-2 path). Note: for the image-prefix arches (Joint-Decoder/MoT/MoMa) video self-attention always takes the explicit-mask SDPA path (FA2 disabled, a `(B,1,S,S)` mask materialized) even for fully-decoded clips — the result is identical to causal-only but not free; a deliberate `torch.compile`/DP-friendly trade-off (one graph, no host sync), with FA2-recovery / FlexAttention left as a follow-up. (Cross-Attention keeps FA2 on its text self-attention and only masks padded image K/V in the cross-attention blocks.) Foundation for variable-length / mixed image+video batches. - `kempnerforge/model/modality.py`: `ModalityContext.key_padding_mask` field (+ invariant). `kempnerforge/model/vlm.py`: `_visual_token_mask` expands `frame_mask (B,F)` → `(B, F·P′)`; the four strategies place it (image-prefix arches → `key_padding_mask`, Cross-Attention → `image_mask`). `attention.py` / `mot.py` / `cross_attention.py` consume the mask (+ NaN guard); `moma.py`'s `MoMaFFN` also excludes padded positions from expert-choice routing (so padded tokens never consume expert capacity). `scripts/train.py` threads `batch["frame_mask"]` into the forward. 
- Deferred: MoT configured with an *MoE* FFN still routes padded tokens through the shared `MoEMLP` — a follow-up ("generic token-validity in MoE") would mask that and padded text alike. MoT-dense (the default) and MoMa are fully masked. diff --git a/docs/how-to/train-on-video.md b/docs/how-to/train-on-video.md index 5fc98f3..47587d1 100644 --- a/docs/how-to/train-on-video.md +++ b/docs/how-to/train-on-video.md @@ -25,8 +25,14 @@ A clip of `F` frames becomes `F × P′` visual tokens: blocks; the residual stays text-only (so it fits more frames per `max_seq_len`). -Temporal order is carried by frame order (sequential positions). Per-frame -timestamp tokens and grounding outputs are a separate follow-up (see below). +Temporal order is carried by frame order (sequential positions). On top of that, +each frame's **timestamp in seconds** is embedded and added to that frame's +visual tokens, so the model sees *when* each frame occurs, not just its order. +The embedding is registry-driven: `[time_embedding].type` selects it +(`sinusoidal` by default — sinusoidal features at log-spaced periods through a +zero-initialized projection; `none` disables it), so new techniques (learned, +Fourier, …) register as small additions and switch via config. Grounding +outputs are a separate follow-up (see below). ## Token budget @@ -99,9 +105,22 @@ time, so it is set in the TOML, not via a `--vlm.arch=` CLI override.) ## Constraints and follow-ups -- **Causal attention; no per-frame timestamps yet** — temporal order is frame - order. Per-frame timestamp tokens + grounding (``/`` outputs - with point-F1 / track-J&F eval) are a follow-up. +- **Grounding outputs are a follow-up** — per-frame timestamps are encoded (see + above), but structured grounding (``/`` outputs with point-F1 + / track-J&F eval) is not yet implemented. +- **Sequence-modifying time encodings are a separate hook** — the + `[time_embedding]` registry is for *additive* per-frame embeddings (no change + to sequence length). 
Molmo2-style interleaved text time-tokens change the + token sequence and need interleaved/variable-length sequence support KF does + not have yet; they would hook the sequence-assembly layer, not this registry. +- **Inference must pass `frame_times`** — a video model silently drops the + learned temporal signal if `frame_times` is `None` (no error is raised). + Training threads it automatically; eval/generate paths must pass it for video + models. +- **Resuming a pre-timestamp video checkpoint** — a checkpoint trained before + per-frame timestamps lacks the `frame_time_embed` keys, so loading it into the + current (default-on) video model needs `[time_embedding].type = "none"` or a + warm-start key-fill. - **Padded frames are masked from attention** — short/undecodable clips pad to `max_frames` with blank frames, and the `frame_mask` is consumed so real tokens never attend to padded-frame visual tokens (MoMa also drops them from diff --git a/kempnerforge/config/job.py b/kempnerforge/config/job.py index fa39ed2..58f10c1 100644 --- a/kempnerforge/config/job.py +++ b/kempnerforge/config/job.py @@ -14,6 +14,7 @@ from kempnerforge.config.optimizer import OptimizerConfig from kempnerforge.config.profiling import ProfilingConfig from kempnerforge.config.scheduler import SchedulerConfig +from kempnerforge.config.time_embedding import TimeEmbeddingConfig from kempnerforge.config.training import TrainConfig from kempnerforge.config.video import VideoConfig from kempnerforge.config.vision import VisionEncoderConfig @@ -53,6 +54,7 @@ class JobConfig: adapter: AdapterConfig | None = None vlm: VLMConfig | None = None video: VideoConfig | None = None + time_embedding: TimeEmbeddingConfig | None = None def __post_init__(self) -> None: """Cross-section invariants that fire at construction time. @@ -124,6 +126,22 @@ def __post_init__(self) -> None: "the VLM wrapper, so a [vlm] section (and [vision_encoder]) is required." ) + # Set-but-ineffective [time_embedding] warning. 
The per-frame time + # embedding is built only for video (frames_per_clip > 1); an explicit, + # enabled [time_embedding] on a non-video config is silently ignored, so + # warn (mirrors the HF-encoder-override warning above). type="none" is + # an intentional disable and stays quiet. + if self.time_embedding is not None and self.time_embedding.enabled and self.video is None: + import logging + + logging.getLogger(__name__).warning( + "[time_embedding] is set (type=%r) but no [video] section is present; " + "the time embedding is built only for video (frames_per_clip > 1), so it " + 'will be ignored. Set [time_embedding].type = "none" or remove the section ' + "to silence this.", + self.time_embedding.type, + ) + @property def is_vlm(self) -> bool: """Whether this job builds a ``VLMWrapper`` around the text backbone.""" diff --git a/kempnerforge/config/registry.py b/kempnerforge/config/registry.py index a53b2c0..0cf9bec 100644 --- a/kempnerforge/config/registry.py +++ b/kempnerforge/config/registry.py @@ -227,6 +227,27 @@ def get_sampling_policy(self, name: str) -> Callable: def list_sampling_policies(self) -> list[str]: return self.list("sampling_policy") + def register_time_embedding(self, name: str) -> Callable: + """Decorator to register a time-embedding builder. + + Builders take ``(dim, **kwargs)`` and return an ``nn.Module`` mapping + per-frame timestamps ``(B, F)`` in seconds to an additive embedding + ``(B, F, dim)`` (and exposing ``reset_parameters()`` for meta-device + builds). Selected by ``[time_embedding].type`` on the VLM video path. 
+ """ + + def decorator(fn: Callable) -> Callable: + self.register("time_embedding", name, fn) + return fn + + return decorator + + def get_time_embedding(self, name: str) -> Callable: + return self.get("time_embedding", name) + + def list_time_embeddings(self) -> list[str]: + return self.list("time_embedding") + def register_dyn_ckpt_strategy(self, name: str) -> Callable: """Decorator to register a dynamic-checkpointing-window strategy. diff --git a/kempnerforge/config/schema.py b/kempnerforge/config/schema.py index 09be75f..7c0180d 100644 --- a/kempnerforge/config/schema.py +++ b/kempnerforge/config/schema.py @@ -15,6 +15,7 @@ from kempnerforge.config.optimizer import OptimizerConfig # noqa: F401 from kempnerforge.config.profiling import ProfilingConfig # noqa: F401 from kempnerforge.config.scheduler import SchedulerConfig, SchedulerType # noqa: F401 +from kempnerforge.config.time_embedding import TimeEmbeddingConfig # noqa: F401 from kempnerforge.config.training import ActivationCheckpointing, TrainConfig # noqa: F401 from kempnerforge.config.vision import VisionEncoderConfig # noqa: F401 from kempnerforge.config.vlm import ( # noqa: F401 diff --git a/kempnerforge/config/time_embedding.py b/kempnerforge/config/time_embedding.py new file mode 100644 index 0000000..4b7b2ab --- /dev/null +++ b/kempnerforge/config/time_embedding.py @@ -0,0 +1,73 @@ +"""Time-embedding (per-frame timestamp) configuration. + +``TimeEmbeddingConfig`` selects which per-frame timestamp embedding the VLM +video path uses and parameterizes it. Dispatched via the ``time_embedding`` +registry at build time (see ``kempnerforge/model/frame_time.py``). + +In TOML, ``[time_embedding]`` is a top-level section parallel to ``[adapter]``. +It is only consumed for video (``frames_per_clip > 1``); the image and text +paths never build one. ``type = "none"`` disables the embedding even for video. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from kempnerforge.config.registry import registry + + +@dataclass +class TimeEmbeddingConfig: + """Selects the time-embedding type and parameterizes it. + + Register a new technique via ``@registry.register_time_embedding`` and select + it with ``type``; ``type = "none"`` disables the embedding entirely. + + Fields: + type: Registry key for the builder (``"sinusoidal"`` default, or ``"none"``). + num_bands: Number of sinusoidal frequency bands (``"sinusoidal"`` only). + min_period: Shortest period in seconds (finest temporal resolution). + max_period: Longest period in seconds (coarsest temporal scale). + """ + + type: str = "sinusoidal" + num_bands: int = 16 + min_period: float = 0.5 + max_period: float = 256.0 + + def __post_init__(self) -> None: + if self.type == "none": + return + # Late import: importing the module triggers the + # ``@registry.register_time_embedding`` decorators. Doing it at module + # scope would create a circular import via the config/model graph. + import kempnerforge.model.frame_time # noqa: F401, PLC0415 + + registered = tuple(registry.list_time_embeddings()) + if self.type not in registered: + raise ValueError( + f"Unknown time_embedding.type: {self.type!r}. " + f"Registered: {sorted(registered)} (or 'none' to disable)." + ) + if self.num_bands <= 0: + raise ValueError(f"time_embedding.num_bands must be positive (got {self.num_bands})") + if not 0.0 < self.min_period < self.max_period: + raise ValueError( + f"time_embedding requires 0 < min_period < max_period " + f"(got min_period={self.min_period}, max_period={self.max_period})" + ) + + @property + def enabled(self) -> bool: + """Whether a module should be built (``type != "none"``).""" + return self.type != "none" + + def extra_kwargs(self) -> dict[str, Any]: + """Builder kwargs beyond ``dim``. 
Type-specific builders take what they + need and swallow the rest via ``**_`` (mirrors ``AdapterConfig``).""" + return { + "num_bands": self.num_bands, + "min_period": self.min_period, + "max_period": self.max_period, + } diff --git a/kempnerforge/config/vlm.py b/kempnerforge/config/vlm.py index 777c27f..6d04bb9 100644 --- a/kempnerforge/config/vlm.py +++ b/kempnerforge/config/vlm.py @@ -47,6 +47,7 @@ "transformer": ["transformer", "transformer.*"], "vision_encoder": ["vision_encoder", "vision_encoder.*"], "adapter": ["adapter", "adapter.*"], + "frame_time_embed": ["frame_time_embed", "frame_time_embed.*"], } diff --git a/kempnerforge/data/video_dataset.py b/kempnerforge/data/video_dataset.py index ff459cc..ad6fe7f 100644 --- a/kempnerforge/data/video_dataset.py +++ b/kempnerforge/data/video_dataset.py @@ -191,7 +191,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: caption = self._caps[idx] path = self._video_path(videoid) try: - frames = decode_video_frames( + frames, frame_times_s = decode_video_frames( path, fps=self._fps, min_frames=self._min_frames, @@ -200,16 +200,22 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: ) except Exception as e: # noqa: BLE001 - any decode failure -> skip-with-mask logger.debug("video decode failed for %s: %s", path, e) - frames = [] + frames, frame_times_s = [], [] f = self._max_frames size = self._frame_size pixel_values = torch.zeros(f, 3, size, size, dtype=torch.float32) frame_mask = torch.zeros(f, dtype=torch.bool) + # Per-frame timestamp in seconds; 0.0 for pad frames. The time projection + # runs over every frame and does not consult frame_mask, so pad frames + # still receive a (time-0) embedding. It can matter only while padded + # frames remain visible to attention; once they are masked it is inert. 
+ frame_times = torch.zeros(f, dtype=torch.float32) n_real = min(len(frames), f) for i in range(n_real): pixel_values[i] = _pil_to_tensor(frames[i], size, self._image_mean, self._image_std) frame_mask[i] = True + frame_times[i] = frame_times_s[i] prompt = self._prompt or None input_ids, labels = _tokenize_and_mask(self._tokenizer, caption, self._max_text_len, prompt) @@ -219,6 +225,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: return { "pixel_values": pixel_values, "frame_mask": frame_mask, + "frame_times": frame_times, "input_ids": input_ids, "labels": labels, } @@ -230,6 +237,7 @@ class VideoCollator: Output keys: - ``pixel_values``: ``(B, F, 3, H, W)`` float32. - ``frame_mask``: ``(B, F)`` bool (``True`` = real frame). + - ``frame_times``: ``(B, F)`` float32 (per-frame time in seconds). - ``input_ids``: ``(B, max_text_len)`` int64. - ``labels``: ``(B, max_text_len)`` int64 with ``-100`` on pad/prompt. @@ -249,6 +257,7 @@ def __call__(self, samples: list[dict[str, torch.Tensor]]) -> dict[str, torch.Te b = len(samples) pixel_values = torch.stack([s["pixel_values"] for s in samples], dim=0) frame_mask = torch.stack([s["frame_mask"] for s in samples], dim=0) + frame_times = torch.stack([s["frame_times"] for s in samples], dim=0) input_ids = torch.full((b, self.max_text_len), self.pad_id, dtype=torch.long) labels = torch.full((b, self.max_text_len), -100, dtype=torch.long) for i, s in enumerate(samples): @@ -260,6 +269,7 @@ def __call__(self, samples: list[dict[str, torch.Tensor]]) -> dict[str, torch.Te return { "pixel_values": pixel_values, "frame_mask": frame_mask, + "frame_times": frame_times, "input_ids": input_ids, "labels": labels, } diff --git a/kempnerforge/data/video_io.py b/kempnerforge/data/video_io.py index b6d91ab..418eb71 100644 --- a/kempnerforge/data/video_io.py +++ b/kempnerforge/data/video_io.py @@ -76,16 +76,18 @@ def _video_duration_seconds(stream: Any, container: Any) -> float: def decode_video_frames( path: str, *, fps: float, 
min_frames: int, max_frames: int, sampling_policy: str = "uniform" -) -> list[PILImage]: - """Decode a clip into a list of sampled ``PIL.Image`` frames (RGB). +) -> tuple[list[PILImage], list[float]]: + """Decode a clip into sampled ``PIL.Image`` frames (RGB) + their times. Frames are chosen by the registered ``sampling_policy`` (default ``"uniform"`` = ``sample_timestamps``) and read in a single decode pass: each target timestamp is mapped to the first decoded frame at or after it (timestamps past the last frame map to the last frame, so the final frame is - always returned). The returned list has length equal to the number of sampled - timestamps (``<= max_frames``), or is empty when the file has no decodable - video stream. + always returned). Returns ``(frames, times)`` where ``times`` are the matched + frames' actual presentation times in seconds (parallel to ``frames``, so the + caller can encode *when* each frame occurs). Both lists have length equal to + the number of sampled timestamps (``<= max_frames``), or are empty when the + file has no decodable video stream. Raises whatever ``av`` raises on a missing/corrupt file; callers that train over noisy data should catch and substitute an empty clip. @@ -100,9 +102,10 @@ def decode_video_frames( sample = registry.get_sampling_policy(sampling_policy) images: list[PILImage] = [] + times: list[float] = [] with av.open(path) as container: if not container.streams.video: - return images + return images, times stream = container.streams.video[0] stream.thread_type = "AUTO" duration_s = _video_duration_seconds(stream, container) @@ -111,17 +114,22 @@ def decode_video_frames( j = 0 eps = 1e-3 last_frame = None + last_t = 0.0 for frame in container.decode(stream): t = float(frame.time) if frame.time is not None else 0.0 while j < len(targets) and t + eps >= targets[j]: images.append(frame.to_image()) + times.append(t) j += 1 last_frame = frame + last_t = t if j >= len(targets): break # Trailing targets (e.g. 
the final ``duration_s`` timestamp, which sits # just past the last frame's PTS) map to the last decoded frame. if j < len(targets) and last_frame is not None: tail = last_frame.to_image() - images.extend(tail for _ in range(len(targets) - j)) - return images + for _ in range(len(targets) - j): + images.append(tail) + times.append(last_t) + return images, times diff --git a/kempnerforge/distributed/parallel.py b/kempnerforge/distributed/parallel.py index 404b91d..dfa1ae7 100644 --- a/kempnerforge/distributed/parallel.py +++ b/kempnerforge/distributed/parallel.py @@ -347,6 +347,13 @@ def _apply_fsdp_vlm( mp_policy=policy, reshard_after_forward=reshard_after_forward, ) + if wrapper.frame_time_embed is not None: + fully_shard( + wrapper.frame_time_embed, + mesh=dp_mesh, + mp_policy=policy, + reshard_after_forward=reshard_after_forward, + ) if not encoder_frozen: fully_shard( wrapper.vision_encoder, @@ -374,6 +381,7 @@ def _build_vlm( compile_model: bool, fp8: bool, frames_per_clip: int = 1, + time_embedding_config=None, ) -> torch.nn.Module: """Build a VLM wrapper with parallelism applied in the correct order. @@ -399,6 +407,7 @@ def _build_vlm( from kempnerforge.distributed.expert_parallel import apply_expert_parallel from kempnerforge.distributed.tensor_parallel import apply_tensor_parallel from kempnerforge.model.adapter import build_adapter + from kempnerforge.model.frame_time import build_time_embedding from kempnerforge.model.transformer import Transformer from kempnerforge.model.vlm import ( VLMWrapper, @@ -434,9 +443,24 @@ def _build_vlm( transformer = Transformer( model_config, vlm_config=vlm_config, num_image_tokens=visual_tokens ) + # Video gets a per-frame timestamp embedding (registry-selected via + # [time_embedding]); built alongside the adapter so it shares the + # meta/CPU build + materialize path below. 
+ frame_time_embed = ( + build_time_embedding(time_embedding_config, model_config.dim) + if frames_per_clip > 1 + else None + ) strategy = build_modality_strategy(vlm_config) - wrapper = VLMWrapper(encoder, adapter, transformer, strategy, frames_per_clip=frames_per_clip) + wrapper = VLMWrapper( + encoder, + adapter, + transformer, + strategy, + frames_per_clip=frames_per_clip, + frame_time_embed=frame_time_embed, + ) # 3. Length cross-check now that num_tokens is resolved. required = wrapper.num_image_tokens + vlm_config.max_text_len @@ -470,12 +494,19 @@ def _build_vlm( # weights after to_empty. nn.Module itself does not declare the # method, so pyright sees an unknown attr; suppress the report. adapter.reset_parameters() # type: ignore[reportCallIssue,reportAttributeAccessIssue] + if frame_time_embed is not None: + frame_time_embed.to_empty(device=device) + frame_time_embed.reset_parameters() else: transformer.to(device=device) adapter.to(device=device) + if frame_time_embed is not None: + frame_time_embed.to(device=device) encoder.to(device) # Keep HF dtype per D16. transformer.to(dtype=param_dtype) adapter.to(dtype=param_dtype) + if frame_time_embed is not None: + frame_time_embed.to(dtype=param_dtype) # 7. Freeze specs + eval() for fully frozen encoder. apply_freeze_specs(wrapper, vlm_config.freeze, vlm_config.module_patterns) @@ -510,6 +541,7 @@ def build_parallel_model( compile_model: bool = False, fp8: bool = False, frames_per_clip: int = 1, + time_embedding_config=None, ) -> torch.nn.Module: """Build a Transformer (or a VLMWrapper) with parallelism applied. 
@@ -554,6 +586,7 @@ def build_parallel_model( compile_model=compile_model, fp8=fp8, frames_per_clip=frames_per_clip, + time_embedding_config=time_embedding_config, ) from kempnerforge.distributed.tensor_parallel import apply_tensor_parallel diff --git a/kempnerforge/model/frame_time.py b/kempnerforge/model/frame_time.py new file mode 100644 index 0000000..fcb0ecf --- /dev/null +++ b/kempnerforge/model/frame_time.py @@ -0,0 +1,161 @@ +"""Per-frame timestamp embedding for the VLM video path. + +A video clip enters the model as ``F`` frames in temporal order, but frame +*order* alone does not tell the model *when* each frame occurs: an 8-frame clip +spanning 2 seconds and one spanning 2 minutes both map to frame indices 0..7. +``FrameTimeEmbedding`` injects the actual per-frame timestamp (seconds) so the +model can reason about elapsed time (Molmo2-style temporal grounding). + +The timestamp is expanded into sinusoidal features at log-spaced periods (à la +Transformer positional encodings, but over continuous seconds rather than +integer positions), then projected to the model dimension and added to that +frame's visual tokens. The output projection is zero-initialized so the +temporal signal starts at zero and is learned from there — matching the +``CrossAttention`` warm-start convention: at step 0 the module is a no-op on +the model's outputs (the added embedding is zero, so logits are unchanged), +while gradients still flow into the projection, so it begins learning from the +first step. + +The frequencies are recomputed in ``forward`` on the input's device rather than +stored in a buffer, so the module is safe to construct under +``torch.device("meta")`` (the meta-device / FSDP build path): the only +parameters are the projection, which materializes like any other ``nn.Linear``. 
+""" + +from __future__ import annotations + +import math +from typing import Any + +import torch +import torch.nn as nn + +from kempnerforge.config.registry import registry + + +class TimeEmbedding(nn.Module): + """Base for per-frame timestamp embeddings (the *additive* family). + + Contract: ``forward(times: (B, F) seconds) -> (B, F, dim)`` — an additive + embedding added to each frame's visual tokens, with **no change to sequence + length** — plus ``reset_parameters()`` so meta-device builds can re-init + after ``to_empty``. Register a new technique with + ``@registry.register_time_embedding`` and select it via + ``[time_embedding].type``; ``build_time_embedding`` dispatches through the + registry. + + Out of scope (a separate, future integration point): sequence-*modifying* + time encodings — e.g. Molmo2-style textual time-tokens interleaved between + frame groups — change the token sequence (count / ``output_slice`` / + ``modality_ids`` / MoT split) and need tokenizer + interleaved-sequence + support KF does not have yet. Those would hook the sequence-assembly layer + (``ModalityStrategy.prepare``), not this additive registry; set + ``[time_embedding].type = "none"`` to run them instead of an additive one. + """ + + def forward(self, times: torch.Tensor) -> torch.Tensor: # pragma: no cover - interface + raise NotImplementedError + + def reset_parameters(self) -> None: # pragma: no cover - interface + raise NotImplementedError + + +class FrameTimeEmbedding(TimeEmbedding): + """Sinusoidal embedding of a per-frame timestamp (seconds) -> model dim. + + Args: + dim: Model dimension (the embedding is added to the visual tokens). + num_bands: Number of sinusoidal frequency bands; the raw feature width + is ``2 * num_bands`` (sin + cos). + min_period: Shortest period in seconds (highest frequency); sets the + finest temporal resolution. + max_period: Longest period in seconds (lowest frequency); sets the + coarsest temporal scale the embedding can represent. 
+ """ + + def __init__( + self, + dim: int, + num_bands: int = 16, + min_period: float = 0.5, + max_period: float = 256.0, + ) -> None: + super().__init__() + if dim <= 0: + raise ValueError(f"FrameTimeEmbedding dim must be positive (got {dim})") + if num_bands <= 0: + raise ValueError(f"FrameTimeEmbedding num_bands must be positive (got {num_bands})") + if not 0.0 < min_period < max_period: + raise ValueError( + f"FrameTimeEmbedding requires 0 < min_period < max_period " + f"(got min_period={min_period}, max_period={max_period})" + ) + self.dim = dim + self.num_bands = num_bands + self.min_period = float(min_period) + self.max_period = float(max_period) + self.proj = nn.Linear(2 * num_bands, dim) + self.reset_parameters() + + def reset_parameters(self) -> None: + # Zero-init so the temporal signal starts at zero and is learned. Also + # re-init contract for the meta-device build (to_empty -> reset). + nn.init.zeros_(self.proj.weight) + nn.init.zeros_(self.proj.bias) + + def forward(self, times: torch.Tensor) -> torch.Tensor: + """Embed per-frame timestamps. + + Args: + times: Per-frame timestamps in seconds, shape ``(batch, frames)``. + + Returns: + ``(batch, frames, dim)`` temporal embeddings to add to each frame's + visual tokens. + """ + # Angular frequencies for log-spaced periods, on the input's device so + # this is safe regardless of where the module was constructed (meta/CPU). 
+ periods = torch.logspace( + math.log10(self.min_period), + math.log10(self.max_period), + self.num_bands, + device=times.device, + dtype=torch.float32, + ) + ang = times.to(torch.float32).unsqueeze(-1) * (2.0 * math.pi / periods) # (B, F, bands) + feats = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1) # (B, F, 2*bands) + return self.proj(feats.to(self.proj.weight.dtype)) + + +@registry.register_time_embedding("sinusoidal") +def _build_sinusoidal( + dim: int, + *, + num_bands: int = 16, + min_period: float = 0.5, + max_period: float = 256.0, + **_: Any, +) -> FrameTimeEmbedding: + """Registry builder for the sinusoidal time embedding.""" + return FrameTimeEmbedding( + dim, num_bands=num_bands, min_period=min_period, max_period=max_period + ) + + +def build_time_embedding(time_embedding_config: Any, dim: int) -> TimeEmbedding | None: + """Build the per-frame time embedding from a ``TimeEmbeddingConfig``. + + Returns ``None`` when disabled (``type == "none"``). A ``None`` config falls + back to the default (sinusoidal) so video callers that pass nothing keep the + default behavior. The config is duck-typed (``.enabled`` / ``.type`` / + ``.extra_kwargs()``) to avoid a model->config import cycle, matching + ``build_adapter``. 
+ """ + if time_embedding_config is None: + from kempnerforge.config.time_embedding import TimeEmbeddingConfig # noqa: PLC0415 + + time_embedding_config = TimeEmbeddingConfig() + if not time_embedding_config.enabled: + return None + builder = registry.get_time_embedding(time_embedding_config.type) + return builder(dim, **time_embedding_config.extra_kwargs()) diff --git a/kempnerforge/model/vlm.py b/kempnerforge/model/vlm.py index c6ca8f9..0d488c4 100644 --- a/kempnerforge/model/vlm.py +++ b/kempnerforge/model/vlm.py @@ -45,9 +45,11 @@ from kempnerforge.config.adapter import AdapterConfig from kempnerforge.config.registry import registry from kempnerforge.config.schema import ModelConfig +from kempnerforge.config.time_embedding import TimeEmbeddingConfig from kempnerforge.config.vision import VisionEncoderConfig from kempnerforge.config.vlm import FreezeSpec, VLMConfig from kempnerforge.model.adapter import VisionAdapter, build_adapter +from kempnerforge.model.frame_time import TimeEmbedding, build_time_embedding from kempnerforge.model.modality import ModalityContext from kempnerforge.model.transformer import Transformer from kempnerforge.model.vision import VisionEncoder @@ -68,13 +70,17 @@ def prepare( wrapper: VLMWrapper, pixel_values: torch.Tensor, input_ids: torch.Tensor, + *, + frame_times: torch.Tensor | None = None, frame_mask: torch.Tensor | None = None, ) -> ModalityContext: ... def num_image_tokens(self, wrapper: VLMWrapper) -> int: ... -def _project_visual_features(wrapper: VLMWrapper, pixel_values: torch.Tensor) -> torch.Tensor: +def _project_visual_features( + wrapper: VLMWrapper, pixel_values: torch.Tensor, frame_times: torch.Tensor | None = None +) -> torch.Tensor: """Encode + adapt visual features into LLM-dim tokens. 
Accepts a single-image batch ``(B, 3, H, W)`` or a video-clip batch @@ -114,7 +120,19 @@ def _project_visual_features(wrapper: VLMWrapper, pixel_values: torch.Tensor) -> embeds = wrapper.adapter(feats) if is_video: # (B*F, P', dim) -> (B, F*P', dim): frame-contiguous, temporal order kept. - embeds = embeds.reshape(b, f * embeds.shape[1], embeds.shape[2]) + pprime, dim = embeds.shape[1], embeds.shape[2] + embeds = embeds.reshape(b, f * pprime, dim) + # Per-frame timestamp embedding, broadcast across each frame's P' tokens, + # so the model can reason about *when* each frame occurs (not just order). + if wrapper.frame_time_embed is not None and frame_times is not None: + if frame_times.shape != (b, f): + raise ValueError( + f"frame_times shape {tuple(frame_times.shape)} does not match " + f"(batch, frames) = ({b}, {f})" + ) + t_emb = wrapper.frame_time_embed(frame_times) # (B, F, dim) + t_emb = t_emb.unsqueeze(2).expand(b, f, pprime, dim).reshape(b, f * pprime, dim) + embeds = embeds + t_emb.to(embeds.dtype) return embeds @@ -181,9 +199,11 @@ def prepare( wrapper: VLMWrapper, pixel_values: torch.Tensor, input_ids: torch.Tensor, + *, + frame_times: torch.Tensor | None = None, frame_mask: torch.Tensor | None = None, ) -> ModalityContext: - img_embeds = _project_visual_features(wrapper, pixel_values) + img_embeds = _project_visual_features(wrapper, pixel_values, frame_times) n = img_embeds.shape[1] # pooling-aware: the adapter's actual visual-token count return ModalityContext( prefix_embeds=img_embeds, @@ -215,9 +235,11 @@ def prepare( wrapper: VLMWrapper, pixel_values: torch.Tensor, input_ids: torch.Tensor, # noqa: ARG002 + *, + frame_times: torch.Tensor | None = None, frame_mask: torch.Tensor | None = None, ) -> ModalityContext: - img_embeds = _project_visual_features(wrapper, pixel_values) + img_embeds = _project_visual_features(wrapper, pixel_values, frame_times) return ModalityContext( image_features=img_embeds, image_mask=_visual_token_mask(frame_mask, 
img_embeds.shape[1]), @@ -253,9 +275,11 @@ def prepare( wrapper: VLMWrapper, pixel_values: torch.Tensor, input_ids: torch.Tensor, + *, + frame_times: torch.Tensor | None = None, frame_mask: torch.Tensor | None = None, ) -> ModalityContext: - img_embeds = _project_visual_features(wrapper, pixel_values) + img_embeds = _project_visual_features(wrapper, pixel_values, frame_times) n = img_embeds.shape[1] # pooling-aware: the adapter's actual visual-token count b, t_text = input_ids.shape modality_ids = torch.zeros(b, n + t_text, dtype=torch.long, device=input_ids.device) @@ -299,9 +323,11 @@ def prepare( wrapper: VLMWrapper, pixel_values: torch.Tensor, input_ids: torch.Tensor, + *, + frame_times: torch.Tensor | None = None, frame_mask: torch.Tensor | None = None, ) -> ModalityContext: - img_embeds = _project_visual_features(wrapper, pixel_values) + img_embeds = _project_visual_features(wrapper, pixel_values, frame_times) n = img_embeds.shape[1] # pooling-aware: the adapter's actual visual-token count b, t_text = input_ids.shape modality_ids = torch.zeros(b, n + t_text, dtype=torch.long, device=input_ids.device) @@ -346,11 +372,15 @@ def __init__( transformer: Transformer, strategy: ModalityStrategy, frames_per_clip: int = 1, + frame_time_embed: TimeEmbedding | None = None, ) -> None: super().__init__() self.vision_encoder = vision_encoder self.adapter = adapter self.transformer = transformer + # Per-frame timestamp embedding (video only; None for the image path). + # A registered submodule so FSDP2 shards it and DCP serializes it. + self.frame_time_embed = frame_time_embed # Frames per video clip (1 for a single image). The static visual-token # count is ``frames_per_clip * adapter.output_num_tokens(...)``; the # strategies use it for ``num_image_tokens`` (residual budget, MoT split). 
@@ -373,14 +403,26 @@ def forward( pixel_values: torch.Tensor, input_ids: torch.Tensor, labels: torch.Tensor | None = None, + frame_times: torch.Tensor | None = None, frame_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Run the VLM forward. + + ``frame_times`` is ``(B, F)`` per-frame timestamps in seconds. + **Inference contract:** a video model (one built with a time embedding) + silently drops the learned temporal signal when ``frame_times`` is + ``None`` — no error is raised — so eval/generate paths must thread it for + video. ``None`` is correct only for image/text models or an untrained + (zero-init) time embedding. + """ # Route the text embedding through Transformer.forward so FSDP2's # per-module hook intercepts the token_embedding call and # materializes the DTensor weight before F.embedding runs. Doing # the embedding externally (transformer.token_embedding(input_ids)) # bypasses FSDP and fails with "mixed torch.Tensor and DTensor". - modality = self.strategy.prepare(self, pixel_values, input_ids, frame_mask=frame_mask) + modality = self.strategy.prepare( + self, pixel_values, input_ids, frame_times=frame_times, frame_mask=frame_mask + ) logits = self.transformer(tokens=input_ids, modality=modality) return logits, labels @@ -424,6 +466,7 @@ def build_vlm_wrapper( adapter_config: AdapterConfig, vlm_config: VLMConfig, frames_per_clip: int = 1, + time_embedding_config: TimeEmbeddingConfig | None = None, ) -> VLMWrapper: """Build a ``VLMWrapper`` from the four top-level configs. @@ -467,4 +510,18 @@ def build_vlm_wrapper( ) transformer = Transformer(model_config, vlm_config=vlm_config, num_image_tokens=visual_tokens) strategy = build_modality_strategy(vlm_config) - return VLMWrapper(encoder, adapter, transformer, strategy, frames_per_clip=frames_per_clip) + # Video clips get a per-frame timestamp embedding (registry-selected via + # [time_embedding]); the image path (F=1) does not. type="none" disables it. 
+ frame_time_embed = ( + build_time_embedding(time_embedding_config, model_config.dim) + if frames_per_clip > 1 + else None + ) + return VLMWrapper( + encoder, + adapter, + transformer, + strategy, + frames_per_clip=frames_per_clip, + frame_time_embed=frame_time_embed, + ) diff --git a/scripts/train.py b/scripts/train.py index b081310..291e1e3 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -175,6 +175,7 @@ def main() -> None: adapter_config=adapter_cfg, vlm_config=vlm_cfg, frames_per_clip=(config.video.max_frames if config.video is not None else 1), + time_embedding_config=config.time_embedding, ac_mode=tc.activation_checkpointing, mp_policy=mp_policy, param_dtype=tc.param_dtype, @@ -756,14 +757,19 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: pixel_values = batch["pixel_values"].to(device) input_ids = batch["input_ids"].to(device) labels = batch["labels"].to(device) - # Video batches carry a per-frame validity mask; image batches do not. + # Video batches carry per-frame timestamps + a validity mask; images do not. 
+ frame_times = batch["frame_times"].to(device) if "frame_times" in batch else None frame_mask = batch["frame_mask"].to(device) if "frame_mask" in batch else None with maybe_no_sync(model, micro_step, tc.grad_accum_steps): if mc.is_moe: inner_transformer(model).set_moe_step(step, tc.max_steps) # type: ignore[attr-defined] logits, labels_out = model( - pixel_values, input_ids, labels, frame_mask=frame_mask + pixel_values, + input_ids, + labels, + frame_times=frame_times, + frame_mask=frame_mask, ) loss = loss_fn(logits, labels_out) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index f9274af..ed23a81 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -25,6 +25,7 @@ VisionEncoderConfig, VLMConfig, ) +from kempnerforge.config.time_embedding import TimeEmbeddingConfig from kempnerforge.config.vlm import MoMaConfig # --------------------------------------------------------------------------- @@ -602,6 +603,43 @@ def test_no_warning_for_random_with_overrides(self, caplog): assert not any("random" in r.getMessage() for r in caplog.records) +class TestTimeEmbeddingIneffectiveWarning: + """``JobConfig.__post_init__`` warns when an explicit, enabled + ``[time_embedding]`` is paired with a non-video config: the per-frame time + embedding is built only for video (``frames_per_clip > 1``), so it is + silently ignored otherwise (mirrors the HF-encoder-override warning).""" + + def setup_method(self): + import logging + + self._kf_logger = logging.getLogger("kempnerforge") + self._old_propagate = self._kf_logger.propagate + self._kf_logger.propagate = True + + def teardown_method(self): + self._kf_logger.propagate = self._old_propagate + + def _image_vlm(self, time_embedding): + return JobConfig( + model=ModelConfig(max_seq_len=2304), + vision_encoder=VisionEncoderConfig(type="random", feature_dim=384, num_tokens=64), + adapter=AdapterConfig(), + vlm=VLMConfig(max_text_len=2048), + time_embedding=time_embedding, + ) + + def 
test_warns_when_time_embedding_set_without_video(self, caplog): + with caplog.at_level("WARNING", logger="kempnerforge.config.job"): + self._image_vlm(TimeEmbeddingConfig()) + assert any("[time_embedding] is set" in r.getMessage() for r in caplog.records) + + def test_no_warning_when_disabled(self, caplog): + # type="none" is an intentional disable and stays quiet. + with caplog.at_level("WARNING", logger="kempnerforge.config.job"): + self._image_vlm(TimeEmbeddingConfig(type="none")) + assert not any("[time_embedding] is set" in r.getMessage() for r in caplog.records) + + class TestMoMaAcFullWarning: """``JobConfig.validate`` warns when MoMa is paired with ``activation_checkpointing="full"`` because ``apply_ac`` matches diff --git a/tests/unit/test_distributed.py b/tests/unit/test_distributed.py index 3c93877..ddb238f 100644 --- a/tests/unit/test_distributed.py +++ b/tests/unit/test_distributed.py @@ -377,6 +377,41 @@ def test_param_dtype_applied_to_transformer_and_adapter(self): assert model.transformer.token_embedding.embedding.weight.dtype == torch.bfloat16 assert model.adapter.proj1.weight.dtype == torch.bfloat16 + def test_video_build_attaches_frame_time_embed(self): + """frames_per_clip>1 routes through _build_vlm and attaches a + FrameTimeEmbedding, built + materialized + cast on the device_mesh=None + (no-FSDP) path.""" + from kempnerforge.distributed.parallel import build_parallel_model + + mc, vc, ac, lc = self._vlm_configs(max_seq_len=64, max_text_len=16, num_tokens=8) + model = build_parallel_model( + mc, + torch.device("cpu"), + device_mesh=None, + vision_config=vc, + adapter_config=ac, + vlm_config=lc, + param_dtype=torch.bfloat16, + frames_per_clip=4, + ) + assert model.frame_time_embed is not None + assert model.frame_time_embed.proj.weight.dtype == torch.bfloat16 # cast applied + + def test_image_build_has_no_frame_time_embed(self): + """frames_per_clip=1 (image / default) attaches no FrameTimeEmbedding.""" + from kempnerforge.distributed.parallel 
import build_parallel_model + + mc, vc, ac, lc = self._vlm_configs() + model = build_parallel_model( + mc, + torch.device("cpu"), + device_mesh=None, + vision_config=vc, + adapter_config=ac, + vlm_config=lc, + ) + assert model.frame_time_embed is None + def test_max_seq_len_too_short_raises(self): """Cross-check: ``num_image_tokens + max_text_len > max_seq_len`` raises.""" from kempnerforge.config.adapter import AdapterConfig diff --git a/tests/unit/test_frame_time.py b/tests/unit/test_frame_time.py new file mode 100644 index 0000000..eb10cc0 --- /dev/null +++ b/tests/unit/test_frame_time.py @@ -0,0 +1,96 @@ +"""Unit tests for FrameTimeEmbedding (per-frame timestamp encoding).""" + +from __future__ import annotations + +import pytest +import torch + +from kempnerforge.model.frame_time import FrameTimeEmbedding, build_time_embedding + + +class TestFrameTimeEmbedding: + def test_output_shape(self): + emb = FrameTimeEmbedding(dim=64, num_bands=8) + out = emb(torch.zeros(2, 4)) # (B, F) -> (B, F, dim) + assert out.shape == (2, 4, 64) + + def test_zero_init_is_zero(self): + # Zero-init proj => the temporal signal starts at exactly zero, so adding + # it is identity at step 0 (the CrossAttention warm-start convention). + emb = FrameTimeEmbedding(dim=32, num_bands=8) + out = emb(torch.tensor([[0.0, 1.0, 5.0, 10.0]])) + assert torch.count_nonzero(out) == 0 + + def test_grad_flows_from_zero_init(self): + # Features are nonzero (cos(0)=1, etc.) so the proj gets a real gradient + # even from zero-init and moves off zero during training. + emb = FrameTimeEmbedding(dim=16, num_bands=4) + emb(torch.tensor([[0.0, 1.0, 2.0, 3.0]])).sum().backward() + assert emb.proj.weight.grad is not None + assert torch.isfinite(emb.proj.weight.grad).all() + assert torch.count_nonzero(emb.proj.weight.grad) > 0 + + def test_distinguishes_timescales(self): + # Same frame INDICES, different absolute times must produce different + # embeddings — the whole point of encoding seconds rather than order. 
+ emb = FrameTimeEmbedding(dim=16, num_bands=8) + with torch.no_grad(): + emb.proj.weight.normal_() + short = emb(torch.tensor([[0.0, 0.5, 1.0, 1.5]])) # 2s clip + long = emb(torch.tensor([[0.0, 20.0, 40.0, 60.0]])) # 60s clip + assert not torch.allclose(short, long) + + def test_dtype_follows_proj(self): + emb = FrameTimeEmbedding(dim=16, num_bands=4).to(torch.bfloat16) + out = emb(torch.zeros(1, 3)) # float32 input, bf16 module + assert out.dtype == torch.bfloat16 + + def test_reset_parameters_rezeros(self): + emb = FrameTimeEmbedding(dim=16, num_bands=4) + with torch.no_grad(): + emb.proj.weight.fill_(1.0) + emb.proj.bias.fill_(1.0) + emb.reset_parameters() + assert torch.count_nonzero(emb.proj.weight) == 0 + assert torch.count_nonzero(emb.proj.bias) == 0 + + @pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"dim": 0}, "dim must be positive"), + ({"dim": 16, "num_bands": 0}, "num_bands must be positive"), + ({"dim": 16, "min_period": 0.0}, "min_period < max_period"), + ({"dim": 16, "min_period": 10.0, "max_period": 5.0}, "min_period < max_period"), + ], + ) + def test_invalid_args_rejected(self, kwargs, match): + with pytest.raises(ValueError, match=match): + FrameTimeEmbedding(**kwargs) + + +class TestTimeEmbeddingRegistry: + """The registry + builder make the time embedding config-switchable.""" + + def test_sinusoidal_registered(self): + from kempnerforge.config.registry import registry + + assert "sinusoidal" in registry.list_time_embeddings() + + def test_build_none_config_defaults_to_sinusoidal(self): + # A None config preserves the default (sinusoidal) so video callers that + # pass nothing keep the current behavior. 
+ m = build_time_embedding(None, dim=64) + assert isinstance(m, FrameTimeEmbedding) + assert m.proj.out_features == 64 + + def test_build_from_config(self): + from kempnerforge.config.time_embedding import TimeEmbeddingConfig + + m = build_time_embedding(TimeEmbeddingConfig(type="sinusoidal", num_bands=8), dim=32) + assert isinstance(m, FrameTimeEmbedding) + assert m.num_bands == 8 + + def test_build_none_type_returns_none(self): + from kempnerforge.config.time_embedding import TimeEmbeddingConfig + + assert build_time_embedding(TimeEmbeddingConfig(type="none"), dim=32) is None diff --git a/tests/unit/test_time_embedding_config.py b/tests/unit/test_time_embedding_config.py new file mode 100644 index 0000000..c8008f1 --- /dev/null +++ b/tests/unit/test_time_embedding_config.py @@ -0,0 +1,35 @@ +"""Unit tests for TimeEmbeddingConfig (the [time_embedding] section).""" + +from __future__ import annotations + +import pytest + +from kempnerforge.config.time_embedding import TimeEmbeddingConfig + + +class TestTimeEmbeddingConfig: + def test_defaults(self): + c = TimeEmbeddingConfig() + assert c.type == "sinusoidal" + assert c.num_bands == 16 + assert c.enabled is True + + def test_none_is_disabled(self): + c = TimeEmbeddingConfig(type="none") + assert c.enabled is False + + def test_unknown_type_rejected(self): + with pytest.raises(ValueError, match="Unknown time_embedding.type"): + TimeEmbeddingConfig(type="bogus") + + def test_non_positive_num_bands_rejected(self): + with pytest.raises(ValueError, match="num_bands must be positive"): + TimeEmbeddingConfig(num_bands=0) + + def test_bad_periods_rejected(self): + with pytest.raises(ValueError, match="min_period < max_period"): + TimeEmbeddingConfig(min_period=10.0, max_period=5.0) + + def test_extra_kwargs(self): + c = TimeEmbeddingConfig(num_bands=8, min_period=1.0, max_period=100.0) + assert c.extra_kwargs() == {"num_bands": 8, "min_period": 1.0, "max_period": 100.0} diff --git a/tests/unit/test_video_dataset.py 
b/tests/unit/test_video_dataset.py index 45f45ef..ff14dd0 100644 --- a/tests/unit/test_video_dataset.py +++ b/tests/unit/test_video_dataset.py @@ -76,6 +76,11 @@ def _frames(n: int, size: int = 16) -> list[Image.Image]: return [Image.new("RGB", (size, size), color=(i * 10 % 255, 0, 0)) for i in range(n)] +def _decoded(n: int, size: int = 16) -> tuple[list[Image.Image], list[float]]: + """Mimic decode_video_frames: (frames, per-frame presentation seconds).""" + return _frames(n, size), [float(i) for i in range(n)] + + # --------------------------------------------------------------------------- # Video path mapping (verified against the on-disk WebVid layout) # --------------------------------------------------------------------------- @@ -103,7 +108,7 @@ def test_validation_is_flat(self): class TestGetItem: def test_shapes_and_mask_full_clip(self, monkeypatch): - monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: _frames(8)) + monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: _decoded(8)) ds = _StubVideoDataset(["1"], ["a cat."], max_frames=8, frame_size=16) item = ds[0] assert item["pixel_values"].shape == (8, 3, 16, 16) @@ -111,19 +116,25 @@ def test_shapes_and_mask_full_clip(self, monkeypatch): assert item["frame_mask"].shape == (8,) assert item["frame_mask"].dtype == torch.bool assert item["frame_mask"].all() + assert item["frame_times"].shape == (8,) + assert item["frame_times"].dtype == torch.float32 + assert item["frame_times"].tolist() == [float(i) for i in range(8)] assert item["input_ids"].shape == (8,) assert item["labels"].shape == (8,) def test_pads_and_masks_short_clip(self, monkeypatch): - monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: _frames(3)) + monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: _decoded(3)) ds = _StubVideoDataset(["1"], ["a dog."], max_frames=8) item = ds[0] assert item["frame_mask"].tolist() == [True, True, True, False, False, False, False, False] # Padded frames are zeros. 
assert torch.count_nonzero(item["pixel_values"][3:]) == 0 + # Real frames carry their times; pad frames are 0.0. + assert item["frame_times"][:3].tolist() == [0.0, 1.0, 2.0] + assert item["frame_times"][3:].tolist() == [0.0, 0.0, 0.0, 0.0, 0.0] def test_caption_is_supervised_when_frames_present(self, monkeypatch): - monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: _frames(4)) + monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: _decoded(4)) ds = _StubVideoDataset(["1"], ["abc"], max_frames=8, max_text_len=8) item = ds[0] # "abc" -> ids 1,2,3 supervised; rest -100. @@ -139,17 +150,18 @@ def _boom(*a, **k): item = ds[0] assert torch.count_nonzero(item["pixel_values"]) == 0 assert not item["frame_mask"].any() + assert torch.count_nonzero(item["frame_times"]) == 0 assert (item["labels"] == -100).all() # no supervision for an unloadable clip def test_empty_decode_yields_zero_clip_no_loss(self, monkeypatch): - monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: []) + monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: ([], [])) ds = _StubVideoDataset(["1"], ["a cat."], max_frames=4) item = ds[0] assert not item["frame_mask"].any() assert (item["labels"] == -100).all() def test_prompt_is_masked(self, monkeypatch): - monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: _frames(2)) + monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: _decoded(2)) ds = _StubVideoDataset(["1"], ["xyz"], max_frames=4, max_text_len=8, prompt="ab") item = ds[0] # prompt "ab" (2 toks) masked; "xyz" (24,25,26) supervised. 
@@ -177,7 +189,15 @@ def _sample(self, n_frames_valid: int, max_frames: int = 4, max_text_len: int = ids[:3] = torch.tensor([1, 2, 3]) labels = torch.full((max_text_len,), -100, dtype=torch.long) labels[:3] = torch.tensor([1, 2, 3]) - return {"pixel_values": pv, "frame_mask": mask, "input_ids": ids, "labels": labels} + times = torch.zeros(max_frames, dtype=torch.float32) + times[:n_frames_valid] = torch.arange(n_frames_valid, dtype=torch.float32) + return { + "pixel_values": pv, + "frame_mask": mask, + "frame_times": times, + "input_ids": ids, + "labels": labels, + } def test_batch_shapes(self): collator = VideoCollator(pad_id=0, max_text_len=8) @@ -185,6 +205,8 @@ def test_batch_shapes(self): assert batch["pixel_values"].shape == (3, 4, 3, 16, 16) assert batch["frame_mask"].shape == (3, 4) assert batch["frame_mask"].dtype == torch.bool + assert batch["frame_times"].shape == (3, 4) + assert batch["frame_times"].dtype == torch.float32 assert batch["input_ids"].shape == (3, 8) assert batch["labels"].shape == (3, 8) @@ -259,6 +281,8 @@ def test_init_getitem_and_decode(self, tmp_path): item = ds[0] assert item["pixel_values"].shape == (8, 3, 32, 32) assert item["frame_mask"].any() # real frames decoded + assert item["frame_times"].shape == (8,) + assert (item["frame_times"][item["frame_mask"]] >= 0).all() # real-frame times set assert (item["labels"] != -100).any() # caption supervised def test_decode_failure_is_masked(self, tmp_path): diff --git a/tests/unit/test_video_io.py b/tests/unit/test_video_io.py index 4b08d33..3a5ac00 100644 --- a/tests/unit/test_video_io.py +++ b/tests/unit/test_video_io.py @@ -89,15 +89,17 @@ def test_decodes_pil_frames(self): from kempnerforge.data.video_io import decode_video_frames - frames = decode_video_frames(_WEBVID_CLIP, fps=2.0, min_frames=4, max_frames=8) + frames, times = decode_video_frames(_WEBVID_CLIP, fps=2.0, min_frames=4, max_frames=8) assert 1 <= len(frames) <= 8 + assert len(times) == len(frames) assert all(isinstance(f, 
Image.Image) and f.mode == "RGB" for f in frames) def test_respects_max_frames(self): from kempnerforge.data.video_io import decode_video_frames - frames = decode_video_frames(_WEBVID_CLIP, fps=8.0, min_frames=4, max_frames=4) + frames, times = decode_video_frames(_WEBVID_CLIP, fps=8.0, min_frames=4, max_frames=4) assert len(frames) == 4 + assert len(times) == 4 def test_missing_file_raises(self): from kempnerforge.data.video_io import decode_video_frames @@ -136,8 +138,10 @@ def test_decodes_rgb_frames(self, tmp_path): path = tmp_path / "clip.mp4" _write_mp4(path, n_frames=20, fps=10) # ~2s - frames = decode_video_frames(str(path), fps=2.0, min_frames=4, max_frames=8) + frames, times = decode_video_frames(str(path), fps=2.0, min_frames=4, max_frames=8) assert 1 <= len(frames) <= 8 + assert len(times) == len(frames) + assert times == sorted(times) # presentation times are non-decreasing assert all(isinstance(f, Image.Image) and f.mode == "RGB" for f in frames) def test_respects_max_frames(self, tmp_path): @@ -145,16 +149,18 @@ def test_respects_max_frames(self, tmp_path): path = tmp_path / "clip.mp4" _write_mp4(path, n_frames=40, fps=10) # ~4s - frames = decode_video_frames(str(path), fps=8.0, min_frames=4, max_frames=4) + frames, times = decode_video_frames(str(path), fps=8.0, min_frames=4, max_frames=4) assert len(frames) == 4 + assert len(times) == 4 def test_short_clip_returns_frames(self, tmp_path): from kempnerforge.data.video_io import decode_video_frames path = tmp_path / "short.mp4" _write_mp4(path, n_frames=3, fps=10) # shorter than min_frames request - frames = decode_video_frames(str(path), fps=2.0, min_frames=4, max_frames=8) + frames, times = decode_video_frames(str(path), fps=2.0, min_frames=4, max_frames=8) assert len(frames) >= 1 + assert len(times) == len(frames) class TestSamplingPolicyRegistry: diff --git a/tests/unit/test_vlm.py b/tests/unit/test_vlm.py index e9fd04c..6b9d7fa 100644 --- a/tests/unit/test_vlm.py +++ b/tests/unit/test_vlm.py @@ 
-650,6 +650,125 @@ def test_frame_count_mismatch_raises(self): wrapper(torch.randn(2, 3, 16, 16, device=DEVICE), input_ids) +class TestVideoFrameTimes: + """Per-frame timestamp embedding (frame-aware visual prefix).""" + + def test_frame_time_embed_built_for_video(self): + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4) + assert wrapper.frame_time_embed is not None + + def test_image_wrapper_has_no_frame_time_embed(self): + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=1) + assert wrapper.frame_time_embed is None + + def test_forward_with_frame_times_keeps_shape(self): + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4).to(DEVICE) + pixels = torch.randn(2, 4, 3, 16, 16, device=DEVICE) + input_ids = torch.randint(0, 256, (2, 6), device=DEVICE) + ft = torch.tensor([[0.0, 1.0, 2.0, 3.0], [0.0, 5.0, 10.0, 15.0]], device=DEVICE) + logits, _ = wrapper(pixels, input_ids, frame_times=ft) + assert logits.shape == (2, 6, 256) + + def test_zero_init_temporal_is_identity(self): + # Zero-init => passing frame_times adds nothing at step 0 (warm-start). + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4).to(DEVICE).eval() + pixels = torch.randn(1, 4, 3, 16, 16, device=DEVICE) + input_ids = torch.randint(0, 256, (1, 6), device=DEVICE) + ft = torch.tensor([[0.0, 1.0, 2.0, 3.0]], device=DEVICE) + with torch.no_grad(): + l_none, _ = wrapper(pixels, input_ids) + l_zero, _ = wrapper(pixels, input_ids, frame_times=ft) + assert torch.allclose(l_none, l_zero) + + def test_frame_times_change_logits_once_learned(self): + # After the temporal proj is nonzero, different timestamps change output. 
+ wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4).to(DEVICE).eval() + with torch.no_grad(): + wrapper.frame_time_embed.proj.weight.normal_(std=0.1) + pixels = torch.randn(1, 4, 3, 16, 16, device=DEVICE) + input_ids = torch.randint(0, 256, (1, 6), device=DEVICE) + t1 = torch.tensor([[0.0, 1.0, 2.0, 3.0]], device=DEVICE) + t2 = torch.tensor([[0.0, 20.0, 40.0, 60.0]], device=DEVICE) + with torch.no_grad(): + l1, _ = wrapper(pixels, input_ids, frame_times=t1) + l2, _ = wrapper(pixels, input_ids, frame_times=t2) + assert not torch.allclose(l1, l2) + + def test_frame_times_shape_mismatch_raises(self): + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4).to(DEVICE) + pixels = torch.randn(2, 4, 3, 16, 16, device=DEVICE) + input_ids = torch.randint(0, 256, (2, 6), device=DEVICE) + bad = torch.zeros(2, 3, device=DEVICE) # F=3 != frames_per_clip=4 + with pytest.raises(ValueError, match="frame_times shape"): + wrapper(pixels, input_ids, frame_times=bad) + + def test_time_embedding_none_disables_for_video(self): + # Registry "none" => no temporal module even for video (frames_per_clip>1). + from kempnerforge.config.time_embedding import TimeEmbeddingConfig + + mc = ModelConfig(dim=64, n_layers=2, n_heads=4, vocab_size=256, max_seq_len=64) + vc = VisionEncoderConfig(type="random", feature_dim=96, num_tokens=16) + ac = AdapterConfig(type="avgpool", pool_window=2) + wrapper = build_vlm_wrapper( + mc, + vc, + ac, + JointDecoderConfig(max_text_len=8), + frames_per_clip=4, + time_embedding_config=TimeEmbeddingConfig(type="none"), + ) + assert wrapper.frame_time_embed is None + + def test_frame_time_embed_state_dict_round_trips(self): + # The default-on time embedding adds frame_time_embed.proj.* keys to + # every video checkpoint; pin that they serialize and reload with a + # bit-equal forward (the per-arch state_dict round-trip invariant, + # extended to the video model). 
Move the projection off zero-init first + # so the round-trip is not trivially comparing zeros. + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4).to(DEVICE).eval() + with torch.no_grad(): + wrapper.frame_time_embed.proj.weight.normal_(std=0.1) + wrapper.frame_time_embed.proj.bias.normal_(std=0.1) + pixels = torch.randn(1, 4, 3, 16, 16, device=DEVICE) + input_ids = torch.randint(0, 256, (1, 6), device=DEVICE) + ft = torch.tensor([[0.0, 1.0, 2.0, 3.0]], device=DEVICE) + with torch.no_grad(): + logits_before, _ = wrapper(pixels, input_ids, frame_times=ft) + + state = wrapper.state_dict() + assert "frame_time_embed.proj.weight" in state + assert "frame_time_embed.proj.bias" in state + + restored = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4).to(DEVICE).eval() + restored.load_state_dict(state) + assert torch.equal( + restored.frame_time_embed.proj.weight, wrapper.frame_time_embed.proj.weight + ) + with torch.no_grad(): + logits_after, _ = restored(pixels, input_ids, frame_times=ft) + assert torch.equal(logits_before, logits_after) + + def test_frame_time_embed_is_freeze_addressable(self): + # frame_time_embed is an injected sibling submodule (like the adapter), + # so a freeze spec must be able to target it — warm-start staging recipes + # freeze it while the LLM warms up. The base module_patterns alias makes + # it clear the effective_freeze typo-guard and lets apply_freeze_specs + # match its params. + from kempnerforge.config.vlm import FreezeSpec + from kempnerforge.training.freeze import apply_freeze_specs, effective_freeze + + vlm_cfg = JointDecoderConfig(max_text_len=8) + wrapper = _video_wrapper(vlm_cfg, frames=4) + assert wrapper.frame_time_embed is not None + # (a) the alias clears the train.py typo-guard (pre-fix this raised). + valid = set(vlm_cfg.module_patterns.keys()) + specs = effective_freeze(0, [FreezeSpec("frame_time_embed", True)], [], valid) + # (b) applying it actually freezes the projection params. 
+ apply_freeze_specs(wrapper, specs, vlm_cfg.module_patterns) + assert not wrapper.frame_time_embed.proj.weight.requires_grad + assert not wrapper.frame_time_embed.proj.bias.requires_grad + + class TestFramePaddingMask: """frame_mask hides padded video frames from attention (and, for MoMa, from expert-choice routing), so real-token outputs are invariant to padded-frame