@@ -88,6 +88,115 @@ def test_forward_from_positions_consistency(self):
             f"Max difference: {(freqs_forward - freqs_from_positions).abs().max().item()}"
         )
 
+    def test_forward_from_positions_batched(self):
+        """Test batched forward_from_positions with 3D input [batch_size, seq_len, 3].
+
+        This test verifies that forward_from_positions correctly handles batched inputs
+        and produces the same result as calling it on each batch element separately.
+        """
+        config = OneVisionEncoderConfig(
+            hidden_size=128,
+            num_hidden_layers=1,
+            num_attention_heads=2,
+            patch_size=16,
+            image_size=64,
+        )
+        model = OneVisionEncoderModel(config)
+
+        batch_size = 2
+        t, h, w = 2, 4, 4
+        seq_len = t * h * w
+
+        device = model.video_rope.inv_freq_t.device
+
+        # Create patch positions for a dense grid (same for both batch elements in this test)
+        t_ids = torch.arange(t, device=device).repeat_interleave(h * w)
+        h_ids = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
+        w_ids = torch.arange(w, device=device).repeat(h).repeat(t)
+        patch_positions_2d = torch.stack([t_ids, h_ids, w_ids], dim=-1)  # [seq_len, 3]
+
+        # Create batched input [batch_size, seq_len, 3]
+        patch_positions_3d = patch_positions_2d.unsqueeze(0).expand(batch_size, -1, -1)
+
+        # Get frequencies using 2D input
+        freqs_2d = model.video_rope.forward_from_positions(patch_positions_2d)  # [seq_len, half]
+
+        # Get frequencies using 3D batched input
+        freqs_3d = model.video_rope.forward_from_positions(patch_positions_3d)  # [batch_size, seq_len, half]
+
+        # Check shapes
+        assert freqs_2d.shape == (seq_len, model.video_rope.half), (
+            f"2D shape mismatch: expected ({seq_len}, {model.video_rope.half}), got {freqs_2d.shape}"
+        )
+        assert freqs_3d.shape == (batch_size, seq_len, model.video_rope.half), (
+            f"3D shape mismatch: expected ({batch_size}, {seq_len}, {model.video_rope.half}), got {freqs_3d.shape}"
+        )
+
+        # Check that each batch element matches the 2D result
+        for b in range(batch_size):
+            assert torch.allclose(freqs_2d, freqs_3d[b], rtol=1e-5, atol=1e-5), (
+                f"Batch {b} value mismatch. Max diff: {(freqs_2d - freqs_3d[b]).abs().max().item()}"
+            )
+
+    def test_forward_from_positions_temporal_scaling(self):
+        """Test that temporal positions in the [0, 64) range produce valid RoPE frequencies.
+
+        This test simulates the chunk_wise_sampling use case where interpolated frame
+        indices are scaled to the range [0, target_frames) with target_frames=64.
+        """
+        config = OneVisionEncoderConfig(
+            hidden_size=128,
+            num_hidden_layers=1,
+            num_attention_heads=2,
+            patch_size=16,
+            image_size=64,
+        )
+        model = OneVisionEncoderModel(config)
+
+        device = model.video_rope.inv_freq_t.device
+        batch_size = 2
+        num_frames = 8  # sampled frames
+        patches_per_frame = 16  # 4x4 spatial patches
+        target_frames = 64
+
+        # Simulate interpolated indices in the [0, 63] range:
+        # 8 frames sampled from a video, spread across 64 target frames
+        interpolated_t = torch.tensor([0, 9, 18, 27, 36, 45, 54, 63], device=device)  # [num_frames]
+
+        # Spatial positions for each frame (4x4 grid)
+        h_ids = torch.arange(4, device=device).repeat_interleave(4)  # [0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3]
+        w_ids = torch.arange(4, device=device).repeat(4)  # [0,1,2,3,0,1,2,3,0,1,2,3,0,1,2,3]
+
+        # Build patch_positions [batch_size, seq_len, 3] for chunk_wise_sampling
+        seq_len = num_frames * patches_per_frame
+        t_positions = interpolated_t.unsqueeze(-1).expand(-1, patches_per_frame).reshape(-1)  # [seq_len]
+        h_positions = h_ids.unsqueeze(0).expand(num_frames, -1).reshape(-1)  # [seq_len]
+        w_positions = w_ids.unsqueeze(0).expand(num_frames, -1).reshape(-1)  # [seq_len]
+
+        patch_positions_2d = torch.stack([t_positions, h_positions, w_positions], dim=-1)  # [seq_len, 3]
+        patch_positions_3d = patch_positions_2d.unsqueeze(0).expand(batch_size, -1, -1)  # [batch_size, seq_len, 3]
+
+        # Get frequencies
+        freqs = model.video_rope.forward_from_positions(patch_positions_3d)
+
+        # Verify shape
+        assert freqs.shape == (batch_size, seq_len, model.video_rope.half), (
+            f"Shape mismatch: expected ({batch_size}, {seq_len}, {model.video_rope.half}), got {freqs.shape}"
+        )
+
+        # Verify that temporal positions scaled to [0, 64) don't cause any issues
+        # (no NaN/Inf values)
+        assert torch.isfinite(freqs).all(), "RoPE frequencies contain NaN or Inf values"
+
+        # Verify the temporal dimension's contribution: for the same spatial
+        # position but different temporal positions, the frequencies should differ
+        frame_0_patch_0 = freqs[0, 0, :]  # t=0
+        frame_7_patch_0 = freqs[0, 7 * patches_per_frame, :]  # t=63
+        assert not torch.allclose(frame_0_patch_0, frame_7_patch_0), (
+            "RoPE frequencies should differ for different temporal positions"
+        )
+
 
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
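
For context, here is a minimal sketch of the kind of `forward_from_positions` these tests exercise. This is an assumption, not the actual `OneVisionEncoderModel` implementation: only `inv_freq_t` and `half` appear in the tests, so the buffer names `inv_freq_h`/`inv_freq_w`, the class name, and the split of the `half` frequency budget across (t, h, w) are all inferred for illustration.

```python
import torch
from torch import nn


class VideoRoPESketch(nn.Module):
    """Minimal 3D RoPE sketch; NOT the actual OneVisionEncoder module."""

    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        self.half = dim // 2
        # Assumption: the half-dim frequency budget is split across the three
        # axes, with the temporal axis taking the remainder.
        dim_h = dim_w = self.half // 3
        dim_t = self.half - dim_h - dim_w
        for name, d in (("t", dim_t), ("h", dim_h), ("w", dim_w)):
            inv_freq = 1.0 / (theta ** (torch.arange(d, dtype=torch.float32) / d))
            self.register_buffer(f"inv_freq_{name}", inv_freq, persistent=False)

    def forward_from_positions(self, patch_positions: torch.Tensor) -> torch.Tensor:
        # Accepts [seq_len, 3] or [batch_size, seq_len, 3]; broadcasting makes
        # the batched case identical to applying the 2D case per batch element,
        # which is exactly what the batched test above asserts.
        t, h, w = patch_positions.float().unbind(dim=-1)
        return torch.cat(
            [
                t.unsqueeze(-1) * self.inv_freq_t,  # [..., seq_len, dim_t]
                h.unsqueeze(-1) * self.inv_freq_h,  # [..., seq_len, dim_h]
                w.unsqueeze(-1) * self.inv_freq_w,  # [..., seq_len, dim_w]
            ],
            dim=-1,
        )  # [..., seq_len, half]
```

Under this sketch, `freqs_3d[b]` equals `freqs_2d` exactly because every operation broadcasts elementwise over the leading batch dimension, and any finite temporal position (including values in [0, 64)) yields finite frequencies.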
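The hardcoded `interpolated_t` tensor in the temporal-scaling test corresponds to spreading 8 sampled frames evenly over 64 target indices. A hypothetical helper that would produce it (the name `interpolate_frame_indices` is not from the source, and the real chunk_wise_sampling path may compute these indices differently):

```python
import torch


def interpolate_frame_indices(num_frames: int, target_frames: int = 64) -> torch.Tensor:
    """Hypothetical helper: evenly spread sampled frames over [0, target_frames)."""
    # linspace(0, 63, 8) -> [0, 9, 18, 27, 36, 45, 54, 63], the values the test hardcodes.
    return torch.linspace(0, target_frames - 1, num_frames).round().long()
```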