Skip to content

Commit 4b6bb0e

Browse files
Remove RegressionEncoder and SegmentEncoder. These were legitimately different in Clay v1, where the segment encoder included a feature pyramid network, intermediate feature extraction, and multi-scale output. That functionality was removed in v1.5, leaving the two classes nearly identical.
1 parent b773a71 commit 4b6bb0e

2 files changed

Lines changed: 21 additions & 156 deletions

File tree

finetune/regression/factory.py

Lines changed: 10 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -7,85 +7,13 @@
77

88
import torch
99
import torch.nn.functional as F
10-
from einops import rearrange, repeat
10+
from einops import rearrange
1111
from torch import nn
1212

1313
from claymodel.model import Encoder
1414
from claymodel.utils import load_encoder_weights
1515

1616

17-
class RegressionEncoder(Encoder):
18-
"""
19-
Encoder class for regression tasks.
20-
21-
Attributes:
22-
ckpt_path (str): Path to the clay checkpoint file.
23-
"""
24-
25-
def __init__( # noqa: PLR0913
26-
self,
27-
mask_ratio: float,
28-
patch_size: int,
29-
shuffle: bool,
30-
dim: int,
31-
depth: int,
32-
heads: int,
33-
dim_head: int,
34-
mlp_ratio: float,
35-
ckpt_path: str | None = None,
36-
) -> None:
37-
super().__init__(
38-
mask_ratio,
39-
patch_size,
40-
shuffle,
41-
dim,
42-
depth,
43-
heads,
44-
dim_head,
45-
mlp_ratio,
46-
)
47-
# Set device
48-
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
49-
# Load model from checkpoint if provided
50-
if ckpt_path:
51-
load_encoder_weights(self, ckpt_path, device=str(self.device))
52-
53-
def forward(self, datacube: dict[str, torch.Tensor]) -> torch.Tensor: # ty: ignore[invalid-method-override]
54-
"""
55-
Forward pass of the RegressionEncoder.
56-
57-
Args:
58-
datacube (dict): A dictionary containing the input datacube and
59-
meta information like time, latlon, gsd & wavelenths.
60-
61-
Returns:
62-
torch.Tensor: The embeddings from the final layer.
63-
"""
64-
cube, time, latlon, gsd, waves = (
65-
datacube["pixels"], # [B C H W]
66-
datacube["time"], # [B 2]
67-
datacube["latlon"], # [B 2]
68-
datacube["gsd"], # 1
69-
datacube["waves"], # [N]
70-
)
71-
72-
B = cube.shape[0]
73-
74-
# Patchify and create embeddings per patch
75-
patches, _ = self.to_patch_embed(cube, waves) # [B L D]
76-
patches = self.add_encodings(patches, time, latlon, gsd) # [B L D]
77-
78-
# Add class tokens
79-
cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D]
80-
patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D]
81-
82-
# Transformer encoder
83-
patches = self.transformer(patches)
84-
85-
# Remove class token
86-
return patches[:, 1:, :] # [B, L, D]
87-
88-
8917
class Regressor(nn.Module):
9018
"""
9119
Clay Regressor class that combines the Encoder with PixelShuffle for regression.
@@ -98,7 +26,7 @@ class Regressor(nn.Module):
9826
def __init__(self, num_classes: int, ckpt_path: str | None) -> None:
9927
super().__init__()
10028
# Initialize the encoder
101-
self.encoder = RegressionEncoder(
29+
self.encoder = Encoder(
10230
mask_ratio=0.0,
10331
patch_size=8,
10432
shuffle=False,
@@ -107,9 +35,13 @@ def __init__(self, num_classes: int, ckpt_path: str | None) -> None:
10735
heads=16,
10836
dim_head=64,
10937
mlp_ratio=4.0,
110-
ckpt_path=ckpt_path,
11138
)
11239

40+
# Set device and load pretrained weights
41+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
42+
if ckpt_path:
43+
load_encoder_weights(self.encoder, ckpt_path, device=str(device))
44+
11345
# Freeze the encoder parameters
11446
for param in self.encoder.parameters():
11547
param.requires_grad = False
@@ -142,8 +74,9 @@ def forward(self, datacube: dict[str, torch.Tensor]) -> torch.Tensor:
14274
cube = datacube["pixels"] # [B C H_in W_in]
14375
_, _, H_in, W_in = cube.shape
14476

145-
# Get embeddings from the encoder
146-
patches = self.encoder(datacube) # [B, L, D]
77+
# Get embeddings from the encoder (strip CLS token)
78+
encoded, *_ = self.encoder(datacube)
79+
patches = encoded[:, 1:, :] # [B, L, D]
14780

14881
# Reshape embeddings to [B, D, H', W']
14982
H_patches = H_in // self.encoder.patch_size

finetune/segment/factory.py

Lines changed: 11 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -9,86 +9,13 @@
99

1010
import torch
1111
import torch.nn.functional as F
12-
from einops import rearrange, repeat
12+
from einops import rearrange
1313
from torch import nn
1414

1515
from claymodel.model import Encoder
1616
from claymodel.utils import load_encoder_weights
1717

1818

19-
class SegmentEncoder(Encoder):
20-
"""
21-
Encoder class for segmentation tasks, incorporating a feature pyramid
22-
network (FPN).
23-
24-
Attributes:
25-
feature_maps (list): Indices of layers to be used for generating
26-
feature maps.
27-
ckpt_path (str): Path to the clay checkpoint file.
28-
"""
29-
30-
def __init__( # noqa: PLR0913
31-
self,
32-
mask_ratio: float,
33-
patch_size: int,
34-
shuffle: bool,
35-
dim: int,
36-
depth: int,
37-
heads: int,
38-
dim_head: int,
39-
mlp_ratio: float,
40-
ckpt_path: str | None = None,
41-
) -> None:
42-
super().__init__(
43-
mask_ratio,
44-
patch_size,
45-
shuffle,
46-
dim,
47-
depth,
48-
heads,
49-
dim_head,
50-
mlp_ratio,
51-
)
52-
53-
# Set device
54-
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
55-
# Load model from checkpoint if provided
56-
if ckpt_path:
57-
load_encoder_weights(self, ckpt_path, device=str(self.device))
58-
59-
def forward(self, datacube: dict[str, torch.Tensor]) -> torch.Tensor: # ty: ignore[invalid-method-override]
60-
"""
61-
Forward pass of the SegmentEncoder.
62-
63-
Args:
64-
datacube (dict): A dictionary containing the input datacube and
65-
meta information like time, latlon, gsd & wavelenths.
66-
67-
Returns:
68-
list: A list of feature maps extracted from the datacube.
69-
"""
70-
cube, time, latlon, gsd, waves = (
71-
datacube["pixels"], # [B C H W]
72-
datacube["time"], # [B 2]
73-
datacube["latlon"], # [B 2]
74-
datacube["gsd"], # 1
75-
datacube["waves"], # [N]
76-
)
77-
78-
B = cube.shape[0]
79-
80-
# Patchify and create embeddings per patch
81-
patches, _ = self.to_patch_embed(cube, waves) # [B L D]
82-
patches = self.add_encodings(patches, time, latlon, gsd) # [B L D]
83-
84-
# Add class tokens
85-
cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D]
86-
patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D]
87-
88-
patches = self.transformer(patches)
89-
return patches[:, 1:, :] # [B L D]
90-
91-
9219
class Segmentor(nn.Module):
9320
"""
9421
Clay Segmentor class that combines the Encoder with FPN layers for semantic
@@ -102,8 +29,8 @@ class Segmentor(nn.Module):
10229

