Skip to content

Commit e20e178

Browse files
Copilotanxiangsir
andcommitted
Fix: Update video inference in Using AutoModel section to use patch_positions instead of visible_indices
Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent 3c1aaf5 commit e20e178

1 file changed

Lines changed: 17 additions & 54 deletions

File tree

README.md

Lines changed: 17 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -171,26 +171,31 @@ with torch.no_grad():
171171
# outputs.last_hidden_state: [B, num_patches, hidden_size]
172172
# outputs.pooler_output: [B, hidden_size]
173173

174-
# Video inference: [B, C, T, H, W] with visible_indices
174+
# Video inference: [B, C, T, H, W] with patch_positions
175+
import math
175176
num_frames, frame_tokens, target_frames = 16, 256, 64
177+
patches_per_side = int(math.sqrt(frame_tokens)) # 16 for 256 tokens
176178
# Load video frames and preprocess each frame (replace with your video frame paths)
177179
frames = [Image.open(f"path/to/frame_{i}.jpg") for i in range(num_frames)]
178180
video_pixel_values = preprocessor(images=frames, return_tensors="pt")["pixel_values"]
179181
# Reshape from [T, C, H, W] to [B, C, T, H, W]
180182
video = video_pixel_values.unsqueeze(0).permute(0, 2, 1, 3, 4).to("cuda")
181183

182-
# Build visible_indices for temporal sampling
183-
frame_pos = torch.linspace(0, target_frames - 1, num_frames).long().cuda()
184-
visible_indices = (frame_pos.unsqueeze(-1) * frame_tokens + torch.arange(frame_tokens).cuda()).reshape(1, -1)
185-
# visible_indices example (with 256 tokens per frame):
186-
# Frame 0 (pos=0): indices [0, 1, 2, ..., 255]
187-
# Frame 1 (pos=4): indices [1024, 1025, 1026, ..., 1279]
188-
# Frame 2 (pos=8): indices [2048, 2049, 2050, ..., 2303]
189-
# ...
190-
# Frame 15 (pos=63): indices [16128, 16129, ..., 16383]
184+
# Build patch_positions for temporal sampling: [B, num_frames * frame_tokens, 3]
185+
# Each position is (t, h, w) where t is temporal index, h/w are spatial patch coordinates
186+
frame_pos = torch.linspace(0, target_frames - 1, num_frames).long().cuda() # [num_frames]
187+
per = torch.arange(frame_tokens).cuda() # [frame_tokens]
188+
189+
# Temporal positions: frame index for each patch
190+
t_positions = frame_pos.unsqueeze(-1).expand(-1, frame_tokens).reshape(1, -1) # [1, num_frames * frame_tokens]
191+
# Spatial positions: h and w within each frame's patch grid
192+
h_positions = (per // patches_per_side).unsqueeze(0).expand(num_frames, -1).reshape(1, -1)
193+
w_positions = (per % patches_per_side).unsqueeze(0).expand(num_frames, -1).reshape(1, -1)
194+
# Stack to create patch_positions: [B, num_frames * frame_tokens, 3]
195+
patch_positions = torch.stack([t_positions, h_positions, w_positions], dim=-1)
191196

192197
with torch.no_grad():
193-
outputs = model(video, visible_indices=visible_indices)
198+
outputs = model(video, patch_positions=patch_positions)
194199
```
195200

196201
### Loading from Source Code
@@ -217,49 +222,7 @@ preprocessor = AutoImageProcessor.from_pretrained(
217222

218223
### Codec Input
219224

220-
For codec-style temporal saliency-based patch selection, use `patch_positions` to specify the 3D coordinates (temporal, height, width) of selected patches:
221-
222-
```python
223-
import torch
224-
import math
225-
226-
# Assume we have selected patches from multiple frames
227-
num_frames = 16 # Number of sampled frames
228-
frame_tokens = 256 # Tokens per frame (e.g., 16x16 patches)
229-
target_frames = 64 # Target temporal dimension for RoPE
230-
patches_per_side = int(math.sqrt(frame_tokens)) # e.g., 16
231-
232-
device = "cuda" # or "cpu"
233-
234-
# Example: frame_indices indicates which frames were sampled (shape: [B, num_frames])
235-
frame_indices = torch.linspace(0, target_frames - 1, num_frames).long().to(device).unsqueeze(0) # [1, 16]
236-
237-
# Build patch_positions: [B, num_frames * frame_tokens, 3] where each position is (t, h, w)
238-
bs = 1
239-
per = torch.arange(frame_tokens, device=device)
240-
241-
# Temporal positions: frame index for each patch
242-
t_positions = frame_indices.unsqueeze(-1).expand(-1, -1, frame_tokens).reshape(bs, -1)
243-
244-
# Spatial positions: h and w within each frame's patch grid
245-
h_positions = (per // patches_per_side).unsqueeze(0).unsqueeze(0).expand(bs, num_frames, -1).reshape(bs, -1)
246-
w_positions = (per % patches_per_side).unsqueeze(0).unsqueeze(0).expand(bs, num_frames, -1).reshape(bs, -1)
247-
248-
# Stack to create patch_positions: [B, num_frames * frame_tokens, 3]
249-
patch_positions = torch.stack([t_positions, h_positions, w_positions], dim=-1)
250-
251-
# Video inference with patch_positions
252-
with torch.no_grad():
253-
outputs = model(video, patch_positions=patch_positions)
254-
# outputs.last_hidden_state: [B, num_patches, hidden_size]
255-
```
256-
257-
The `patch_positions` tensor has shape `[batch_size, num_patches, 3]` where each position contains:
258-
- `t`: Temporal frame index (0 to target_frames-1)
259-
- `h`: Height position in the patch grid (0 to patches_per_side-1)
260-
- `w`: Width position in the patch grid (0 to patches_per_side-1)
261-
262-
This enables flexible sparse patch selection for codec-style video understanding.
225+
Add codec-style input documentation for temporal saliency-based patch selection.
263226

264227
---
265228

0 commit comments

Comments
 (0)