- **Native Resolution Support**: Supports native-resolution input without tiling or cropping.
- **Flash Attention 2**: Efficient attention implementation for higher throughput and lower memory use.

## Unified Input Processing

OneVision Encoder processes three types of visual input through the same Vision Transformer backbone: images, video chunks (uniform frame sampling), and codec-style sparse patches. The key idea is that every input is converted into a sequence of patch tokens with 3D position encodings, so a single model can handle all of these visual modalities.

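In practice the three modes share one calling convention; only the input shape and the optional `visible_indices` argument change. The sketch below is illustrative only: `model` stands for an already-loaded encoder (the same assumed handle used in the codec-style example later in this section), and `image`, `video`, `frame_indices`, and `salient_indices` are placeholder tensors whose construction is described in the subsections that follow.

```python
# Shared interface across modes (assumed handle `model`; shapes are examples)
image_feats = model(image)                                   # image: [B, 3, H, W]
chunk_feats = model(video, visible_indices=frame_indices)    # video: [B, 3, T, H, W], uniform frames
codec_feats = model(video, visible_indices=salient_indices)  # video: [B, 3, T, H, W], top-K salient patches
```
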
### Image Processing

For single-image input, the ViT processes data in the standard 4D tensor format `[B, C, H, W]`:

```
Input: [B, C, H, W] → e.g., [1, 3, 448, 448]
        ↓
Patch Embedding (Conv2d with kernel=16, stride=16)
        ↓
Flatten: [B, num_patches, hidden_size]
        e.g., [1, 784, 1024] for a 448×448 image
        ↓
3D RoPE Position Encoding (T=1, H=28, W=28)
        ↓
Transformer Encoder (24 layers)
        ↓
Output: [B, num_patches, hidden_size]
```

**Key points:**
- Images are internally treated as single-frame videos with `T=1`
- Position encoding uses the same 3D RoPE, with the temporal dimension fixed at 1
- All patches are processed (no masking), yielding `(H/patch_size) × (W/patch_size)` tokens

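For a concrete sense of the shapes involved, the sketch below works through the token count for a 448×448 image and the corresponding forward call. It assumes `model` is the already-loaded encoder (the same assumed handle as in the codec-style example later in this section), so treat it as illustrative rather than a definitive API reference.

```python
import torch

patch_size, hidden_size = 16, 1024
image = torch.randn(1, 3, 448, 448)                      # [B, C, H, W], native resolution
num_tokens = (448 // patch_size) * (448 // patch_size)   # 28 × 28 = 784 patches

# Images need no visible_indices: every patch is visible and T is fixed at 1
features = model(image)                                  # assumed loaded encoder
assert features.shape == (1, num_tokens, hidden_size)    # [1, 784, 1024]
```
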
### Video Chunk Sampling

For video input with uniform frame sampling, the ViT processes the 5D tensor format `[B, C, T, H, W]`:

```
Input: [B, C, T, H, W] → e.g., [1, 3, 16, 224, 224]
        ↓
Patch Embedding (per-frame Conv2d)
        ↓
Flatten: [B, T × H_patches × W_patches, hidden_size]
        e.g., [1, 16 × 14 × 14, 1024] = [1, 3136, 1024]
        ↓
Build visible_indices for temporal mapping
        ↓
3D RoPE Position Encoding with frame positions
        ↓
Transformer Encoder (24 layers)
        ↓
Output: [B, num_visible_patches, hidden_size]
```

**The `visible_indices` mechanism:**

The `visible_indices` tensor maps the actual frame positions onto a virtual temporal grid (e.g., 64 virtual frames), enabling correct temporal position encoding even with sparse frame sampling:

```python
import torch

# Example: 16 frames sampled from a video, mapped onto 64 virtual frame positions
num_frames = 16      # Actual number of sampled frames
frame_tokens = 256   # Patches per frame (16×16 patches for 256×256 input with patch_size=16)
target_frames = 64   # Virtual temporal grid size (the model's RoPE temporal dimension)

# Map the 16 actual frames to roughly evenly spaced positions in the 64-frame virtual grid
frame_pos = torch.linspace(0, target_frames - 1, num_frames).long()
# frame_pos ≈ [0, 4, 8, 12, 16, ..., 58, 63]

# Build visible_indices: each frame's patches receive position encodings based on frame_pos
visible_indices = (frame_pos.unsqueeze(-1) * frame_tokens +
                   torch.arange(frame_tokens)).reshape(1, -1)
# Shape: [1, 4096] (16 frames × 256 patches per frame)
```

This enables the model to understand temporal relationships even when frames are not densely sampled.

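Continuing from the snippet above, the sampled clip and its `visible_indices` go to the encoder together. This is a hedged sketch: `model` stands for the already-loaded encoder and follows the same calling convention as the codec-style example in the next subsection.

```python
import torch

video = torch.randn(1, 3, 16, 256, 256)                  # [B, C, T, H, W]: 16 sampled frames
outputs = model(video, visible_indices=visible_indices)  # indices from the snippet above, [1, 4096]
# outputs: [1, 4096, 1024], one feature vector per visible patch (16 frames × 256 patches)
```
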
### Codec-Style Input

Codec-style input is the most sophisticated processing mode, inspired by HEVC video compression. Instead of processing every patch from every frame, it processes only the temporally salient patches identified through motion and residual analysis.

```
Input Video: 64 frames
                       ↓
┌───────────────────────────────────────────────┐
│ HEVC Feature Extraction                       │
│ ├── Motion Vectors (MV): quarter-pel motion   │
│ └── Residuals: prediction error signals       │
└───────────────────────────────────────────────┘
                       ↓
┌───────────────────────────────────────────────┐
│ Temporal Saliency Detection                   │
│ ├── MV Energy: camera-compensated motion mag  │
│ ├── Residual Energy: prediction error mag     │
│ └── Fused Energy: weighted combination        │
└───────────────────────────────────────────────┘
                       ↓
┌───────────────────────────────────────────────┐
│ Top-K Patch Selection                         │
│ ├── Score each patch by fused energy          │
│ ├── Select K most salient patches             │
│ └── Build sparse visible_indices              │
└───────────────────────────────────────────────┘
                       ↓
┌───────────────────────────────────────────────┐
│ ViT Processing with Sparse visible_indices    │
│ ├── Input: [B, C, T, H, W] full video         │
│ ├── visible_indices: [B, K] selected patches  │
│ └── Output: [B, K, hidden_size]               │
└───────────────────────────────────────────────┘
```

**Detailed Codec Processing Pipeline:**

1. **Motion Vector Analysis**: Extract motion vectors from the HEVC codec at quarter-pixel precision, then apply camera motion compensation (median, similarity, or affine model) to isolate object motion from camera movement.

2. **Residual Analysis**: Extract prediction residuals, which capture texture changes and fine-grained motion that block-based motion compensation misses.

3. **Energy Fusion**: Combine MV energy and residual energy with configurable weights to produce a unified saliency map.

4. **Top-K Selection**: Rank all patches (across all frames) by saliency score and keep the top K (see the sketch after this list). This achieves 75-98% compression while retaining the critical temporal dynamics.

5. **Sparse Processing**: The selected patches are processed by the ViT with proper 3D position encodings, so the model retains the spatiotemporal context of each selected patch.

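Steps 1, 3, and 4 can be summarized with plain tensor operations. The sketch below is illustrative only: the shapes, the median camera-motion model, and the fusion weight `alpha` are assumptions standing in for the actual codec saliency pipeline, and the energies are random stand-ins for real codec features.

```python
import torch

# Toy setup: 64 frames at 224×224 with patch_size=16, i.e. a 14×14 patch grid per frame
T, h_p, w_p = 64, 14, 14
mv = torch.randn(T, h_p, w_p, 2)               # per-patch motion vectors (dx, dy)
residual_energy = torch.rand(T, h_p, w_p)      # prediction-error magnitude per patch

# Step 1 (median model): subtract each frame's median MV to remove camera motion
camera = mv.flatten(1, 2).median(dim=1).values                # [T, 2]
mv_energy = (mv - camera[:, None, None, :]).norm(dim=-1)      # [T, 14, 14]

# Step 3: weighted fusion into one saliency map (alpha is an illustrative choice)
alpha = 0.5
fused = alpha * mv_energy + (1 - alpha) * residual_energy

# Step 4: keep the K most salient patches across all frames, in grid order
K = 2048
visible_indices = fused.flatten().topk(K).indices.sort().values.unsqueeze(0)  # [1, K]
```
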
**Example codec-style inference:**

```python
# Codec-style: select the 2048 most salient patches from 64 frames
# (equivalent to 8 full frames' worth of tokens at 256 patches per frame)
K_keep = 2048  # 256 patches/frame × 8 frames

# visible_indices are computed by the codec saliency detection;
# each index points to a specific (frame, h, w) position in the patch grid.
# `video_path`, `video`, and `model` are assumed to be defined/loaded beforehand.
visible_indices = compute_codec_visible_indices(
    video_path,
    K=K_keep,
    mv_compensate="similarity",  # Camera motion compensation model
    patch_size=16
)

# Process with the model
outputs = model(video, visible_indices=visible_indices)
# Output: [B, 2048, 1024] - features for the 2048 selected patches
```

### Comparison of Input Modes

| Mode | Input Shape | visible_indices | Output Shape | Use Case |
|------|-------------|-----------------|--------------|----------|
| **Image** | `[B, 3, H, W]` | All patches | `[B, (H/16)×(W/16), 1024]` | Single-image understanding |
| **Video Chunk** | `[B, 3, T, H, W]` | Frame-mapped | `[B, T×(H/16)×(W/16), 1024]` | Uniform temporal sampling |
| **Codec-Style** | `[B, 3, T, H, W]` | Top-K salient | `[B, K, 1024]` | Token-efficient dense temporal coverage |

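For a rough sense of scale (illustrative numbers, not a benchmark): with a 64-frame clip at 224×224 and patch size 16, codec-style selection at K = 2048 keeps roughly a sixth of the dense token budget, consistent with the 75-98% compression range quoted above.

```python
frames, patches_per_frame = 64, (224 // 16) ** 2     # 196 patches per frame
dense_tokens = frames * patches_per_frame            # 12544 tokens for a full video chunk
codec_tokens = 2048                                  # top-K codec-style budget
print(f"{1 - codec_tokens / dense_tokens:.0%} fewer tokens")  # ~84%
```
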
### 3D RoPE Position Encoding

All three input modes share the same 3D Rotary Position Embedding (RoPE) with a 4:6:6 split:

- **Temporal (T)**: 4/16 of head dimension → captures frame ordering
- **Height (H)**: 6/16 of head dimension → captures vertical position
- **Width (W)**: 6/16 of head dimension → captures horizontal position

```python
import torch

# 3D rotary position encoding construction (simplified, illustrative sketch)
hidden_size, num_heads = 1024, 16
head_dim = hidden_size // num_heads   # 1024 // 16 = 64
half = head_dim // 2                  # 32 rotary frequency slots

# Split the frequency slots with a 4:6:6 ratio (4 + 6 + 6 = 16 units total)
unit = half // 16                     # 32 // 16 = 2
t_size = 4 * unit                     # 4 * 2 = 8  frequencies for temporal
h_size = 6 * unit                     # 6 * 2 = 12 frequencies for height
w_size = 6 * unit                     # 6 * 2 = 12 frequencies for width
# Total: 8 + 12 + 12 = 32 = half of head_dim

# Per-axis angle tables: table[pos] = pos * inverse frequencies
# (the standard RoPE frequency schedule is assumed here for illustration)
def make_table(max_pos, size):
    inv_freq = 1.0 / (10000 ** (torch.arange(size, dtype=torch.float32) / size))
    return torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)

freq_temporal = make_table(64, t_size)   # frame positions on the virtual temporal grid
freq_height   = make_table(64, h_size)   # patch rows
freq_width    = make_table(64, w_size)   # patch columns

# (t, h, w) ids per token, e.g. a single 448×448 image: T=1 and a 28×28 patch grid
t_ids, h_ids, w_ids = torch.meshgrid(
    torch.arange(1), torch.arange(28), torch.arange(28), indexing="ij")
t_ids, h_ids, w_ids = (x.reshape(-1) for x in (t_ids, h_ids, w_ids))

# Look up each token's angles by frame index, patch row, and patch column
freqs = torch.cat([
    freq_temporal[t_ids],   # based on frame index
    freq_height[h_ids],     # based on patch row
    freq_width[w_ids],      # based on patch column
], dim=-1)                  # [784, 32]
```

This unified position encoding allows the model to maintain consistent spatial and temporal understanding across all input modalities.

## Intended Use

### Primary Use Cases