Skip to content

Commit dc264de

Browse files
authored
Update logit scaling method (#29)
1 parent c5d1244 commit dc264de

6 files changed

Lines changed: 73 additions & 51 deletions

File tree

mmlearn/modules/layers/logit_scaling.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -22,15 +22,15 @@ class LearnableLogitScaling(torch.nn.Module):
2222

2323
def __init__(
2424
self,
25-
logit_scale_init: float = 1 / 0.07,
26-
learnable: bool = True,
25+
init_logit_scale: float = 1 / 0.07,
2726
max_logit_scale: float = 100,
27+
learnable: bool = True,
2828
) -> None:
2929
super().__init__()
3030
self.max_logit_scale = max_logit_scale
31-
self.logit_scale_init = logit_scale_init
31+
self.init_logit_scale = init_logit_scale
3232
self.learnable = learnable
33-
log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
33+
log_logit_scale = torch.ones([]) * np.log(self.init_logit_scale)
3434
if learnable:
3535
self.log_logit_scale = torch.nn.Parameter(log_logit_scale)
3636
else:
@@ -49,6 +49,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4949
def extra_repr(self) -> str:
5050
"""Return the string representation of the layer."""
5151
return (
52-
f"logit_scale_init={self.logit_scale_init},learnable={self.learnable},"
52+
f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
5353
f" max_logit_scale={self.max_logit_scale}"
5454
)

mmlearn/modules/losses/contrastive.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def _get_logits(
8383
self,
8484
features_1: torch.Tensor,
8585
features_2: torch.Tensor,
86+
logit_scale: torch.Tensor,
8687
rank: int,
8788
world_size: int,
8889
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -93,7 +94,9 @@ def _get_logits(
9394
features_1 : torch.Tensor
9495
First feature tensor.
9596
features_2 : torch.Tensor
96-
Second feature tensor
97+
Second feature tensor.
98+
logit_scale : torch.Tensor
99+
Logit scale.
97100
rank : int
98101
Rank of the current process.
99102
world_size : int
@@ -114,19 +117,28 @@ def _get_logits(
114117
)
115118

116119
if self.local_loss:
117-
logits_per_feature_1 = _safe_matmul(features_1, all_features_2)
118-
logits_per_feature_2 = _safe_matmul(features_2, all_features_1)
120+
logits_per_feature_1 = logit_scale * _safe_matmul(
121+
features_1, all_features_2
122+
)
123+
logits_per_feature_2 = logit_scale * _safe_matmul(
124+
features_2, all_features_1
125+
)
119126
else:
120-
logits_per_feature_1 = _safe_matmul(all_features_1, all_features_2)
127+
logits_per_feature_1 = logit_scale * _safe_matmul(
128+
all_features_1, all_features_2
129+
)
121130
logits_per_feature_2 = logits_per_feature_1.T
122131
else:
123-
logits_per_feature_1 = _safe_matmul(features_1, features_2)
124-
logits_per_feature_2 = _safe_matmul(features_2, features_1)
132+
logits_per_feature_1 = logit_scale * _safe_matmul(features_1, features_2)
133+
logits_per_feature_2 = logit_scale * _safe_matmul(features_2, features_1)
125134

126135
return logits_per_feature_1, logits_per_feature_2
127136

128137
def forward(
129-
self, features_1: torch.Tensor, features_2: torch.Tensor
138+
self,
139+
features_1: torch.Tensor,
140+
features_2: torch.Tensor,
141+
logit_scale: torch.Tensor,
130142
) -> torch.Tensor:
131143
"""Calculate the CLIP-style loss between two sets of features.
132144
@@ -136,6 +148,8 @@ def forward(
136148
First set of features.
137149
features_2 : torch.Tensor
138150
Second set of features.
151+
logit_scale : torch.Tensor
152+
Logit scale.
139153
140154
Returns
141155
-------
@@ -150,7 +164,7 @@ def forward(
150164
features_2 = F.normalize(features_2, p=2, dim=-1)
151165

152166
logits_per_feat1, logits_per_feat2 = self._get_logits(
153-
features_1, features_2, rank=rank, world_size=world_size
167+
features_1, features_2, logit_scale, rank=rank, world_size=world_size
154168
)
155169
labels = self._get_ground_truth(
156170
features_1.device,

mmlearn/tasks/contrastive_pretraining.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
import inspect
44
import itertools
5+
import math
56
from dataclasses import dataclass
67
from functools import partial
78
from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union
89

910
import lightning as L # noqa: N812
11+
import numpy as np
1012
import torch
1113
import torch.distributed
1214
import torch.distributed.nn
@@ -151,6 +153,9 @@ def __init__( # noqa: PLR0912, PLR0915
151153
partial[torch.optim.lr_scheduler.LRScheduler],
152154
]
153155
] = None,
156+
init_logit_scale: float = 1 / 0.07,
157+
max_logit_scale: float = 100,
158+
learnable_logit_scale: bool = True,
154159
loss: Optional[CLIPLoss] = None,
155160
modality_loss_pairs: Optional[List[LossPairSpec]] = None,
156161
auxiliary_tasks: Optional[Dict[str, AuxiliaryTaskSpec]] = None,
@@ -259,6 +264,19 @@ def __init__( # noqa: PLR0912, PLR0915
259264
}
260265
)
261266

267+
# set up logit scaling
268+
log_logit_scale = torch.ones([]) * np.log(init_logit_scale)
269+
self.max_logit_scale = max_logit_scale
270+
self.learnable_logit_scale = learnable_logit_scale
271+
272+
if self.learnable_logit_scale:
273+
self.log_logit_scale = torch.nn.Parameter(
274+
log_logit_scale, requires_grad=True
275+
)
276+
else:
277+
self.register_buffer("log_logit_scale", log_logit_scale)
278+
279+
# set up contrastive loss pairs
262280
if modality_loss_pairs is None:
263281
modality_loss_pairs = [
264282
LossPairSpec(modalities=(m1.name, m2.name))
@@ -277,6 +295,7 @@ def __init__( # noqa: PLR0912, PLR0915
277295
)
278296
self.modality_loss_pairs = modality_loss_pairs
279297

298+
# set up auxiliary tasks
280299
self.aux_task_specs = auxiliary_tasks or {}
281300
self.auxiliary_tasks: Dict[str, L.LightningModule] = {}
282301
for task_name, task_spec in self.aux_task_specs.items():
@@ -313,10 +332,11 @@ def __init__( # noqa: PLR0912, PLR0915
313332
f"Expected {eval_task_spec.task} to be an instance of `EvaluationHooks` "
314333
f"but got {type(eval_task_spec.task)}."
315334
)
316-
317335
self.evaluation_tasks = evaluation_tasks
318336

319-
def encode(self, inputs: Dict[str, Any], modality: Modality) -> torch.Tensor:
337+
def encode(
338+
self, inputs: Dict[str, Any], modality: Modality, normalize: bool = False
339+
) -> torch.Tensor:
320340
"""Encode the input values for the given modality.
321341
322342
Parameters
@@ -325,6 +345,9 @@ def encode(self, inputs: Dict[str, Any], modality: Modality) -> torch.Tensor:
325345
Input values.
326346
modality : Modality
327347
The modality to encode.
348+
normalize : bool, optional, default=False
349+
Whether to apply L2 normalization to the output (after the head and
350+
postprocessor layers, if present).
328351
329352
Returns
330353
-------
@@ -339,6 +362,9 @@ def encode(self, inputs: Dict[str, Any], modality: Modality) -> torch.Tensor:
339362
if self.postprocessors and modality.name in self.postprocessors:
340363
output = self.postprocessors[modality.name](output)
341364

365+
if normalize:
366+
output = torch.nn.functional.normalize(output, p=2, dim=-1)
367+
342368
return output
343369

344370
def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
@@ -355,7 +381,7 @@ def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
355381
The encodings for each modality.
356382
"""
357383
outputs = {
358-
modality.embedding: self.encode(inputs, modality)
384+
modality.embedding: self.encode(inputs, modality, normalize=True)
359385
for modality in self._available_modalities
360386
}
361387

@@ -373,6 +399,16 @@ def _compute_loss(
373399
if self.loss_fn is None:
374400
return None
375401

402+
with torch.no_grad():
403+
self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))
404+
self.log(
405+
"train/logit_scale",
406+
self.log_logit_scale.exp(),
407+
prog_bar=True,
408+
on_step=True,
409+
on_epoch=False,
410+
)
411+
376412
contrastive_losses: list[torch.Tensor] = []
377413
for loss_pair in self.modality_loss_pairs:
378414
modality_a = Modalities.get_modality(loss_pair.modalities[0])
@@ -389,6 +425,7 @@ def _compute_loss(
389425
self.loss_fn(
390426
outputs[modality_a.embedding][indices_a],
391427
outputs[modality_b.embedding][indices_b],
428+
self.log_logit_scale.exp(),
392429
)
393430
* loss_pair.weight
394431
)

projects/bioscan_clip/configs/experiment/bioscan_1m.yaml

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ defaults:
1212
- /modules/encoders@task.encoders.rgb: timm-vit-lora
1313
- /modules/encoders@task.encoders.dna: barcode-bert-lora
1414
- /modules/layers@task.heads.text: MLP # the other modalities have projection heads in their encoders
15-
- /modules/layers@task.postprocessors.norm_and_logit_scale.norm: L2Norm
16-
- /modules/layers@task.postprocessors.norm_and_logit_scale.logit_scale: LearnableLogitScaling
1715
- /modules/losses@task.loss: CLIPLoss
1816
- /modules/optimizers@task.optimizer: AdamW
1917
- /modules/lr_schedulers@task.lr_scheduler.scheduler: OneCycleLR
@@ -67,19 +65,6 @@ task:
6765
text:
6866
in_dim: 512
6967
out_dim: ${task.encoders.rgb.projection_dim}
70-
postprocessors:
71-
norm_and_logit_scale:
72-
norm:
73-
dim: -1
74-
logit_scale:
75-
learnable: True
76-
modality_module_mapping:
77-
text:
78-
postprocessor_key: norm_and_logit_scale
79-
rgb:
80-
postprocessor_key: norm_and_logit_scale
81-
dna:
82-
postprocessor_key: norm_and_logit_scale
8368
optimizer:
8469
lr: 1.0e-3
8570
eps: 1.0e-6

projects/med_benchmarking/configs/experiment/baseline.yaml

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ defaults:
1515
- /datasets/tokenizers@dataloader.val.collate_fn.batch_processors.text: HFCLIPTokenizer
1616
- /modules/encoders@task.encoders.text: HFCLIPTextEncoderWithProjection
1717
- /modules/encoders@task.encoders.rgb: HFCLIPVisionEncoderWithProjection
18-
- /modules/layers@task.postprocessors.norm_and_logit_scale.norm: L2Norm
19-
- /modules/layers@task.postprocessors.norm_and_logit_scale.logit_scale: LearnableLogitScaling
2018
- /modules/losses@task.loss: CLIPLoss
2119
- /modules/optimizers@task.optimizer: AdamW
2220
- /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR
@@ -47,17 +45,6 @@ dataloader:
4745
num_workers: 4
4846

4947
task:
50-
postprocessors:
51-
norm_and_logit_scale:
52-
norm:
53-
dim: -1
54-
logit_scale:
55-
learnable: True
56-
modality_module_mapping:
57-
text:
58-
postprocessor_key: norm_and_logit_scale
59-
rgb:
60-
postprocessor_key: norm_and_logit_scale
6148
optimizer:
6249
betas:
6350
- 0.9

projects/med_benchmarking/datasets/pad_ufes_20.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ def __init__(
4343
self.split = split
4444

4545
# Load cached data if available
46-
cache_path = f"cache/PadUfes20_{split}.pkl"
46+
cache_path = f".cache/PadUfes20_{split}.pkl"
4747
if os.path.exists(cache_path):
4848
print(f"!!! Using cached dataset for {split}")
4949
with open(cache_path, "rb") as f:
5050
self.metadata = pickle.load(f)
5151
else:
52-
os.makedirs("cache/", exist_ok=True)
52+
os.makedirs(".cache/", exist_ok=True)
5353
self.metadata = self._load_and_process_metadata()
5454
with open(cache_path, "wb") as f:
5555
pickle.dump(self.metadata.to_dict("records"), f)
@@ -68,14 +68,13 @@ def _load_and_process_metadata(self) -> pd.DataFrame:
6868
df["path"] = df["img_id"].apply(
6969
lambda imgid: os.path.join(self.root_dir, "Dataset", imgid)
7070
)
71-
df.drop(columns=["img_id", "diagnostic"], inplace=True).reset_index(
72-
drop=True, inplace=True
73-
)
71+
df.drop(columns=["img_id", "diagnostic"], inplace=True)
72+
df.reset_index(drop=True, inplace=True)
7473

7574
# Split into train and test
7675
dataset = {}
77-
dataset["test"] = df.sample(frac=0.2)
78-
dataset["train"] = df.drop(dataset["test"].index)
76+
dataset["test"] = df.sample(frac=0.2, ignore_index=True)
77+
dataset["train"] = df.drop(dataset["test"].index).reset_index(drop=True)
7978
return dataset[self.split]
8079

8180
def _build_label(self, str_label: str) -> int:

0 commit comments

Comments
 (0)