Skip to content

Commit e8d9ca6

Browse files
srmsoumyayellowcap
authored and committed
Simplify head of Regressor and update for Clay v1.5
1 parent d175a95 commit e8d9ca6

4 files changed

Lines changed: 110 additions & 142 deletions

File tree

configs/regression_biomasters.yaml

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,18 @@
22
seed_everything: 42
33
data:
44
metadata_path: configs/metadata.yaml
5-
batch_size: 10
5+
batch_size: 25
66
num_workers: 8
77
train_chip_dir: data/biomasters/train_cube
88
train_label_dir: data/biomasters/train_agbm
99
val_chip_dir: data/biomasters/test_cube
1010
val_label_dir: data/biomasters/test_agbm
1111
model:
12-
ckpt_path: checkpoints/clay-v1-base.ckpt
13-
lr: 1e-3
12+
ckpt_path: checkpoints/clay_v1.5.ckpt
13+
lr: 1e-2
1414
wd: 0.05
1515
b1: 0.9
1616
b2: 0.95
17-
feature_maps:
18-
- 2
19-
- 5
20-
- 7
21-
- 9
22-
- 11
2317
trainer:
2418
accelerator: auto
2519
strategy: ddp
@@ -33,13 +27,14 @@ trainer:
3327
num_sanity_val_steps: 0
3428
# limit_train_batches: 0.25
3529
# limit_val_batches: 0.25
36-
accumulate_grad_batches: 4
30+
accumulate_grad_batches: 1
3731
logger:
3832
- class_path: lightning.pytorch.loggers.WandbLogger
3933
init_args:
4034
entity: developmentseed
4135
project: clay-regression
4236
log_model: false
37+
group: v1.5-test
4338
callbacks:
4439
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
4540
init_args:
@@ -55,9 +50,9 @@ trainer:
5550
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
5651
init_args:
5752
logging_interval: step
58-
- class_path: src.callbacks.LayerwiseFinetuning
59-
init_args:
60-
phase: 10
61-
train_bn: True
53+
# - class_path: src.callbacks.LayerwiseFinetuning
54+
# init_args:
55+
# phase: 10
56+
# train_bn: True
6257
plugins:
6358
- class_path: lightning.pytorch.plugins.io.AsyncCheckpointIO

finetune/regression/biomasters_inference.ipynb

Lines changed: 38 additions & 18 deletions
Large diffs are not rendered by default.

