22
33import inspect
44import itertools
5+ import math
56from dataclasses import dataclass
67from functools import partial
78from typing import Any , Dict , List , Literal , Mapping , Optional , Tuple , Union
89
910import lightning as L # noqa: N812
11+ import numpy as np
1012import torch
1113import torch .distributed
1214import torch .distributed .nn
@@ -151,6 +153,9 @@ def __init__( # noqa: PLR0912, PLR0915
151153 partial [torch .optim .lr_scheduler .LRScheduler ],
152154 ]
153155 ] = None ,
156+ init_logit_scale : float = 1 / 0.07 ,
157+ max_logit_scale : float = 100 ,
158+ learnable_logit_scale : bool = True ,
154159 loss : Optional [CLIPLoss ] = None ,
155160 modality_loss_pairs : Optional [List [LossPairSpec ]] = None ,
156161 auxiliary_tasks : Optional [Dict [str , AuxiliaryTaskSpec ]] = None ,
@@ -259,6 +264,19 @@ def __init__( # noqa: PLR0912, PLR0915
259264 }
260265 )
261266
267+ # set up logit scaling
268+ log_logit_scale = torch .ones ([]) * np .log (init_logit_scale )
269+ self .max_logit_scale = max_logit_scale
270+ self .learnable_logit_scale = learnable_logit_scale
271+
272+ if self .learnable_logit_scale :
273+ self .log_logit_scale = torch .nn .Parameter (
274+ log_logit_scale , requires_grad = True
275+ )
276+ else :
277+ self .register_buffer ("log_logit_scale" , log_logit_scale )
278+
279+ # set up contrastive loss pairs
262280 if modality_loss_pairs is None :
263281 modality_loss_pairs = [
264282 LossPairSpec (modalities = (m1 .name , m2 .name ))
@@ -277,6 +295,7 @@ def __init__( # noqa: PLR0912, PLR0915
277295 )
278296 self .modality_loss_pairs = modality_loss_pairs
279297
298+ # set up auxiliary tasks
280299 self .aux_task_specs = auxiliary_tasks or {}
281300 self .auxiliary_tasks : Dict [str , L .LightningModule ] = {}
282301 for task_name , task_spec in self .aux_task_specs .items ():
@@ -313,10 +332,11 @@ def __init__( # noqa: PLR0912, PLR0915
313332 f"Expected { eval_task_spec .task } to be an instance of `EvaluationHooks` "
314333 f"but got { type (eval_task_spec .task )} ."
315334 )
316-
317335 self .evaluation_tasks = evaluation_tasks
318336
319- def encode (self , inputs : Dict [str , Any ], modality : Modality ) -> torch .Tensor :
337+ def encode (
338+ self , inputs : Dict [str , Any ], modality : Modality , normalize : bool = False
339+ ) -> torch .Tensor :
320340 """Encode the input values for the given modality.
321341
322342 Parameters
@@ -325,6 +345,9 @@ def encode(self, inputs: Dict[str, Any], modality: Modality) -> torch.Tensor:
325345 Input values.
326346 modality : Modality
327347 The modality to encode.
348+ normalize : bool, optional, default=False
349+ Whether to apply L2 normalization to the output (after the head and
350+ postprocessor layers, if present).
328351
329352 Returns
330353 -------
@@ -339,6 +362,9 @@ def encode(self, inputs: Dict[str, Any], modality: Modality) -> torch.Tensor:
339362 if self .postprocessors and modality .name in self .postprocessors :
340363 output = self .postprocessors [modality .name ](output )
341364
365+ if normalize :
366+ output = torch .nn .functional .normalize (output , p = 2 , dim = - 1 )
367+
342368 return output
343369
344370 def forward (self , inputs : Dict [str , Any ]) -> Dict [str , torch .Tensor ]:
@@ -355,7 +381,7 @@ def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
355381 The encodings for each modality.
356382 """
357383 outputs = {
358- modality .embedding : self .encode (inputs , modality )
384+ modality .embedding : self .encode (inputs , modality , normalize = True )
359385 for modality in self ._available_modalities
360386 }
361387
@@ -373,6 +399,16 @@ def _compute_loss(
373399 if self .loss_fn is None :
374400 return None
375401
402+ with torch .no_grad ():
403+ self .log_logit_scale .clamp_ (0 , math .log (self .max_logit_scale ))
404+ self .log (
405+ "train/logit_scale" ,
406+ self .log_logit_scale .exp (),
407+ prog_bar = True ,
408+ on_step = True ,
409+ on_epoch = False ,
410+ )
411+
376412 contrastive_losses : list [torch .Tensor ] = []
377413 for loss_pair in self .modality_loss_pairs :
378414 modality_a = Modalities .get_modality (loss_pair .modalities [0 ])
@@ -389,6 +425,7 @@ def _compute_loss(
389425 self .loss_fn (
390426 outputs [modality_a .embedding ][indices_a ],
391427 outputs [modality_b .embedding ][indices_b ],
428+ self .log_logit_scale .exp (),
392429 )
393430 * loss_pair .weight
394431 )
0 commit comments