@@ -171,26 +171,31 @@ with torch.no_grad():
171171 # outputs.last_hidden_state: [B, num_patches, hidden_size]
172172 # outputs.pooler_output: [B, hidden_size]
173173
174- # Video inference: [B, C, T, H, W] with visible_indices
174+ # Video inference: [B, C, T, H, W] with patch_positions
175+ import math
175176num_frames, frame_tokens, target_frames = 16 , 256 , 64
177+ patches_per_side = int (math.sqrt(frame_tokens)) # 16 for 256 tokens
176178# Load video frames and preprocess each frame (replace with your video frame paths)
177179frames = [Image.open(f " path/to/frame_ { i} .jpg " ) for i in range (num_frames)]
178180video_pixel_values = preprocessor(images = frames, return_tensors = " pt" )[" pixel_values" ]
179181# Reshape from [T, C, H, W] to [B, C, T, H, W]
180182video = video_pixel_values.unsqueeze(0 ).permute(0 , 2 , 1 , 3 , 4 ).to(" cuda" )
181183
182- # Build visible_indices for temporal sampling
183- frame_pos = torch.linspace(0 , target_frames - 1 , num_frames).long().cuda()
184- visible_indices = (frame_pos.unsqueeze(- 1 ) * frame_tokens + torch.arange(frame_tokens).cuda()).reshape(1 , - 1 )
185- # visible_indices example (with 256 tokens per frame):
186- # Frame 0 (pos=0): indices [0, 1, 2, ..., 255]
187- # Frame 1 (pos=4): indices [1024, 1025, 1026, ..., 1279]
188- # Frame 2 (pos=8): indices [2048, 2049, 2050, ..., 2303]
189- # ...
190- # Frame 15 (pos=63): indices [16128, 16129, ..., 16383]
184+ # Build patch_positions for temporal sampling: [B, num_frames * frame_tokens, 3]
185+ # Each position is (t, h, w) where t is temporal index, h/w are spatial patch coordinates
186+ frame_pos = torch.linspace(0 , target_frames - 1 , num_frames).long().cuda() # [num_frames]
187+ per = torch.arange(frame_tokens).cuda() # [frame_tokens]
188+
189+ # Temporal positions: frame index for each patch
190+ t_positions = frame_pos.unsqueeze(- 1 ).expand(- 1 , frame_tokens).reshape(1 , - 1 ) # [1, num_frames * frame_tokens]
191+ # Spatial positions: h and w within each frame's patch grid
192+ h_positions = (per // patches_per_side).unsqueeze(0 ).expand(num_frames, - 1 ).reshape(1 , - 1 )
193+ w_positions = (per % patches_per_side).unsqueeze(0 ).expand(num_frames, - 1 ).reshape(1 , - 1 )
194+ # Stack to create patch_positions: [B, num_frames * frame_tokens, 3]
195+ patch_positions = torch.stack([t_positions, h_positions, w_positions], dim = - 1 )
191196
192197with torch.no_grad():
193- outputs = model(video, visible_indices = visible_indices )
198+ outputs = model(video, patch_positions = patch_positions )
194199```
195200
196201### Loading from Source Code
@@ -217,49 +222,7 @@ preprocessor = AutoImageProcessor.from_pretrained(
217222
218223### Codec Input
219224
220- For codec-style temporal saliency-based patch selection, use ` patch_positions ` to specify the 3D coordinates (temporal, height, width) of selected patches:
221-
222- ``` python
223- import torch
224- import math
225-
226- # Assume we have selected patches from multiple frames
227- num_frames = 16 # Number of sampled frames
228- frame_tokens = 256 # Tokens per frame (e.g., 16x16 patches)
229- target_frames = 64 # Target temporal dimension for RoPE
230- patches_per_side = int (math.sqrt(frame_tokens)) # e.g., 16
231-
232- device = " cuda" # or "cpu"
233-
234- # Example: frame_indices indicates which frames were sampled (shape: [B, num_frames])
235- frame_indices = torch.linspace(0 , target_frames - 1 , num_frames).long().to(device).unsqueeze(0 ) # [1, 16]
236-
237- # Build patch_positions: [B, num_frames * frame_tokens, 3] where each position is (t, h, w)
238- bs = 1
239- per = torch.arange(frame_tokens, device = device)
240-
241- # Temporal positions: frame index for each patch
242- t_positions = frame_indices.unsqueeze(- 1 ).expand(- 1 , - 1 , frame_tokens).reshape(bs, - 1 )
243-
244- # Spatial positions: h and w within each frame's patch grid
245- h_positions = (per // patches_per_side).unsqueeze(0 ).unsqueeze(0 ).expand(bs, num_frames, - 1 ).reshape(bs, - 1 )
246- w_positions = (per % patches_per_side).unsqueeze(0 ).unsqueeze(0 ).expand(bs, num_frames, - 1 ).reshape(bs, - 1 )
247-
248- # Stack to create patch_positions: [B, num_frames * frame_tokens, 3]
249- patch_positions = torch.stack([t_positions, h_positions, w_positions], dim = - 1 )
250-
251- # Video inference with patch_positions
252- with torch.no_grad():
253- outputs = model(video, patch_positions = patch_positions)
254- # outputs.last_hidden_state: [B, num_patches, hidden_size]
255- ```
256-
257- The ` patch_positions ` tensor has shape ` [batch_size, num_patches, 3] ` where each position contains:
258- - ` t ` : Temporal frame index (0 to target_frames-1)
259- - ` h ` : Height position in the patch grid (0 to patches_per_side-1)
260- - ` w ` : Width position in the patch grid (0 to patches_per_side-1)
261-
262- This enables flexible sparse patch selection for codec-style video understanding.
225+ Add codec-style input documentation for temporal saliency-based patch selection.
263226
264227---
265228
0 commit comments