Skip to content

Commit 05c167e

Browse files
authored
Check non MRL loss for the model (#331)
- Use DINO v2 with linear projection - Don't load from base checkpoint of MRL
1 parent c638972 commit 05c167e

9 files changed

Lines changed: 63 additions & 32 deletions

File tree

configs/config.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ model:
2020
patch_size: 8
2121
shuffle: True
2222
metadata_path: configs/metadata.yaml
23-
teacher: samvit_base_patch16.sa1b
23+
teacher: vit_large_patch14_reg4_dinov2.lvd142m
2424
dolls: [16, 32, 64, 128, 256, 768, 1024]
2525
doll_weights: [1, 1, 1, 1, 1, 1, 1]
2626
lr: 5e-6
@@ -32,7 +32,7 @@ trainer:
3232
accelerator: gpu
3333
strategy: ddp
3434
devices: 8
35-
num_nodes: 20
35+
num_nodes: 48
3636
precision: bf16-mixed
3737
log_every_n_steps: 1
3838
max_epochs: 1000
@@ -48,9 +48,9 @@ trainer:
4848
init_args:
4949
entity: developmentseed
5050
project: clay
51-
group: v1.5
52-
# id: v8jh2pn9
53-
# resume: must
51+
group: v1.5-nomrl-dinov2
52+
id: 0uy3in7l
53+
resume: must
5454
log_model: false
5555
callbacks:
5656
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
@@ -70,4 +70,4 @@ trainer:
7070
- class_path: src.callbacks_wandb.LogIntermediatePredictions
7171
plugins:
7272
- class_path: lightning.pytorch.plugins.io.AsyncCheckpointIO
73-
ckpt_path: null
73+
ckpt_path: checkpoints/v1.5.0/last.ckpt

src/callbacks_wandb.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,7 @@ def on_validation_end(
257257

258258
for j in range(n_cols):
259259
# Plot actual images in rows 0 and 2
260-
axs[0, j].imshow(
261-
batch["pixels"][j][0], cmap="viridis"
262-
)
260+
axs[0, j].imshow(batch["pixels"][j][0], cmap="viridis")
263261
axs[0, j].set_title(f"Actual {j}")
264262
axs[0, j].axis("off")
265263

@@ -271,15 +269,11 @@ def on_validation_end(
271269
axs[2, j].axis("off")
272270

273271
# Plot predicted images in rows 1 and 3
274-
axs[1, j].imshow(
275-
pixels[j][0], cmap="viridis"
276-
)
272+
axs[1, j].imshow(pixels[j][0], cmap="viridis")
277273
axs[1, j].set_title(f"Pred {j}")
278274
axs[1, j].axis("off")
279275

280-
axs[3, j].imshow(
281-
pixels[j + n_cols][0], cmap="viridis"
282-
)
276+
axs[3, j].imshow(pixels[j + n_cols][0], cmap="viridis")
283277
axs[3, j].set_title(f"Pred {j+n_cols}")
284278
axs[3, j].axis("off")
285279

src/datamodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def setup(self, stage: Literal["fit", "predict"] | None = None) -> None:
253253
# chips_path = list(dp.list_files_by_s3(masks="*.npz"))
254254
# else: # if self.data_dir is a local data path
255255
chips_path = sorted(list(Path(self.data_dir).glob("**/*.npz")))
256-
chips_platform = [chip.parent.parent.name for chip in chips_path]
256+
chips_platform = [chip.parent.name for chip in chips_path]
257257
# chips_platform = [chip.parent.parent.name for chip in chips_path]
258258
print(f"Total number of chips: {len(chips_path)}")
259259

src/model.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import torch.nn.functional as F
88
from einops import rearrange, reduce, repeat
99
from torch import nn
10+
from torchvision.transforms import v2
1011

1112
from src.backbone import Transformer
1213
from src.factory import DynamicEmbedding
13-
from src.mrl import MRL, MRLLoss
1414
from src.utils import posemb_sincos_2d_with_gsd
1515

1616
torch.set_float32_matmul_precision("medium")
@@ -386,8 +386,13 @@ def __init__( # noqa: PLR0913
386386
self.shuffle = shuffle
387387
self.metadata = metadata
388388
self.teacher = timm.create_model(teacher, pretrained=True, num_classes=0)
389-
self.mrl = MRL(features=self.teacher.num_features, dolls=dolls)
390-
self.mrl_loss = MRLLoss(weights=doll_weights)
389+
self.teacher_chip_size = 518
390+
self.teacher_resize = v2.Resize(
391+
size=(self.teacher_chip_size, self.teacher_chip_size)
392+
)
393+
# self.mrl = MRL(features=self.teacher.num_features, dolls=dolls)
394+
# self.mrl_loss = MRLLoss(weights=doll_weights)
395+
self.proj = nn.Linear(dim, self.teacher.num_features)
391396

392397
self.encoder = Encoder(
393398
mask_ratio=mask_ratio,
@@ -516,8 +521,11 @@ def forward(self, datacube):
516521
if platform == "modis":
517522
reconstruction_loss /= 10
518523

519-
# MRL
520-
representations = self.mrl(encoded_unmasked_patches[:, 0, :]) # [(B D') ...]
524+
# # MRL
525+
# representations = self.mrl(encoded_unmasked_patches[:, 0, :]) # [(B D') ...]
526+
527+
# PROJ
528+
representations = self.proj(encoded_unmasked_patches[:, 0, :]) # [B D']
521529

522530
with torch.no_grad():
523531
if platform == "sentinel-1-rtc":
@@ -529,9 +537,12 @@ def forward(self, datacube):
529537
# Read RGB bands from the sensor to feed the teacher model
530538
indices = self.metadata[platform].rgb_indices
531539
rgb = datacube["pixels"][:, indices, :, :]
540+
rgb = self.teacher_resize(rgb)
532541
target = self.teacher(rgb)
542+
# target = self.teacher(rgb)
533543

534-
representation_loss = self.mrl_loss(representations, target)
544+
# representation_loss = self.mrl_loss(representations, target)
545+
representation_loss = 1.0 - F.cosine_similarity(representations, target).mean()
535546

536547
loss = 0.9 * reconstruction_loss + 0.1 * representation_loss
537548
return (loss, reconstruction_loss, representation_loss)

src/module.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__( # noqa: PLR0913
2727
embeddings_level: Literal["mean", "patch", "group"] = "mean",
2828
):
2929
super().__init__()
30+
# self.strict_loading = False # Allow partial loading to check if MRL was the bug
3031
self.save_hyperparameters(logger=True)
3132
self.metadata = Box(yaml.safe_load(open(metadata_path)))
3233
model_map = {
@@ -47,6 +48,26 @@ def __init__( # noqa: PLR0913
4748
"doll_weights": doll_weights,
4849
}
4950
self.model = model_map[model_size](**model_args)
51+
# checkpoint_path = 'mae_v1.5.0_epoch-76_val-loss-0.1612.ckpt'
52+
# checkpoint = torch.load(checkpoint_path, map_location="cpu")
53+
# # Extract the state dictionary
54+
# state_dict = checkpoint['state_dict']
55+
56+
# # Modify the state dictionary
57+
# new_state_dict = OrderedDict()
58+
# for k, v in state_dict.items():
59+
# # Remove 'model.' prefix if it exists
60+
# if k.startswith('model.'):
61+
# k = k[len('model.'):]
62+
# # Exclude keys related to the 'teacher'
63+
# if not (k.startswith('teacher') or k.startswith('mrl')):
64+
# new_state_dict[k] = v
65+
# with torch.no_grad():
66+
# # Load the modified state dictionary into your model
67+
# missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
68+
# # Optionally, print missing and unexpected keys
69+
# print(f"Missing keys: {missing_keys}")
70+
# print(f"Unexpected keys: {unexpected_keys}")
5071
else:
5172
raise ValueError(
5273
f"Invalid model size {model_size}. Expected one of {model_map.keys()}"

src/mrl.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ class MRL(nn.Module):
99
def __init__(self, features, dolls: list = [16, 32, 64, 128, 256, 768]) -> None:
1010
super().__init__()
1111
self.dolls = dolls
12+
self.layers = nn.ModuleDict()
1213
for doll in dolls:
13-
setattr(self, f"mrl_{doll}", nn.Linear(doll, features))
14+
self.layers[f"mrl_{doll}"] = nn.Linear(doll, features)
1415

1516
def forward(self, x):
1617
"x: (batch, features)"
17-
logits = [getattr(self, f"mrl_{doll}")(x[:, :doll]) for doll in self.dolls]
18+
logits = [self.layers[f"mrl_{doll}"](x[:, :doll]) for doll in self.dolls]
1819
return logits
1920

2021

train_clay_v2.sh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
#!/bin/bash
22

33
#SBATCH --job-name=clay-launcher
4-
#SBATCH --nodes=20
4+
#SBATCH --nodes=24
55
#SBATCH --ntasks-per-node=8 # EDIT if it's not 8-gpus per node
66
#SBATCH --cpus-per-task=12 # EDIT this to how many cpu cores the node has divided by num of gpus
77
#SBATCH --gres=gpu:8 # EDIT this if it's not 8-gpus per node
88
#SBATCH --time=0-00:00:00 # EDIT the desired runtime
99
#SBATCH --exclusive
1010
#SBATCH --partition=gpu # EDIT to the desired partition name
11+
#SBATCH --nodelist=gpu-dy-g6-[1-12],gpu-dy-g5-[1-12]
1112
#SBATCH --output=%x-%j-%N.out
1213

1314
echo "START TIME: $(date)"
@@ -31,8 +32,11 @@ LOG_PATH="main_log.txt"
3132
# PTL doesn't need a special launcher
3233
LAUNCHER="python -u"
3334

35+
# Capture the number of nodes allocated by Slurm
36+
NUM_NODES=$SLURM_JOB_NUM_NODES
37+
3438
# EDIT the path+name of the python script and whatever args it needs
35-
PROGRAM="trainer.py fit --config configs/config.yaml"
39+
PROGRAM="trainer.py fit --config configs/config.yaml --trainer.num_nodes=$NUM_NODES"
3640

3741
export CMD="$LAUNCHER $PROGRAM"
3842

train_environment.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
name: claymodel
22
channels:
3-
- pytorch
43
- conda-forge
4+
- nvidia
5+
- pytorch
56
dependencies:
67
- python=3.11
78
- pip
89
- pip:
10+
- --extra-index-url https://download.pytorch.org/whl/cu121
11+
- torch==2.4.0+cu121
12+
- torchvision==0.19.0+cu121
913
- einops~=0.7.0
1014
- geopandas
1115
- jsonargparse[signatures]>=4.27.7
1216
- lightning
1317
- matplotlib
1418
- python-box
15-
- torch
1619
- scikit-image
1720
- scikit-learn
1821
- timm
19-
- torchvision
2022
- vit-pytorch
2123
- wandb

trainer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ def cli_main():
2222
Command-line interface to run ClayMAE with ClayDataModule.
2323
"""
2424
cli = LightningCLI(
25-
ClayMAEModule,
26-
ClayDataModule,
27-
save_config_kwargs={"overwrite": True}
25+
ClayMAEModule, ClayDataModule, save_config_kwargs={"overwrite": True}
2826
)
2927
return cli
3028

0 commit comments

Comments
 (0)