Skip to content

Commit 50074dd

Browse files
committed
Simplify the head for segmentation
1 parent 05c167e commit 50074dd

3 files changed

Lines changed: 47 additions & 64 deletions

File tree

configs/segment_chesapeake.yaml

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -11,12 +11,7 @@ data:
1111
platform: naip
1212
model:
1313
num_classes: 7
14-
feature_maps:
15-
- 5
16-
- 11
17-
- 15
18-
- 23
19-
ckpt_path: checkpoints/v1.5.0-no-mrl-dinov2/mae_v1.5.0_epoch-05_val-loss-0.1734.ckpt
14+
ckpt_path: checkpoints/clay_v1.5.ckpt
2015
lr: 1e-5
2116
wd: 0.05
2217
b1: 0.9
@@ -28,7 +23,7 @@ trainer:
2823
num_nodes: 1
2924
precision: bf16-mixed
3025
log_every_n_steps: 5
31-
max_epochs: 10
26+
max_epochs: 100
3227
accumulate_grad_batches: 1
3328
default_root_dir: checkpoints/segment
3429
fast_dev_run: False

finetune/segment/chesapeake_model.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@ class ChesapeakeSegmentor(L.LightningModule):
2828
def __init__( # # noqa: PLR0913
2929
self,
3030
num_classes,
31-
feature_maps,
3231
ckpt_path,
3332
lr,
3433
wd,
@@ -39,7 +38,6 @@ def __init__( # # noqa: PLR0913
3938
self.save_hyperparameters() # Save hyperparameters for checkpointing
4039
self.model = Segmentor(
4140
num_classes=num_classes,
42-
feature_maps=feature_maps,
4341
ckpt_path=ckpt_path,
4442
)
4543

@@ -101,7 +99,7 @@ def configure_optimizers(self):
10199
optimizer,
102100
T_0=100,
103101
T_mult=1,
104-
eta_min=self.hparams.lr * 10,
102+
eta_min=self.hparams.lr * 100,
105103
last_epoch=-1,
106104
)
107105
return {

finetune/segment/factory.py

Lines changed: 44 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import re
1111

1212
import torch
13+
import torch.nn.functional as F
1314
from einops import rearrange, repeat
1415
from torch import nn
1516

@@ -37,7 +38,6 @@ def __init__( # noqa: PLR0913
3738
heads,
3839
dim_head,
3940
mlp_ratio,
40-
feature_maps,
4141
ckpt_path=None,
4242
):
4343
super().__init__(
@@ -50,27 +50,6 @@ def __init__( # noqa: PLR0913
5050
dim_head,
5151
mlp_ratio,
5252
)
53-
self.feature_maps = feature_maps
54-
55-
# Define Feature Pyramid Network (FPN) layers
56-
self.fpn1 = nn.Sequential(
57-
nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
58-
nn.BatchNorm2d(dim),
59-
nn.GELU(),
60-
nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
61-
)
62-
63-
self.fpn2 = nn.Sequential(
64-
nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
65-
)
66-
67-
self.fpn3 = nn.Identity()
68-
69-
self.fpn4 = nn.Sequential(
70-
nn.MaxPool2d(kernel_size=2, stride=2),
71-
)
72-
73-
self.fpn5 = nn.Identity()
7453

7554
# Set device
7655
self.device = (
@@ -143,25 +122,10 @@ def forward(self, datacube):
143122
cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D]
144123
patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D]
145124

146-
features = []
147-
for idx, (attn, ff) in enumerate(self.transformer.layers):
148-
patches = attn(patches) + patches
149-
patches = ff(patches) + patches
150-
if idx in self.feature_maps:
151-
_cube = rearrange(
152-
patches[:, 1:, :], "B (H W) D -> B D H W", H=H // 8, W=W // 8
153-
)
154-
features.append(_cube)
155-
patches = self.transformer.norm(patches)
156-
_cube = rearrange(patches[:, 1:, :], "B (H W) D -> B D H W", H=H // 8, W=W // 8)
157-
features.append(_cube)
158-
159-
# Apply FPN layers
160-
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4, self.fpn5]
161-
for i in range(len(features)):
162-
features[i] = ops[i](features[i])
125+
patches = self.transformer(patches)
126+
patches = patches[:, 1:, :] # [B L D]
163127

164-
return features
128+
return patches
165129

166130

167131
class Segmentor(nn.Module):
@@ -175,7 +139,7 @@ class Segmentor(nn.Module):
175139
ckpt_path (str): Path to the checkpoint file.
176140
"""
177141

178-
def __init__(self, num_classes, feature_maps, ckpt_path):
142+
def __init__(self, num_classes, ckpt_path):
179143
super().__init__()
180144
# Default values are for the clay mae base model.
181145
self.encoder = SegmentEncoder(
@@ -187,14 +151,26 @@ def __init__(self, num_classes, feature_maps, ckpt_path):
187151
heads=16,
188152
dim_head=64,
189153
mlp_ratio=4.0,
190-
feature_maps=feature_maps,
191154
ckpt_path=ckpt_path,
192155
)
193-
self.upsamples = [nn.Upsample(scale_factor=2**i) for i in range(4)] + [
194-
nn.Upsample(scale_factor=4),
195-
]
196-
self.fusion = nn.Conv2d(self.encoder.dim * 5, self.encoder.dim, kernel_size=1)
197-
self.seg_head = nn.Conv2d(self.encoder.dim, num_classes, kernel_size=1)
156+
157+
# Freeze the encoder parameters
158+
for param in self.encoder.parameters():
159+
param.requires_grad = False
160+
161+
# Define layers after the encoder
162+
D = self.encoder.dim # embedding dimension
163+
hidden_dim = 512
164+
C_out = 64
165+
r = self.encoder.patch_size # upscale factor (patch_size)
166+
167+
self.conv1 = nn.Conv2d(D, hidden_dim, kernel_size=3, padding=1)
168+
self.bn1 = nn.BatchNorm2d(hidden_dim)
169+
self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
170+
self.bn2 = nn.BatchNorm2d(hidden_dim)
171+
self.conv_ps = nn.Conv2d(hidden_dim, C_out * r * r, kernel_size=3, padding=1)
172+
self.pixel_shuffle = nn.PixelShuffle(upscale_factor=r)
173+
self.conv_out = nn.Conv2d(C_out, num_classes, kernel_size=3, padding=1)
198174

199175
def forward(self, datacube):
200176
"""
@@ -207,12 +183,26 @@ def forward(self, datacube):
207183
Returns:
208184
torch.Tensor: The segmentation logits.
209185
"""
210-
features = self.encoder(datacube)
211-
for i in range(len(features)):
212-
features[i] = self.upsamples[i](features[i])
186+
cube = datacube["pixels"] # [B C H_in W_in]
187+
B, C, H_in, W_in = cube.shape
188+
189+
# Get embeddings from the encoder
190+
patches = self.encoder(datacube) # [B, L, D]
191+
192+
# Reshape embeddings to [B, D, H', W']
193+
H_patches = H_in // self.encoder.patch_size
194+
W_patches = W_in // self.encoder.patch_size
195+
x = rearrange(patches, "B (H W) D -> B D H W", H=H_patches, W=W_patches)
196+
197+
# Pass through convolutional layers
198+
x = F.relu(self.bn1(self.conv1(x)))
199+
x = F.relu(self.bn2(self.conv2(x)))
200+
x = self.conv_ps(x) # [B, C_out * r^2, H', W']
201+
202+
# Upsample using PixelShuffle
203+
x = self.pixel_shuffle(x) # [B, C_out, H_in, W_in]
213204

214-
fused = torch.cat(features, dim=1)
215-
fused = self.fusion(fused)
205+
# Final convolution to get desired output channels
206+
x = self.conv_out(x) # [B, num_outputs, H_in, W_in]
216207

217-
logits = self.seg_head(fused)
218-
return logits
208+
return x

0 commit comments

Comments (0)