Commit fd80a71

Copilot and anxiangsir committed
docs: Add detailed documentation for ViT unified input processing (image, video chunk sampling, codec-style)
Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent 26e8bfb commit fd80a71

1 file changed: docs/model_card.md
Lines changed: 177 additions & 0 deletions

- **Native Resolution Support**: Supports native resolution input without tiling or cropping.
- **Flash Attention 2**: Efficient attention implementation for improved performance and memory efficiency.

## Unified Input Processing

OneVision Encoder uses a unified architecture to process three types of visual inputs—images, video chunks (uniform frame sampling), and codec-style sparse patches—through the same Vision Transformer backbone. The key insight is that all inputs are converted to a sequence of patch tokens with 3D position encodings, enabling a single model to handle diverse visual modalities.

### Image Processing

For single image input, the ViT processes data in the standard 4D tensor format `[B, C, H, W]`:
```
Input: [B, C, H, W] → e.g., [1, 3, 448, 448]
        ↓
Patch Embedding (Conv2d with kernel=16, stride=16)
        ↓
Flatten: [B, num_patches, hidden_size]
        e.g., [1, 784, 1024] for 448×448 image
        ↓
3D RoPE Position Encoding (T=1, H=28, W=28)
        ↓
Transformer Encoder (24 layers)
        ↓
Output: [B, num_patches, hidden_size]
```

**Key points:**

- Images are internally treated as single-frame videos with `T=1`
- Position encoding uses the same 3D RoPE with temporal dimension fixed at 1
- All patches are processed (no masking), resulting in `(H/patch_size) × (W/patch_size)` tokens
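
To make the shape bookkeeping concrete, here is a minimal sketch of the image path above (the `patch_embed` module is illustrative; the real model additionally applies 3D RoPE and the 24-layer encoder):

```python
import torch
import torch.nn as nn

B, C, H, W = 1, 3, 448, 448
hidden_size, patch_size = 1024, 16

# Patch embedding as a strided convolution (kernel=16, stride=16)
patch_embed = nn.Conv2d(C, hidden_size, kernel_size=patch_size, stride=patch_size)

image = torch.randn(B, C, H, W)
tokens = patch_embed(image)                 # [1, 1024, 28, 28]
tokens = tokens.flatten(2).transpose(1, 2)  # [1, 784, 1024]
# 784 = (448/16) × (448/16) patch tokens; the temporal axis is T=1
```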

### Video Chunk Sampling

For video input with uniform frame sampling, the ViT processes the 5D tensor format `[B, C, T, H, W]`:

```
Input: [B, C, T, H, W] → e.g., [1, 3, 16, 224, 224]
        ↓
Patch Embedding (per-frame Conv2d)
        ↓
Flatten: [B, T × H_patches × W_patches, hidden_size]
        e.g., [1, 16 × 14 × 14, 1024] = [1, 3136, 1024]
        ↓
Build visible_indices for temporal mapping
        ↓
3D RoPE Position Encoding with frame positions
        ↓
Transformer Encoder (24 layers)
        ↓
Output: [B, num_visible_patches, hidden_size]
```

**The `visible_indices` mechanism:**

The `visible_indices` tensor maps actual frame positions to a virtual temporal grid (e.g., 64 virtual frames), enabling proper temporal position encoding even with sparse frame sampling:

```python
import torch

# Example: 16 frames sampled from a video, mapped to 64 virtual frame positions
num_frames = 16     # Actual number of sampled frames
frame_tokens = 256  # Patches per frame (16×16 for 256×256 with patch_size=16)
target_frames = 64  # Virtual temporal grid size (model's RoPE temporal dimension)

# Map 16 actual frames to positions in the 64-frame virtual grid
frame_pos = torch.linspace(0, target_frames - 1, num_frames).long()
# frame_pos: 16 indices spread uniformly over [0, 63], e.g., [0, 4, 8, ..., 63]

# Build visible_indices: each frame's patches get position encoding based on frame_pos
visible_indices = (frame_pos.unsqueeze(-1) * frame_tokens +
                   torch.arange(frame_tokens)).reshape(1, -1)
# Shape: [1, 4096] (16 frames × 256 patches)
```

This enables the model to understand temporal relationships even when frames are not densely sampled.
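
Assuming the model accepts the same keyword argument shown in the codec-style example later in this section, the indices are then passed alongside the 5D video tensor (a hypothetical continuation of the snippet above):

```python
# Continues the previous snippet; `model` is a loaded OneVision Encoder (assumed API)
video = torch.randn(1, 3, 16, 256, 256)  # [B, C, T, H, W], 16 sampled frames
outputs = model(video, visible_indices=visible_indices)
# outputs: [1, 4096, 1024], one feature vector per visible patch
```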

### Codec-Style Input

Codec-style input is the most sophisticated processing mode, inspired by HEVC video compression. Instead of processing all patches from all frames, it selectively processes only temporally salient patches identified through motion and residual analysis.
```
Input Video: 64 frames
                        ↓
┌───────────────────────────────────────────────┐
│ HEVC Feature Extraction                       │
│ ├── Motion Vectors (MV): quarter-pel motion   │
│ └── Residuals: prediction error signals       │
└───────────────────────────────────────────────┘
                        ↓
┌───────────────────────────────────────────────┐
│ Temporal Saliency Detection                   │
│ ├── MV Energy: camera-compensated motion mag  │
│ ├── Residual Energy: prediction error mag     │
│ └── Fused Energy: weighted combination        │
└───────────────────────────────────────────────┘
                        ↓
┌───────────────────────────────────────────────┐
│ Top-K Patch Selection                         │
│ ├── Score each patch by fused energy          │
│ ├── Select K most salient patches             │
│ └── Build sparse visible_indices              │
└───────────────────────────────────────────────┘
                        ↓
┌───────────────────────────────────────────────┐
│ ViT Processing with Sparse visible_indices    │
│ ├── Input: [B, C, T, H, W] full video         │
│ ├── visible_indices: [B, K] selected patches  │
│ └── Output: [B, K, hidden_size]               │
└───────────────────────────────────────────────┘
```

**Detailed Codec Processing Pipeline:**

1. **Motion Vector Analysis**: Extract motion vectors from the HEVC codec at quarter-pixel precision. Apply camera motion compensation (median, similarity, or affine model) to isolate object motion from camera movement.

2. **Residual Analysis**: Extract prediction residuals that capture texture changes and fine-grained motion not captured by block-based motion compensation.

3. **Energy Fusion**: Combine MV energy and residual energy with configurable weights to produce a unified saliency map.

4. **Top-K Selection**: Rank all patches (across all frames) by their saliency scores and select the top K patches, as sketched after this list. This achieves 75–98% compression while retaining critical temporal dynamics.

5. **Sparse Processing**: The selected patches are processed by the ViT with proper 3D position encodings, enabling the model to understand the spatiotemporal context of each selected patch.
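
A minimal sketch of steps 3 and 4, assuming per-patch MV and residual energy maps have already been extracted (the array names, grid size, and fusion weights are illustrative, not the repository's API):

```python
import torch

# Hypothetical per-patch energy maps over a [T, H_p, W_p] patch grid
T, Hp, Wp = 64, 16, 16
mv_energy = torch.rand(T, Hp, Wp)   # camera-compensated motion magnitude
res_energy = torch.rand(T, Hp, Wp)  # prediction-residual magnitude

# Step 3: fuse the two energies with configurable weights (illustrative values)
w_mv, w_res = 0.5, 0.5
fused = w_mv * mv_energy + w_res * res_energy

# Step 4: rank all T × Hp × Wp patches by saliency and keep the top K
K = 2048
scores = fused.flatten()                       # [16384]
topk = scores.topk(K).indices                  # K most salient patch indices
visible_indices = topk.sort().values[None, :]  # [1, K]; sorted order assumed
```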

**Example codec-style inference:**

```python
# Codec-style: select the 2048 most salient patches from 64 frames
# (equivalent to 8 full frames' worth of tokens)
K_keep = 2048  # 256 patches/frame × 8 frames equivalent

# visible_indices are computed by the codec saliency detection;
# each index points to a specific (frame, h, w) position in the patch grid
visible_indices = compute_codec_visible_indices(
    video_path,
    K=K_keep,
    mv_compensate="similarity",  # Camera motion compensation
    patch_size=16,
)

# Process with the model
outputs = model(video, visible_indices=visible_indices)
# Output: [B, 2048, 1024] - features for the 2048 selected patches
```
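
For the configuration above, a dense pass over all 64 frames would process 64 × 256 = 16,384 patch tokens, so keeping K = 2048 cuts the token count by 87.5%, within the 75–98% compression range quoted earlier.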

### Comparison of Input Modes

| Mode | Input Shape | visible_indices | Output Shape | Use Case |
|------|-------------|-----------------|--------------|----------|
| **Image** | `[B, 3, H, W]` | All patches | `[B, (H/16)×(W/16), 1024]` | Single image understanding |
| **Video Chunk** | `[B, 3, T, H, W]` | Frame-mapped | `[B, T×(H/16)×(W/16), 1024]` | Uniform temporal sampling |
| **Codec-Style** | `[B, 3, T, H, W]` | Top-K salient | `[B, K, 1024]` | Token-efficient dense video |

### 3D RoPE Position Encoding

All three input modes share the same 3D Rotary Position Embedding (RoPE) with a 4:6:6 split:

- **Temporal (T)**: 4/16 of head dimension → captures frame ordering
- **Height (H)**: 6/16 of head dimension → captures vertical position
- **Width (W)**: 6/16 of head dimension → captures horizontal position

A minimal sketch of the construction (assuming a standard RoPE inverse-frequency schedule; the grid sizes are illustrative):

```python
import torch

# 3D position encoding construction
hidden_size, num_heads = 1024, 16
head_dim = hidden_size // num_heads  # 1024 // 16 = 64
half = head_dim // 2                 # 32 rotary frequency slots

# Split dimensions with a 4:6:6 ratio (4 + 6 + 6 = 16 units total)
unit = half // 16  # 32 // 16 = 2
t_size = 4 * unit  # 4 * 2 = 8 dims for temporal
h_size = 6 * unit  # 6 * 2 = 12 dims for height
w_size = 6 * unit  # 6 * 2 = 12 dims for width
# Total: 8 + 12 + 12 = 32 = half of head_dim

def inv_freq(n):
    # Standard RoPE inverse-frequency schedule for one axis (sketch)
    return 1.0 / (10000.0 ** (torch.arange(n, dtype=torch.float32) / n))

# Per-token grid coordinates: frame index, patch row, patch column
# (here a full 64-frame video with a 14×14 patch grid)
t_ids, h_ids, w_ids = torch.meshgrid(
    torch.arange(64), torch.arange(14), torch.arange(14), indexing="ij")
t_ids, h_ids, w_ids = t_ids.flatten(), h_ids.flatten(), w_ids.flatten()

# Rotation angles per token: position × per-axis inverse frequency
freqs = torch.cat([
    t_ids[:, None] * inv_freq(t_size),  # Based on frame index
    h_ids[:, None] * inv_freq(h_size),  # Based on patch row
    w_ids[:, None] * inv_freq(w_size),  # Based on patch column
], dim=-1)  # [num_tokens, 32]
```

This unified position encoding allows the model to maintain consistent spatial and temporal understanding across all input modalities.

## Intended Use

### Primary Use Cases