77import torch .nn .functional as F
88from einops import rearrange , reduce , repeat
99from torch import nn
10+ from torchvision .transforms import v2
1011
1112from src .backbone import Transformer
1213from src .factory import DynamicEmbedding
13- from src .mrl import MRL , MRLLoss
1414from src .utils import posemb_sincos_2d_with_gsd
1515
1616torch .set_float32_matmul_precision ("medium" )
@@ -386,8 +386,13 @@ def __init__( # noqa: PLR0913
386386 self .shuffle = shuffle
387387 self .metadata = metadata
388388 self .teacher = timm .create_model (teacher , pretrained = True , num_classes = 0 )
389- self .mrl = MRL (features = self .teacher .num_features , dolls = dolls )
390- self .mrl_loss = MRLLoss (weights = doll_weights )
389+ self .teacher_chip_size = 518
390+ self .teacher_resize = v2 .Resize (
391+ size = (self .teacher_chip_size , self .teacher_chip_size )
392+ )
393+ # self.mrl = MRL(features=self.teacher.num_features, dolls=dolls)
394+ # self.mrl_loss = MRLLoss(weights=doll_weights)
395+ self .proj = nn .Linear (dim , self .teacher .num_features )
391396
392397 self .encoder = Encoder (
393398 mask_ratio = mask_ratio ,
@@ -516,8 +521,11 @@ def forward(self, datacube):
516521 if platform == "modis" :
517522 reconstruction_loss /= 10
518523
519- # MRL
520- representations = self .mrl (encoded_unmasked_patches [:, 0 , :]) # [(B D') ...]
524+ # # MRL
525+ # representations = self.mrl(encoded_unmasked_patches[:, 0, :]) # [(B D') ...]
526+
527+ # PROJ
528+ representations = self .proj (encoded_unmasked_patches [:, 0 , :]) # [B D']
521529
522530 with torch .no_grad ():
523531 if platform == "sentinel-1-rtc" :
@@ -529,9 +537,12 @@ def forward(self, datacube):
529537 # Read RGB bands from the sensor to feed the teacher model
530538 indices = self .metadata [platform ].rgb_indices
531539 rgb = datacube ["pixels" ][:, indices , :, :]
540+ rgb = self .teacher_resize (rgb )
532541 target = self .teacher (rgb )
542+ # target = self.teacher(rgb)
533543
534- representation_loss = self .mrl_loss (representations , target )
544+ # representation_loss = self.mrl_loss(representations, target)
545+ representation_loss = 1.0 - F .cosine_similarity (representations , target ).mean ()
535546
536547 loss = 0.9 * reconstruction_loss + 0.1 * representation_loss
537548 return (loss , reconstruction_loss , representation_loss )
0 commit comments