## Contact

For questions and issues, please open an issue on the [GitHub repository](https://github.com/Evolvinglmms-lab/OneVision-Encoder).

## Multi-Modal Training Strategy

OneVision Encoder trains on images, codec-style video patches, uniformly sampled video frames, and video collages within the same batch. Mixing these input formats in every batch is intended to help the model learn representations that stay robust across input modalities.

### Training Batch Composition

Within each training batch, samples are routed to one of three heads, and the video samples are further split by processing mode:

```
Training Batch (bs=16)
┌───────────────────────────────────────────────────────────────────────┐
│                                                                       │
│  ┌───────────────────┐  ┌───────────────────┐  ┌───────────────────┐  │
│  │    Image Head     │  │    Video Head     │  │     OCR Head      │  │
│  │     (origin)      │  │ (decord_residual) │  │       (ocr)       │  │
│  │                   │  │                   │  │                   │  │
│  │   [B, 3, H, W]    │  │  Split by mode:   │  │   [B, 3, H, W]    │  │
│  │                   │  │  • Codec 50%      │  │                   │  │
│  │                   │  │  • Sampling 37.5% │  │                   │  │
│  │                   │  │  • Collage 12.5%  │  │                   │  │
│  └───────────────────┘  └───────────────────┘  └───────────────────┘  │
│                                                                       │
└───────────────────────────────────────────────────────────────────────┘
```

### Video Processing Modes

For video inputs, the batch is further split into three processing modes:

| Mode | Batch % | Description | Input → Output |
|------|---------|-------------|----------------|
| **Codec-Style** | 50% | Select top-K salient patches based on HEVC residual | `[n, 3, 64, 224, 224]` → `[n, 3, 8, 224, 224]` |
| **Frame Sampling** | 37.5% | Uniform temporal sampling, 1 frame per bin | `[n, 3, 64, 224, 224]` → `[n, 3, 8, 224, 224]` |
| **Collage** | 12.5% | 8 frames concatenated into a tall image | `[n, 3, 64, 224, 224]` → `[n, 3, 1792, 224]` |
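
The percentages in the table translate into per-mode sample counts. The following is a minimal sketch, assuming the split is contiguous (codec-style first, then frame sampling, then collage), as the batch examples in this section suggest. `split_video_batch` is a hypothetical helper, and the handling of batch sizes that are not divisible by 8 is an assumption rather than documented behavior.

```python
def split_video_batch(bs: int) -> dict[str, range]:
    """Return the sample index range per video mode for a batch of size bs.

    Hypothetical helper: the fractions follow the table (50% / 37.5% / 12.5%).
    Any remainder from integer division goes to collage (an assumption).
    """
    n_codec = bs // 2          # 50%
    n_sampling = bs * 3 // 8   # 37.5%
    return {
        "codec": range(0, n_codec),
        "frame_sampling": range(n_codec, n_codec + n_sampling),
        "collage": range(n_codec + n_sampling, bs),
    }


split = split_video_batch(16)
print({mode: (r.start, r.stop - 1) for mode, r in split.items()})
# → {'codec': (0, 7), 'frame_sampling': (8, 13), 'collage': (14, 15)}
```

For `bs=16` this gives 8, 6, and 2 samples, matching the sample ranges used in the per-mode examples.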

### Processing Pipeline

```
          Video Input: [bs, 3, 64, 224, 224]
                           │
        ┌──────────────────┼──────────────────┐
        │                  │                  │
        ▼                  ▼                  ▼
┌───────────────┐  ┌───────────────┐  ┌───────────────┐
│  Codec-Style  │  │Frame Sampling │  │    Collage    │
│     (50%)     │  │    (37.5%)    │  │    (12.5%)    │
└───────┬───────┘  └───────┬───────┘  └───────┬───────┘
        │                  │                  │
        ▼                  ▼                  ▼
┌───────────────┐  ┌───────────────┐  ┌───────────────┐
│   Patchify    │  │ Sample frames │  │ Sample frames │
│[n,3,16384,p²] │  │  from 8 bins  │  │  from 8 bins  │
└───────┬───────┘  └───────┬───────┘  └───────┬───────┘
        │                  │                  │
        ▼                  ▼                  ▼
┌───────────────┐  ┌───────────────┐  ┌───────────────┐
│ Select top-K  │  │ Build indices │  │ Concat frames │
│  by vis_idx   │  │ for 8 frames  │  │  vertically   │
└───────┬───────┘  └───────┬───────┘  └───────┬───────┘
        │                  │                  │
        ▼                  ▼                  ▼
┌───────────────┐  ┌───────────────┐  ┌───────────────┐
│  Unpatchify   │  │               │  │               │
│[n,3,8,224,224]│  │[n,3,8,224,224]│  │[n,3,1792,224] │
└───────┬───────┘  └───────┬───────┘  └───────┬───────┘
        │                  │                  │
        └──────────────────┼──────────────────┘
                           │
                           ▼
                   ┌───────────────┐
                   │ ViT Backbone  │
                   │   with RoPE   │
                   └───────┬───────┘
                           │
                           ▼
                   [bs, hidden_size]
```
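
The shapes in the pipeline diagram follow from simple arithmetic. As a quick sanity check, assuming 224×224 frames and the 14×14 patch size that appears in the codec-style example, the numbers can be reproduced like this (the constant names are illustrative):

```python
FRAMES, SIZE, PATCH = 64, 224, 14   # values taken from this section
KEPT_FRAMES = 8                     # frames' worth of patches kept per video

patches_per_side = SIZE // PATCH                 # 16
patches_per_frame = patches_per_side ** 2        # 256
total_patches = FRAMES * patches_per_frame       # patchify output length
kept_patches = KEPT_FRAMES * patches_per_frame   # top-K, or 8 full frames
collage_height = KEPT_FRAMES * SIZE              # height of the tall image

print(patches_per_frame, total_patches, kept_patches, collage_height)
# → 256 16384 2048 1792
```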

### 1. Codec-Style Processing (50% of batch)

This mode uses saliency information extracted from the HEVC residual to select the most informative patches:

```python
import torch

# Example: bs=16, the first 8 samples use codec-style processing.
# visible_indices holds pre-computed salient patch indices from HEVC analysis.
# Excerpt: video, out, mask_residual, and idx come from the surrounding training code.

# Step 1: Use pre-computed visible_indices (sorted by saliency)
out[mask_residual] = visible_indices[mask_residual, :target_num]  # [8, 2048]

# Step 2: Patchify the full video
# [8, 3, 64, 224, 224] → [8, 3, 16384, 14, 14] (64 frames × 256 patches/frame)
patches = (
    video.view(n, C, T, Hp, patch_size, Wp, patch_size)
    .permute(0, 1, 2, 3, 5, 4, 6)
    .reshape(n, C, T * Hp * Wp, patch_size, patch_size)
)

# Step 3: Select top-K patches by visible_indices
# torch.gather needs idx to have as many dimensions as patches: [8, 3, 2048, 14, 14]
selected = torch.gather(patches, 2, idx)  # [8, 3, 2048, 14, 14]

# Step 4: Unpatchify back to video format
# 2048 patches = 8 frames × 256 patches/frame
combined_head_input = (
    selected.view(n, C, 8, Hp, Wp, patch_size, patch_size)
    .permute(0, 1, 2, 3, 5, 4, 6)
    .reshape(n, C, 8, H, W)
)  # [8, 3, 8, 224, 224]
```
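
Steps 2 and 4 are inverse reshapes. To make the row-major patch layout concrete, here is a dependency-free sketch on a tiny single-channel 4×4 frame with 2×2 patches. `patchify` and `unpatchify` are illustrative helpers written for this example, not functions from the repository.

```python
def patchify(frame, p):
    """Split an H x W frame (list of rows) into row-major p x p patches."""
    h, w = len(frame), len(frame[0])
    patches = []
    for top in range(0, h, p):
        for left in range(0, w, p):
            patches.append([row[left:left + p] for row in frame[top:top + p]])
    return patches


def unpatchify(patches, h, w, p):
    """Inverse of patchify: place row-major patches back on an H x W grid."""
    frame = [[None] * w for _ in range(h)]
    per_row = w // p
    for k, patch in enumerate(patches):
        top, left = (k // per_row) * p, (k % per_row) * p
        for dy in range(p):
            for dx in range(p):
                frame[top + dy][left + dx] = patch[dy][dx]
    return frame


frame = [[r * 4 + c for c in range(4)] for r in range(4)]
patches = patchify(frame, 2)
print(patches[1])                                # → [[2, 3], [6, 7]]
print(unpatchify(patches, 4, 4, 2) == frame)     # → True
```

The tensor version does the same thing for every sample, channel, and frame at once. The only difference in the codec-style path is that a subset of patches is gathered between the two steps.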

### 2. Frame Sampling Processing (37.5% of batch)

This mode samples one frame from each of 8 equal-width temporal bins:

```python
import torch

# Example: samples 8-13 in the batch use frame sampling (nB = 6).
# Divide 64 frames into 8 bins of 8 frames each and sample 1 frame from each bin.

# Step 1: Sample frame indices
# bins: [0-7], [8-15], [16-23], [24-31], [32-39], [40-47], [48-55], [56-63]
frames = torch.arange(8) * 8 + torch.randint(8, (nB, 8))  # [6, 8]

# Step 2: Build patch indices for all patches in the selected frames
# Each frame has 256 patches
out[mask_frame_sampling] = (
    frames.unsqueeze(-1) * 256 + torch.arange(256)
).reshape(nB, -1)  # [6, 2048]

# Step 3: Same patchify → select → unpatchify as codec-style
# Result: [6, 3, 8, 224, 224]
```
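
The same index construction can be traced for a single video without tensors. This is a minimal sketch using only the standard library. `sample_frames` and `frame_patch_indices` are hypothetical names, and the defaults (64 frames, 8 bins, 256 patches per frame) are taken from the example in this section.

```python
import random


def sample_frames(num_frames=64, num_bins=8, rng=random):
    """Pick one random frame index from each equally sized temporal bin."""
    bin_size = num_frames // num_bins
    return [b * bin_size + rng.randrange(bin_size) for b in range(num_bins)]


def frame_patch_indices(frames, patches_per_frame=256):
    """Flat indices of every patch in the chosen frames, frame by frame."""
    return [f * patches_per_frame + p
            for f in frames
            for p in range(patches_per_frame)]


frames = sample_frames(rng=random.Random(0))   # seeded for repeatability
indices = frame_patch_indices(frames)
print(frames)         # one frame index per bin: 0-7, 8-15, ..., 56-63
print(len(indices))   # → 2048
```

Because every selected frame contributes all 256 of its patches, the resulting index list has the same length as the codec-style top-K list, which is why both modes can share the patchify → select → unpatchify path.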

### 3. Collage Processing (12.5% of batch)

This mode concatenates the sampled frames into a single tall image:

```python
import torch

# Example: samples 14-15 in the batch use collage (nB = 2).
# Sample 8 frames and concatenate them vertically.

# Step 1: Sample 8 frames (same bin-based sampling as frame sampling:
# base is the start of each bin, offsets the position within the bin)
frames_idx = base + offsets  # [2, 8], values in [0, 63]

# Step 2: Gather the selected frames
# idx_expand is frames_idx expanded to [2, 3, 8, 224, 224] to match video's dimensions
sel_frames = torch.gather(video, 2, idx_expand)  # [2, 3, 8, 224, 224]

# Step 3: Concatenate frames vertically
sel_frames = sel_frames.permute(0, 2, 1, 3, 4)  # [2, 8, 3, 224, 224]
grid = torch.cat([sel_frames[:, i] for i in range(8)], dim=-2)  # [2, 3, 1792, 224]

# Result: processed as a tall image (1792 = 224 × 8)
```
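
Vertical concatenation simply stacks the rows of each frame in temporal order. A toy sketch with plain lists, where `make_collage` is an illustrative helper and the 2×3 frames stand in for the 224×224 frames of the real pipeline:

```python
def make_collage(frames):
    """Stack frames (each a list of pixel rows) top to bottom into one image."""
    return [row for frame in frames for row in frame]


# Eight toy 2x3 frames; every pixel stores its frame number.
frames = [[[t] * 3 for _ in range(2)] for t in range(8)]
collage = make_collage(frames)
print(len(collage), len(collage[0]))   # → 16 3
print(collage[0], collage[-1])         # → [0, 0, 0] [7, 7, 7]
```

With 224-row frames the same stacking yields the 1792×224 image shown in the table, with earlier frames at the top.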

### Benefits of Multi-Modal Training

1. **Unified Architecture**: The same ViT backbone handles all modalities through different input preprocessing
2. **Complementary Learning**:
   - Codec-style: Learns to focus on temporally salient regions
   - Frame sampling: Learns uniform temporal understanding
   - Collage: Learns spatial arrangement of temporal information
3. **Robust Representations**: Exposure to diverse input formats improves generalization
4. **Efficient Training**: A single forward pass processes all modalities together

### Position Encoding Consistency

All video modes share the same 3D RoPE position encoding mechanism:

```python
# visible_indices maps selected patches to positions in a 64-frame virtual grid
# This enables consistent temporal position encoding across all modes

# Codec-style: patches scattered across 64 frames
# Frame sampling: 8 complete frames with gaps
# Collage: treated as a single image (T=1)

# The model learns to handle all patterns through the unified RoPE mechanism
```
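
To illustrate what a position in the 64-frame virtual grid means, the sketch below decodes a flat patch index into `(frame, row, col)` coordinates. It assumes the row-major `T × Hp × Wp` ordering implied by the patchify reshape in the codec-style example. How the model feeds these coordinates into RoPE is not shown in this section, and `patch_position` is a hypothetical helper.

```python
def patch_position(index, grid=16):
    """Map a flat patch index to (frame, row, col) in a T x grid x grid layout."""
    frame, within = divmod(index, grid * grid)
    row, col = divmod(within, grid)
    return frame, row, col


print(patch_position(0))       # → (0, 0, 0)
print(patch_position(273))     # → (1, 1, 1)
print(patch_position(16383))   # → (63, 15, 15)
```

Under this layout, a frame-sampling index list covers every `(row, col)` of 8 distinct frames, while a codec-style list may touch many frames with only a few patches each.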