finetune/regression/biomasters_model.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,11 @@ class BioMastersClassifier(L.LightningModule):
4242
b2 (float): Beta2 parameter for the Adam optimizer.
4343
"""
4444

45-
def __init__(self, ckpt_path, feature_maps, lr, wd, b1, b2): # noqa: PLR0913
45+
def __init__(self, ckpt_path, lr, wd, b1, b2): # noqa: PLR0913
4646
super().__init__()
4747
self.save_hyperparameters()
4848
# self.model = Classifier(num_classes=1, ckpt_path=ckpt_path)
49-
self.model = Regressor(
50-
num_classes=1, feature_maps=feature_maps, ckpt_path=ckpt_path
51-
)
49+
self.model = Regressor(num_classes=1, ckpt_path=ckpt_path)
5250
self.loss_fn = NoNaNRMSE()
5351
self.score_fn = MeanSquaredError()
5452

@@ -110,7 +108,7 @@ def configure_optimizers(self):
110108
weight_decay=self.hparams.wd,
111109
betas=(self.hparams.b1, self.hparams.b2),
112110
)
113-
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)
111+
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
114112
return {
115113
"optimizer": optimizer,
116114
"lr_scheduler": {

finetune/regression/factory.py

Lines changed: 60 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
"""
2-
Clay Segmentor for semantic segmentation tasks.
2+
Clay Regressor for semantic regression tasks using PixelShuffle.
33
44
Attribution:
5-
Decoder from Segformer: Simple and Efficient Design for Semantic Segmentation
6-
with Transformers
7-
Paper URL: https://arxiv.org/abs/2105.15203
5+
Decoder inspired by PixelShuffle-based upsampling.
86
"""
97

108
import re
@@ -17,18 +15,15 @@
1715
from src.model import Encoder
1816

1917

20-
class SegmentEncoder(Encoder):
18+
class RegressionEncoder(Encoder):
2119
"""
22-
Encoder class for segmentation tasks, incorporating a feature pyramid
23-
network (FPN).
20+
Encoder class for regression tasks.
2421
2522
Attributes:
26-
feature_maps (list): Indices of layers to be used for generating
27-
feature maps.
2823
ckpt_path (str): Path to the clay checkpoint file.
2924
"""
3025

31-
def __init__( # noqa: PLR0913
26+
def __init__(
3227
self,
3328
mask_ratio,
3429
patch_size,
@@ -38,7 +33,6 @@ def __init__( # noqa: PLR0913
3833
heads,
3934
dim_head,
4035
mlp_ratio,
41-
feature_maps,
4236
ckpt_path=None,
4337
):
4438
super().__init__(
@@ -51,30 +45,6 @@ def __init__( # noqa: PLR0913
5145
dim_head,
5246
mlp_ratio,
5347
)
54-
self.feature_maps = feature_maps
55-
56-
# Define Feature Pyramid Network (FPN) layers
57-
self.fpn1 = nn.Sequential(
58-
nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
59-
nn.BatchNorm2d(dim),
60-
nn.GELU(),
61-
nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
62-
)
63-
64-
self.fpn2 = nn.Sequential(
65-
nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
66-
)
67-
68-
self.fpn3 = nn.Identity()
69-
70-
self.fpn4 = nn.Sequential(
71-
nn.MaxPool2d(kernel_size=2, stride=2),
72-
)
73-
74-
self.fpn5 = nn.Sequential(
75-
nn.MaxPool2d(kernel_size=4, stride=4),
76-
)
77-
7848
# Set device
7949
self.device = (
8050
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -119,14 +89,14 @@ def load_from_ckpt(self, ckpt_path):
11989

12090
def forward(self, datacube):
12191
"""
122-
Forward pass of the SegmentEncoder.
92+
Forward pass of the RegressionEncoder.
12393
12494
Args:
12595
datacube (dict): A dictionary containing the input datacube and
12696
meta information like time, latlon, gsd & wavelenths.
12797
12898
Returns:
129-
list: A list of feature maps extracted from the datacube.
99+
torch.Tensor: The embeddings from the final layer.
130100
"""
131101
cube, time, latlon, gsd, waves = (
132102
datacube["pixels"], # [B C H W]
@@ -146,84 +116,56 @@ def forward(self, datacube):
146116
cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D]
147117
patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D]
148118

149-
features = []
150-
for idx, (attn, ff) in enumerate(self.transformer.layers):
151-
patches = attn(patches) + patches
152-
patches = ff(patches) + patches
153-
if idx in self.feature_maps:
154-
_cube = rearrange(
155-
patches[:, 1:, :], "B (H W) D -> B D H W", H=H // 8, W=W // 8
156-
)
157-
features.append(_cube)
158-
# patches = self.transformer.norm(patches)
159-
# _cube = rearrange(patches[:, 1:, :], "B (H W) D -> B D H W", H=H//8, W=W//8)
160-
# features.append(_cube)
161-
162-
# Apply FPN layers
163-
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4, self.fpn5]
164-
for i in range(len(features)):
165-
features[i] = ops[i](features[i])
166-
167-
return features
168-
169-
170-
class FusionBlock(nn.Module):
171-
def __init__(self, input_dim, output_dim):
172-
super().__init__()
173-
self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1)
174-
self.bn = nn.BatchNorm2d(output_dim)
175-
176-
def forward(self, x):
177-
x = F.relu(self.bn(self.conv(x)))
178-
return x
179-
119+
# Transformer encoder
120+
patches = self.transformer(patches)
180121

181-
class SegmentationHead(nn.Module):
182-
def __init__(self, input_dim, num_classes):
183-
super().__init__()
184-
self.conv1 = nn.Conv2d(input_dim, input_dim // 2, kernel_size=3, padding=1)
185-
self.conv2 = nn.Conv2d(
186-
input_dim // 2, num_classes, kernel_size=1
187-
) # final conv to num_classes
188-
self.bn1 = nn.BatchNorm2d(input_dim // 2)
122+
# Remove class token
123+
patches = patches[:, 1:, :] # [B, L, D]
189124

190-
def forward(self, x):
191-
x = F.relu(self.bn1(self.conv1(x)))
192-
x = self.conv2(x) # No activation before final layer
193-
return x
125+
return patches
194126

195127

196128
class Regressor(nn.Module):
197129
"""
198-
Clay Regressor class that combines the Encoder with FPN layers for semantic
199-
regression.
130+
Clay Regressor class that combines the Encoder with PixelShuffle for regression.
200131
201132
Attributes:
202-
num_classes (int): Number of output classes for segmentation.
203-
feature_maps (list): Indices of layers to be used for generating feature maps.
133+
num_classes (int): Number of output classes for regression.
204134
ckpt_path (str): Path to the checkpoint file.
205135
"""
206136

207-
def __init__(self, num_classes, feature_maps, ckpt_path):
137+
def __init__(self, num_classes, ckpt_path):
208138
super().__init__()
209-
# Default values are for the clay mae base model.
210-
self.encoder = SegmentEncoder(
139+
# Initialize the encoder
140+
self.encoder = RegressionEncoder(
211141
mask_ratio=0.0,
212142
patch_size=8,
213143
shuffle=False,
214-
dim=768,
215-
depth=12,
216-
heads=12,
144+
dim=1024,
145+
depth=24,
146+
heads=16,
217147
dim_head=64,
218148
mlp_ratio=4.0,
219-
feature_maps=feature_maps,
220149
ckpt_path=ckpt_path,
221150
)
222-
self.upsamples = [nn.Upsample(scale_factor=2**i) for i in range(5)]
223-
self.fusion = FusionBlock(self.encoder.dim, self.encoder.dim // 4)
224-
self.seg_head = nn.Conv2d(
225-
self.encoder.dim // 4, num_classes, kernel_size=3, padding=1
226-
)
151+
152+
# Freeze the encoder parameters
153+
for param in self.encoder.parameters():
154+
param.requires_grad = False
155+
156+
# Define layers after the encoder
157+
D = self.encoder.dim # embedding dimension
158+
hidden_dim = 512
159+
C_out = 64
160+
r = self.encoder.patch_size # upscale factor (patch_size)
161+
162+
self.conv1 = nn.Conv2d(D, hidden_dim, kernel_size=3, padding=1)
163+
self.bn1 = nn.BatchNorm2d(hidden_dim)
164+
self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
165+
self.bn2 = nn.BatchNorm2d(hidden_dim)
166+
self.conv_ps = nn.Conv2d(hidden_dim, C_out * r * r, kernel_size=3, padding=1)
167+
self.pixel_shuffle = nn.PixelShuffle(upscale_factor=r)
168+
self.conv_out = nn.Conv2d(C_out, num_classes, kernel_size=3, padding=1)
227169

228170
def forward(self, datacube):
229171
"""
@@ -234,15 +176,28 @@ def forward(self, datacube):
234176
meta information like time, latlon, gsd & wavelenths.
235177
236178
Returns:
237-
torch.Tensor: The segmentation logits.
179+
torch.Tensor: The regression output.
238180
"""
239-
features = self.encoder(datacube)
240-
for i in range(len(features)):
241-
features[i] = self.upsamples[i](features[i])
181+
cube = datacube["pixels"] # [B C H_in W_in]
182+
B, C, H_in, W_in = cube.shape
242183

243-
# fused = torch.cat(features, dim=1)
244-
fused = torch.sum(torch.stack(features), dim=0)
245-
fused = self.fusion(fused)
184+
# Get embeddings from the encoder
185+
patches = self.encoder(datacube) # [B, L, D]
246186

247-
logits = self.seg_head(fused)
248-
return logits
187+
# Reshape embeddings to [B, D, H', W']
188+
H_patches = H_in // self.encoder.patch_size
189+
W_patches = W_in // self.encoder.patch_size
190+
x = rearrange(patches, "B (H W) D -> B D H W", H=H_patches, W=W_patches)
191+
192+
# Pass through convolutional layers
193+
x = F.relu(self.bn1(self.conv1(x)))
194+
x = F.relu(self.bn2(self.conv2(x)))
195+
x = self.conv_ps(x) # [B, C_out * r^2, H', W']
196+
197+
# Upsample using PixelShuffle
198+
x = self.pixel_shuffle(x) # [B, C_out, H_in, W_in]
199+
200+
# Final convolution to get desired output channels
201+
x = self.conv_out(x) # [B, num_outputs, H_in, W_in]
202+
203+
return x

0 commit comments

Comments
 (0)