Commit 757485d

Copilot and anxiangsir committed
Add Multi-Modal Training Strategy section to model_card.md
Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent 7d6adbe commit 757485d

1 file changed

Lines changed: 172 additions & 0 deletions

docs/model_card.md

@@ -294,3 +294,175 @@ Training on a mixed dataset of 740K samples from LLaVA-OneVision and 800K sample
## Contact

For questions and issues, please open an issue on the [GitHub repository](https://github.com/Evolvinglmms-lab/OneVision-Encoder).

## Multi-Modal Training Strategy

OneVision Encoder uses a unified training approach that processes images, codec-style video patches, frame-sampled video, and video collages within the same batch. Training on these input formats together enables the model to learn robust representations across input modalities.

### Training Batch Composition

Within each training batch, samples are routed to different processing modes:

```
Training Batch (bs=16)
┌─────────────────────────────────────────────────────────────────────┐
│                                                                     │
│   ┌─────────────────┐   ┌─────────────────┐   ┌─────────────────┐   │
│   │   Image Head    │   │   Video Head    │   │    OCR Head     │   │
│   │   (origin)      │   │(decord_residual)│   │     (ocr)       │   │
│   │                 │   │                 │   │                 │   │
│   │  [B, 3, H, W]   │   │ Split by mode:  │   │  [B, 3, H, W]   │   │
│   │                 │   │ • Codec 50%     │   │                 │   │
│   │                 │   │ • Sampling 37.5%│   │                 │   │
│   │                 │   │ • Collage 12.5% │   │                 │   │
│   └─────────────────┘   └─────────────────┘   └─────────────────┘   │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
```

### Video Processing Modes

For video inputs, the batch is further split into three processing modes:

| Mode | Batch % | Description | Input → Output |
|------|---------|-------------|----------------|
| **Codec-Style** | 50% | Select top-K salient patches based on HEVC residual | `[n, 3, 64, 224, 224]` → `[n, 3, 8, 224, 224]` |
| **Frame Sampling** | 37.5% | Uniform temporal sampling, 1 frame per bin | `[n, 3, 64, 224, 224]` → `[n, 3, 8, 224, 224]` |
| **Collage** | 12.5% | 8 sampled frames concatenated vertically into a tall image | `[n, 3, 64, 224, 224]` → `[n, 3, 1792, 224]` |
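
As a rough illustration of the split in the table, the sketch below turns the three fractions into per-mode sample counts and contiguous index ranges. The helper name `split_video_batch`, the flooring rule, and the choice to let the last mode absorb any remainder are assumptions made for this example, not details taken from the training code.

```python
from fractions import Fraction

# Fractions from the table; exact rationals avoid float rounding.
MODE_FRACTIONS = {
    "codec": Fraction(1, 2),     # 50%
    "sampling": Fraction(3, 8),  # 37.5%
    "collage": Fraction(1, 8),   # 12.5%
}


def split_video_batch(batch_size):
    """Return {mode: range of sample indices} (hypothetical helper)."""
    ranges = {}
    start = 0
    modes = list(MODE_FRACTIONS)
    for i, mode in enumerate(modes):
        if i == len(modes) - 1:
            # Assumption: the last mode takes whatever is left over.
            count = batch_size - start
        else:
            # Assumption: round down when the batch does not divide evenly.
            count = int(batch_size * MODE_FRACTIONS[mode])
        ranges[mode] = range(start, start + count)
        start += count
    return ranges


splits = split_video_batch(16)
print({mode: (r.start, r.stop) for mode, r in splits.items()})
# {'codec': (0, 8), 'sampling': (8, 14), 'collage': (14, 16)}
```

With `bs=16` this gives 8 codec-style, 6 frame-sampling, and 2 collage samples.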
332+
333+
### Processing Pipeline
334+
335+
```
336+
Video Input: [bs, 3, 64, 224, 224]
337+
338+
┌──────────────────┼──────────────────┐
339+
│ │ │
340+
▼ ▼ ▼
341+
┌───────────────┐ ┌───────────────┐ ┌───────────────┐
342+
│ Codec-Style │ │Frame Sampling │ │ Collage │
343+
│ (50%) │ │ (37.5%) │ │ (12.5%) │
344+
└───────┬───────┘ └───────┬───────┘ └───────┬───────┘
345+
│ │ │
346+
▼ ▼ ▼
347+
┌───────────────┐ ┌───────────────┐ ┌───────────────┐
348+
│ Patchify │ │ Sample frames │ │ Sample frames │
349+
│ [n,3,16384,p²]│ │ from 8 bins │ │ from 8 bins │
350+
└───────┬───────┘ └───────┬───────┘ └───────┬───────┘
351+
│ │ │
352+
▼ ▼ ▼
353+
┌───────────────┐ ┌───────────────┐ ┌───────────────┐
354+
│ Select top-K │ │ Build indices │ │ Concat frames │
355+
│ by vis_idx │ │ for 8 frames │ │ vertically │
356+
└───────┬───────┘ └───────┬───────┘ └───────┬───────┘
357+
│ │ │
358+
▼ ▼ ▼
359+
┌───────────────┐ ┌───────────────┐ ┌───────────────┐
360+
│ Unpatchify │ │ │ │ │
361+
│[n,3,8,224,224]│ │[n,3,8,224,224]│ │[n,3,1792,224] │
362+
└───────┬───────┘ └───────┬───────┘ └───────┬───────┘
363+
│ │ │
364+
└──────────────────┼──────────────────┘
365+
366+
367+
┌───────────────┐
368+
│ ViT Backbone │
369+
│ with RoPE │
370+
└───────┬───────┘
371+
372+
373+
[bs, hidden_size]
374+
```

### 1. Codec-Style Processing (50% of batch)

This mode uses HEVC-extracted saliency information to select the most informative patches:

```python
import torch

# Example: bs=16, first 8 samples use codec-style.
# The random tensors are stand-ins so the snippet runs on its own; in training,
# visible_indices holds pre-computed salient patch indices from HEVC analysis.
bs, C, T, H, W, patch_size = 16, 3, 64, 224, 224, 14
Hp, Wp = H // patch_size, W // patch_size  # 16 × 16 = 256 patches per frame
target_num = 8 * Hp * Wp                   # 2048
mask_residual = torch.arange(bs) < 8       # marks the codec-style samples
n = int(mask_residual.sum())               # 8
video = torch.randn(n, C, T, H, W)
visible_indices = torch.stack([torch.randperm(T * Hp * Wp) for _ in range(bs)])
out = torch.zeros(bs, target_num, dtype=torch.long)

# Step 1: Use pre-computed visible_indices (sorted by saliency)
out[mask_residual] = visible_indices[mask_residual, :target_num]  # [8, 2048]

# Step 2: Patchify full video
# [8, 3, 64, 224, 224] → [8, 3, 16384, 14, 14] (64 frames × 256 patches/frame)
patches = (
    video.view(n, C, T, Hp, patch_size, Wp, patch_size)
    .permute(0, 1, 2, 3, 5, 4, 6)
    .reshape(n, C, T * Hp * Wp, patch_size, patch_size)
)

# Step 3: Select top-K patches by visible_indices
# torch.gather needs an index with the same number of dims as patches.
idx = out[mask_residual].view(n, 1, target_num, 1, 1)
idx = idx.expand(-1, C, -1, patch_size, patch_size)
selected = torch.gather(patches, 2, idx)  # [8, 3, 2048, 14, 14]

# Step 4: Unpatchify back to video format
# 2048 patches = 8 frames × 256 patches/frame
combined_head_input = (
    selected.view(n, C, 8, Hp, Wp, patch_size, patch_size)
    .permute(0, 1, 2, 3, 5, 4, 6)
    .reshape(n, C, 8, H, W)
)  # [8, 3, 8, 224, 224]
```
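
The shape comments in the codec-style block follow from simple patch arithmetic, which the sketch below spells out in plain Python. `patch_layout` is a hypothetical helper written for this note; the defaults (64 frames, 224×224 input, 14×14 patches, 8 frames' worth of patches kept) are the values quoted in this section.

```python
def patch_layout(num_frames=64, height=224, width=224, patch_size=14, kept_frames=8):
    """Token-count arithmetic for codec-style selection (illustrative only)."""
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be multiples of patch_size")
    hp, wp = height // patch_size, width // patch_size
    per_frame = hp * wp               # patches in one frame
    total = num_frames * per_frame    # patches in the full clip
    kept = kept_frames * per_frame    # patches passed to the backbone
    return {
        "grid": (hp, wp),
        "patches_per_frame": per_frame,
        "total_patches": total,
        "kept_patches": kept,
        "kept_fraction": kept / total,
    }


layout = patch_layout()
print(layout["grid"], layout["patches_per_frame"])       # (16, 16) 256
print(layout["total_patches"], layout["kept_patches"])   # 16384 2048
print(layout["kept_fraction"])                           # 0.125
```

Under these values the backbone sees one eighth of the clip's patches, the same token budget as eight full frames.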

### 2. Frame Sampling Processing (37.5% of batch)

This mode samples one frame at random from each of 8 equal temporal bins:

```python
import torch

# Example: samples 8-13 in batch use frame sampling
# Divide 64 frames into 8 bins of 8 frames each, sample 1 from each bin
# Stand-ins so the snippet runs on its own:
bs, nB = 16, 6
mask_frame_sampling = (torch.arange(bs) >= 8) & (torch.arange(bs) < 14)
out = torch.zeros(bs, 2048, dtype=torch.long)

# Step 1: Sample frame indices
# bins: [0-7], [8-15], [16-23], [24-31], [32-39], [40-47], [48-55], [56-63]
frames = torch.arange(8) * 8 + torch.randint(8, (nB, 8))  # [6, 8]

# Step 2: Build patch indices for all patches in selected frames
# Each frame has 256 patches
out[mask_frame_sampling] = (
    frames.unsqueeze(-1) * 256 + torch.arange(256)
).reshape(nB, -1)  # [6, 2048]

# Step 3: Same patchify → select → unpatchify as codec-style
# Result: [6, 3, 8, 224, 224]
```
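
The same bin-based index construction can be written without tensors, which makes the arithmetic easier to check by hand. This is a minimal sketch in plain Python; `sample_bin_frames` and `frames_to_patch_indices` are names invented for this example, and the fixed seed is only there to make the run repeatable.

```python
import random


def sample_bin_frames(num_frames=64, num_bins=8, rng=random):
    """Pick one random frame from each equal-width temporal bin."""
    bin_size = num_frames // num_bins
    return [b * bin_size + rng.randrange(bin_size) for b in range(num_bins)]


def frames_to_patch_indices(frames, patches_per_frame=256):
    """Expand frame indices into the flat indices of all their patches."""
    return [f * patches_per_frame + p for f in frames for p in range(patches_per_frame)]


rng = random.Random(0)  # seeded only so the example is repeatable
frames = sample_bin_frames(rng=rng)
patch_idx = frames_to_patch_indices(frames)
print(len(frames), len(patch_idx))  # 8 2048
```

Each sampled frame contributes a contiguous run of 256 indices, so the result has the same length as the codec-style index list and can go through the same select step.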

### 3. Collage Processing (12.5% of batch)

This mode concatenates sampled frames into a single tall image:

```python
import torch

# Example: samples 14-15 in batch use collage
# Sample 8 frames and concatenate vertically
nC = 2
video = torch.randn(nC, 3, 64, 224, 224)  # stand-in input so the snippet runs

# Step 1: Sample 8 frames (same bin-based sampling)
base = torch.arange(8) * 8           # first frame of each bin (assumed layout)
offsets = torch.randint(8, (nC, 8))  # random offset inside each bin
frames_idx = base + offsets  # [2, 8], values in [0, 63]

# Step 2: Gather selected frames
idx_expand = frames_idx.view(nC, 1, 8, 1, 1).expand(-1, 3, -1, 224, 224)
sel_frames = torch.gather(video, 2, idx_expand)  # [2, 3, 8, 224, 224]

# Step 3: Concatenate frames vertically
sel_frames = sel_frames.permute(0, 2, 1, 3, 4)  # [2, 8, 3, 224, 224]
grid = torch.cat([sel_frames[:, i] for i in range(8)], dim=-2)  # [2, 3, 1792, 224]

# Result: Processed as a tall image (1792 = 224 × 8)
```
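
Because the frames are stacked along the height axis, every row of the collage belongs to exactly one sampled frame. The sketch below shows that mapping in plain Python; both helpers are hypothetical and exist only to illustrate the layout implied by the vertical concatenation.

```python
def collage_height(num_frames=8, frame_height=224):
    """Height of the tall image built from vertically stacked frames."""
    return num_frames * frame_height


def collage_row_to_source(row, frame_height=224, num_frames=8):
    """Map a collage row to (slot, row_in_frame).

    slot is the position among the sampled frames (0..7), not the
    original frame index in the 64-frame clip.
    """
    if not 0 <= row < num_frames * frame_height:
        raise IndexError("row outside the collage")
    return divmod(row, frame_height)


print(collage_height())             # 1792
print(collage_row_to_source(0))     # (0, 0)
print(collage_row_to_source(223))   # (0, 223)
print(collage_row_to_source(224))   # (1, 0)
print(collage_row_to_source(1791))  # (7, 223)
```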

### Benefits of Multi-Modal Training

1. **Unified Architecture**: The same ViT backbone handles all modalities; only the input preprocessing differs
2. **Complementary Learning**:
   - Codec-style: Learns to focus on temporally salient regions
   - Frame sampling: Learns uniform temporal understanding
   - Collage: Learns spatial arrangement of temporal information
3. **Robust Representations**: Exposure to diverse input formats improves generalization
4. **Efficient Training**: A single forward pass processes all modalities together

### Position Encoding Consistency

All video modes use the same 3D RoPE position encoding:

```python
# visible_indices maps selected patches to positions in a 64-frame virtual grid
# This enables consistent temporal position encoding across all modes

# Codec-style: patches scattered across 64 frames
# Frame sampling: 8 complete frames with gaps
# Collage: treated as single image (T=1)

# The model learns to handle all patterns through the unified RoPE mechanism
```
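
To make the "64-frame virtual grid" concrete, the sketch below decodes a flat patch index into a `(t, h, w)` coordinate. It assumes the row-major layout `t * 256 + h * 16 + w`, which matches the patchify reshape and the frame-sampling index construction shown earlier. `patch_position` is a hypothetical helper, and how the RoPE implementation actually consumes these coordinates is not shown in this document.

```python
def patch_position(flat_index, grid_h=16, grid_w=16, num_frames=64):
    """Decode a flat patch index into (t, h, w) on the virtual grid.

    Assumes the layout flat = t * (grid_h * grid_w) + h * grid_w + w.
    """
    per_frame = grid_h * grid_w
    if not 0 <= flat_index < num_frames * per_frame:
        raise IndexError("patch index outside the virtual grid")
    t, rem = divmod(flat_index, per_frame)
    h, w = divmod(rem, grid_w)
    return t, h, w


# Frame sampling: every patch of a sampled frame shares the same t.
print(patch_position(5 * 256))        # (5, 0, 0)
print(patch_position(5 * 256 + 255))  # (5, 15, 15)

# Codec-style: selected indices can land anywhere in the clip.
print(patch_position(300))            # (1, 2, 12)
print(patch_position(16383))          # (63, 15, 15)
```

Under this assumption, a codec-style patch and a frame-sampled patch taken from the same location in the clip decode to the same coordinate.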
