diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f7c8a7..9d13fa9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `scripts/train.py`: training-loop hook that runs `mot_warm_start_from_text_stack` once at step 0 when `mot_warm_start_from_text=True`, between `ckpt_mgr.load(...)` and `apply_freeze_specs(...)`. - Tests: `tests/unit/test_mot.py` (Algorithm-1 reference parity, warm-start helper round-trips); `tests/unit/test_model.py::TestMoT` + `TestModalityIdsCrossArgs`; MoT cases in `tests/unit/test_vlm.py`, `test_vlm_config.py`, `test_modality_context.py`; `tests/integration/test_vlm_mot.py`; `tests/distributed/test_vlm_mot_fsdp.py` (gated on multi-GPU). - Configs: `configs/train/vlm_debug_mot.toml` (1-GPU smoke) and `configs/train/vlm_7b_mot.toml` (4-GPU 7B). +- **Video understanding (all four archs).** Extend the VLM path to ingest video — a clip is an ordered set of frames sampled at a target fps — through the same registry-driven design as the image path; trains end-to-end on WebVid-10M with Joint-Decoder, Cross-Attention, MoT, and MoMa. The text-only and single-image paths are unchanged (bit-exact). + - **Pooling connector.** `kempnerforge/model/adapter.py`: a `VisionAdapter` base with an `output_num_tokens()` contract, plus `avgpool` and `attentional_pool` (Molmo2-style mean-query MHA) connectors registered via `@registry.register_adapter`; `kempnerforge/config/adapter.py` gains `pool_window` / `pool_heads`. The adapter-derived visual-token count is threaded through `build_vlm_wrapper` / `_build_vlm`, the four modality strategies, and the three seq-len checks (`config/job.py`, `distributed/parallel.py`, `model/vlm.py`). Projection adapters keep the count (identity) so the image path is bit-exact. + - **Video data path.** `kempnerforge/data/video_io.py`: timestamp-based frame sampling (target fps, uniform, first & last frame kept — Molmo2 §3.1/§A) + PyAV decode (lazily imported; the wheel bundles FFmpeg, so no system FFmpeg is required). `kempnerforge/data/video_dataset.py`: `WebVidVideoDataset` (CSV manifest + `id[:2]/id[:4]/id[:6]/id.mp4` path mapping, reuses the image preprocessing) and `VideoCollator` → `(B, F, 3, H, W)` + a frame-validity mask; an undecodable clip is masked out (contributes no loss). `kempnerforge/config/video.py`: the `[video]` `VideoConfig` section (`data_root`, `split`, `fps`, `max_frames`, `min_frames`, `frame_size`, `max_samples`), wired into `JobConfig` (+ `is_video`). Adds `av` to dependencies. + - **Frame-aware model + training wiring.** `kempnerforge/model/vlm.py`: `_project_image_features` → `_project_visual_features` folds the frame axis through the encoder + pooler to `(B, F·P′, dim)` (a single image is the `F == 1` case). `VLMWrapper` gains `frames_per_clip`, threaded through `build_parallel_model` / `_build_vlm` / `build_vlm_wrapper` so the static visual-token count equals `F·P′` (drives the residual budget and MoT's positional split; static == runtime). `scripts/train.py` builds the video dataset/collator when `[video]` is set. Adds `configs/train/vlm_video_webvid.toml` (SigLIP2 + avgpool + WebVid). + - Tests: `tests/unit/test_video_io.py`, `test_video_dataset.py`, `test_video_config.py`; video-forward cases (all four archs) + image-path regression in `test_vlm.py`; pooling-adapter cases in `test_adapter.py`. Docs: `docs/how-to/train-on-video.md`. + - Deferred (follow-ups): per-frame timestamp tokens + grounding (``/`` outputs with point-F1 / track-J&F eval), frame-mask-aware attention, bidirectional visual attention, VLM sequence packing, long-context (blocked on context-parallel being wired), and warm-start from a converted image-VLM checkpoint. - `install-and-verify` plugin skill: runs `uv sync`, asserts Python ≥ 3.12, then runs the four CI gate checks (`ruff check`, `ruff format --check`, `pyright`, `pytest tests/unit/`). Canonical first command after cloning. - `.python-version` pinned to `>=3.12` so uv resolves the interpreter explicitly. Teammates on 3.13 use 3.13 (no download); 3.11-only users get 3.12 auto-fetched. - **Dynamic-checkpointing window** (`[checkpoint.dyn_ckpt_window]`). Opt-in dense save phase: inside `[start, stop]` a registered strategy decides which steps to save; outside the window the regular `interval` cadence applies. The default strategy, `"power2"`, saves at `start` and at every `start + 2^k` while `<= stop` — tight near the start of the window, doubling thereafter. Useful for analyzing early-training dynamics, where the loss moves fastest. The default `CheckpointConfig` is unchanged (no `dyn_ckpt_window`, interval-only saves). diff --git a/README.md b/README.md index 7e404fb..69873c2 100644 --- a/README.md +++ b/README.md @@ -141,13 +141,15 @@ Further reading: ### Vision-language models -KempnerForge supports VLM training via a thin wrapper around the existing `Transformer`. Image tokens come from a frozen HF vision encoder (SigLIP2, CLIP, or a tiny `random` test stub), pass through a 2-layer adapter, and feed the backbone via an arch-specific path: +KempnerForge supports VLM training — images **or video** (a clip is an ordered set of frames) — via a thin wrapper around the existing `Transformer`. Visual tokens come from a frozen HF vision encoder (SigLIP2, CLIP, or a tiny `random` test stub), pass through a connector (a 2-layer adapter, or an `avgpool` / `attentional_pool` connector that pools patches to reduce tokens per frame), and feed the backbone via an arch-specific path: - **Joint-Decoder** (`arch = "joint_decoder"`): image embeds are prepended to the text sequence; the transformer runs over the concatenated `(image, text)` sequence and the LM head is applied to text positions only. - **Cross-Attention** (`arch = "cross_attention"`, Llama-3-V style): the residual stream carries text only. Separate `CrossAttentionBlock`s inserted at a configurable cadence let text queries attend to image K/V. CA blocks are zero-initialized so adding the arch on top of a text-only checkpoint is identity at step 0 and learns from there. - **Mixture-of-Transformers** (`arch = "mot"`, Liang et al. 2024 Algorithm 1): every layer carries per-modality Q/K/V/O projections plus a per-modality FFN; a single global self-attention mixes all modality streams. Image tokens prepend the text sequence (same residual layout as Joint-Decoder); per-modality residual projections are zero-initialized so a fresh MoT block is identity at construction. A warm-start helper (`mot_warm_start_from_text_stack`) translates a JD or text-only checkpoint into per-modality copies — toggle via `[model.vlm].mot_warm_start_from_text` + `mot_warm_start_path`. - **Mixture of Modality-Aware Experts** (`arch = "moma"`, Lin et al. 2024 arXiv:2407.21770): one shared set of Q/K/V/O projections feeding a global self-attention, plus per-modality MoE FFN groups (paper's optimal default 4 image + 4 text experts per layer). Tokens route deterministically to their modality group (level-1, reusing the same `modality_ids` mechanism MoT uses) and then through a learned expert-choice + Sigmoid router within the group (level-2, with Gumbel-Sigmoid noise during training, paper Eq. 5). Image tokens prepend the text sequence (same residual layout as JD/MoT). v1 supports training only — expert-choice routing is non-causal, so autoregressive generation requires auxiliary routers (paper §2.4) which are deferred to a follow-up. +**Video** works across all four archs with no arch-specific changes: a clip is decoded into frames (sampled by timestamp at a target fps — uniform, with the first and last frame always kept), each frame is encoded and pooled by the connector, and the `F × tokens_per_frame` visual tokens enter the backbone exactly like image tokens. Configure the `[video]` section (`data_root`, `fps`, `max_frames`, `frame_size`); see `configs/train/vlm_video_webvid.toml`. + ```bash # 1-GPU smoke (random encoder, Joint-Decoder) uv run python scripts/train.py configs/train/vlm_debug.toml \ @@ -164,9 +166,12 @@ uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_7b_mot.tom # 4-GPU 7B Mixture of Modality-Aware Experts (4 text + 4 image experts per layer) uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_7b_moma.toml + +# 4-GPU video training on WebVid (Joint-Decoder; flip [vlm].arch for cross_attention / mot / moma) +uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_video_webvid.toml ``` -Configs set `[model.vlm]` with `arch`, the encoder registry key, the number of image tokens, and a freeze list (`FreezeSpec`). For Cross-Attention, set `cross_attention_every_n_layers` and optionally `cross_attention_n_kv_heads` (0 → MHA; positive → GQA on the cross path). For MoT, set `mot_modalities` (must include both `"image"` and `"text"`); `mot_image_n_heads` / `mot_image_n_kv_heads` are forward-looking per-modality head fields (v1 enforces equality with the text-side counts since the operator runs a single global SDPA). For MoMa, set `moma_experts_per_modality = {image = N, text = M}` as a nested TOML table (the paper's optimal balanced default is `4t4i`; unbalanced allocations like `{image = 1, text = 7}` match the paper's `moe_7t1i` ablation), and optionally `moma_capacity_factor` (defaults to `1/|E^M|` per modality — the paper's perfect-balance setting) and `moma_gumbel_noise` (`true` by default for paper-faithful EC routing). `model.num_experts` must be `0` when `arch = "moma"`; the per-modality counts supersede it, and JobConfig.validate rejects the combination. The vision encoder stays in its HF-loaded dtype; the transformer, adapter, and CA / MoT / MoMa blocks are cast to `param_dtype`. Pipeline Parallel + VLM is not supported on this branch (raises at startup); MoMa + Expert Parallelism is also rejected in v1. Multi-image and video are reserved slots for follow-up work. +Configs set `[model.vlm]` with `arch`, the encoder registry key, the number of image tokens, and a freeze list (`FreezeSpec`). For Cross-Attention, set `cross_attention_every_n_layers` and optionally `cross_attention_n_kv_heads` (0 → MHA; positive → GQA on the cross path). For MoT, set `mot_modalities` (must include both `"image"` and `"text"`); `mot_image_n_heads` / `mot_image_n_kv_heads` are forward-looking per-modality head fields (v1 enforces equality with the text-side counts since the operator runs a single global SDPA). For MoMa, set `moma_experts_per_modality = {image = N, text = M}` as a nested TOML table (the paper's optimal balanced default is `4t4i`; unbalanced allocations like `{image = 1, text = 7}` match the paper's `moe_7t1i` ablation), and optionally `moma_capacity_factor` (defaults to `1/|E^M|` per modality — the paper's perfect-balance setting) and `moma_gumbel_noise` (`true` by default for paper-faithful EC routing). `model.num_experts` must be `0` when `arch = "moma"`; the per-modality counts supersede it, and JobConfig.validate rejects the combination. The vision encoder stays in its HF-loaded dtype; the transformer, adapter, and CA / MoT / MoMa blocks are cast to `param_dtype`. Pipeline Parallel + VLM is not supported on this branch (raises at startup); MoMa + Expert Parallelism is also rejected in v1. Video is supported across all four archs via the `[video]` section (a clip is decoded into frames, pooled by the connector, and fed like image tokens); multi-image inputs and video *grounding* (point/track outputs with per-frame timestamps) are reserved for follow-up work. **Adding a new VLM arch.** The discriminated-union dispatch is registry-driven, so a new arch is four small additions, no edits to existing call sites: diff --git a/configs/train/vlm_video_webvid.toml b/configs/train/vlm_video_webvid.toml new file mode 100644 index 0000000..785b48b --- /dev/null +++ b/configs/train/vlm_video_webvid.toml @@ -0,0 +1,93 @@ +# Video VLM training on WebVid-10M (Joint-Decoder). +# +# A clip is decoded into `max_frames` frames (sampled by timestamp at `fps`, +# Molmo2-style: uniform, first + last frame kept), each encoded by the vision +# tower, pooled by the connector, and prepended to the caption tokens. The LLM +# is trained from scratch here; warm-starting from a pretrained backbone is a +# later phase. +# +# Visual-token budget (Joint-Decoder): max_frames * tokens_per_frame + max_text_len +# must be <= model.max_seq_len. With SigLIP2 @224/patch16 (14x14 = 196 patches) +# and the avgpool connector at pool_window=2 -> ceil(14/2)^2 = 49 tokens/frame. +# 8 frames -> 8*49 = 392 visual + 64 text = 456 <= 576. +# +# Launch (single node, 4 GPUs): +# uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_video_webvid.toml +# +# Quick smoke (no SigLIP download, a few clips; pair with a small step count): +# ... --vision_encoder.type=random --vision_encoder.num_tokens=196 \ +# --vision_encoder.feature_dim=768 --video.max_samples=256 --train.max_steps=20 + +[model] +dim = 1024 +n_layers = 12 +n_heads = 16 +n_kv_heads = 16 +vocab_size = 50257 # gpt2 tokenizer +max_seq_len = 576 # 8 frames * 49 pooled tokens + 64 text + headroom +norm_type = "rmsnorm" +activation = "silu" +rope_theta = 10000.0 + +[vision_encoder] +type = "siglip2" +path = "google/siglip2-base-patch16-224" +# feature_dim / num_tokens left at 0 -> probed from the HF model at build time. + +[adapter] +type = "avgpool" # token-reducing connector; "attentional_pool" is the faithful upgrade +pool_window = 2 # 14x14 patch grid -> 7x7 = 49 tokens/frame + +[vlm] +arch = "joint_decoder" # also: cross_attention / mot / moma +max_text_len = 64 +freeze = [{module = "vision_encoder", frozen = true}] + +[video] +data_root = "/n/holylfs06/LABS/kempner_shared/Everyone/testbed/video/webvid-10m" +split = "train" +fps = 2.0 +max_frames = 8 +min_frames = 4 +frame_size = 224 +max_samples = 0 # 0 = full manifest; set small (e.g. 256) for a smoke + +[data] +tokenizer_path = "gpt2" +num_workers = 4 +pin_memory = true + +[train] +batch_size = 4 +seq_len = 576 +max_steps = 5000 +grad_accum_steps = 4 +grad_clip_norm = 1.0 +seed = 42 +compile_model = false +activation_checkpointing = "none" + +[optimizer] +name = "adamw" +lr = 2e-4 +weight_decay = 0.1 +betas = [0.9, 0.95] +fused = false + +[scheduler] +name = "cosine" +warmup_steps = 200 +min_lr_ratio = 0.1 + +[distributed] +dp_shard = -1 + +[checkpoint] +dir = "checkpoints/vlm_video_webvid" +interval = 1000 +keep_last_n = 2 + +[metrics] +log_interval = 10 +enable_wandb = false +enable_tensorboard = false diff --git a/docs/how-to/index.md b/docs/how-to/index.md index 2f2a82d..d44ff02 100644 --- a/docs/how-to/index.md +++ b/docs/how-to/index.md @@ -52,6 +52,12 @@ with runnable code (or a link to a notebook, config, or script that runs it) — `ActivationStore`, `extract_representations()`, save to `.npz`, feed to probing / CKA / SVCCA. +## Multimodal + +- [Train on video](train-on-video.md) — ingest a clip as frames through the + VLM path: timestamp frame sampling, the pooling connector, the `[video]` + config + token budget, and how all four archs consume a clip. + ```{toctree} :maxdepth: 1 :hidden: @@ -69,4 +75,5 @@ data-mixing-annealing fp8-training moe-experiments mechanistic-interpretability +train-on-video ``` diff --git a/docs/how-to/train-on-video.md b/docs/how-to/train-on-video.md new file mode 100644 index 0000000..ce146e8 --- /dev/null +++ b/docs/how-to/train-on-video.md @@ -0,0 +1,99 @@ +# Train on video + +The VLM path ingests **video** through the same wrapper, connectors, and fusion +archs as images — a clip is just an ordered set of frames. This guide covers the +data layout, the `[video]` config, the frame-sampling policy, and how all four +archs consume a clip. + +## The model's view of a clip + +A clip of `F` frames becomes `F × P′` visual tokens: + +1. **Sample** `F` frames from the video by timestamp (target `fps`, uniform, + first and last frame always kept). +2. **Encode** each frame with the frozen vision tower (e.g. SigLIP2), fold the + frame axis into the batch so `B×F` frames run through the encoder once. +3. **Pool + project** each frame with the connector — an `avgpool` or + `attentional_pool` adapter reduces a `grid×grid` patch map to + `P′ = ceil(grid/window)²` tokens per frame (e.g. SigLIP2 @224/patch16 → + 14×14 → 49 tokens at `pool_window=2`). +4. **Fuse** the resulting `(B, F·P′, dim)` visual tokens into the backbone the + same way images are fused — so **all four archs work unchanged**: + - `joint_decoder` / `mot` / `moma`: the `F·P′` tokens prepend the text in the + residual stream and are trimmed before the LM head. + - `cross_attention`: the `F·P′` tokens flow as K/V into the cross-attention + blocks; the residual stays text-only (so it fits more frames per + `max_seq_len`). + +Temporal order is carried by frame order (sequential positions). Per-frame +timestamp tokens and grounding outputs are a separate follow-up (see below). + +## Token budget + +For the residual-stream archs (JD / MoT / MoMa): + +``` +max_frames × tokens_per_frame + max_text_len ≤ model.max_seq_len +``` + +e.g. 8 frames × 49 + 64 text = 456 ≤ 576. Cross-attention only needs +`max_text_len ≤ max_seq_len` (visual tokens are K/V, not in the residual). The +build- and config-time checks enforce this and fail before any GPU work. + +## Configure it + +A video run adds a `[video]` section (sibling of `[vision_encoder]` / +`[adapter]` / `[vlm]`) and a token-reducing connector. See +`configs/train/vlm_video_webvid.toml` for a complete example; the key parts: + +```toml +[adapter] +type = "avgpool" # or "attentional_pool"; pools patches per frame +pool_window = 2 # 14×14 grid -> 7×7 = 49 tokens/frame + +[vlm] +arch = "joint_decoder" # also: cross_attention | mot | moma + +[video] +data_root = "/path/to/webvid-10m" +split = "train" # "train" | "validation" +fps = 2.0 # target sampling rate +max_frames = 8 # per-clip frame budget +min_frames = 4 +frame_size = 224 +max_samples = 0 # 0 = full manifest; set small for a smoke +``` + +The `[video]` section is decoded by `WebVidVideoDataset` (a WebVid-style layout: +CSV manifests under `raw/webvid-10M/data//partitions/` and `.mp4` files +under `raw/videos//`). Decoding uses PyAV (its wheel bundles FFmpeg, so no +system FFmpeg is required); it is imported lazily, so the package imports without +`av` and only actual decoding needs it. + +## Launch + +```bash +# 4-GPU video training (Joint-Decoder) +uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_video_webvid.toml + +# Quick smoke: no SigLIP download, a few clips, few steps +uv run torchrun --nproc_per_node=2 scripts/train.py configs/train/vlm_video_webvid.toml \ + --vision_encoder.type=random --vision_encoder.num_tokens=196 \ + --vision_encoder.feature_dim=768 --video.max_samples=256 --train.max_steps=20 +``` + +To switch arch, change `[vlm].arch` in the config — everything else (frame +sampling, connector, dataset) is identical. (`arch` is resolved at config-load +time, so it is set in the TOML, not via a `--vlm.arch=` CLI override.) + +## Constraints and follow-ups + +- **Causal attention; no per-frame timestamps yet** — temporal order is frame + order. Per-frame timestamp tokens + grounding (``/`` outputs + with point-F1 / track-J&F eval) are a follow-up. +- **Padded frames are not yet masked from attention** — short clips pad to + `max_frames` with blank frames; a `frame_mask` is produced but not yet + consumed by the attention mask. +- **Fixed `F` per batch** keeps tensor shapes static (for `torch.compile` and + DP-rank consistency); variable-length clips arrive with VLM sequence packing. +- **Long-context** (many frames) is blocked on context-parallel being wired. diff --git a/kempnerforge/config/adapter.py b/kempnerforge/config/adapter.py index 94470b2..4fa88bf 100644 --- a/kempnerforge/config/adapter.py +++ b/kempnerforge/config/adapter.py @@ -1,14 +1,12 @@ -"""Adapter configuration. +"""Adapter (connector) configuration. ``AdapterConfig`` selects which adapter the VLM wrapper instantiates and parameterizes the chosen adapter. Dispatched via the ``adapter`` registry at build time (see ``kempnerforge/model/adapter.py``). -This module is registered-component-shaped (parallel to ``VisionEncoderConfig``, -``VLMConfig``). A follow-up PR will flatten the VLM TOML schema to expose -``[adapter]`` as a top-level section; until then ``build_vlm_wrapper`` -constructs an ``AdapterConfig`` internally from the existing ``VLMConfig`` -fields (``adapter_hidden_dim``, ``adapter_activation``). +In TOML, ``[adapter]`` is a top-level section parallel to ``[model]``, +``[vision_encoder]``, and ``[vlm]``. When ``[vlm]`` is set without an +``[adapter]`` section, ``JobConfig`` materializes the default ``AdapterConfig``. """ from __future__ import annotations @@ -24,18 +22,24 @@ class AdapterConfig: """Selects the adapter type and parameterizes it. Fields: - type: Registry key for the adapter builder. ``"mlp_2layer"`` (default) - or ``"linear"``. Custom adapters register additional names. + type: Registry key for the adapter builder. Projection adapters + ``"mlp_2layer"`` (default) / ``"linear"`` keep the token count; + pooling adapters ``"avgpool"`` / ``"attentional_pool"`` reduce it. hidden_dim: Hidden width for ``mlp_2layer``. ``0`` means "match - ``out_dim``"; ignored by ``linear``. + ``out_dim``"; ignored by the other types. activation: Activation between the two MLP projections. One of - ``"gelu"`` (default), ``"silu"``, ``"relu"``. Ignored by - ``linear``. + ``"gelu"`` (default), ``"silu"``, ``"relu"``. ``mlp_2layer`` only. + pool_window: Pooling kernel side for the pooling adapters (e.g. ``2`` + for image 2×2, ``3`` for video 3×3); ignored by projection adapters. + pool_heads: Number of attention heads for ``attentional_pool``; must + divide the vision feature dim. Ignored by the other types. """ type: str = "mlp_2layer" hidden_dim: int = 0 activation: str = "gelu" + pool_window: int = 2 + pool_heads: int = 16 def __post_init__(self) -> None: # Late import: importing the adapter module triggers the @@ -56,14 +60,49 @@ def __post_init__(self) -> None: raise ValueError( f"Unknown adapter.activation: {self.activation!r}. Options: 'gelu', 'silu', 'relu'." ) + if self.pool_window <= 0: + raise ValueError(f"adapter.pool_window must be positive (got {self.pool_window})") + if self.pool_heads <= 0: + raise ValueError(f"adapter.pool_heads must be positive (got {self.pool_heads})") def extra_kwargs(self) -> dict[str, Any]: """Builder kwargs beyond ``in_dim`` / ``out_dim``. ``hidden_dim=0`` is mapped to ``None`` so the adapter falls back to - its own default (e.g., ``out_dim`` for ``MLP2LayerAdapter``). + its own default (e.g., ``out_dim`` for ``MLP2LayerAdapter``). Pooling + kwargs are always passed; projection builders swallow them via ``**_``. """ return { "hidden_dim": self.hidden_dim or None, "activation": self.activation, + "pool_window": self.pool_window, + "pool_heads": self.pool_heads, } + + def output_num_tokens(self, num_input_tokens: int) -> int: + """Predict the post-adapter token count for ``num_input_tokens`` in. + + Mirrors the built module's ``output_num_tokens`` so config-time + sequence-length checks match the build-time/runtime token budget. + Projection adapters are the identity; pooling adapters apply the + shared ``pooled_token_count`` math. Non-positive inputs (e.g. the + ``num_tokens=0`` "infer at build time" sentinel) pass through. + """ + if num_input_tokens <= 0 or self.type not in self._pooling_types(): + return num_input_tokens + from kempnerforge.model.adapter import ( # noqa: PLC0415 + DIVISIBLE_ONLY_POOL_TYPES, + pooled_token_count, + ) + + return pooled_token_count( + num_input_tokens, + self.pool_window, + require_divisible=self.type in DIVISIBLE_ONLY_POOL_TYPES, + ) + + @staticmethod + def _pooling_types() -> tuple[str, ...]: + from kempnerforge.model.adapter import POOLING_ADAPTER_TYPES # noqa: PLC0415 + + return POOLING_ADAPTER_TYPES diff --git a/kempnerforge/config/job.py b/kempnerforge/config/job.py index 4b3e1dd..fa39ed2 100644 --- a/kempnerforge/config/job.py +++ b/kempnerforge/config/job.py @@ -15,6 +15,7 @@ from kempnerforge.config.profiling import ProfilingConfig from kempnerforge.config.scheduler import SchedulerConfig from kempnerforge.config.training import TrainConfig +from kempnerforge.config.video import VideoConfig from kempnerforge.config.vision import VisionEncoderConfig from kempnerforge.config.vlm import MoMaConfig, VLMConfig @@ -51,6 +52,7 @@ class JobConfig: vision_encoder: VisionEncoderConfig | None = None adapter: AdapterConfig | None = None vlm: VLMConfig | None = None + video: VideoConfig | None = None def __post_init__(self) -> None: """Cross-section invariants that fire at construction time. @@ -103,9 +105,11 @@ def __post_init__(self) -> None: # build time" sentinel) defers the check to ``build_vlm_wrapper`` # / ``_build_vlm`` using the encoder's resolved value. if self.vision_encoder.num_tokens > 0: - residual_image_tokens = self.vlm.residual_stream_image_tokens( - self.vision_encoder.num_tokens - ) + assert self.adapter is not None # materialized in __post_init__ when [vlm] set + per_frame = self.adapter.output_num_tokens(self.vision_encoder.num_tokens) + frames = self.video.max_frames if self.video is not None else 1 + visual_tokens = frames * per_frame + residual_image_tokens = self.vlm.residual_stream_image_tokens(visual_tokens) required = residual_image_tokens + self.vlm.max_text_len if self.model.max_seq_len < required: raise ValueError( @@ -114,11 +118,22 @@ def __post_init__(self) -> None: f"vlm.max_text_len ({self.vlm.max_text_len}) = {required}" ) + if self.video is not None and self.vlm is None: + raise ValueError( + "[video] is set but [vlm] is missing; video training runs through " + "the VLM wrapper, so a [vlm] section (and [vision_encoder]) is required." + ) + @property def is_vlm(self) -> bool: """Whether this job builds a ``VLMWrapper`` around the text backbone.""" return self.vlm is not None + @property + def is_video(self) -> bool: + """Whether this job trains on a video dataset (a VLM sub-mode).""" + return self.video is not None + def validate(self, world_size: int = 1) -> None: """Run cross-config validations that depend on the world size.""" self.distributed.validate_world_size(world_size) @@ -197,9 +212,11 @@ def validate(self, world_size: int = 1) -> None: # text-only). # num_tokens=0 is deferred to build_vlm_wrapper. if self.vision_encoder.num_tokens > 0: - residual_image_tokens = self.vlm.residual_stream_image_tokens( - self.vision_encoder.num_tokens - ) + assert self.adapter is not None # materialized in __post_init__ when [vlm] set + per_frame = self.adapter.output_num_tokens(self.vision_encoder.num_tokens) + frames = self.video.max_frames if self.video is not None else 1 + visual_tokens = frames * per_frame + residual_image_tokens = self.vlm.residual_stream_image_tokens(visual_tokens) required = residual_image_tokens + self.vlm.max_text_len if self.train.seq_len < required: raise ValueError( diff --git a/kempnerforge/config/video.py b/kempnerforge/config/video.py new file mode 100644 index 0000000..8099c6c --- /dev/null +++ b/kempnerforge/config/video.py @@ -0,0 +1,64 @@ +"""Video input configuration. + +``VideoConfig`` is the ``[video]`` top-level section. When present, the job +trains on a video dataset through the VLM wrapper: a clip is decoded into an +ordered set of frames, each preprocessed like an image and fed to the vision +encoder. The section is a sibling of ``[vision_encoder]`` / ``[adapter]`` / +``[vlm]`` and requires ``[vlm]`` to be set. + +Frame-sampling defaults follow the Molmo2 paper (sample at ``fps`` per second, +include the first and last frame, cap at ``max_frames``). ``max_frames`` is the +per-clip frame budget; the number of visual tokens it implies +(``max_frames * tokens_per_frame``) feeds the residual-stream / sequence-length +math once the model consumes video. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +_VIDEO_SPLITS = ("train", "validation") + + +@dataclass +class VideoConfig: + """Video dataset location and frame-sampling knobs. + + Fields: + data_root: Root directory of the on-disk video dataset. + split: Which split to read (``"train"`` or ``"validation"``). + max_samples: Cap the manifest to this many examples (``0`` = all). + max_frames: Maximum frames sampled per clip (the per-clip budget). + min_frames: Minimum frames sampled per clip; short clips pad up to this. + fps: Target sampling rate in frames per second (Molmo2 uses 2). + frame_size: Square pixel size each frame is resized to. + prompt: Optional instruction prepended to the target text, masked from loss. + """ + + data_root: str = "" + split: str = "train" + max_samples: int = 0 + max_frames: int = 16 + min_frames: int = 4 + fps: float = 2.0 + frame_size: int = 224 + prompt: str = "" + + def __post_init__(self) -> None: + if self.split not in _VIDEO_SPLITS: + raise ValueError(f"video.split must be one of {_VIDEO_SPLITS} (got {self.split!r})") + if self.max_samples < 0: + raise ValueError(f"video.max_samples must be non-negative (got {self.max_samples})") + if self.min_frames < 1: + raise ValueError(f"video.min_frames must be >= 1 (got {self.min_frames})") + if self.max_frames < 1: + raise ValueError(f"video.max_frames must be >= 1 (got {self.max_frames})") + if self.min_frames > self.max_frames: + raise ValueError( + f"video.min_frames ({self.min_frames}) must be <= video.max_frames " + f"({self.max_frames})" + ) + if self.fps <= 0: + raise ValueError(f"video.fps must be positive (got {self.fps})") + if self.frame_size <= 0: + raise ValueError(f"video.frame_size must be positive (got {self.frame_size})") diff --git a/kempnerforge/data/video_dataset.py b/kempnerforge/data/video_dataset.py new file mode 100644 index 0000000..6dc3abc --- /dev/null +++ b/kempnerforge/data/video_dataset.py @@ -0,0 +1,234 @@ +"""Video dataset and collator for the VLM video path (WebVid-style layout). + +``WebVidVideoDataset`` reads a WebVid-style on-disk corpus — per-partition CSV +manifests (``videoid``, ``name`` = caption) plus ``.mp4`` files laid out under +``raw/videos//`` — and produces the video analogue of the single-image +``VLMSample``: + +- ``pixel_values``: ``(F, 3, H, W)`` float tensor — ``F = max_frames`` frames, + each resized/normalized exactly like the image path. Clips that yield fewer + than ``F`` real frames are zero-padded. +- ``frame_mask``: ``(F,)`` bool — ``True`` for real frames, ``False`` for padding. +- ``input_ids`` / ``labels``: ``(T,)`` int64, right-padded to ``max_text_len``, + with ``-100`` on pad/prompt positions. A clip that fails to decode contributes + no loss (all labels ``-100``) so noisy data never crashes training. + +``VideoCollator`` stacks samples into a fixed-shape batch +(``pixel_values: (B, F, 3, H, W)``, ``frame_mask: (B, F)``) so every DP rank +sees identical shapes under FSDP2. + +Frame decoding lives in ``video_io.decode_video_frames`` and is imported at +module scope so tests can substitute a stub; ``av`` itself is imported lazily +inside the decoder. +""" + +from __future__ import annotations + +import logging +import os +from typing import Any + +import torch +from torch.utils.data import Dataset + +from kempnerforge.data.video_io import decode_video_frames +from kempnerforge.data.vlm_dataset import ( + DEFAULT_IMAGE_MEAN, + DEFAULT_IMAGE_STD, + _pil_to_tensor, + _tokenize_and_mask, +) + +logger = logging.getLogger(__name__) + +# WebVid layout: the metadata split directory ("val") differs from the video +# directory name ("validation"); "train" matches both. +_CSV_SUBDIR = {"train": "train", "validation": "val"} +_VIDEO_SUBDIR = {"train": "train", "validation": "validation"} + + +def _resolve_pad_id(tokenizer: Any) -> int: + pad_id = tokenizer.pad_token_id + if pad_id is None: + pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0 + return int(pad_id) + + +class WebVidVideoDataset(Dataset): + """Map-style WebVid-style video-caption dataset for VLM training. + + Args: + data_root: Dataset root (contains ``raw/webvid-10M/data`` and + ``raw/videos``). + split: ``"train"`` or ``"validation"``. + tokenizer_path: HF tokenizer id or local path. + max_text_len: Fixed-length text pad target. + max_frames / min_frames / fps: Frame-sampling knobs (see ``video_io``). + frame_size: Square pixel size per frame. + max_samples: Cap the manifest (``0`` = all). + prompt: Optional instruction prepended and masked from the loss. + image_mean / image_std: Per-channel normalization (SigLIP defaults). + """ + + def __init__( + self, + data_root: str, + split: str, + tokenizer_path: str, + max_text_len: int, + *, + max_frames: int, + min_frames: int, + fps: float, + frame_size: int = 224, + max_samples: int = 0, + prompt: str = "", + image_mean: tuple[float, float, float] = DEFAULT_IMAGE_MEAN, + image_std: tuple[float, float, float] = DEFAULT_IMAGE_STD, + ) -> None: + from transformers import AutoTokenizer + + if split not in _VIDEO_SUBDIR: + raise ValueError(f"split must be one of {tuple(_VIDEO_SUBDIR)} (got {split!r})") + self._split = split + self._video_dir = os.path.join(data_root, "raw", "videos", _VIDEO_SUBDIR[split]) + csv_dir = os.path.join( + data_root, "raw", "webvid-10M", "data", _CSV_SUBDIR[split], "partitions" + ) + self._ids, self._caps = self._load_manifest(csv_dir, max_samples) + self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + self._pad_id = _resolve_pad_id(self._tokenizer) + self._max_text_len = max_text_len + self._max_frames = max_frames + self._min_frames = min_frames + self._fps = fps + self._frame_size = frame_size + self._prompt = prompt + self._image_mean = image_mean + self._image_std = image_std + logger.info( + "WebVidVideoDataset: %s [%s], %d clips, max_frames=%d, fps=%s, frame_size=%d", + data_root, + split, + len(self._ids), + max_frames, + fps, + frame_size, + ) + + @staticmethod + def _load_manifest(csv_dir: str, max_samples: int) -> tuple[list[str], list[str]]: + """Read partition CSVs into (videoid, caption) lists. + + Reads partitions in sorted order, stopping early once ``max_samples`` + rows are collected so a quick run does not scan the entire 10M-row + corpus. ``videoid`` is kept as a string to preserve the digits used by + the on-disk path mapping. + """ + import glob + + import pandas as pd + + files = sorted(glob.glob(os.path.join(csv_dir, "*.csv"))) + if not files: + raise FileNotFoundError(f"No partition CSVs found under {csv_dir!r}") + ids: list[str] = [] + caps: list[str] = [] + for path in files: + df = pd.read_csv(path, usecols=["videoid", "name"], dtype={"videoid": str}) + ids.extend(df["videoid"].tolist()) + caps.extend(df["name"].astype(str).tolist()) + if max_samples and len(ids) >= max_samples: + break + if max_samples: + ids = ids[:max_samples] + caps = caps[:max_samples] + return ids, caps + + def _video_path(self, videoid: str) -> str: + """Map a videoid to its ``.mp4`` path. + + Train videos are nested by id prefixes (``id[:2]/id[:4]/id[:6]/id.mp4``); + validation videos are flat (``id.mp4``). + """ + s = str(videoid) + if self._split == "train": + return os.path.join(self._video_dir, s[:2], s[:4], s[:6], f"{s}.mp4") + return os.path.join(self._video_dir, f"{s}.mp4") + + def __len__(self) -> int: + return len(self._ids) + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + videoid = self._ids[idx] + caption = self._caps[idx] + path = self._video_path(videoid) + try: + frames = decode_video_frames( + path, fps=self._fps, min_frames=self._min_frames, max_frames=self._max_frames + ) + except Exception as e: # noqa: BLE001 - any decode failure -> skip-with-mask + logger.debug("video decode failed for %s: %s", path, e) + frames = [] + + f = self._max_frames + size = self._frame_size + pixel_values = torch.zeros(f, 3, size, size, dtype=torch.float32) + frame_mask = torch.zeros(f, dtype=torch.bool) + n_real = min(len(frames), f) + for i in range(n_real): + pixel_values[i] = _pil_to_tensor(frames[i], size, self._image_mean, self._image_std) + frame_mask[i] = True + + prompt = self._prompt or None + input_ids, labels = _tokenize_and_mask(self._tokenizer, caption, self._max_text_len, prompt) + if n_real == 0: + # Undecodable clip: keep static shapes but contribute no loss. + labels = torch.full_like(labels, -100) + return { + "pixel_values": pixel_values, + "frame_mask": frame_mask, + "input_ids": input_ids, + "labels": labels, + } + + +class VideoCollator: + """Stack video samples into a fixed-shape batch. + + Output keys: + - ``pixel_values``: ``(B, F, 3, H, W)`` float32. + - ``frame_mask``: ``(B, F)`` bool (``True`` = real frame). + - ``input_ids``: ``(B, max_text_len)`` int64. + - ``labels``: ``(B, max_text_len)`` int64 with ``-100`` on pad/prompt. + + Text is always padded to ``max_text_len`` (never batch-max) so DP ranks + see identical shapes under FSDP2, matching ``VLMCollator``. + """ + + def __init__(self, pad_id: int, max_text_len: int) -> None: + if max_text_len <= 0: + raise ValueError("max_text_len must be positive") + self.pad_id = pad_id + self.max_text_len = max_text_len + + def __call__(self, samples: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + if not samples: + raise ValueError("VideoCollator received an empty batch") + b = len(samples) + pixel_values = torch.stack([s["pixel_values"] for s in samples], dim=0) + frame_mask = torch.stack([s["frame_mask"] for s in samples], dim=0) + input_ids = torch.full((b, self.max_text_len), self.pad_id, dtype=torch.long) + labels = torch.full((b, self.max_text_len), -100, dtype=torch.long) + for i, s in enumerate(samples): + ids = s["input_ids"] + lbl = s["labels"] + n = min(ids.shape[0], self.max_text_len) + input_ids[i, :n] = ids[:n] + labels[i, :n] = lbl[:n] + return { + "pixel_values": pixel_values, + "frame_mask": frame_mask, + "input_ids": input_ids, + "labels": labels, + } diff --git a/kempnerforge/data/video_io.py b/kempnerforge/data/video_io.py new file mode 100644 index 0000000..9981c67 --- /dev/null +++ b/kempnerforge/data/video_io.py @@ -0,0 +1,116 @@ +"""Video frame sampling and decoding for the VLM video path. + +A clip is reduced to an ordered set of still frames that the VLM pipeline +treats like a sequence of images. Two concerns live here: + +1. ``sample_timestamps`` — *which* timestamps to sample. This is the policy + from the Molmo2 paper (§3.1, §A): sample at a target frame-rate ``fps``, + cap the total at ``max_frames`` (uniformly subsampling longer clips), and + always include the first and last frame. Sampling is expressed in + *seconds* rather than frame indices so it is robust to variable-fps video. + This function is pure (no decoder dependency) and unit-tested directly. + +2. ``decode_video_frames`` — *how* to read those frames. Decoding uses PyAV + (``av``), whose manylinux wheel bundles FFmpeg, so no system FFmpeg or + matching CUDA libraries are required (torchcodec needs both). ``av`` is + imported lazily so this module imports cleanly without it; only actual + decoding requires the package. + +Returned frames are ``PIL.Image`` objects so the caller can reuse the exact +image preprocessing (``_pil_to_tensor``) used on the single-image path. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # pragma: no cover - typing only + from PIL.Image import Image as PILImage + +# AV_TIME_BASE: container.duration is expressed in microseconds. +_AV_TIME_BASE = 1_000_000.0 + + +def sample_timestamps( + duration_s: float, fps: float, min_frames: int, max_frames: int +) -> list[float]: + """Timestamps (seconds) to sample from a clip of length ``duration_s``. + + Policy (Molmo2 §3.1/§A): aim for ``fps`` frames per second, clamp the + count to ``[min_frames, max_frames]``, and lay the samples out uniformly + over ``[0, duration_s]`` so the first frame (``0.0``) and last frame + (``duration_s``) are always included. A non-positive duration (unknown or + instantaneous) yields a single timestamp at the start. + + Returns a strictly increasing list of length in ``[1, max_frames]``. + """ + if fps <= 0: + raise ValueError(f"fps must be positive (got {fps})") + if min_frames < 1 or max_frames < 1: + raise ValueError(f"min_frames and max_frames must be >= 1 (got {min_frames}, {max_frames})") + if min_frames > max_frames: + raise ValueError(f"min_frames ({min_frames}) must be <= max_frames ({max_frames})") + if duration_s <= 0.0: + return [0.0] + desired = round(duration_s * fps) + desired = max(min_frames, min(max_frames, desired)) + if desired <= 1: + return [0.0] + step = duration_s / (desired - 1) + return [step * i for i in range(desired)] + + +def _video_duration_seconds(stream: Any, container: Any) -> float: + """Best-effort clip duration in seconds from PyAV stream/container metadata.""" + if stream.duration is not None and stream.time_base is not None: + return float(stream.duration * stream.time_base) + if container.duration is not None: + return float(container.duration) / _AV_TIME_BASE + if stream.frames and stream.average_rate: + return float(stream.frames) / float(stream.average_rate) + return 0.0 + + +def decode_video_frames( + path: str, *, fps: float, min_frames: int, max_frames: int +) -> list[PILImage]: + """Decode a clip into a list of sampled ``PIL.Image`` frames (RGB). + + Frames are chosen by ``sample_timestamps`` and read in a single decode + pass: each target timestamp is mapped to the first decoded frame at or + after it (timestamps past the last frame map to the last frame, so the + final frame is always returned). The returned list has length equal to the + number of sampled timestamps (``<= max_frames``), or is empty when the file + has no decodable video stream. + + Raises whatever ``av`` raises on a missing/corrupt file; callers that train + over noisy data should catch and substitute an empty clip. + """ + import av # lazy: bundled-FFmpeg decoder, optional at import time + + images: list[PILImage] = [] + with av.open(path) as container: + if not container.streams.video: + return images + stream = container.streams.video[0] + stream.thread_type = "AUTO" + duration_s = _video_duration_seconds(stream, container) + targets = sample_timestamps(duration_s, fps, min_frames, max_frames) + + j = 0 + eps = 1e-3 + last_frame = None + for frame in container.decode(stream): + t = float(frame.time) if frame.time is not None else 0.0 + while j < len(targets) and t + eps >= targets[j]: + images.append(frame.to_image()) + j += 1 + last_frame = frame + if j >= len(targets): + break + # Trailing targets (e.g. the final ``duration_s`` timestamp, which sits + # just past the last frame's PTS) map to the last decoded frame. + if j < len(targets) and last_frame is not None: + tail = last_frame.to_image() + images.extend(tail for _ in range(len(targets) - j)) + return images diff --git a/kempnerforge/distributed/parallel.py b/kempnerforge/distributed/parallel.py index d0c12e6..404b91d 100644 --- a/kempnerforge/distributed/parallel.py +++ b/kempnerforge/distributed/parallel.py @@ -373,6 +373,7 @@ def _build_vlm( param_dtype: torch.dtype, compile_model: bool, fp8: bool, + frames_per_clip: int = 1, ) -> torch.nn.Module: """Build a VLM wrapper with parallelism applied in the correct order. @@ -426,13 +427,16 @@ def _build_vlm( # extra vlm_config / num_image_tokens kwargs the VLM path needs). ctx = torch.device("meta") if tp_enabled else contextlib.nullcontext() with ctx: + adapter = build_adapter(adapter_config, in_dim=in_dim, out_dim=model_config.dim) + # Pooling adapters reduce the token count; size the transformer's + # image-prefix split on the adapter's output, not the raw patch count. + visual_tokens = frames_per_clip * adapter.output_num_tokens(encoder.num_tokens) transformer = Transformer( - model_config, vlm_config=vlm_config, num_image_tokens=encoder.num_tokens + model_config, vlm_config=vlm_config, num_image_tokens=visual_tokens ) - adapter = build_adapter(adapter_config, in_dim=in_dim, out_dim=model_config.dim) strategy = build_modality_strategy(vlm_config) - wrapper = VLMWrapper(encoder, adapter, transformer, strategy) + wrapper = VLMWrapper(encoder, adapter, transformer, strategy, frames_per_clip=frames_per_clip) # 3. Length cross-check now that num_tokens is resolved. required = wrapper.num_image_tokens + vlm_config.max_text_len @@ -505,6 +509,7 @@ def build_parallel_model( param_dtype: torch.dtype = torch.bfloat16, compile_model: bool = False, fp8: bool = False, + frames_per_clip: int = 1, ) -> torch.nn.Module: """Build a Transformer (or a VLMWrapper) with parallelism applied. @@ -548,6 +553,7 @@ def build_parallel_model( param_dtype=param_dtype, compile_model=compile_model, fp8=fp8, + frames_per_clip=frames_per_clip, ) from kempnerforge.distributed.tensor_parallel import apply_tensor_parallel diff --git a/kempnerforge/model/adapter.py b/kempnerforge/model/adapter.py index 2134e4f..acea41a 100644 --- a/kempnerforge/model/adapter.py +++ b/kempnerforge/model/adapter.py @@ -1,21 +1,35 @@ -"""Vision-to-LLM adapter modules. +"""Vision-to-LLM adapter modules (the "connector"). -The adapter projects image features (shape ``(B, num_tokens, feature_dim)``) -into the LLM embedding space (shape ``(B, num_tokens, model.dim)``). It sits +The adapter projects vision features (shape ``(B, num_tokens, feature_dim)``) +into the LLM embedding space (shape ``(B, out_tokens, model.dim)``). It sits between the vision encoder and the transformer in ``VLMWrapper``. -Adapters register themselves under the ``adapter`` registry category. The -default is ``mlp_2layer`` (a 2-layer MLP, the canonical adapter shape across -LLaVA-family papers). ``linear`` is a single ``nn.Linear`` with no -activation, useful for ablations. +Two families: + +- **Projection adapters** keep the token count (``out_tokens == num_tokens``): + ``mlp_2layer`` (default, the canonical LLaVA-family 2-layer MLP) and + ``linear`` (single ``nn.Linear``, an ablation baseline). +- **Pooling adapters** reduce the token count by pooling the square patch grid + before projecting: ``avgpool`` (window-average, the cheapest reducer) and + ``attentional_pool`` (Molmo2-style per-window multi-head attention with the + window mean as query). Pooling is what makes many-frame video fit the + sequence budget: a 27×27 SigLIP grid (729 tokens) pools to 81 tokens at a + 3×3 window. + +Every adapter is a ``VisionAdapter`` exposing ``output_num_tokens(n_in)`` so the +build path can size the residual stream (and MoT's positional split) without a +dry-run forward. Adapters register themselves under the ``adapter`` registry +category. """ from __future__ import annotations +import math from typing import Any import torch import torch.nn as nn +import torch.nn.functional as F from kempnerforge.config.registry import registry @@ -25,12 +39,93 @@ "relu": nn.ReLU, } +# Registry keys of adapters that pool the patch grid (reduce token count). The +# config layer (``AdapterConfig.output_num_tokens``) consults this to predict +# the post-adapter token count without building the module. Keep in sync with +# the registered pooling builders below. +POOLING_ADAPTER_TYPES: tuple[str, ...] = ("avgpool", "attentional_pool") + +# Pooling adapters whose ``forward`` requires the patch grid be divisible by the +# window (no ragged edge windows). Their token count must enforce the same so a +# ragged config is rejected at config/build time, not at the first training step. +DIVISIBLE_ONLY_POOL_TYPES: tuple[str, ...] = ("attentional_pool",) + + +def pooled_token_count( + num_input_tokens: int, window: int, *, require_divisible: bool = False +) -> int: + """Token count out of a ``window×window`` pool over a square patch grid. + + A vision encoder emits ``num_input_tokens`` patch tokens laid out on a + square ``grid × grid`` map (``grid = sqrt(num_input_tokens)``). Pooling with + a ``window × window`` kernel and ceil edges yields ``ceil(grid/window) ** 2`` + tokens; edge windows that do not fill the kernel pool only the patches they + cover (Molmo2 §A: "the bottom and far-right image patches are pooled with a + reduced number of patches"). + + Connectors that cannot pool ragged edges (``require_divisible=True``, e.g. + ``attentional_pool``) raise when ``grid`` is not divisible by ``window``, so a + ragged config is rejected at config/build time rather than deterministically + failing in ``forward`` at the first step. + + This is the single source of truth for the post-pool count: it must equal + the pooling adapters' actual ``forward`` output length, because the build + path uses it to size MoT's positional split. + """ + if window <= 0: + raise ValueError(f"pool window must be positive (got {window})") + if num_input_tokens <= 0: + raise ValueError(f"num_input_tokens must be positive (got {num_input_tokens})") + grid = _grid_side(num_input_tokens) + if require_divisible and grid % window != 0: + raise ValueError( + f"this pooling connector requires the patch grid ({grid}x{grid}) be " + f"divisible by the pool window ({window}); got a ragged grid " + f"(num_tokens={num_input_tokens}). Use avgpool for ragged grids, or pick " + "a divisible window." + ) + per_side = math.ceil(grid / window) + return per_side * per_side + + +def _grid_side(num_tokens: int) -> int: + """Side length of the square patch grid, or raise if not a perfect square.""" + grid = math.isqrt(num_tokens) + if grid * grid != num_tokens: + raise ValueError( + f"pooling requires a square patch grid, but num_tokens={num_tokens} is " + "not a perfect square. Use a vision encoder that strips any CLS token so " + "the patch tokens form a square grid." + ) + return grid + + +class VisionAdapter(nn.Module): + """Base class for vision→LLM adapters (the connector). + + Contract: ``forward`` maps ``(B, N, in_dim) -> (B, M, out_dim)`` where + ``M == output_num_tokens(N)``. Projection adapters keep ``M == N``; pooling + adapters reduce it. ``output_num_tokens`` lets the build path size the + residual stream and MoT's positional split without a dry-run forward, and + must agree exactly with the forward output length. + """ + + def output_num_tokens(self, num_input_tokens: int) -> int: + """Tokens emitted per image given ``num_input_tokens`` patch tokens in. + + Identity by default (projection adapters); pooling adapters override. + """ + return num_input_tokens + + def forward(self, x: torch.Tensor) -> torch.Tensor: # pragma: no cover + raise NotImplementedError + -class MLP2LayerAdapter(nn.Module): +class MLP2LayerAdapter(VisionAdapter): """2-layer MLP from image-feature dim to LLM embedding dim. Architecture: ``Linear(in_dim, hidden) -> activation -> Linear(hidden, out_dim)``. - ``hidden_dim=None`` defaults to ``out_dim``. + ``hidden_dim=None`` defaults to ``out_dim``. Keeps the token count. ``reset_parameters`` is provided so callers that materialize adapters from meta can re-initialize weights with the standard Linear defaults. @@ -67,11 +162,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj2(self.act(self.proj1(x))) -class LinearAdapter(nn.Module): +class LinearAdapter(VisionAdapter): """Single ``nn.Linear`` from image-feature dim to LLM embedding dim. - No activation, no hidden layer. Useful as an ablation baseline against - ``MLP2LayerAdapter``. + No activation, no hidden layer. Keeps the token count. Useful as an + ablation baseline against ``MLP2LayerAdapter``. """ def __init__(self, in_dim: int, out_dim: int) -> None: @@ -87,6 +182,142 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) +class AvgPoolAdapter(VisionAdapter): + """Average-pool a square patch grid by a window, then project. + + ``(B, N, in_dim)`` patch tokens (``N == grid**2``) are averaged over + ``window × window`` spatial windows (ceil edges; partial edge windows + average only the real patches they cover), giving ``(B, M, in_dim)`` with + ``M == ceil(grid/window)**2``, then a ``Linear`` maps ``in_dim -> out_dim``. + + The cheapest token-count reducer (LLaVA-NeXT / sibling-repo style). ``window`` + is overridable per ``forward`` call so one connector can pool images (e.g. + 2×2) and video frames (3×3) with the same projection weights. + """ + + def __init__(self, in_dim: int, out_dim: int, pool_window: int = 2) -> None: + super().__init__() + if in_dim <= 0 or out_dim <= 0: + raise ValueError("AvgPoolAdapter in_dim and out_dim must be positive") + if pool_window <= 0: + raise ValueError(f"AvgPoolAdapter pool_window must be positive (got {pool_window})") + self.in_dim = in_dim + self.out_dim = out_dim + self.pool_window = pool_window + self.proj = nn.Linear(in_dim, out_dim, bias=True) + + def reset_parameters(self) -> None: + self.proj.reset_parameters() + + def output_num_tokens(self, num_input_tokens: int) -> int: + return pooled_token_count(num_input_tokens, self.pool_window) + + def forward(self, x: torch.Tensor, pool_window: int | None = None) -> torch.Tensor: + w = pool_window if pool_window is not None else self.pool_window + if w <= 0: + raise ValueError(f"pool_window must be positive (got {w})") + b, n, c = x.shape + grid = _grid_side(n) + per = math.ceil(grid / w) + padded = per * w + x = x.view(b, grid, grid, c) + if padded != grid: + pad = padded - grid + # F.pad pads from the last dim backward: (C:0,0)(W:0,pad)(H:0,pad). + x = F.pad(x, (0, 0, 0, pad, 0, pad)) + mask = torch.ones(b, grid, grid, 1, dtype=x.dtype, device=x.device) + mask = F.pad(mask, (0, 0, 0, pad, 0, pad)) + else: + mask = torch.ones(b, padded, padded, 1, dtype=x.dtype, device=x.device) + # Group into windows and average over real (unpadded) cells only. + sums = x.view(b, per, w, per, w, c).sum(dim=(2, 4)) # (B, per, per, C) + counts = mask.view(b, per, w, per, w, 1).sum(dim=(2, 4)).clamp_(min=1) # (B, per, per, 1) + pooled = (sums / counts).reshape(b, per * per, c) + return self.proj(pooled) + + +class AttentionalPoolAdapter(VisionAdapter): + """Attentional pooling connector (Molmo2 §3.1). + + For each ``window × window`` patch window, a multi-head attention layer + pools the window's patches into one vector, using the **mean of the window's + patches as the query** and the patches themselves as keys/values; the result + is projected ``in_dim -> out_dim``. Output length is ``ceil(grid/window)**2``. + + ``window`` is overridable per ``forward`` call (shared params across image + 2×2 and video 3×3 pooling, per the paper). v1 requires the grid be divisible + by the window (no ragged edge windows); ragged attentional pooling is a + follow-up. + """ + + def __init__( + self, in_dim: int, out_dim: int, pool_window: int = 2, pool_heads: int = 16 + ) -> None: + super().__init__() + if in_dim <= 0 or out_dim <= 0: + raise ValueError("AttentionalPoolAdapter in_dim and out_dim must be positive") + if pool_window <= 0: + raise ValueError( + f"AttentionalPoolAdapter pool_window must be positive (got {pool_window})" + ) + if pool_heads <= 0: + raise ValueError( + f"AttentionalPoolAdapter pool_heads must be positive (got {pool_heads})" + ) + if in_dim % pool_heads != 0: + raise ValueError( + f"AttentionalPoolAdapter in_dim ({in_dim}) must be divisible by " + f"pool_heads ({pool_heads})" + ) + self.in_dim = in_dim + self.out_dim = out_dim + self.pool_window = pool_window + self.pool_heads = pool_heads + self.head_dim = in_dim // pool_heads + self.q_proj = nn.Linear(in_dim, in_dim, bias=True) + self.k_proj = nn.Linear(in_dim, in_dim, bias=True) + self.v_proj = nn.Linear(in_dim, in_dim, bias=True) + self.o_proj = nn.Linear(in_dim, in_dim, bias=True) + self.out_proj = nn.Linear(in_dim, out_dim, bias=True) + + def reset_parameters(self) -> None: + for layer in (self.q_proj, self.k_proj, self.v_proj, self.o_proj, self.out_proj): + layer.reset_parameters() + + def output_num_tokens(self, num_input_tokens: int) -> int: + # require_divisible mirrors forward()'s ragged-grid rejection so a bad + # config fails at build / seq-len-check time, not at the first step. + return pooled_token_count(num_input_tokens, self.pool_window, require_divisible=True) + + def forward(self, x: torch.Tensor, pool_window: int | None = None) -> torch.Tensor: + w = pool_window if pool_window is not None else self.pool_window + if w <= 0: + raise ValueError(f"pool_window must be positive (got {w})") + b, n, c = x.shape + grid = _grid_side(n) + if grid % w != 0: + raise ValueError( + f"attentional_pool v1 requires the patch grid ({grid}x{grid}) be divisible " + f"by the pool window ({w}); ragged edge windows are not yet supported. " + "Use avgpool for ragged grids, or pick a divisible window." + ) + per = grid // w + k_win = w * w + # (B, grid, grid, C) -> windows (B*per*per, w*w, C): each window's patches contiguous. + windows = ( + x.view(b, per, w, per, w, c).permute(0, 1, 3, 2, 4, 5).reshape(b * per * per, k_win, c) + ) + m = windows.shape[0] + query = windows.mean(dim=1, keepdim=True) # (M, 1, C) — window mean as query + q = self.q_proj(query).view(m, 1, self.pool_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(windows).view(m, k_win, self.pool_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(windows).view(m, k_win, self.pool_heads, self.head_dim).transpose(1, 2) + attn = F.scaled_dot_product_attention(q, k, v) # (M, H, 1, head_dim) + attn = attn.transpose(1, 2).reshape(m, c) # (M, C) + pooled = self.o_proj(attn).view(b, per * per, c) + return self.out_proj(pooled) + + @registry.register_adapter("mlp_2layer") def _build_mlp_2layer( in_dim: int, @@ -94,7 +325,7 @@ def _build_mlp_2layer( hidden_dim: int | None = None, activation: str = "gelu", **_: Any, -) -> MLP2LayerAdapter: +) -> VisionAdapter: return MLP2LayerAdapter( in_dim=in_dim, out_dim=out_dim, @@ -108,11 +339,34 @@ def _build_linear( in_dim: int, out_dim: int, **_: Any, -) -> LinearAdapter: +) -> VisionAdapter: return LinearAdapter(in_dim=in_dim, out_dim=out_dim) -def build_adapter(adapter_config, in_dim: int, out_dim: int) -> nn.Module: +@registry.register_adapter("avgpool") +def _build_avgpool( + in_dim: int, + out_dim: int, + pool_window: int = 2, + **_: Any, +) -> VisionAdapter: + return AvgPoolAdapter(in_dim=in_dim, out_dim=out_dim, pool_window=pool_window) + + +@registry.register_adapter("attentional_pool") +def _build_attentional_pool( + in_dim: int, + out_dim: int, + pool_window: int = 2, + pool_heads: int = 16, + **_: Any, +) -> VisionAdapter: + return AttentionalPoolAdapter( + in_dim=in_dim, out_dim=out_dim, pool_window=pool_window, pool_heads=pool_heads + ) + + +def build_adapter(adapter_config, in_dim: int, out_dim: int) -> VisionAdapter: """Dispatch to the registered adapter builder. Args: @@ -122,7 +376,8 @@ def build_adapter(adapter_config, in_dim: int, out_dim: int) -> nn.Module: out_dim: Target embedding dim (the transformer's ``dim``). Returns: - An ``nn.Module`` with signature ``(B, N, in_dim) -> (B, N, out_dim)``. + A ``VisionAdapter`` with signature ``(B, N, in_dim) -> (B, M, out_dim)``, + where ``M == adapter.output_num_tokens(N)``. """ builder = registry.get_adapter(adapter_config.type) return builder(in_dim=in_dim, out_dim=out_dim, **adapter_config.extra_kwargs()) diff --git a/kempnerforge/model/vlm.py b/kempnerforge/model/vlm.py index 51ac879..65daf9c 100644 --- a/kempnerforge/model/vlm.py +++ b/kempnerforge/model/vlm.py @@ -47,7 +47,7 @@ from kempnerforge.config.schema import ModelConfig from kempnerforge.config.vision import VisionEncoderConfig from kempnerforge.config.vlm import FreezeSpec, VLMConfig -from kempnerforge.model.adapter import build_adapter +from kempnerforge.model.adapter import VisionAdapter, build_adapter from kempnerforge.model.modality import ModalityContext from kempnerforge.model.transformer import Transformer from kempnerforge.model.vision import VisionEncoder @@ -73,20 +73,48 @@ def prepare( def num_image_tokens(self, wrapper: VLMWrapper) -> int: ... -def _project_image_features(wrapper: VLMWrapper, pixel_values: torch.Tensor) -> torch.Tensor: - """Encode + adapt image features. Cast at the encoder/adapter - boundary so the encoder can stay in its HF dtype (often fp32) while - the adapter and transformer run in bf16 without an inner dtype - clash. +def _project_visual_features(wrapper: VLMWrapper, pixel_values: torch.Tensor) -> torch.Tensor: + """Encode + adapt visual features into LLM-dim tokens. + + Accepts a single-image batch ``(B, 3, H, W)`` or a video-clip batch + ``(B, F, 3, H, W)``. For video the frame axis is folded into the batch so + the per-frame vision encoder + adapter run once over ``B*F`` frames, then + the per-frame tokens are concatenated back per clip in frame order to + ``(B, F * tokens_per_frame, dim)``. A single image is just the ``F == 1`` + case with the frame axis absent. + + Casts at the encoder/adapter boundary so the encoder can stay in its HF + dtype (often fp32) while the adapter and transformer run in bf16. """ - feats = wrapper.vision_encoder(pixel_values) - # Adapter-agnostic dtype lookup: use the first adapter parameter's - # dtype so any registered adapter (mlp_2layer, linear, ...) works + is_video = pixel_values.dim() == 5 + # The static visual-token count (residual budget, MoT's positional split) is + # sized for ``frames_per_clip``, so each clip must carry exactly that many + # frames. Validate here to turn a downstream shape/split error into a clear one. + effective_frames = pixel_values.shape[1] if is_video else 1 + if effective_frames != wrapper.frames_per_clip: + raise ValueError( + f"frames-per-clip mismatch: received {effective_frames} frame(s) " + f"(pixel_values.dim()={pixel_values.dim()}) but the wrapper was built with " + f"frames_per_clip={wrapper.frames_per_clip}. Pass a clip with exactly " + "frames_per_clip frames (or rebuild the wrapper for this frame count)." + ) + if is_video: + b, f = pixel_values.shape[0], pixel_values.shape[1] + encoder_input = pixel_values.reshape(b * f, *pixel_values.shape[2:]) + else: + encoder_input = pixel_values + feats = wrapper.vision_encoder(encoder_input) + # Adapter-agnostic dtype lookup: use the first adapter parameter's dtype + # so any registered adapter (mlp_2layer, linear, avgpool, ...) works # without coupling to a specific submodule attribute. adapter_dtype = next(wrapper.adapter.parameters()).dtype if feats.dtype != adapter_dtype: feats = feats.to(adapter_dtype) - return wrapper.adapter(feats) + embeds = wrapper.adapter(feats) + if is_video: + # (B*F, P', dim) -> (B, F*P', dim): frame-contiguous, temporal order kept. + embeds = embeds.reshape(b, f * embeds.shape[1], embeds.shape[2]) + return embeds @registry.register_modality_strategy("joint_decoder") @@ -106,12 +134,14 @@ def prepare( pixel_values: torch.Tensor, input_ids: torch.Tensor, # noqa: ARG002 ) -> ModalityContext: - img_embeds = _project_image_features(wrapper, pixel_values) - n = wrapper.vision_encoder.num_tokens + img_embeds = _project_visual_features(wrapper, pixel_values) + n = img_embeds.shape[1] # pooling-aware: the adapter's actual visual-token count return ModalityContext(prefix_embeds=img_embeds, output_slice=slice(n, None)) def num_image_tokens(self, wrapper: VLMWrapper) -> int: - return wrapper.vision_encoder.num_tokens + return wrapper.frames_per_clip * wrapper.adapter.output_num_tokens( + wrapper.vision_encoder.num_tokens + ) @registry.register_modality_strategy("cross_attention") @@ -132,7 +162,7 @@ def prepare( pixel_values: torch.Tensor, input_ids: torch.Tensor, # noqa: ARG002 ) -> ModalityContext: - img_embeds = _project_image_features(wrapper, pixel_values) + img_embeds = _project_visual_features(wrapper, pixel_values) return ModalityContext(image_features=img_embeds, image_mask=None) def num_image_tokens(self, wrapper: VLMWrapper) -> int: # noqa: ARG002 @@ -166,8 +196,8 @@ def prepare( pixel_values: torch.Tensor, input_ids: torch.Tensor, ) -> ModalityContext: - img_embeds = _project_image_features(wrapper, pixel_values) - n = wrapper.vision_encoder.num_tokens + img_embeds = _project_visual_features(wrapper, pixel_values) + n = img_embeds.shape[1] # pooling-aware: the adapter's actual visual-token count b, t_text = input_ids.shape modality_ids = torch.zeros(b, n + t_text, dtype=torch.long, device=input_ids.device) modality_ids[:, n:] = 1 @@ -178,7 +208,9 @@ def prepare( ) def num_image_tokens(self, wrapper: VLMWrapper) -> int: - return wrapper.vision_encoder.num_tokens + return wrapper.frames_per_clip * wrapper.adapter.output_num_tokens( + wrapper.vision_encoder.num_tokens + ) @registry.register_modality_strategy("moma") @@ -208,8 +240,8 @@ def prepare( pixel_values: torch.Tensor, input_ids: torch.Tensor, ) -> ModalityContext: - img_embeds = _project_image_features(wrapper, pixel_values) - n = wrapper.vision_encoder.num_tokens + img_embeds = _project_visual_features(wrapper, pixel_values) + n = img_embeds.shape[1] # pooling-aware: the adapter's actual visual-token count b, t_text = input_ids.shape modality_ids = torch.zeros(b, n + t_text, dtype=torch.long, device=input_ids.device) modality_ids[:, n:] = 1 @@ -220,7 +252,9 @@ def prepare( ) def num_image_tokens(self, wrapper: VLMWrapper) -> int: - return wrapper.vision_encoder.num_tokens + return wrapper.frames_per_clip * wrapper.adapter.output_num_tokens( + wrapper.vision_encoder.num_tokens + ) def build_modality_strategy(vlm: VLMConfig) -> ModalityStrategy: @@ -246,14 +280,19 @@ class VLMWrapper(nn.Module): def __init__( self, vision_encoder: VisionEncoder, - adapter: nn.Module, + adapter: VisionAdapter, transformer: Transformer, strategy: ModalityStrategy, + frames_per_clip: int = 1, ) -> None: super().__init__() self.vision_encoder = vision_encoder self.adapter = adapter self.transformer = transformer + # Frames per video clip (1 for a single image). The static visual-token + # count is ``frames_per_clip * adapter.output_num_tokens(...)``; the + # strategies use it for ``num_image_tokens`` (residual budget, MoT split). + self.frames_per_clip = frames_per_clip # Strategy is a plain Python object (not nn.Module). nn.Module's # __setattr__ only routes Module/Parameter/Tensor attributes into # _modules/_parameters/_buffers, so plain objects are stored as @@ -321,6 +360,7 @@ def build_vlm_wrapper( vision_config: VisionEncoderConfig, adapter_config: AdapterConfig, vlm_config: VLMConfig, + frames_per_clip: int = 1, ) -> VLMWrapper: """Build a ``VLMWrapper`` from the four top-level configs. @@ -347,19 +387,21 @@ def build_vlm_wrapper( # num_tokens=0 (the "infer from encoder at build time" sentinel) the # config-time check is skipped and the residual-stream allocation goes # unchecked until the model actually runs. This guard fills that gap. - residual_image_tokens = vlm_config.residual_stream_image_tokens(encoder.num_tokens) + in_dim = vision_config.feature_dim or encoder.feature_dim + adapter = build_adapter(adapter_config, in_dim=in_dim, out_dim=model_config.dim) + # Visual tokens entering the LLM = the adapter's output count (pooling + # adapters reduce it; projection adapters are the identity), not the raw + # encoder patch count. This drives the residual budget and MoT's split. + visual_tokens = frames_per_clip * adapter.output_num_tokens(encoder.num_tokens) + residual_image_tokens = vlm_config.residual_stream_image_tokens(visual_tokens) required = residual_image_tokens + vlm_config.max_text_len if model_config.max_seq_len < required: raise ValueError( f"max_seq_len ({model_config.max_seq_len}) insufficient for VLM at build time: " - f"encoder.num_tokens ({encoder.num_tokens}) -> " - f"residual_image_tokens ({residual_image_tokens}) + " + f"encoder.num_tokens ({encoder.num_tokens}) -> adapter visual_tokens " + f"({visual_tokens}) -> residual_image_tokens ({residual_image_tokens}) + " f"vlm.max_text_len ({vlm_config.max_text_len}) = {required}" ) - in_dim = vision_config.feature_dim or encoder.feature_dim - adapter = build_adapter(adapter_config, in_dim=in_dim, out_dim=model_config.dim) - transformer = Transformer( - model_config, vlm_config=vlm_config, num_image_tokens=encoder.num_tokens - ) + transformer = Transformer(model_config, vlm_config=vlm_config, num_image_tokens=visual_tokens) strategy = build_modality_strategy(vlm_config) - return VLMWrapper(encoder, adapter, transformer, strategy) + return VLMWrapper(encoder, adapter, transformer, strategy, frames_per_clip=frames_per_clip) diff --git a/pyproject.toml b/pyproject.toml index c3ef2da..8bf0752 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "datasets", "tokenizers", "torchao>=0.17.0", + "av>=17.1.0", ] [dependency-groups] diff --git a/scripts/train.py b/scripts/train.py index d9b8d29..4bfbf8c 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -174,6 +174,7 @@ def main() -> None: vision_config=vision_cfg, adapter_config=adapter_cfg, vlm_config=vlm_cfg, + frames_per_clip=(config.video.max_frames if config.video is not None else 1), ac_mode=tc.activation_checkpointing, mp_policy=mp_policy, param_dtype=tc.param_dtype, @@ -287,37 +288,69 @@ def main() -> None: eos_token_id = _AT.from_pretrained(config.data.tokenizer_path).eos_token_id if is_vlm: - # --- VLM (Joint-Decoder) data path --- - # Mixing VLM + text-only datasets in one run is out of scope on this - # branch. DatasetSource doesn't describe image sources yet; follow-up. - if not config.data.hf_dataset_name or not config.data.tokenizer_path: - raise ValueError("VLM training requires data.hf_dataset_name and data.tokenizer_path") - from transformers import AutoTokenizer + assert vlm_cfg is not None # narrowed by is_vlm + if config.is_video: + # --- Video data path (a clip = ordered frames; same VLM wrapper) --- + assert config.video is not None # narrowed by is_video + if not config.data.tokenizer_path: + raise ValueError("Video training requires data.tokenizer_path") + from transformers import AutoTokenizer - from kempnerforge.data.vlm_dataset import HuggingFaceVLMDataset, VLMCollator + from kempnerforge.data.video_dataset import VideoCollator, WebVidVideoDataset - assert vlm_cfg is not None # narrowed by is_vlm - dataset = HuggingFaceVLMDataset( - dataset_name=config.data.hf_dataset_name, - split=config.data.hf_dataset_split, - image_field=config.data.hf_dataset_image_field, - text_field=config.data.hf_dataset_text_field, - tokenizer_path=config.data.tokenizer_path, - max_text_len=vlm_cfg.max_text_len, - prompt_field=config.data.hf_dataset_prompt_field or None, - image_size=config.data.hf_image_size, - dataset_config=config.data.hf_dataset_config, - ) - # Resolve pad_id from the tokenizer for VLMCollator. Fall back to - # EOS when pad_token_id is unset (gpt2, some Llama families), then - # to 0 as a last resort. Collator also enforces fixed-length - # padding so all DP ranks see identical tensor shapes and emits - # the image_positions slot (D18) for downstream multi-image work. - _tok = AutoTokenizer.from_pretrained(config.data.tokenizer_path) - _pad_id = _tok.pad_token_id - if _pad_id is None: - _pad_id = _tok.eos_token_id if _tok.eos_token_id is not None else 0 - collator = VLMCollator(pad_id=int(_pad_id), max_text_len=vlm_cfg.max_text_len) + vcfg = config.video + dataset = WebVidVideoDataset( + data_root=vcfg.data_root, + split=vcfg.split, + tokenizer_path=config.data.tokenizer_path, + max_text_len=vlm_cfg.max_text_len, + max_frames=vcfg.max_frames, + min_frames=vcfg.min_frames, + fps=vcfg.fps, + frame_size=vcfg.frame_size, + max_samples=vcfg.max_samples, + prompt=vcfg.prompt, + ) + _tok = AutoTokenizer.from_pretrained(config.data.tokenizer_path) + _pad_id = _tok.pad_token_id + if _pad_id is None: + _pad_id = _tok.eos_token_id if _tok.eos_token_id is not None else 0 + collator = VideoCollator(pad_id=int(_pad_id), max_text_len=vlm_cfg.max_text_len) + logger.info(f"Video dataset: {len(dataset):,} clips from {vcfg.data_root}") + else: + # --- Image VLM (Joint-Decoder) data path --- + # Mixing VLM + text-only datasets in one run is out of scope on this + # branch. DatasetSource doesn't describe image sources yet; follow-up. + if not config.data.hf_dataset_name or not config.data.tokenizer_path: + raise ValueError( + "VLM training requires data.hf_dataset_name and data.tokenizer_path" + ) + from transformers import AutoTokenizer + + from kempnerforge.data.vlm_dataset import HuggingFaceVLMDataset, VLMCollator + + dataset = HuggingFaceVLMDataset( + dataset_name=config.data.hf_dataset_name, + split=config.data.hf_dataset_split, + image_field=config.data.hf_dataset_image_field, + text_field=config.data.hf_dataset_text_field, + tokenizer_path=config.data.tokenizer_path, + max_text_len=vlm_cfg.max_text_len, + prompt_field=config.data.hf_dataset_prompt_field or None, + image_size=config.data.hf_image_size, + dataset_config=config.data.hf_dataset_config, + ) + # Resolve pad_id from the tokenizer for VLMCollator. Fall back to + # EOS when pad_token_id is unset (gpt2, some Llama families), then + # to 0 as a last resort. Collator also enforces fixed-length + # padding so all DP ranks see identical tensor shapes and emits + # the image_positions slot (D18) for downstream multi-image work. + _tok = AutoTokenizer.from_pretrained(config.data.tokenizer_path) + _pad_id = _tok.pad_token_id + if _pad_id is None: + _pad_id = _tok.eos_token_id if _tok.eos_token_id is not None else 0 + collator = VLMCollator(pad_id=int(_pad_id), max_text_len=vlm_cfg.max_text_len) + logger.info(f"VLM dataset: {len(dataset):,} samples from {config.data.hf_dataset_name}") sampler = DistributedSampler( dataset, num_replicas=dp_size, rank=dp_rank, shuffle=True, seed=tc.effective_data_seed ) @@ -328,7 +361,6 @@ def main() -> None: config=config.data, collate_fn=collator, ) - logger.info(f"VLM dataset: {len(dataset):,} samples from {config.data.hf_dataset_name}") elif config.data.datasets: # --- Multi-dataset mixing --- diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index ffbc283..5cb8ee3 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -10,9 +10,14 @@ from kempnerforge.config.adapter import AdapterConfig from kempnerforge.config.registry import registry from kempnerforge.model.adapter import ( + POOLING_ADAPTER_TYPES, + AttentionalPoolAdapter, + AvgPoolAdapter, LinearAdapter, MLP2LayerAdapter, + VisionAdapter, build_adapter, + pooled_token_count, ) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -233,3 +238,280 @@ def test_activation_passes_through_to_mlp(self): import torch.nn as nn assert isinstance(adapter.act, nn.SiLU) + + +# --------------------------------------------------------------------------- +# pooled_token_count (pure helper, single source of truth for the post-pool +# token count) +# --------------------------------------------------------------------------- + + +class TestPooledTokenCount: + @pytest.mark.parametrize( + ("n_in", "window", "expected"), + [ + (196, 2, 49), # 14x14 grid, divisible -> 7x7 + (256, 2, 64), # 16x16 -> 8x8 + (729, 3, 81), # 27x27 (Molmo2 SigLIP 378/14) -> 9x9 + (16, 2, 4), # 4x4 -> 2x2 + (16, 3, 4), # 4x4 ragged: ceil(4/3)=2 -> 2x2 + (100, 3, 16), # 10x10 ragged: ceil(10/3)=4 -> 4x4 + (16, 1, 16), # window 1 == identity + (16, 4, 1), # whole grid into one token + ], + ) + def test_counts(self, n_in, window, expected): + assert pooled_token_count(n_in, window) == expected + + def test_non_square_grid_raises(self): + with pytest.raises(ValueError, match="square patch grid"): + pooled_token_count(8, 2) # 8 is not a perfect square + + def test_non_positive_window_raises(self): + with pytest.raises(ValueError, match="window must be positive"): + pooled_token_count(16, 0) + + def test_non_positive_tokens_raises(self): + with pytest.raises(ValueError, match="must be positive"): + pooled_token_count(0, 2) + + def test_require_divisible_raises_on_ragged(self): + # attentional_pool path: a ragged grid is rejected up front, not at forward. + with pytest.raises(ValueError, match="ragged grid"): + pooled_token_count(196, 3, require_divisible=True) # 14x14 not divisible by 3 + + def test_require_divisible_ok_when_divisible(self): + assert pooled_token_count(196, 2, require_divisible=True) == 49 # 14x14 -> 7x7 + + +# --------------------------------------------------------------------------- +# AvgPoolAdapter +# --------------------------------------------------------------------------- + + +class TestAvgPoolAdapter: + def test_forward_shape_divisible(self): + adapter = AvgPoolAdapter(in_dim=96, out_dim=64, pool_window=2).to(DEVICE) + x = torch.randn(2, 16, 96, device=DEVICE) # 4x4 grid + out = adapter(x) + assert out.shape == (2, 4, 64) # 2x2 windows + + def test_forward_shape_ragged(self): + adapter = AvgPoolAdapter(in_dim=96, out_dim=64, pool_window=3).to(DEVICE) + x = torch.randn(2, 16, 96, device=DEVICE) # 4x4 grid, ceil(4/3)=2 + out = adapter(x) + assert out.shape == (2, 4, 64) + + def test_is_vision_adapter(self): + assert isinstance(AvgPoolAdapter(in_dim=8, out_dim=4), VisionAdapter) + + @pytest.mark.parametrize(("n_in", "window"), [(16, 2), (16, 3), (256, 2), (729, 3), (100, 3)]) + def test_output_num_tokens_matches_forward(self, n_in, window): + """The static count MUST equal the actual forward length — MoT's + positional split relies on this agreement.""" + adapter = AvgPoolAdapter(in_dim=32, out_dim=16, pool_window=window) + x = torch.randn(1, n_in, 32) + assert adapter(x).shape[1] == adapter.output_num_tokens(n_in) + + def test_pool_window_override_in_forward(self): + adapter = AvgPoolAdapter(in_dim=32, out_dim=16, pool_window=2) + x = torch.randn(1, 16, 32) # 4x4 grid + # Override to window 4 -> ceil(4/4)=1 -> single token. + assert adapter(x, pool_window=4).shape == (1, 1, 16) + # Default still 2x2 -> 4 tokens. + assert adapter(x).shape == (1, 4, 16) + + def test_non_square_input_raises(self): + adapter = AvgPoolAdapter(in_dim=8, out_dim=4, pool_window=2) + with pytest.raises(ValueError, match="square patch grid"): + adapter(torch.randn(1, 8, 8)) # 8 not a perfect square + + def test_rejects_zero_dim(self): + with pytest.raises(ValueError, match="must be positive"): + AvgPoolAdapter(in_dim=0, out_dim=8) + + def test_rejects_zero_window(self): + with pytest.raises(ValueError, match="pool_window must be positive"): + AvgPoolAdapter(in_dim=8, out_dim=8, pool_window=0) + + def test_backward_grads_flow(self): + adapter = AvgPoolAdapter(in_dim=32, out_dim=16, pool_window=2).to(DEVICE) + x = torch.randn(1, 16, 32, device=DEVICE, requires_grad=True) + adapter(x).sum().backward() + for p in adapter.parameters(): + assert p.grad is not None + assert torch.isfinite(p.grad).all() + + def test_reset_parameters_reinitializes(self): + adapter = AvgPoolAdapter(in_dim=8, out_dim=4) + adapter.reset_parameters() + assert torch.isfinite(adapter.proj.weight).all() + + def test_divisible_pool_is_plain_mean(self): + """With a divisible grid, the pooled value is the exact window mean.""" + adapter = AvgPoolAdapter(in_dim=4, out_dim=4, pool_window=2) + # Make proj an identity so we can read the pooled values directly. + with torch.no_grad(): + adapter.proj.weight.copy_(torch.eye(4)) + adapter.proj.bias.zero_() + x = torch.arange(16.0).view(1, 16, 1).expand(1, 16, 4).contiguous() # 4x4 grid + out = adapter(x) + # Top-left window = mean of grid cells (0,0),(0,1),(1,0),(1,1) = + # tokens 0,1,4,5 -> mean 2.5. + assert torch.allclose(out[0, 0], torch.full((4,), 2.5)) + + def test_forward_rejects_nonpositive_window_override(self): + adapter = AvgPoolAdapter(in_dim=8, out_dim=4, pool_window=2) + with pytest.raises(ValueError, match="pool_window must be positive"): + adapter(torch.randn(1, 16, 8), pool_window=0) + + +# --------------------------------------------------------------------------- +# AttentionalPoolAdapter +# --------------------------------------------------------------------------- + + +class TestAttentionalPoolAdapter: + def test_forward_shape(self): + adapter = AttentionalPoolAdapter(in_dim=96, out_dim=64, pool_window=2, pool_heads=16).to( + DEVICE + ) + x = torch.randn(2, 16, 96, device=DEVICE) # 4x4 grid + out = adapter(x) + assert out.shape == (2, 4, 64) + + def test_is_vision_adapter(self): + assert isinstance(AttentionalPoolAdapter(in_dim=16, out_dim=8, pool_heads=4), VisionAdapter) + + @pytest.mark.parametrize(("n_in", "window"), [(16, 2), (256, 2), (729, 3)]) + def test_output_num_tokens_matches_forward(self, n_in, window): + adapter = AttentionalPoolAdapter(in_dim=32, out_dim=16, pool_window=window, pool_heads=4) + x = torch.randn(1, n_in, 32) + assert adapter(x).shape[1] == adapter.output_num_tokens(n_in) + + def test_ragged_grid_raises(self): + adapter = AttentionalPoolAdapter(in_dim=96, out_dim=64, pool_window=3, pool_heads=16) + with pytest.raises(ValueError, match="divisible"): + adapter(torch.randn(1, 16, 96)) # 4x4 grid, not divisible by 3 + + def test_output_num_tokens_rejects_ragged(self): + # The static count must mirror forward()'s ragged rejection so an invalid + # config fails at build / seq-len-check time, not at the first step. + adapter = AttentionalPoolAdapter(in_dim=96, out_dim=64, pool_window=3, pool_heads=16) + with pytest.raises(ValueError, match="ragged grid"): + adapter.output_num_tokens(16) # 4x4 grid, not divisible by 3 + + def test_heads_must_divide_dim(self): + with pytest.raises(ValueError, match="divisible by"): + AttentionalPoolAdapter(in_dim=96, out_dim=64, pool_heads=7) + + def test_rejects_zero_dim(self): + with pytest.raises(ValueError, match="must be positive"): + AttentionalPoolAdapter(in_dim=0, out_dim=8) + + def test_backward_grads_flow(self): + adapter = AttentionalPoolAdapter(in_dim=32, out_dim=16, pool_window=2, pool_heads=4).to( + DEVICE + ) + x = torch.randn(1, 16, 32, device=DEVICE, requires_grad=True) + adapter(x).sum().backward() + for p in adapter.parameters(): + assert p.grad is not None + assert torch.isfinite(p.grad).all() + + def test_reset_parameters_reinitializes(self): + adapter = AttentionalPoolAdapter(in_dim=16, out_dim=8, pool_heads=4) + adapter.reset_parameters() + assert torch.isfinite(adapter.out_proj.weight).all() + + def test_forward_rejects_nonpositive_window_override(self): + adapter = AttentionalPoolAdapter(in_dim=16, out_dim=8, pool_window=2, pool_heads=4) + with pytest.raises(ValueError, match="pool_window must be positive"): + adapter(torch.randn(1, 16, 16), pool_window=0) + + +# --------------------------------------------------------------------------- +# Pooling adapters: registry + config wiring +# --------------------------------------------------------------------------- + + +class TestPoolingAdapterRegistry: + def test_avgpool_registered(self): + adapter = registry.get_adapter("avgpool")(in_dim=8, out_dim=4, pool_window=2) + assert isinstance(adapter, AvgPoolAdapter) + + def test_attentional_pool_registered(self): + adapter = registry.get_adapter("attentional_pool")(in_dim=8, out_dim=4, pool_heads=2) + assert isinstance(adapter, AttentionalPoolAdapter) + + def test_pooling_types_constant_matches_registry(self): + for name in POOLING_ADAPTER_TYPES: + assert name in registry.list_adapters() + + +class TestAdapterConfigPooling: + def test_pool_defaults(self): + cfg = AdapterConfig() + assert cfg.pool_window == 2 + assert cfg.pool_heads == 16 + + def test_rejects_zero_pool_window(self): + with pytest.raises(ValueError, match="pool_window must be positive"): + AdapterConfig(pool_window=0) + + def test_rejects_zero_pool_heads(self): + with pytest.raises(ValueError, match="pool_heads must be positive"): + AdapterConfig(pool_heads=0) + + def test_extra_kwargs_includes_pool_fields(self): + kwargs = AdapterConfig(type="avgpool", pool_window=3, pool_heads=8).extra_kwargs() + assert kwargs["pool_window"] == 3 + assert kwargs["pool_heads"] == 8 + + def test_output_num_tokens_identity_for_projection(self): + assert AdapterConfig(type="mlp_2layer").output_num_tokens(196) == 196 + assert AdapterConfig(type="linear").output_num_tokens(196) == 196 + + def test_output_num_tokens_pools_for_avgpool(self): + assert AdapterConfig(type="avgpool", pool_window=2).output_num_tokens(196) == 49 + + def test_output_num_tokens_pools_for_attentional(self): + assert AdapterConfig(type="attentional_pool", pool_window=3).output_num_tokens(729) == 81 + + def test_attentional_output_num_tokens_rejects_ragged(self): + # Config-time check rejects a ragged attentional_pool grid (mirrors forward), + # so the misconfig fails at config load, not at the first training step. + with pytest.raises(ValueError, match="ragged grid"): + AdapterConfig(type="attentional_pool", pool_window=3).output_num_tokens(196) + + def test_avgpool_output_num_tokens_allows_ragged(self): + # avgpool pools ragged edges, so the same ragged grid is fine (ceil math). + assert AdapterConfig(type="avgpool", pool_window=3).output_num_tokens(196) == 25 + + def test_output_num_tokens_passthrough_on_sentinel(self): + # num_tokens=0 ("infer at build time") must not trigger the square check. + assert AdapterConfig(type="avgpool").output_num_tokens(0) == 0 + + @pytest.mark.parametrize("adapter_type", ["avgpool", "attentional_pool"]) + def test_config_count_matches_module_count(self, adapter_type): + """Config-time estimate must equal the built module's count so the + config-time seq-len check matches the build-time budget.""" + cfg = AdapterConfig(type=adapter_type, pool_window=2, pool_heads=4) + module = build_adapter(cfg, in_dim=32, out_dim=16) + assert cfg.output_num_tokens(256) == module.output_num_tokens(256) + + +class TestBuildAdapterPooling: + def test_dispatches_to_avgpool_with_window(self): + cfg = AdapterConfig(type="avgpool", pool_window=2) + adapter = build_adapter(cfg, in_dim=32, out_dim=16) + assert isinstance(adapter, AvgPoolAdapter) + assert adapter.pool_window == 2 + out = adapter(torch.randn(1, 16, 32)) # 4x4 grid -> 2x2 + assert out.shape == (1, 4, 16) + + def test_dispatches_to_attentional_with_heads(self): + cfg = AdapterConfig(type="attentional_pool", pool_window=2, pool_heads=8) + adapter = build_adapter(cfg, in_dim=32, out_dim=16) + assert isinstance(adapter, AttentionalPoolAdapter) + assert adapter.pool_heads == 8 diff --git a/tests/unit/test_moma.py b/tests/unit/test_moma.py index d5458c3..4bf05cc 100644 --- a/tests/unit/test_moma.py +++ b/tests/unit/test_moma.py @@ -192,6 +192,9 @@ def __init__(self, in_dim: int, out_dim: int) -> None: super().__init__() self.proj = nn.Linear(in_dim, out_dim) + def output_num_tokens(self, num_input_tokens: int) -> int: + return num_input_tokens # projection stub: token count unchanged + def forward(self, feats: torch.Tensor) -> torch.Tensor: return self.proj(feats) @@ -201,6 +204,7 @@ def __init__(self, num_tokens: int, feature_dim: int, dim: int) -> None: super().__init__() self.vision_encoder = _StubVisionEncoder(num_tokens, feature_dim) self.adapter = _StubAdapter(feature_dim, dim) + self.frames_per_clip = 1 class TestMoMaStrategy: diff --git a/tests/unit/test_video_config.py b/tests/unit/test_video_config.py new file mode 100644 index 0000000..11895a4 --- /dev/null +++ b/tests/unit/test_video_config.py @@ -0,0 +1,78 @@ +"""Unit tests for VideoConfig and its JobConfig wiring.""" + +from __future__ import annotations + +import pytest + +from kempnerforge.config.job import JobConfig +from kempnerforge.config.model import ModelConfig +from kempnerforge.config.video import VideoConfig +from kempnerforge.config.vision import VisionEncoderConfig +from kempnerforge.config.vlm import JointDecoderConfig + + +class TestVideoConfig: + def test_defaults(self): + cfg = VideoConfig() + assert cfg.split == "train" + assert cfg.max_frames == 16 + assert cfg.min_frames == 4 + assert cfg.fps == 2.0 + assert cfg.frame_size == 224 + assert cfg.max_samples == 0 + + def test_bad_split_rejected(self): + with pytest.raises(ValueError, match="video.split"): + VideoConfig(split="test") + + def test_min_greater_than_max_rejected(self): + with pytest.raises(ValueError, match="must be <="): + VideoConfig(min_frames=8, max_frames=4) + + def test_non_positive_fps_rejected(self): + with pytest.raises(ValueError, match="video.fps must be positive"): + VideoConfig(fps=0.0) + + def test_non_positive_frame_size_rejected(self): + with pytest.raises(ValueError, match="video.frame_size must be positive"): + VideoConfig(frame_size=0) + + def test_negative_max_samples_rejected(self): + with pytest.raises(ValueError, match="video.max_samples"): + VideoConfig(max_samples=-1) + + def test_min_frames_below_one_rejected(self): + with pytest.raises(ValueError, match="video.min_frames must be >= 1"): + VideoConfig(min_frames=0) + + def test_max_frames_below_one_rejected(self): + with pytest.raises(ValueError, match="video.max_frames must be >= 1"): + VideoConfig(max_frames=0) + + +class TestJobConfigVideoWiring: + def _vlm_kwargs(self) -> dict: + return { + "model": ModelConfig(dim=64, n_layers=2, n_heads=4, vocab_size=256, max_seq_len=64), + "vision_encoder": VisionEncoderConfig(type="random"), + "vlm": JointDecoderConfig(max_text_len=32), + } + + def test_video_requires_vlm(self): + with pytest.raises(ValueError, match=r"\[video\] is set but \[vlm\] is missing"): + JobConfig(video=VideoConfig(data_root="/some/root")) + + def test_valid_video_job(self): + cfg = JobConfig(video=VideoConfig(data_root="/some/root"), **self._vlm_kwargs()) + assert cfg.is_video is True + assert cfg.is_vlm is True + + def test_is_video_false_without_section(self): + cfg = JobConfig(**self._vlm_kwargs()) + assert cfg.is_video is False + assert cfg.is_vlm is True + + def test_text_only_job_is_not_video(self): + cfg = JobConfig() + assert cfg.is_video is False + assert cfg.is_vlm is False diff --git a/tests/unit/test_video_dataset.py b/tests/unit/test_video_dataset.py new file mode 100644 index 0000000..cc3f964 --- /dev/null +++ b/tests/unit/test_video_dataset.py @@ -0,0 +1,292 @@ +"""Unit tests for WebVidVideoDataset and VideoCollator. + +The dataset is exercised with a stubbed decoder (no real video / no ``av``) +and a char-level mock tokenizer (no HF download), mirroring the approach in +``test_vlm_dataset.py``. +""" + +from __future__ import annotations + +import importlib.util + +import pytest +import torch +from PIL import Image + +from kempnerforge.data import video_dataset as vd +from kempnerforge.data.video_dataset import VideoCollator, WebVidVideoDataset +from kempnerforge.data.vlm_dataset import DEFAULT_IMAGE_MEAN, DEFAULT_IMAGE_STD + + +class _MockTokenizer: + """Char-level tokenizer (a->1..z->26, space->27, '.'->28), pad id 0.""" + + pad_token_id = 0 + eos_token_id = 28 + + def __call__(self, text: str, add_special_tokens: bool = False): + del add_special_tokens + ids = [] + for ch in text.lower(): + if ch == " ": + ids.append(27) + elif ch == ".": + ids.append(28) + elif "a" <= ch <= "z": + ids.append(1 + ord(ch) - ord("a")) + else: + ids.append(0) + return {"input_ids": ids, "attention_mask": [1] * len(ids)} + + +class _StubVideoDataset(WebVidVideoDataset): + """Bypass __init__ (no CSV/tokenizer loading); set attributes directly.""" + + def __init__( + self, + ids: list[str], + caps: list[str], + split: str = "train", + *, + max_frames: int = 8, + min_frames: int = 2, + fps: float = 2.0, + frame_size: int = 16, + max_text_len: int = 8, + prompt: str = "", + ) -> None: + self._ids = ids + self._caps = caps + self._split = split + self._video_dir = f"/fake/videos/{'train' if split == 'train' else 'validation'}" + self._tokenizer = _MockTokenizer() + self._pad_id = 0 + self._max_text_len = max_text_len + self._max_frames = max_frames + self._min_frames = min_frames + self._fps = fps + self._frame_size = frame_size + self._prompt = prompt + self._image_mean = DEFAULT_IMAGE_MEAN + self._image_std = DEFAULT_IMAGE_STD + + +def _frames(n: int, size: int = 16) -> list[Image.Image]: + return [Image.new("RGB", (size, size), color=(i * 10 % 255, 0, 0)) for i in range(n)] + + +# --------------------------------------------------------------------------- +# Video path mapping (verified against the on-disk WebVid layout) +# --------------------------------------------------------------------------- + + +class TestVideoPath: + def test_train_prefix_nesting(self): + ds = _StubVideoDataset(["8469580"], ["x"], split="train") + assert ds._video_path("8469580") == "/fake/videos/train/84/8469/846958/8469580.mp4" + + def test_short_id_prefix(self): + # id shorter than 6 chars: id[:6] is the whole id. + ds = _StubVideoDataset(["84490"], ["x"], split="train") + assert ds._video_path("84490") == "/fake/videos/train/84/8449/84490/84490.mp4" + + def test_validation_is_flat(self): + ds = _StubVideoDataset(["10006310"], ["x"], split="validation") + assert ds._video_path("10006310") == "/fake/videos/validation/10006310.mp4" + + +# --------------------------------------------------------------------------- +# __getitem__ (stubbed decoder) +# --------------------------------------------------------------------------- + + +class TestGetItem: + def test_shapes_and_mask_full_clip(self, monkeypatch): + monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: _frames(8)) + ds = _StubVideoDataset(["1"], ["a cat."], max_frames=8, frame_size=16) + item = ds[0] + assert item["pixel_values"].shape == (8, 3, 16, 16) + assert item["pixel_values"].dtype == torch.float32 + assert item["frame_mask"].shape == (8,) + assert item["frame_mask"].dtype == torch.bool + assert item["frame_mask"].all() + assert item["input_ids"].shape == (8,) + assert item["labels"].shape == (8,) + + def test_pads_and_masks_short_clip(self, monkeypatch): + monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: _frames(3)) + ds = _StubVideoDataset(["1"], ["a dog."], max_frames=8) + item = ds[0] + assert item["frame_mask"].tolist() == [True, True, True, False, False, False, False, False] + # Padded frames are zeros. + assert torch.count_nonzero(item["pixel_values"][3:]) == 0 + + def test_caption_is_supervised_when_frames_present(self, monkeypatch): + monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: _frames(4)) + ds = _StubVideoDataset(["1"], ["abc"], max_frames=8, max_text_len=8) + item = ds[0] + # "abc" -> ids 1,2,3 supervised; rest -100. + assert item["labels"][:3].tolist() == [1, 2, 3] + assert (item["labels"][3:] == -100).all() + + def test_decode_failure_yields_zero_clip_no_loss(self, monkeypatch): + def _boom(*a, **k): + raise RuntimeError("corrupt video") + + monkeypatch.setattr(vd, "decode_video_frames", _boom) + ds = _StubVideoDataset(["1"], ["a cat."], max_frames=8) + item = ds[0] + assert torch.count_nonzero(item["pixel_values"]) == 0 + assert not item["frame_mask"].any() + assert (item["labels"] == -100).all() # no supervision for an unloadable clip + + def test_empty_decode_yields_zero_clip_no_loss(self, monkeypatch): + monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: []) + ds = _StubVideoDataset(["1"], ["a cat."], max_frames=4) + item = ds[0] + assert not item["frame_mask"].any() + assert (item["labels"] == -100).all() + + def test_prompt_is_masked(self, monkeypatch): + monkeypatch.setattr(vd, "decode_video_frames", lambda *a, **k: _frames(2)) + ds = _StubVideoDataset(["1"], ["xyz"], max_frames=4, max_text_len=8, prompt="ab") + item = ds[0] + # prompt "ab" (2 toks) masked; "xyz" (24,25,26) supervised. + assert item["input_ids"][:5].tolist() == [1, 2, 24, 25, 26] + assert item["labels"][:2].tolist() == [-100, -100] + assert item["labels"][2:5].tolist() == [24, 25, 26] + + def test_len(self): + ds = _StubVideoDataset(["1", "2", "3"], ["a", "b", "c"]) + assert len(ds) == 3 + + +# --------------------------------------------------------------------------- +# VideoCollator +# --------------------------------------------------------------------------- + + +class TestVideoCollator: + def _sample(self, n_frames_valid: int, max_frames: int = 4, max_text_len: int = 8): + pv = torch.zeros(max_frames, 3, 16, 16) + pv[:n_frames_valid] = torch.randn(n_frames_valid, 3, 16, 16) + mask = torch.zeros(max_frames, dtype=torch.bool) + mask[:n_frames_valid] = True + ids = torch.zeros(max_text_len, dtype=torch.long) + ids[:3] = torch.tensor([1, 2, 3]) + labels = torch.full((max_text_len,), -100, dtype=torch.long) + labels[:3] = torch.tensor([1, 2, 3]) + return {"pixel_values": pv, "frame_mask": mask, "input_ids": ids, "labels": labels} + + def test_batch_shapes(self): + collator = VideoCollator(pad_id=0, max_text_len=8) + batch = collator([self._sample(4), self._sample(2), self._sample(3)]) + assert batch["pixel_values"].shape == (3, 4, 3, 16, 16) + assert batch["frame_mask"].shape == (3, 4) + assert batch["frame_mask"].dtype == torch.bool + assert batch["input_ids"].shape == (3, 8) + assert batch["labels"].shape == (3, 8) + + def test_frame_mask_preserved(self): + collator = VideoCollator(pad_id=0, max_text_len=8) + batch = collator([self._sample(2, max_frames=4)]) + assert batch["frame_mask"][0].tolist() == [True, True, False, False] + + def test_empty_batch_raises(self): + with pytest.raises(ValueError, match="empty batch"): + VideoCollator(pad_id=0, max_text_len=8)([]) + + def test_max_text_len_must_be_positive(self): + with pytest.raises(ValueError, match="max_text_len must be positive"): + VideoCollator(pad_id=0, max_text_len=0) + + +# --------------------------------------------------------------------------- +# Real dataset integration: build a synthetic WebVid layout (CSV manifest + +# a tiny encoded .mp4 at the prefix path) and exercise the real __init__, +# manifest load, path mapping, __getitem__ decode, and the decode-failure +# path. Uses av (a hard dependency) so it runs in CI; gpt2 tokenizer matches +# the existing VLM dataset tests. +# --------------------------------------------------------------------------- + +_AV_AVAILABLE = importlib.util.find_spec("av") is not None + + +def _write_mp4(path, n_frames: int, size: int = 32, fps: int = 8) -> None: + import av + import numpy as np + + with av.open(str(path), mode="w") as container: + stream = container.add_stream("mpeg4", rate=fps) + stream.width = size + stream.height = size + stream.pix_fmt = "yuv420p" + for i in range(n_frames): + arr = np.full((size, size, 3), (i * 17) % 256, dtype=np.uint8) + frame = av.VideoFrame.from_ndarray(arr, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + for packet in stream.encode(): + container.mux(packet) + + +@pytest.mark.skipif(not _AV_AVAILABLE, reason="requires the 'av' package") +class TestRealDatasetIntegration: + def _manifest_dir(self, root): + d = root / "raw" / "webvid-10M" / "data" / "train" / "partitions" + d.mkdir(parents=True) + return d + + def test_init_getitem_and_decode(self, tmp_path): + vid, cap = "123456", "a test clip" + (self._manifest_dir(tmp_path) / "0000.csv").write_text(f"videoid,name\n{vid},{cap}\n") + vdir = tmp_path / "raw" / "videos" / "train" / vid[:2] / vid[:4] / vid[:6] + vdir.mkdir(parents=True) + _write_mp4(vdir / f"{vid}.mp4", n_frames=16, size=32, fps=8) + + ds = WebVidVideoDataset( + data_root=str(tmp_path), + split="train", + tokenizer_path="gpt2", + max_text_len=16, + max_frames=8, + min_frames=4, + fps=2.0, + frame_size=32, + ) + assert len(ds) == 1 + item = ds[0] + assert item["pixel_values"].shape == (8, 3, 32, 32) + assert item["frame_mask"].any() # real frames decoded + assert (item["labels"] != -100).any() # caption supervised + + def test_decode_failure_is_masked(self, tmp_path): + # Manifest points at a videoid with no .mp4 on disk -> decode raises, + # __getitem__ catches it and yields a zero clip with no loss. + (self._manifest_dir(tmp_path) / "0000.csv").write_text("videoid,name\n999999,missing\n") + ds = WebVidVideoDataset( + data_root=str(tmp_path), + split="train", + tokenizer_path="gpt2", + max_text_len=8, + max_frames=4, + min_frames=2, + fps=2.0, + frame_size=16, + ) + item = ds[0] + assert not item["frame_mask"].any() + assert (item["labels"] == -100).all() + + def test_empty_manifest_raises(self, tmp_path): + self._manifest_dir(tmp_path) # dir exists but no CSVs + with pytest.raises(FileNotFoundError, match="No partition CSVs"): + WebVidVideoDataset( + data_root=str(tmp_path), + split="train", + tokenizer_path="gpt2", + max_text_len=8, + max_frames=4, + min_frames=2, + fps=2.0, + ) diff --git a/tests/unit/test_video_io.py b/tests/unit/test_video_io.py new file mode 100644 index 0000000..2224764 --- /dev/null +++ b/tests/unit/test_video_io.py @@ -0,0 +1,157 @@ +"""Unit tests for video frame sampling and decoding.""" + +from __future__ import annotations + +import importlib.util +import os + +import pytest + +from kempnerforge.data.video_io import sample_timestamps + +# A known-good WebVid clip on the Kempner testbed; the decode integration test +# is skipped when ``av`` or the data are unavailable (CI without either). +_WEBVID_CLIP = ( + "/n/holylfs06/LABS/kempner_shared/Everyone/testbed/video/webvid-10m/" + "raw/videos/train/21/2117/211794/21179416.mp4" +) +_AV_AVAILABLE = importlib.util.find_spec("av") is not None + + +# --------------------------------------------------------------------------- +# sample_timestamps (pure policy, no decoder) +# --------------------------------------------------------------------------- + + +class TestSampleTimestamps: + def test_zero_duration_returns_single_start(self): + assert sample_timestamps(0.0, fps=2.0, min_frames=4, max_frames=16) == [0.0] + + def test_negative_duration_returns_single_start(self): + assert sample_timestamps(-3.0, fps=2.0, min_frames=4, max_frames=16) == [0.0] + + def test_includes_first_and_last_frame(self): + ts = sample_timestamps(10.0, fps=2.0, min_frames=4, max_frames=16) + assert ts[0] == 0.0 + assert ts[-1] == pytest.approx(10.0) + + def test_strictly_increasing(self): + ts = sample_timestamps(7.5, fps=2.0, min_frames=4, max_frames=16) + assert all(b > a for a, b in zip(ts, ts[1:], strict=False)) + + def test_caps_at_max_frames(self): + # 100s * 2fps = 200 desired, capped to 16, uniformly over [0, 100]. + ts = sample_timestamps(100.0, fps=2.0, min_frames=4, max_frames=16) + assert len(ts) == 16 + assert ts[-1] == pytest.approx(100.0) + + def test_target_rate_when_under_cap(self): + # 2s * 2fps = 4 frames, within [4, 16]. + ts = sample_timestamps(2.0, fps=2.0, min_frames=4, max_frames=16) + assert len(ts) == 4 + assert ts == pytest.approx([0.0, 2 / 3, 4 / 3, 2.0]) + + def test_floors_at_min_frames(self): + # 1s * 2fps = 2 desired, raised to min_frames=4. + ts = sample_timestamps(1.0, fps=2.0, min_frames=4, max_frames=16) + assert len(ts) == 4 + + def test_single_frame_when_max_is_one(self): + ts = sample_timestamps(5.0, fps=2.0, min_frames=1, max_frames=1) + assert ts == [0.0] + + @pytest.mark.parametrize("fps", [0.0, -1.0]) + def test_bad_fps_raises(self, fps): + with pytest.raises(ValueError, match="fps must be positive"): + sample_timestamps(10.0, fps=fps, min_frames=4, max_frames=16) + + def test_min_greater_than_max_raises(self): + with pytest.raises(ValueError, match="must be <="): + sample_timestamps(10.0, fps=2.0, min_frames=8, max_frames=4) + + def test_min_below_one_raises(self): + with pytest.raises(ValueError, match=">= 1"): + sample_timestamps(10.0, fps=2.0, min_frames=0, max_frames=4) + + +# --------------------------------------------------------------------------- +# decode_video_frames (integration; needs av + the testbed data) +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif( + not _AV_AVAILABLE or not os.path.exists(_WEBVID_CLIP), + reason="requires the 'av' package and the WebVid testbed clip", +) +class TestDecodeVideoFramesIntegration: + def test_decodes_pil_frames(self): + from PIL import Image + + from kempnerforge.data.video_io import decode_video_frames + + frames = decode_video_frames(_WEBVID_CLIP, fps=2.0, min_frames=4, max_frames=8) + assert 1 <= len(frames) <= 8 + assert all(isinstance(f, Image.Image) and f.mode == "RGB" for f in frames) + + def test_respects_max_frames(self): + from kempnerforge.data.video_io import decode_video_frames + + frames = decode_video_frames(_WEBVID_CLIP, fps=8.0, min_frames=4, max_frames=4) + assert len(frames) == 4 + + def test_missing_file_raises(self): + from kempnerforge.data.video_io import decode_video_frames + + with pytest.raises(Exception): # noqa: B017,PT011 - any av/OS error is acceptable + decode_video_frames("/no/such/video.mp4", fps=2.0, min_frames=4, max_frames=8) + + +def _write_mp4(path, n_frames: int, size: int = 32, fps: int = 10) -> None: + """Encode a tiny solid-color clip with PyAV (av is a hard dependency).""" + import av + import numpy as np + + with av.open(str(path), mode="w") as container: + stream = container.add_stream("mpeg4", rate=fps) + stream.width = size + stream.height = size + stream.pix_fmt = "yuv420p" + for i in range(n_frames): + arr = np.full((size, size, 3), (i * 17) % 256, dtype=np.uint8) + frame = av.VideoFrame.from_ndarray(arr, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + for packet in stream.encode(): # flush + container.mux(packet) + + +@pytest.mark.skipif(not _AV_AVAILABLE, reason="requires the 'av' package") +class TestDecodeSynthetic: + """Decode a synthetic clip (no external data) — runs in CI since av is a dep.""" + + def test_decodes_rgb_frames(self, tmp_path): + from PIL import Image + + from kempnerforge.data.video_io import decode_video_frames + + path = tmp_path / "clip.mp4" + _write_mp4(path, n_frames=20, fps=10) # ~2s + frames = decode_video_frames(str(path), fps=2.0, min_frames=4, max_frames=8) + assert 1 <= len(frames) <= 8 + assert all(isinstance(f, Image.Image) and f.mode == "RGB" for f in frames) + + def test_respects_max_frames(self, tmp_path): + from kempnerforge.data.video_io import decode_video_frames + + path = tmp_path / "clip.mp4" + _write_mp4(path, n_frames=40, fps=10) # ~4s + frames = decode_video_frames(str(path), fps=8.0, min_frames=4, max_frames=4) + assert len(frames) == 4 + + def test_short_clip_returns_frames(self, tmp_path): + from kempnerforge.data.video_io import decode_video_frames + + path = tmp_path / "short.mp4" + _write_mp4(path, n_frames=3, fps=10) # shorter than min_frames request + frames = decode_video_frames(str(path), fps=2.0, min_frames=4, max_frames=8) + assert len(frames) >= 1 diff --git a/tests/unit/test_vlm.py b/tests/unit/test_vlm.py index b05275d..19a1760 100644 --- a/tests/unit/test_vlm.py +++ b/tests/unit/test_vlm.py @@ -16,6 +16,7 @@ CrossAttentionConfig, FreezeSpec, JointDecoderConfig, + MoMaConfig, MoTConfig, VLMConfig, ) @@ -449,3 +450,201 @@ def test_ca_forward_logits_text_only_shape(self): with torch.no_grad(): logits, _ = wrapper(pixel_values, input_ids) assert logits.shape == (1, 16, 256) + + +# --------------------------------------------------------------------------- +# Pooling connector (Phase 1): a pooling adapter reduces the visual-token +# count between the encoder and the LLM. The whole VLM path must use the +# adapter's output count (not encoder.num_tokens) for the prefix length, +# output_slice, modality_ids, and MoT's positional split. +# --------------------------------------------------------------------------- + + +def _build_pooled_jd_wrapper(adapter_type: str = "avgpool", pool_window: int = 2) -> VLMWrapper: + # Encoder emits a 4x4 grid (16 tokens); a 2x2 pool -> 4 visual tokens. + mc = ModelConfig(dim=64, n_layers=2, n_heads=4, vocab_size=256, max_seq_len=64) + vc = VisionEncoderConfig(type="random", feature_dim=96, num_tokens=16) + ac = AdapterConfig(type=adapter_type, pool_window=pool_window) + lc = JointDecoderConfig(max_text_len=32) + return build_vlm_wrapper(mc, vc, ac, lc) + + +def _build_pooled_mot_wrapper(pool_window: int = 2) -> VLMWrapper: + mc = ModelConfig( + dim=64, n_layers=2, n_heads=4, vocab_size=256, max_seq_len=64, ffn_hidden_dim=128 + ) + vc = VisionEncoderConfig(type="random", feature_dim=96, num_tokens=16) + ac = AdapterConfig(type="avgpool", pool_window=pool_window) + lc = MoTConfig(max_text_len=32) + return build_vlm_wrapper(mc, vc, ac, lc) + + +class TestPoolingConnector: + @pytest.mark.parametrize("adapter_type", ["avgpool", "attentional_pool"]) + def test_num_image_tokens_is_pooled_count(self, adapter_type): + # 16 patch tokens, 2x2 pool -> 4 visual tokens. + wrapper = _build_pooled_jd_wrapper(adapter_type, pool_window=2) + assert wrapper.num_image_tokens == 4 + + @pytest.mark.parametrize("adapter_type", ["avgpool", "attentional_pool"]) + def test_jd_prefix_length_is_pooled(self, adapter_type): + wrapper = _build_pooled_jd_wrapper(adapter_type, pool_window=2) + strategy = JointDecoderStrategy() + pixels = torch.randn(2, 3, 16, 16) + input_ids = torch.randint(0, 256, (2, 12)) + ctx = strategy.prepare(wrapper, pixels, input_ids) + assert ctx.prefix_embeds is not None + assert ctx.prefix_embeds.shape == (2, 4, 64) # pooled prefix, model dim + assert ctx.output_slice == slice(4, None) + + @pytest.mark.parametrize("adapter_type", ["avgpool", "attentional_pool"]) + def test_jd_forward_trims_pooled_prefix(self, adapter_type): + wrapper = _build_pooled_jd_wrapper(adapter_type, pool_window=2).to(DEVICE).eval() + pixels = torch.randn(2, 3, 16, 16, device=DEVICE) + input_ids = torch.randint(0, 256, (2, 20), device=DEVICE) + with torch.no_grad(): + logits, _ = wrapper(pixels, input_ids) + # output_slice trims the 4 pooled positions -> logits cover text only. + assert logits.shape == (2, 20, 256) + + def test_mot_split_uses_pooled_count(self): + """MoT modality_ids length and the Transformer's positional split + both key off the pooled count; a mismatch would crash the forward. + """ + wrapper = _build_pooled_mot_wrapper(pool_window=2).to(DEVICE) + strategy = MoTStrategy() + pixels = torch.randn(2, 3, 16, 16, device=DEVICE) + input_ids = torch.randint(0, 256, (2, 16), device=DEVICE) + ctx = strategy.prepare(wrapper, pixels, input_ids) + assert ctx.modality_ids is not None + assert ctx.modality_ids.shape == (2, 4 + 16) # pooled prefix + text + assert (ctx.modality_ids[:, :4] == 0).all() + assert (ctx.modality_ids[:, 4:] == 1).all() + # End-to-end forward exercises the build-time _mot_n_image split, + # which must equal the runtime pooled prefix length (4). + labels = torch.full((2, 16), -100, dtype=torch.long, device=DEVICE) + logits, _ = wrapper(pixels, input_ids, labels) + assert logits.shape == (2, 16, 256) + + def test_projection_adapter_keeps_token_count(self): + """Regression: a non-pooling adapter must leave num_image_tokens == + encoder.num_tokens (the image path stays bit-for-bit unchanged).""" + wrapper = _build_pooled_jd_wrapper("mlp_2layer", pool_window=2) + assert wrapper.num_image_tokens == 16 # no pooling -> unchanged + + +# --------------------------------------------------------------------------- +# Video forward (Phase 3): the wrapper consumes a (B, F, 3, H, W) clip batch. +# _project_visual_features folds the frame axis through the encoder+pooler to +# (B, F*P', dim); the static visual-token count (frames_per_clip * per-frame) +# drives the residual budget + MoT's split and must equal the runtime prefix. +# --------------------------------------------------------------------------- + + +def _video_wrapper(vlm_cfg, frames: int = 4, *, ffn_hidden_dim: int | None = None) -> VLMWrapper: + # Encoder: 4x4 patch grid (16 tokens); avgpool 2x2 -> 4 tokens/frame. + # frames=4 -> 4*4 = 16 visual tokens in the residual prefix. + mc_kwargs: dict[str, int] = { + "dim": 64, + "n_layers": 2, + "n_heads": 4, + "vocab_size": 256, + "max_seq_len": 64, + } + if ffn_hidden_dim is not None: + mc_kwargs["ffn_hidden_dim"] = ffn_hidden_dim + mc = ModelConfig(**mc_kwargs) + vc = VisionEncoderConfig(type="random", feature_dim=96, num_tokens=16) + ac = AdapterConfig(type="avgpool", pool_window=2) + return build_vlm_wrapper(mc, vc, ac, vlm_cfg, frames_per_clip=frames) + + +class TestVideoForward: + def test_num_image_tokens_is_frames_times_per_frame(self): + # 4 frames * (16 patches -> 2x2 pool -> 4) = 16 visual tokens. + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4) + assert wrapper.num_image_tokens == 16 + + def test_projector_folds_frame_axis(self): + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4) + ctx = JointDecoderStrategy().prepare( + wrapper, torch.randn(2, 4, 3, 16, 16), torch.randint(0, 256, (2, 6)) + ) + assert ctx.prefix_embeds is not None + assert ctx.prefix_embeds.shape == (2, 16, 64) # (B, F*P', dim) + assert ctx.output_slice == slice(16, None) + + def test_static_count_matches_runtime_prefix(self): + """MoT's positional split uses the build-time count; it must equal the + runtime prefix length (frames * per-frame).""" + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4) + ctx = JointDecoderStrategy().prepare( + wrapper, torch.randn(1, 4, 3, 16, 16), torch.randint(0, 256, (1, 6)) + ) + assert ctx.prefix_embeds is not None + assert ctx.prefix_embeds.shape[1] == wrapper.num_image_tokens == 16 + + @pytest.mark.parametrize("arch", ["joint_decoder", "cross_attention", "mot", "moma"]) + def test_video_forward_all_archs(self, arch): + ffn = 128 if arch in ("mot", "moma") else None + if arch == "joint_decoder": + vlm_cfg: VLMConfig = JointDecoderConfig(max_text_len=8) + elif arch == "cross_attention": + vlm_cfg = CrossAttentionConfig(max_text_len=8, cross_attention_every_n_layers=2) + elif arch == "mot": + vlm_cfg = MoTConfig(max_text_len=8) + else: + vlm_cfg = MoMaConfig(max_text_len=8) + wrapper = _video_wrapper(vlm_cfg, frames=4, ffn_hidden_dim=ffn).to(DEVICE) + pixels = torch.randn(2, 4, 3, 16, 16, device=DEVICE) + input_ids = torch.randint(0, 256, (2, 6), device=DEVICE) + labels = torch.full((2, 6), -100, dtype=torch.long, device=DEVICE) + logits, _ = wrapper(pixels, input_ids, labels) + # output_slice trims the F*P' visual prefix -> logits cover text only. + assert logits.shape == (2, 6, 256) + + def test_video_forward_backward_grads(self): + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4).to(DEVICE) + pixels = torch.randn(1, 4, 3, 16, 16, device=DEVICE) + input_ids = torch.randint(0, 256, (1, 6), device=DEVICE) + logits, _ = wrapper(pixels, input_ids) + logits.sum().backward() + for p in wrapper.adapter.parameters(): + assert p.grad is not None + assert torch.isfinite(p.grad).all() + + def test_image_path_unchanged_with_4d(self): + """frames_per_clip=1 + a 4D image batch still works (image path intact).""" + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=1).to(DEVICE).eval() + pixels = torch.randn(2, 3, 16, 16, device=DEVICE) # 4D single-image batch + input_ids = torch.randint(0, 256, (2, 6), device=DEVICE) + with torch.no_grad(): + logits, _ = wrapper(pixels, input_ids) + assert logits.shape == (2, 6, 256) + assert wrapper.num_image_tokens == 4 # per-frame pooled count, frames=1 + + def test_video_forward_dtype_mismatch_cast(self): + """Encoder fp32 + bf16 adapter/transformer: the cast inside the visual + projector lets the 5D video path run (covers the dtype-cast branch).""" + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4).to(DEVICE) + wrapper.adapter.to(torch.bfloat16) + wrapper.transformer.to(torch.bfloat16) + # vision_encoder (RandomVisionEncoder) stays fp32 + pixels = torch.randn(2, 4, 3, 16, 16, device=DEVICE) + input_ids = torch.randint(0, 256, (2, 6), device=DEVICE) + logits, _ = wrapper(pixels, input_ids) + assert logits.dtype == torch.bfloat16 + assert logits.shape == (2, 6, 256) + + def test_frame_count_mismatch_raises(self): + """A clip whose frame count != frames_per_clip is rejected at the + projection boundary: the static MoT split and seq-len budget assume + frames_per_clip, so a mismatch is a clear error, not a confusing one.""" + wrapper = _video_wrapper(JointDecoderConfig(max_text_len=8), frames=4).to(DEVICE) + input_ids = torch.randint(0, 256, (2, 6), device=DEVICE) + # 2-frame clip into a 4-frame wrapper. + with pytest.raises(ValueError, match="frames-per-clip mismatch"): + wrapper(torch.randn(2, 2, 3, 16, 16, device=DEVICE), input_ids) + # 4D single-image batch into a video (frames_per_clip>1) wrapper. + with pytest.raises(ValueError, match="frames-per-clip mismatch"): + wrapper(torch.randn(2, 3, 16, 16, device=DEVICE), input_ids) diff --git a/uv.lock b/uv.lock index 400a8b2..42f3530 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.12" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", @@ -208,6 +208,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/64/b4/17d4b0b2a2dc85a6df63d1157e028ed19f90d4cd97c36717afef2bc2f395/attrs-26.1.0-py3-none-any.whl", hash = "sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309", size = 67548, upload-time = "2026-03-19T14:22:23.645Z" }, ] +[[package]] +name = "av" +version = "17.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/e3/477fa20578c284abeda08d91b63ee9abaebc93445d8feeb989d3d444bae1/av-17.1.0.tar.gz", hash = "sha256:7f1e71ff621b66253333926f948e00faae11d855b2442133c65128bca64cdeb3", size = 4288546, upload-time = "2026-06-07T05:52:55.999Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/87/8036b5c781bc3639ea04ef42d4e26da253bd4bd4311d8705b6a1c8824047/av-17.1.0-cp311-abi3-macosx_11_0_x86_64.whl", hash = "sha256:ad7b4aa011093324b7118245f50ac6db244cfe9900d4072508a5245a2b0d3f41", size = 22460847, upload-time = "2026-06-07T05:52:04.261Z" }, + { url = "https://files.pythonhosted.org/packages/6d/af/dfdf6fc7b17814b50d0aa9e7a7e37b87be91be3890f44b0d525433cd1fd1/av-17.1.0-cp311-abi3-macosx_14_0_arm64.whl", hash = "sha256:43ebbe977f19a7f2d2bd1a4e119675a0b15e05852cf7309846b6ab922ba7ffe9", size = 18159115, upload-time = "2026-06-07T05:52:06.64Z" }, + { url = "https://files.pythonhosted.org/packages/ad/13/64f6c466471cea225b8b2f4cdc51a571f8a286984b55a08d169b932fda5d/av-17.1.0-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:6a20658ec7d96a70e14b1196eff00b7cdd8831ac3b99868e16b8ba8b24090847", size = 33224427, upload-time = "2026-06-07T05:52:09.165Z" }, + { url = "https://files.pythonhosted.org/packages/77/43/96b35170bf2e64e00a41748c6400ff73232dc0fc62ded283679fb07c7fe0/av-17.1.0-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f9a65d1f48b818323fb411e80358f89d77dec340b01d27c6b2dfbb9cbf4b779f", size = 35370183, upload-time = "2026-06-07T05:52:11.959Z" }, + { url = "https://files.pythonhosted.org/packages/2e/b3/8e8b4b6498731bfbd88e8399a756543f8088f1bd33d08eab678b5aebe728/av-17.1.0-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:58f7593726437cda5bd19793027e027768450b5c4a594777bf487798a33db702", size = 24459265, upload-time = "2026-06-07T05:52:14.66Z" }, + { url = "https://files.pythonhosted.org/packages/14/ac/ceb84b7553db21f1143d817245c560d9267168e1e58b1a8eeae2b62c4d04/av-17.1.0-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:bbab058bd965309f39962e53caac8126987c68c0be094fc4f9427e5615b0218f", size = 34283709, upload-time = "2026-06-07T05:52:17.389Z" }, + { url = "https://files.pythonhosted.org/packages/59/f9/4115fd84148c9a1cf365096694be6ac882fd3cd3cdb7a2f35e71fecf1631/av-17.1.0-cp311-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:9514cfda85180554c430695282faf4be3ffdf95775d8519733821244eecb58e0", size = 25397573, upload-time = "2026-06-07T05:52:20.012Z" }, + { url = "https://files.pythonhosted.org/packages/e2/ac/92e52d5ed0e0b84d9d93e52b4338c2713d8a44082b8696e6516fdae7c4e4/av-17.1.0-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e1c90f85cd7431ede95b11e8e711571a896ebea433f298849c2c0f1594c8d86e", size = 36451495, upload-time = "2026-06-07T05:52:22.581Z" }, + { url = "https://files.pythonhosted.org/packages/6b/f2/53a7cd34adb6a971d7e6d99663e74db286966c9db8afdca17472fdf0f98e/av-17.1.0-cp311-abi3-win_amd64.whl", hash = "sha256:5df5c1172ef1cf65a1529d612f7da7798ce2cf82c1ff7212466b538a6cc7214c", size = 28036393, upload-time = "2026-06-07T05:52:25.657Z" }, + { url = "https://files.pythonhosted.org/packages/66/47/cd9ae0edf2206351c1251bb94b5ec58728e42c5f6ee16c03c412f3a1bb3e/av-17.1.0-cp311-abi3-win_arm64.whl", hash = "sha256:ee98534242a74da847af78624779ac5a3177dc7c69f956a4da9e6f0fdb37d7f6", size = 21174601, upload-time = "2026-06-07T05:52:28.077Z" }, + { url = "https://files.pythonhosted.org/packages/36/90/b5668cddb3c401fcf22553bc495d5b0c6d8a01d118624b26f0db1d0b8653/av-17.1.0-cp314-cp314t-macosx_11_0_x86_64.whl", hash = "sha256:5327807c1219293803ef0c5d1578ff3ae1cf638c09e5998962026e1a554ec240", size = 22699499, upload-time = "2026-06-07T05:52:30.335Z" }, + { url = "https://files.pythonhosted.org/packages/e0/7e/7be6bfddb823d045ff9fd5d4deb922ee3847605e162c3882e6c45b4c35ff/av-17.1.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:6c9b71fe5c0c5a8d303b1588d4d8ce9397d6b023f467cfef95000ba1f75507fa", size = 18366696, upload-time = "2026-06-07T05:52:32.645Z" }, + { url = "https://files.pythonhosted.org/packages/a2/23/391dcfa75c1ae1977efca44b753a11b929399b558826670c16a8808dd0e3/av-17.1.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:f997e3351bdf51127c07a74e21741a2996e9230cbeb2d81c14acde761b116c9c", size = 36582649, upload-time = "2026-06-07T05:52:35.218Z" }, + { url = "https://files.pythonhosted.org/packages/fb/32/7312854868b318b9d1b1dcbd1bddb460aaaeac7d57f816e11efec3bef5b1/av-17.1.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:efe9b1397300b67b644ad220c89df4892a76f2debe70f16bae1749fa20526e63", size = 38479390, upload-time = "2026-06-07T05:52:37.968Z" }, + { url = "https://files.pythonhosted.org/packages/2a/72/af47f59b4458e81ca7d89f477698dbfb3d5a0cd8ae6c1e4441d01074af8a/av-17.1.0-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:fa64e1f1500d01c4a98e7a41dc1a9a35fb4dfe71f5de0389264ec1192200c76a", size = 27127432, upload-time = "2026-06-07T05:52:40.371Z" }, + { url = "https://files.pythonhosted.org/packages/88/85/c2e6861baf0f8c7d21c4ce811d4d424fedac915e3910d3570ce4377717dc/av-17.1.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ffbd78d73d2c9bf31e9a007c992faec3991428b2941a3b085b84fb82e8c32d19", size = 37406592, upload-time = "2026-06-07T05:52:43.215Z" }, + { url = "https://files.pythonhosted.org/packages/ba/40/3cc13125aea976101c0858af99ac47257c0654411aa199b5d8e81eea7002/av-17.1.0-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:bff8896454b38fcb785a70e5ae0485d7021cb776303a5849393128a30b8f850b", size = 28336228, upload-time = "2026-06-07T05:52:46.134Z" }, + { url = "https://files.pythonhosted.org/packages/a2/38/c7d9c3e746209a1a695c13e3aa7d817229e84a85d0a84271f313d1befdd3/av-17.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:1284addf3c0dd939887a9722dc30df2241a97471ad52c3c507e31583ae22ff02", size = 39490680, upload-time = "2026-06-07T05:52:48.887Z" }, + { url = "https://files.pythonhosted.org/packages/a1/25/9d42da561b7b8f7dabdfaebba07b52977bee58c5c7e4285ac991abcfaa72/av-17.1.0-cp314-cp314t-win_amd64.whl", hash = "sha256:ec630be6321b04e317862f6082e84812bbd801e55a3c2298312e3fc8a0a4af4f", size = 28355673, upload-time = "2026-06-07T05:52:51.614Z" }, + { url = "https://files.pythonhosted.org/packages/a8/41/562a61d5a61fba3ffb273a115e249f1d8471b9515c59fcc38b4b9deda238/av-17.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:b41647e42884bf543b8e8d0a1dabd4d1b006c99183eb1a2d7afc5b01f73eeff4", size = 21324700, upload-time = "2026-06-07T05:52:53.972Z" }, +] + [[package]] name = "babel" version = "2.18.0" @@ -1246,6 +1274,7 @@ name = "kempnerforge" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "av" }, { name = "datasets" }, { name = "tensorboard" }, { name = "tokenizers" }, @@ -1280,6 +1309,7 @@ docs = [ [package.metadata] requires-dist = [ + { name = "av", specifier = ">=17.1.0" }, { name = "datasets" }, { name = "tensorboard" }, { name = "tokenizers" },