@@ -237,31 +237,38 @@ def apply_masks(
237237 Parameters
238238 ----------
239239 x : torch.Tensor
240- Input tensor of shape (B, N, D), where B is the batch size, N is the number
241- of patches, and D is the feature dimension.
240+ Input tensor of shape (B, N, D).
242241 masks : Union[torch.Tensor, List[torch.Tensor]]
243- A list of tensors containing the indices of patches to keep for each sample.
244- Each mask tensor has shape (B, N), where B is the batch size and N is the number
245- of patches.
242+ A list of mask tensors of shape (N,), (1, N), or (B, N).
246243
247244 Returns
248245 -------
249246 torch.Tensor
250247 The masked tensor where only the patches indicated by the masks are kept.
251- The output tensor has shape (B', N', D), where B' is the new batch size
252- (which may be different due to concatenation) and N' is the
253- reduced number of patches.
254-
255- Notes
256- -----
257- - The masks should indicate which patches to keep (1 for keep, 0 for discard).
258- - The function uses `torch.gather` to select the patches specified by the masks.
248+ The output tensor has shape (B * num_masks, N', D),
249+ where N' is the number of patches kept.
259250 """
260251 all_x = []
261- for m in masks :
262- # Expand the mask to match the feature dimension and gather the relevant patches
263- mask_keep = m .unsqueeze (- 1 ).repeat (1 , 1 , x .size (- 1 ))
264- all_x .append (torch .gather (x , dim = 1 , index = mask_keep ))
252+ batch_size = x .size (0 )
253+ for m_ in masks :
254+ m = m_ .to (x .device )
255+
256+ # Ensure mask is at least 2D
257+ if m .dim () == 1 :
258+ m = m .unsqueeze (0 ) # Shape: (1, N)
259+
260+ # Expand mask to match the batch size if needed
261+ if m .size (0 ) == 1 and batch_size > 1 :
262+ m = m .expand (batch_size , - 1 ) # Shape: (B, N)
263+
264+ # Expand mask to match x's dimensions
265+ m_expanded = (
266+ m .unsqueeze (- 1 ).expand (- 1 , - 1 , x .size (- 1 )).bool ()
267+ ) # Shape: (B, N, D)
268+
269+ # Use boolean indexing
270+ selected_patches = x [m_expanded ].view (batch_size , - 1 , x .size (- 1 ))
271+ all_x .append (selected_patches )
265272
266273 # Concatenate along the batch dimension
267274 return torch .cat (all_x , dim = 0 )
@@ -271,40 +278,39 @@ def apply_masks(
271278class IJEPAMaskGenerator :
272279 """Generates encoder and predictor masks for preprocessing.
273280
274- This class generates masks dynamically for individual examples and can be passed to
275- a data loader as a preprocessing step.
281+ This class generates masks dynamically for batches of examples.
276282
277283 Parameters
278284 ----------
279285 input_size : tuple[int, int], default=(224, 224)
280286 Input image size.
281287 patch_size : int, default=16
282288 Size of each patch.
283- min_keep : int, default=4
289+ min_keep : int, default=10
284290 Minimum number of patches to keep.
285291 allow_overlap : bool, default=False
286292 Whether to allow overlap between encoder and predictor masks.
287- enc_mask_scale : tuple[float, float], default=(0.2, 0.8 )
293+ enc_mask_scale : tuple[float, float], default=(0.85, 1.0 )
288294 Scale range for encoder mask.
289- pred_mask_scale : tuple[float, float], default=(0.2 , 0.8 )
295+ pred_mask_scale : tuple[float, float], default=(0.15 , 0.2 )
290296 Scale range for predictor mask.
291- aspect_ratio : tuple[float, float], default=(0.3, 3 .0)
297+ aspect_ratio : tuple[float, float], default=(0.75, 1 .0)
292298 Aspect ratio range for mask blocks.
293299 nenc : int, default=1
294300 Number of encoder masks to generate.
295- npred : int, default=2
301+ npred : int, default=4
296302 Number of predictor masks to generate.
297303 """
298304
299305 input_size : Tuple [int , int ] = (224 , 224 )
300306 patch_size : int = 16
301- min_keep : int = 4
307+ min_keep : int = 10
302308 allow_overlap : bool = False
303- enc_mask_scale : Tuple [float , float ] = (0.2 , 0.8 )
304- pred_mask_scale : Tuple [float , float ] = (0.2 , 0.8 )
305- aspect_ratio : Tuple [float , float ] = (0.3 , 3 .0 )
309+ enc_mask_scale : Tuple [float , float ] = (0.85 , 1.0 )
310+ pred_mask_scale : Tuple [float , float ] = (0.15 , 0.2 )
311+ aspect_ratio : Tuple [float , float ] = (0.75 , 1 .0 )
306312 nenc : int = 1
307- npred : int = 2
313+ npred : int = 4
308314
309315 def __post_init__ (self ) -> None :
310316 """Initialize the mask generator."""
@@ -353,8 +359,14 @@ def _sample_block_mask(
353359
354360 def __call__ (
355361 self ,
362+ batch_size : int = 1 ,
356363 ) -> Dict [str , Any ]:
357- """Generate encoder and predictor masks for a single example.
364+ """Generate encoder and predictor masks for a batch of examples.
365+
366+ Parameters
367+ ----------
368+ batch_size : int, default=1
369+ The batch size for which to generate masks.
358370
359371 Returns
360372 -------
@@ -378,14 +390,18 @@ def __call__(
378390 masks_pred , masks_enc = [], []
379391 for _ in range (self .npred ):
380392 mask_p , _ = self ._sample_block_mask (p_size )
393+ # Expand mask to match batch size
394+ mask_p = mask_p .unsqueeze (0 ).expand (batch_size , - 1 )
381395 masks_pred .append (mask_p )
382396
383397 # Generate encoder masks
384398 for _ in range (self .nenc ):
385399 mask_e , _ = self ._sample_block_mask (e_size )
400+ # Expand mask to match batch size
401+ mask_e = mask_e .unsqueeze (0 ).expand (batch_size , - 1 )
386402 masks_enc .append (mask_e )
387403
388404 return {
389- "encoder_masks" : torch . stack ( masks_enc ),
390- "predictor_masks" : torch . stack ( masks_pred ),
405+ "encoder_masks" : masks_enc , # List of tensors of shape (batch_size, N)
406+ "predictor_masks" : masks_pred , # List of tensors of shape (batch_size, N)
391407 }
0 commit comments