
Commit 6c3f060

Copilot and anxiangsir committed
Add patch_positions support for chunk_wise_sampling RoPE with temporal scaling to [0, 64)
Co-authored-by: anxiangsir <31175974+anxiangsir@users.noreply.github.com>
1 parent 1f15bc0 commit 6c3f060

4 files changed

Lines changed: 178 additions & 11 deletions
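In short, the commit threads an explicit [t, h, w] position for every visible patch token into the encoder's rotary embedding, with the temporal coordinate taken from the chunk-wise sampling indices that are already interpolated into [0, 64). The following is a minimal, standalone sketch of that position layout under a toy configuration; the names frame_tokens, interpolated_indices and per mirror the variables in the diffs below, but the snippet itself is illustrative rather than the committed code:

    import torch

    bs, num_frames, frame_tokens, target_frames = 2, 8, 16, 64
    patches_per_side = int(frame_tokens ** 0.5)              # 4x4 spatial grid per frame
    per = torch.arange(frame_tokens)                         # flattened patch index within a frame
    # frame indices already interpolated into [0, target_frames - 1], i.e. [0, 64)
    interpolated_indices = torch.linspace(0, target_frames - 1, num_frames).long().repeat(bs, 1)

    # frame-major layout: every token of frame k carries that frame's temporal index
    t = interpolated_indices.unsqueeze(-1).expand(-1, -1, frame_tokens).reshape(bs, -1)
    h = (per // patches_per_side).repeat(num_frames).repeat(bs, 1)
    w = (per % patches_per_side).repeat(num_frames).repeat(bs, 1)
    patch_positions = torch.stack([t, h, w], dim=-1)         # [bs, num_frames * frame_tokens, 3]
    print(patch_positions.shape)                              # torch.Size([2, 128, 3])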


eval_encoder/attentive_probe.py

Lines changed: 23 additions & 1 deletion
@@ -201,7 +201,29 @@ def video_to_images(videos: torch.Tensor) -> torch.Tensor:
     visible_index = (interpolated_indices.unsqueeze(-1) * frame_tokens + per).reshape(bs, -1)
     visible_index = visible_index.clamp_max(target_frames * frame_tokens - 1)

-    enc_out = model(videos, visible_index)
+    # ===> Compute patch_positions for RoPE with temporal scaling to [0, 64) <===
+    # Calculate spatial grid dimensions (assume square patches)
+    patches_per_side = int(math.sqrt(frame_tokens))  # e.g., 14 for 196 tokens
+    seq_len = frame_indices.shape[1]  # Number of frames sampled
+
+    # Temporal positions: use interpolated_indices (already in [0, target_frames-1])
+    # Shape: [bs, seq_len] -> expand to [bs, seq_len * frame_tokens]
+    t_positions = interpolated_indices.unsqueeze(-1).expand(-1, -1, frame_tokens).reshape(bs, -1)
+
+    # Spatial positions: h and w within each frame
+    # per is [0, 1, 2, ..., frame_tokens-1]
+    h_per_patch = per // patches_per_side  # [0,0,...,0,1,1,...,1,...,patches_per_side-1]
+    w_per_patch = per % patches_per_side   # [0,1,...,patches_per_side-1,0,1,...,patches_per_side-1,...]
+
+    # Expand spatial positions for all frames and batches
+    # Shape: [frame_tokens] -> [bs, seq_len * frame_tokens]
+    h_positions = h_per_patch.unsqueeze(0).unsqueeze(0).expand(bs, seq_len, -1).reshape(bs, -1)
+    w_positions = w_per_patch.unsqueeze(0).unsqueeze(0).expand(bs, seq_len, -1).reshape(bs, -1)
+
+    # Stack to create patch_positions: [bs, seq_len * frame_tokens, 3]
+    patch_positions = torch.stack([t_positions, h_positions, w_positions], dim=-1)
+
+    enc_out = model(videos, patch_positions, visible_index)
     if hasattr(enc_out, "last_hidden_state"):
         outputs = enc_out.last_hidden_state
     else:
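A quick consistency note on the hunk above (a hedged sketch, not part of the commit): patch_positions is built in the same frame-major order as visible_index, so, assuming interpolated_indices already lie in [0, target_frames - 1] and the clamp is a no-op, the flattened positions fold back into the visible indices token for token:

    # Sketch only; assumes the tensors from the hunk above are in scope.
    flat_spatial = patch_positions[..., 1] * patches_per_side + patch_positions[..., 2]
    reconstructed = patch_positions[..., 0] * frame_tokens + flat_spatial
    assert torch.equal(reconstructed, visible_index)  # same ordering, token for token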

eval_encoder/attentive_probe_codec.py

Lines changed: 22 additions & 1 deletion
@@ -241,7 +241,28 @@ def video_to_images(videos: torch.Tensor) -> torch.Tensor:
     visible_index = (interpolated_indices.unsqueeze(-1) * frame_tokens + per).reshape(bs, -1)
     visible_index = visible_index.clamp_max(target_frames * frame_tokens - 1)

-    enc_out = model(padded_videos, visible_index)
+    # ===> Compute patch_positions for RoPE with temporal scaling to [0, 64) <===
+    # Calculate spatial grid dimensions (assume square patches)
+    patches_per_side = int(math.sqrt(frame_tokens))  # e.g., 14 for 196 tokens
+
+    # Temporal positions: use interpolated_indices (already in [0, target_frames-1])
+    # Shape: [bs, seq_len] -> expand to [bs, seq_len * frame_tokens]
+    t_positions = interpolated_indices.unsqueeze(-1).expand(-1, -1, frame_tokens).reshape(bs, -1)
+
+    # Spatial positions: h and w within each frame
+    # per is [0, 1, 2, ..., frame_tokens-1]
+    h_per_patch = per // patches_per_side  # [0,0,...,0,1,1,...,1,...,patches_per_side-1]
+    w_per_patch = per % patches_per_side   # [0,1,...,patches_per_side-1,0,1,...,patches_per_side-1,...]
+
+    # Expand spatial positions for all frames and batches
+    # Shape: [frame_tokens] -> [bs, seq_len * frame_tokens]
+    h_positions = h_per_patch.unsqueeze(0).unsqueeze(0).expand(bs, seq_len, -1).reshape(bs, -1)
+    w_positions = w_per_patch.unsqueeze(0).unsqueeze(0).expand(bs, seq_len, -1).reshape(bs, -1)
+
+    # Stack to create patch_positions: [bs, seq_len * frame_tokens, 3]
+    patch_positions = torch.stack([t_positions, h_positions, w_positions], dim=-1)
+
+    enc_out = model(padded_videos, patch_positions, visible_index)
     if hasattr(enc_out, "last_hidden_state"):
         outputs = enc_out.last_hidden_state
     else:

onevision_encoder/modeling_onevision_encoder.py

Lines changed: 24 additions & 9 deletions
@@ -165,25 +165,40 @@ def forward_from_positions(self, patch_positions: torch.Tensor) -> torch.Tensor:
         Compute rotary position embeddings from explicit patch positions.

         Args:
-            patch_positions: [seq_len, 3] tensor with [t, h, w] positions for each patch
+            patch_positions: [seq_len, 3] or [batch_size, seq_len, 3] tensor with [t, h, w] positions for each patch

         Returns:
-            freqs: [seq_len, half] tensor of position frequencies
+            freqs: [seq_len, half] or [batch_size, seq_len, half] tensor of position frequencies
         """
         device = patch_positions.device
         inv_t = self.inv_freq_t.to(device=device)
         inv_h = self.inv_freq_h.to(device=device)
         inv_w = self.inv_freq_w.to(device=device)

-        t_pos = patch_positions[:, 0].float()
-        h_pos = patch_positions[:, 1].float()
-        w_pos = patch_positions[:, 2].float()
+        # Handle both 2D [seq_len, 3] and 3D [batch_size, seq_len, 3] inputs
+        if patch_positions.dim() == 2:
+            # Original 2D case: [seq_len, 3]
+            t_pos = patch_positions[:, 0].float()
+            h_pos = patch_positions[:, 1].float()
+            w_pos = patch_positions[:, 2].float()

-        ft = torch.outer(t_pos, inv_t)
-        fh = torch.outer(h_pos, inv_h)
-        fw = torch.outer(w_pos, inv_w)
+            ft = torch.outer(t_pos, inv_t)
+            fh = torch.outer(h_pos, inv_h)
+            fw = torch.outer(w_pos, inv_w)

-        return torch.cat([ft, fh, fw], dim=-1)
+            return torch.cat([ft, fh, fw], dim=-1)
+        else:
+            # Batched 3D case: [batch_size, seq_len, 3]
+            t_pos = patch_positions[..., 0].float()  # [batch_size, seq_len]
+            h_pos = patch_positions[..., 1].float()  # [batch_size, seq_len]
+            w_pos = patch_positions[..., 2].float()  # [batch_size, seq_len]
+
+            # Use einsum for batched outer product: [batch_size, seq_len] x [dim] -> [batch_size, seq_len, dim]
+            ft = torch.einsum("bs,d->bsd", t_pos, inv_t)
+            fh = torch.einsum("bs,d->bsd", h_pos, inv_h)
+            fw = torch.einsum("bs,d->bsd", w_pos, inv_w)
+
+            return torch.cat([ft, fh, fw], dim=-1)


 class Siglip2MultiheadAttentionPoolingHead(nn.Module):
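The einsum "bs,d->bsd" in the new branch is just a batched outer product, so the 3D path is the original 2D torch.outer path applied row by row, which is what the new tests below assert. A standalone equivalence check, with dummy positions and inv standing in for any of the inv_freq buffers:

    import torch

    pos = torch.rand(2, 5)   # [batch_size, seq_len] positions
    inv = torch.rand(8)      # [dim] inverse frequencies

    batched = torch.einsum("bs,d->bsd", pos, inv)                # [2, 5, 8]
    per_row = torch.stack([torch.outer(p, inv) for p in pos])    # 2D path per batch element
    assert torch.allclose(batched, per_row)
    assert torch.allclose(batched, pos.unsqueeze(-1) * inv)      # plain broadcasting, same result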

tests/test_onevision_encoder.py

Lines changed: 109 additions & 0 deletions
@@ -88,6 +88,115 @@ def test_forward_from_positions_consistency(self):
             f"Max difference: {(freqs_forward - freqs_from_positions).abs().max().item()}"
         )

+    def test_forward_from_positions_batched(self):
+        """Test batched forward_from_positions with 3D input [batch_size, seq_len, 3].
+
+        This test verifies that forward_from_positions correctly handles batched inputs
+        and produces the same result as calling it on each batch element separately.
+        """
+        config = OneVisionEncoderConfig(
+            hidden_size=128,
+            num_hidden_layers=1,
+            num_attention_heads=2,
+            patch_size=16,
+            image_size=64,
+        )
+        model = OneVisionEncoderModel(config)
+
+        batch_size = 2
+        t, h, w = 2, 4, 4
+        seq_len = t * h * w
+
+        device = model.video_rope.inv_freq_t.device
+
+        # Create patch positions for dense grid (same for both batches in this test)
+        t_ids = torch.arange(t, device=device).repeat_interleave(h * w)
+        h_ids = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
+        w_ids = torch.arange(w, device=device).repeat(h).repeat(t)
+        patch_positions_2d = torch.stack([t_ids, h_ids, w_ids], dim=-1)  # [seq_len, 3]
+
+        # Create batched input [batch_size, seq_len, 3]
+        patch_positions_3d = patch_positions_2d.unsqueeze(0).expand(batch_size, -1, -1)
+
+        # Get frequencies using 2D input
+        freqs_2d = model.video_rope.forward_from_positions(patch_positions_2d)  # [seq_len, half]
+
+        # Get frequencies using 3D batched input
+        freqs_3d = model.video_rope.forward_from_positions(patch_positions_3d)  # [batch_size, seq_len, half]
+
+        # Check shapes
+        assert freqs_2d.shape == (seq_len, model.video_rope.half), (
+            f"2D shape mismatch: expected ({seq_len}, {model.video_rope.half}), got {freqs_2d.shape}"
+        )
+        assert freqs_3d.shape == (batch_size, seq_len, model.video_rope.half), (
+            f"3D shape mismatch: expected ({batch_size}, {seq_len}, {model.video_rope.half}), got {freqs_3d.shape}"
+        )
+
+        # Check that each batch element matches the 2D result
+        for b in range(batch_size):
+            assert torch.allclose(freqs_2d, freqs_3d[b], rtol=1e-5, atol=1e-5), (
+                f"Batch {b} value mismatch. Max diff: {(freqs_2d - freqs_3d[b]).abs().max().item()}"
+            )
+
+    def test_forward_from_positions_temporal_scaling(self):
+        """Test that temporal positions in [0, 64) range produce valid RoPE frequencies.
+
+        This test simulates the chunk_wise_sampling use case where interpolated frame
+        indices are scaled to the range [0, target_frames) where target_frames=64.
+        """
+        config = OneVisionEncoderConfig(
+            hidden_size=128,
+            num_hidden_layers=1,
+            num_attention_heads=2,
+            patch_size=16,
+            image_size=64,
+        )
+        model = OneVisionEncoderModel(config)
+
+        device = model.video_rope.inv_freq_t.device
+        batch_size = 2
+        num_frames = 8  # sampled frames
+        patches_per_frame = 16  # 4x4 spatial patches
+        target_frames = 64
+
+        # Simulate interpolated indices in [0, 63] range
+        # For 8 sampled frames from a video, spread across 64 target frames
+        interpolated_t = torch.tensor([0, 9, 18, 27, 36, 45, 54, 63], device=device)  # [num_frames]
+
+        # Spatial positions for each frame (4x4 grid)
+        h_ids = torch.arange(4, device=device).repeat_interleave(4)  # [0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3]
+        w_ids = torch.arange(4, device=device).repeat(4)  # [0,1,2,3,0,1,2,3,0,1,2,3,0,1,2,3]
+
+        # Build patch_positions [batch_size, seq_len, 3] for chunk_wise_sampling
+        seq_len = num_frames * patches_per_frame
+        t_positions = interpolated_t.unsqueeze(-1).expand(-1, patches_per_frame).reshape(-1)  # [seq_len]
+        h_positions = h_ids.unsqueeze(0).expand(num_frames, -1).reshape(-1)  # [seq_len]
+        w_positions = w_ids.unsqueeze(0).expand(num_frames, -1).reshape(-1)  # [seq_len]
+
+        patch_positions_2d = torch.stack([t_positions, h_positions, w_positions], dim=-1)  # [seq_len, 3]
+        patch_positions_3d = patch_positions_2d.unsqueeze(0).expand(batch_size, -1, -1)  # [batch_size, seq_len, 3]
+
+        # Get frequencies
+        freqs = model.video_rope.forward_from_positions(patch_positions_3d)
+
+        # Verify shape
+        assert freqs.shape == (batch_size, seq_len, model.video_rope.half), (
+            f"Shape mismatch: expected ({batch_size}, {seq_len}, {model.video_rope.half}), got {freqs.shape}"
+        )
+
+        # Verify that temporal positions scaled to [0, 64) don't cause any issues
+        # (no NaN/Inf values)
+        assert torch.isfinite(freqs).all(), "RoPE frequencies contain NaN or Inf values"
+
+        # Verify temporal dimension contribution
+        # For the same spatial position but different temporal positions,
+        # the temporal part of freqs should differ
+        frame_0_patch_0 = freqs[0, 0, :]  # t=0
+        frame_7_patch_0 = freqs[0, 7 * patches_per_frame, :]  # t=63
+        assert not torch.allclose(frame_0_patch_0, frame_7_patch_0), (
+            "RoPE frequencies should differ for different temporal positions"
+        )
+

 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