10330
def __init__(self, num_classes: int, ckpt_path: str | None) -> None:
10431
super().__init__()
105-
# Default values are for the clay mae base model.
106-
self.encoder = SegmentEncoder(
32+
# Default values are for the clay mae large model.
33+
self.encoder = Encoder(
10734
mask_ratio=0.0,
10835
patch_size=8,
10936
shuffle=False,
@@ -112,9 +39,13 @@ def __init__(self, num_classes: int, ckpt_path: str | None) -> None:
11239
heads=16,
11340
dim_head=64,
11441
mlp_ratio=4.0,
115-
ckpt_path=ckpt_path,
11642
)
11743

44+
# Set device and load pretrained weights
45+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
46+
if ckpt_path:
47+
load_encoder_weights(self.encoder, ckpt_path, device=str(device))
48+
11849
# Freeze the encoder parameters
11950
for param in self.encoder.parameters():
12051
param.requires_grad = False
@@ -147,8 +78,9 @@ def forward(self, datacube: dict[str, torch.Tensor]) -> torch.Tensor:
14778
cube = datacube["pixels"] # [B C H_in W_in]
14879
_, _, H_in, W_in = cube.shape
14980

150-
# Get embeddings from the encoder
151-
patches = self.encoder(datacube) # [B, L, D]
81+
# Get embeddings from the encoder (strip CLS token)
82+
encoded, *_ = self.encoder(datacube)
83+
patches = encoded[:, 1:, :] # [B, L, D]
15284

15385
# Reshape embeddings to [B, D, H', W']
15486
H_patches = H_in // self.encoder.patch_size

0 commit comments

Comments
 (0)