Skip to content

Commit 7c93886

Browse files
Copilotanxiangsir
andcommitted
Address code review feedback: add validation for perfect square frame_tokens
Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent 6c3f060 commit 7c93886

3 files changed

Lines changed: 9 additions & 3 deletions

File tree

eval_encoder/attentive_probe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ def video_to_images(videos: torch.Tensor) -> torch.Tensor:
204204
# ===> Compute patch_positions for RoPE with temporal scaling to [0, 64) <===
205205
# Calculate spatial grid dimensions (assume square patches)
206206
patches_per_side = int(math.sqrt(frame_tokens)) # e.g., 14 for 196 tokens
207+
assert patches_per_side * patches_per_side == frame_tokens, (
208+
f"frame_tokens must be a perfect square, got {frame_tokens}"
209+
)
207210
seq_len = frame_indices.shape[1] # Number of frames sampled
208211

209212
# Temporal positions: use interpolated_indices (already in [0, target_frames-1])

eval_encoder/attentive_probe_codec.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ def video_to_images(videos: torch.Tensor) -> torch.Tensor:
244244
# ===> Compute patch_positions for RoPE with temporal scaling to [0, 64) <===
245245
# Calculate spatial grid dimensions (assume square patches)
246246
patches_per_side = int(math.sqrt(frame_tokens)) # e.g., 14 for 196 tokens
247+
assert patches_per_side * patches_per_side == frame_tokens, (
248+
f"frame_tokens must be a perfect square, got {frame_tokens}"
249+
)
247250

248251
# Temporal positions: use interpolated_indices (already in [0, target_frames-1])
249252
# Shape: [bs, seq_len] -> expand to [bs, seq_len * frame_tokens]

tests/test_onevision_encoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,9 @@ def test_forward_from_positions_temporal_scaling(self):
159159
patches_per_frame = 16 # 4x4 spatial patches
160160
target_frames = 64
161161

162-
# Simulate interpolated indices in [0, 63] range
163-
# For 8 sampled frames from a video, spread across 64 target frames
164-
interpolated_t = torch.tensor([0, 9, 18, 27, 36, 45, 54, 63], device=device) # [num_frames]
162+
# Simulate interpolated indices in [0, target_frames-1] range
163+
# For num_frames sampled frames from a video, spread across target_frames
164+
interpolated_t = torch.linspace(0, target_frames - 1, num_frames, dtype=torch.long, device=device)
165165

166166
# Spatial positions for each frame (4x4 grid)
167167
h_ids = torch.arange(4, device=device).repeat_interleave(4) # [0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3]

0 commit comments

Comments
 (0)