Skip to content

Commit f0c4a0c

Browse files
committed
Merge branch 'main' of https://github.com/VectorInstitute/mmlearn into main
2 parents bdaec8f + c6b07e0 commit f0c4a0c

11 files changed

Lines changed: 761 additions & 193 deletions

File tree

mmlearn/datasets/processors/masking.py

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -237,31 +237,38 @@ def apply_masks(
237237
Parameters
238238
----------
239239
x : torch.Tensor
240-
Input tensor of shape (B, N, D), where B is the batch size, N is the number
241-
of patches, and D is the feature dimension.
240+
Input tensor of shape (B, N, D).
242241
masks : Union[torch.Tensor, List[torch.Tensor]]
243-
A list of tensors containing the indices of patches to keep for each sample.
244-
Each mask tensor has shape (B, N), where B is the batch size and N is the number
245-
of patches.
242+
A list of mask tensors of shape (N,), (1, N), or (B, N).
246243
247244
Returns
248245
-------
249246
torch.Tensor
250247
The masked tensor where only the patches indicated by the masks are kept.
251-
The output tensor has shape (B', N', D), where B' is the new batch size
252-
(which may be different due to concatenation) and N' is the
253-
reduced number of patches.
254-
255-
Notes
256-
-----
257-
- The masks should indicate which patches to keep (1 for keep, 0 for discard).
258-
- The function uses `torch.gather` to select the patches specified by the masks.
248+
The output tensor has shape (B * num_masks, N', D),
249+
where N' is the number of patches kept.
259250
"""
260251
all_x = []
261-
for m in masks:
262-
# Expand the mask to match the feature dimension and gather the relevant patches
263-
mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
264-
all_x.append(torch.gather(x, dim=1, index=mask_keep))
252+
batch_size = x.size(0)
253+
for m_ in masks:
254+
m = m_.to(x.device)
255+
256+
# Ensure mask is at least 2D
257+
if m.dim() == 1:
258+
m = m.unsqueeze(0) # Shape: (1, N)
259+
260+
# Expand mask to match the batch size if needed
261+
if m.size(0) == 1 and batch_size > 1:
262+
m = m.expand(batch_size, -1) # Shape: (B, N)
263+
264+
# Expand mask to match x's dimensions
265+
m_expanded = (
266+
m.unsqueeze(-1).expand(-1, -1, x.size(-1)).bool()
267+
) # Shape: (B, N, D)
268+
269+
# Use boolean indexing
270+
selected_patches = x[m_expanded].view(batch_size, -1, x.size(-1))
271+
all_x.append(selected_patches)
265272

266273
# Concatenate along the batch dimension
267274
return torch.cat(all_x, dim=0)
@@ -271,40 +278,39 @@ def apply_masks(
271278
class IJEPAMaskGenerator:
272279
"""Generates encoder and predictor masks for preprocessing.
273280
274-
This class generates masks dynamically for individual examples and can be passed to
275-
a data loader as a preprocessing step.
281+
This class generates masks dynamically for batches of examples.
276282
277283
Parameters
278284
----------
279285
input_size : tuple[int, int], default=(224, 224)
280286
Input image size.
281287
patch_size : int, default=16
282288
Size of each patch.
283-
min_keep : int, default=4
289+
min_keep : int, default=10
284290
Minimum number of patches to keep.
285291
allow_overlap : bool, default=False
286292
Whether to allow overlap between encoder and predictor masks.
287-
enc_mask_scale : tuple[float, float], default=(0.2, 0.8)
293+
enc_mask_scale : tuple[float, float], default=(0.85, 1.0)
288294
Scale range for encoder mask.
289-
pred_mask_scale : tuple[float, float], default=(0.2, 0.8)
295+
pred_mask_scale : tuple[float, float], default=(0.15, 0.2)
290296
Scale range for predictor mask.
291-
aspect_ratio : tuple[float, float], default=(0.3, 3.0)
297+
aspect_ratio : tuple[float, float], default=(0.75, 1.0)
292298
Aspect ratio range for mask blocks.
293299
nenc : int, default=1
294300
Number of encoder masks to generate.
295-
npred : int, default=2
301+
npred : int, default=4
296302
Number of predictor masks to generate.
297303
"""
298304

299305
input_size: Tuple[int, int] = (224, 224)
300306
patch_size: int = 16
301-
min_keep: int = 4
307+
min_keep: int = 10
302308
allow_overlap: bool = False
303-
enc_mask_scale: Tuple[float, float] = (0.2, 0.8)
304-
pred_mask_scale: Tuple[float, float] = (0.2, 0.8)
305-
aspect_ratio: Tuple[float, float] = (0.3, 3.0)
309+
enc_mask_scale: Tuple[float, float] = (0.85, 1.0)
310+
pred_mask_scale: Tuple[float, float] = (0.15, 0.2)
311+
aspect_ratio: Tuple[float, float] = (0.75, 1.0)
306312
nenc: int = 1
307-
npred: int = 2
313+
npred: int = 4
308314

309315
def __post_init__(self) -> None:
310316
"""Initialize the mask generator."""
@@ -353,8 +359,14 @@ def _sample_block_mask(
353359

354360
def __call__(
355361
self,
362+
batch_size: int = 1,
356363
) -> Dict[str, Any]:
357-
"""Generate encoder and predictor masks for a single example.
364+
"""Generate encoder and predictor masks for a batch of examples.
365+
366+
Parameters
367+
----------
368+
batch_size : int, default=1
369+
The batch size for which to generate masks.
358370
359371
Returns
360372
-------
@@ -378,14 +390,18 @@ def __call__(
378390
masks_pred, masks_enc = [], []
379391
for _ in range(self.npred):
380392
mask_p, _ = self._sample_block_mask(p_size)
393+
# Expand mask to match batch size
394+
mask_p = mask_p.unsqueeze(0).expand(batch_size, -1)
381395
masks_pred.append(mask_p)
382396

383397
# Generate encoder masks
384398
for _ in range(self.nenc):
385399
mask_e, _ = self._sample_block_mask(e_size)
400+
# Expand mask to match batch size
401+
mask_e = mask_e.unsqueeze(0).expand(batch_size, -1)
386402
masks_enc.append(mask_e)
387403

388404
return {
389-
"encoder_masks": torch.stack(masks_enc),
390-
"predictor_masks": torch.stack(masks_pred),
405+
"encoder_masks": masks_enc, # List of tensors of shape (batch_size, N)
406+
"predictor_masks": masks_pred, # List of tensors of shape (batch_size, N)
391407
}

mmlearn/modules/ema.py

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,52 @@ def __init__(
5252
self.ema_end_decay = ema_end_decay
5353
self.ema_anneal_end_step = ema_anneal_end_step
5454

55+
@staticmethod
56+
def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
57+
"""Deep copy the model."""
58+
try:
59+
return copy.deepcopy(model)
60+
except RuntimeError as e:
61+
raise RuntimeError("Unable to copy the model ", e) from e
62+
63+
@staticmethod
64+
def get_annealed_rate(
65+
start: float,
66+
end: float,
67+
curr_step: int,
68+
total_steps: int,
69+
) -> float:
70+
"""Calculate EMA annealing rate."""
71+
r = end - start
72+
pct_remaining = 1 - curr_step / total_steps
73+
return end - r * pct_remaining
74+
75+
def step(self, new_model: torch.nn.Module) -> None:
76+
"""Perform single EMA update step."""
77+
self._update_weights(new_model)
78+
self._update_ema_decay()
79+
80+
def restore(self, model: torch.nn.Module) -> torch.nn.Module:
81+
"""Reassign weights from another model.
82+
83+
Parameters
84+
----------
85+
model : nn.Module
86+
Model to load weights from.
87+
88+
Returns
89+
-------
90+
nn.Module
91+
model with new weights
92+
"""
93+
d = self.model.state_dict()
94+
model.load_state_dict(d, strict=False)
95+
return model
96+
97+
def state_dict(self) -> dict[str, Any]:
98+
"""Return the state dict of the model."""
99+
return self.model.state_dict() # type: ignore[no-any-return]
100+
55101
@torch.no_grad() # type: ignore[misc]
56102
def _update_weights(self, new_model: torch.nn.Module) -> None:
57103
if self.decay < 1:
@@ -98,49 +144,3 @@ def _update_ema_decay(self) -> None:
98144
self.ema_anneal_end_step,
99145
)
100146
self.decay = decay
101-
102-
def step(self, new_model: torch.nn.Module) -> None:
103-
"""Perform single EMA update step."""
104-
self._update_weights(new_model)
105-
self._update_ema_decay()
106-
107-
@staticmethod
108-
def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
109-
"""Deep copy the model."""
110-
try:
111-
return copy.deepcopy(model)
112-
except RuntimeError as e:
113-
raise RuntimeError("Unable to copy the model ", e) from e
114-
115-
def restore(self, model: torch.nn.Module) -> torch.nn.Module:
116-
"""Reassign weights from another model.
117-
118-
Parameters
119-
----------
120-
model : nn.Module
121-
Model to load weights from.
122-
123-
Returns
124-
-------
125-
nn.Module
126-
model with new weights
127-
"""
128-
d = self.model.state_dict()
129-
model.load_state_dict(d, strict=False)
130-
return model
131-
132-
def state_dict(self) -> dict[str, Any]:
133-
"""Return the state dict of the model."""
134-
return self.model.state_dict() # type: ignore[no-any-return]
135-
136-
@staticmethod
137-
def get_annealed_rate(
138-
start: float,
139-
end: float,
140-
curr_step: int,
141-
total_steps: int,
142-
) -> float:
143-
"""Calculate EMA annealing rate."""
144-
r = end - start
145-
pct_remaining = 1 - curr_step / total_steps
146-
return end - r * pct_remaining

0 commit comments

Comments
 (0)