Skip to content

Commit 3c1aaf5

Browse files
Copilotanxiangsir
andcommitted
Update README codec input section with patch_positions approach
Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent 8159be1 commit 3c1aaf5

1 file changed

Lines changed: 43 additions & 1 deletion

File tree

README.md

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,49 @@ preprocessor = AutoImageProcessor.from_pretrained(
217217

218218
### Codec Input
219219

220-
Add codec-style input documentation for temporal saliency-based patch selection.
220+
For codec-style temporal saliency-based patch selection, use `patch_positions` to specify the 3D coordinates (temporal, height, width) of selected patches:
221+
222+
```python
223+
import torch
224+
import math
225+
226+
# Assume we have selected patches from multiple frames
227+
num_frames = 16 # Number of sampled frames
228+
frame_tokens = 256 # Tokens per frame (e.g., 16x16 patches)
229+
target_frames = 64 # Target temporal dimension for RoPE
230+
patches_per_side = int(math.sqrt(frame_tokens)) # e.g., 16
231+
232+
device = "cuda" # or "cpu"
233+
234+
# Example: frame_indices indicates which frames were sampled (shape: [B, num_frames])
235+
frame_indices = torch.linspace(0, target_frames - 1, num_frames).long().to(device).unsqueeze(0) # [1, 16]
236+
237+
# Build patch_positions: [B, num_frames * frame_tokens, 3] where each position is (t, h, w)
238+
bs = 1
239+
per = torch.arange(frame_tokens, device=device)
240+
241+
# Temporal positions: frame index for each patch
242+
t_positions = frame_indices.unsqueeze(-1).expand(-1, -1, frame_tokens).reshape(bs, -1)
243+
244+
# Spatial positions: h and w within each frame's patch grid
245+
h_positions = (per // patches_per_side).unsqueeze(0).unsqueeze(0).expand(bs, num_frames, -1).reshape(bs, -1)
246+
w_positions = (per % patches_per_side).unsqueeze(0).unsqueeze(0).expand(bs, num_frames, -1).reshape(bs, -1)
247+
248+
# Stack to create patch_positions: [B, num_frames * frame_tokens, 3]
249+
patch_positions = torch.stack([t_positions, h_positions, w_positions], dim=-1)
250+
251+
# Video inference with patch_positions
252+
with torch.no_grad():
253+
outputs = model(video, patch_positions=patch_positions)
254+
# outputs.last_hidden_state: [B, num_patches, hidden_size]
255+
```
256+
257+
The `patch_positions` tensor has shape `[batch_size, num_patches, 3]` where each position contains:
258+
- `t`: Temporal frame index (0 to target_frames-1)
259+
- `h`: Height position in the patch grid (0 to patches_per_side-1)
260+
- `w`: Width position in the patch grid (0 to patches_per_side-1)
261+
262+
This enables flexible sparse patch selection for codec-style video understanding.
221263

222264
---
223265

0 commit comments

Comments
 (0)