"""
Clay Regressor for semantic regression tasks using PixelShuffle.

Attribution:
Decoder inspired by PixelShuffle-based upsampling.
"""
97
108import re
1715from src .model import Encoder
1816
1917
# NOTE(review): this span is unified-diff residue (a rendered GitHub diff
# page), not runnable Python — original/new line numbers and the +/- diff
# markers are fused into the text. Code bytes kept identical; the comments
# below only annotate what the visible lines show.
# Visible change: SegmentEncoder is renamed RegressionEncoder; the
# `feature_maps` constructor parameter and all five FPN branches
# (fpn1..fpn5) are deleted.
20- class SegmentEncoder (Encoder ):
18+ class RegressionEncoder (Encoder ):
2119 """
22- Encoder class for segmentation tasks, incorporating a feature pyramid
23- network (FPN).
20+ Encoder class for regression tasks.
2421
2522 Attributes:
26- feature_maps (list): Indices of layers to be used for generating
27- feature maps.
2823 ckpt_path (str): Path to the clay checkpoint file.
2924 """
3025
# Constructor: the middle of the parameter list and of the
# super().__init__() argument list is elided by the '@@' hunk markers —
# those lines are not visible here; do not infer them.
31- def __init__ ( # noqa: PLR0913
26+ def __init__ (
3227 self ,
3328 mask_ratio ,
3429 patch_size ,
@@ -38,7 +33,6 @@ def __init__( # noqa: PLR0913
3833 heads ,
3934 dim_head ,
4035 mlp_ratio ,
41- feature_maps ,
4236 ckpt_path = None ,
4337 ):
4438 super ().__init__ (
@@ -51,30 +45,6 @@ def __init__( # noqa: PLR0913
5145 dim_head ,
5246 mlp_ratio ,
5347 )
# Deleted below ('-' lines): the FPN multi-scale heads — two
# ConvTranspose2d upsamplers, an Identity pass-through, and two MaxPool2d
# downsamplers.
54- self .feature_maps = feature_maps
55-
56- # Define Feature Pyramid Network (FPN) layers
57- self .fpn1 = nn .Sequential (
58- nn .ConvTranspose2d (dim , dim , kernel_size = 2 , stride = 2 ),
59- nn .BatchNorm2d (dim ),
60- nn .GELU (),
61- nn .ConvTranspose2d (dim , dim , kernel_size = 2 , stride = 2 ),
62- )
63-
64- self .fpn2 = nn .Sequential (
65- nn .ConvTranspose2d (dim , dim , kernel_size = 2 , stride = 2 ),
66- )
67-
68- self .fpn3 = nn .Identity ()
69-
70- self .fpn4 = nn .Sequential (
71- nn .MaxPool2d (kernel_size = 2 , stride = 2 ),
72- )
73-
74- self .fpn5 = nn .Sequential (
75- nn .MaxPool2d (kernel_size = 4 , stride = 4 ),
76- )
77-
# Retained in both versions: pick CUDA when available, else CPU. The
# statement's closing parenthesis falls past the next '@@' hunk and is not
# visible here.
7848 # Set device
7949 self .device = (
8050 torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
@@ -119,14 +89,14 @@ def load_from_ckpt(self, ckpt_path):
11989
# NOTE(review): unified-diff residue — not runnable Python. Code bytes kept
# identical; comments annotate the visible change only.
# Visible change in forward(): the per-layer loop that collected
# intermediate feature maps and the FPN application are deleted; the new
# '+' lines run the whole transformer, strip the class token, and return
# the per-patch embeddings [B, L, D]. The middle of the method (patch
# embedding etc.) is elided by the '@@' hunk marker below.
12090 def forward (self , datacube ):
12191 """
122- Forward pass of the SegmentEncoder .
92+ Forward pass of the RegressionEncoder .
12393
12494 Args:
12595 datacube (dict): A dictionary containing the input datacube and
12696 meta information like time, latlon, gsd & wavelenths.
12797
12898 Returns:
129- list: A list of feature maps extracted from the datacube .
99+ torch.Tensor: The embeddings from the final layer .
130100 """
131101 cube , time , latlon , gsd , waves = (
132102 datacube ["pixels" ], # [B C H W]
@@ -146,84 +116,56 @@ def forward(self, datacube):
146116 cls_tokens = repeat (self .cls_token , "1 1 D -> B 1 D" , B = B ) # [B 1 D]
147117 patches = torch .cat ((cls_tokens , patches ), dim = 1 ) # [B (1 + L) D]
148118
# Deleted below ('-' lines): manual layer-by-layer transformer loop that
# harvested features at self.feature_maps indices, plus the fpn1..fpn5
# application over those features.
149- features = []
150- for idx , (attn , ff ) in enumerate (self .transformer .layers ):
151- patches = attn (patches ) + patches
152- patches = ff (patches ) + patches
153- if idx in self .feature_maps :
154- _cube = rearrange (
155- patches [:, 1 :, :], "B (H W) D -> B D H W" , H = H // 8 , W = W // 8
156- )
157- features .append (_cube )
158- # patches = self.transformer.norm(patches)
159- # _cube = rearrange(patches[:, 1:, :], "B (H W) D -> B D H W", H=H//8, W=W//8)
160- # features.append(_cube)
161-
162- # Apply FPN layers
163- ops = [self .fpn1 , self .fpn2 , self .fpn3 , self .fpn4 , self .fpn5 ]
164- for i in range (len (features )):
165- features [i ] = ops [i ](features [i ])
166-
167- return features
168-
169-
# Deleted outright: FusionBlock (conv+BN+ReLU fuser) — the new Regressor
# below no longer references it.
170- class FusionBlock (nn .Module ):
171- def __init__ (self , input_dim , output_dim ):
172- super ().__init__ ()
173- self .conv = nn .Conv2d (input_dim , output_dim , kernel_size = 3 , padding = 1 )
174- self .bn = nn .BatchNorm2d (output_dim )
175-
176- def forward (self , x ):
177- x = F .relu (self .bn (self .conv (x )))
178- return x
179-
# New ('+' lines): single call through the full transformer stack.
119+ # Transformer encoder
120+ patches = self .transformer (patches )
180121
# Deleted outright: SegmentationHead (two-conv classification head) — also
# unreferenced by the new Regressor.
181- class SegmentationHead (nn .Module ):
182- def __init__ (self , input_dim , num_classes ):
183- super ().__init__ ()
184- self .conv1 = nn .Conv2d (input_dim , input_dim // 2 , kernel_size = 3 , padding = 1 )
185- self .conv2 = nn .Conv2d (
186- input_dim // 2 , num_classes , kernel_size = 1
187- ) # final conv to num_classes
188- self .bn1 = nn .BatchNorm2d (input_dim // 2 )
# New ('+' lines): drop the class token and return patch embeddings.
122+ # Remove class token
123+ patches = patches [:, 1 :, :] # [B, L, D]
189124
190- def forward (self , x ):
191- x = F .relu (self .bn1 (self .conv1 (x )))
192- x = self .conv2 (x ) # No activation before final layer
193- return x
125+ return patches
194126
195127
# NOTE(review): unified-diff residue — Regressor's class docstring changes
# from the segmentation/FPN wording to the PixelShuffle-regression wording,
# and the `feature_maps` attribute entry is dropped. Code bytes kept
# identical.
196128class Regressor (nn .Module ):
197129 """
198- Clay Regressor class that combines the Encoder with FPN layers for semantic
199- regression.
130+ Clay Regressor class that combines the Encoder with PixelShuffle for regression.
200131
201132 Attributes:
202- num_classes (int): Number of output classes for segmentation.
203- feature_maps (list): Indices of layers to be used for generating feature maps.
133+ num_classes (int): Number of output classes for regression.
204134 ckpt_path (str): Path to the checkpoint file.
205135 """
206136
207- def __init__ (self , num_classes , feature_maps , ckpt_path ):
137+ def __init__ (self , num_classes , ckpt_path ):
208138 super ().__init__ ()
209- # Default values are for the clay mae base model.
210- self .encoder = SegmentEncoder (
139+ # Initialize the encoder
140+ self .encoder = RegressionEncoder (
211141 mask_ratio = 0.0 ,
212142 patch_size = 8 ,
213143 shuffle = False ,
214- dim = 768 ,
215- depth = 12 ,
216- heads = 12 ,
144+ dim = 1024 ,
145+ depth = 24 ,
146+ heads = 16 ,
217147 dim_head = 64 ,
218148 mlp_ratio = 4.0 ,
219- feature_maps = feature_maps ,
220149 ckpt_path = ckpt_path ,
221150 )
222- self .upsamples = [nn .Upsample (scale_factor = 2 ** i ) for i in range (5 )]
223- self .fusion = FusionBlock (self .encoder .dim , self .encoder .dim // 4 )
224- self .seg_head = nn .Conv2d (
225- self .encoder .dim // 4 , num_classes , kernel_size = 3 , padding = 1
226- )
151+
152+ # Freeze the encoder parameters
153+ for param in self .encoder .parameters ():
154+ param .requires_grad = False
155+
156+ # Define layers after the encoder
157+ D = self .encoder .dim # embedding dimension
158+ hidden_dim = 512
159+ C_out = 64
160+ r = self .encoder .patch_size # upscale factor (patch_size)
161+
162+ self .conv1 = nn .Conv2d (D , hidden_dim , kernel_size = 3 , padding = 1 )
163+ self .bn1 = nn .BatchNorm2d (hidden_dim )
164+ self .conv2 = nn .Conv2d (hidden_dim , hidden_dim , kernel_size = 3 , padding = 1 )
165+ self .bn2 = nn .BatchNorm2d (hidden_dim )
166+ self .conv_ps = nn .Conv2d (hidden_dim , C_out * r * r , kernel_size = 3 , padding = 1 )
167+ self .pixel_shuffle = nn .PixelShuffle (upscale_factor = r )
168+ self .conv_out = nn .Conv2d (C_out , num_classes , kernel_size = 3 , padding = 1 )
227169
228170 def forward (self , datacube ):
229171 """
@@ -234,15 +176,28 @@ def forward(self, datacube):
234176 meta information like time, latlon, gsd & wavelenths.
235177
236178 Returns:
237- torch.Tensor: The segmentation logits .
179+ torch.Tensor: The regression output .
238180 """
239- features = self .encoder (datacube )
240- for i in range (len (features )):
241- features [i ] = self .upsamples [i ](features [i ])
181+ cube = datacube ["pixels" ] # [B C H_in W_in]
182+ B , C , H_in , W_in = cube .shape
242183
243- # fused = torch.cat(features, dim=1)
244- fused = torch .sum (torch .stack (features ), dim = 0 )
245- fused = self .fusion (fused )
184+ # Get embeddings from the encoder
185+ patches = self .encoder (datacube ) # [B, L, D]
246186
247- logits = self .seg_head (fused )
248- return logits
187+ # Reshape embeddings to [B, D, H', W']
188+ H_patches = H_in // self .encoder .patch_size
189+ W_patches = W_in // self .encoder .patch_size
190+ x = rearrange (patches , "B (H W) D -> B D H W" , H = H_patches , W = W_patches )
191+
192+ # Pass through convolutional layers
193+ x = F .relu (self .bn1 (self .conv1 (x )))
194+ x = F .relu (self .bn2 (self .conv2 (x )))
195+ x = self .conv_ps (x ) # [B, C_out * r^2, H', W']
196+
197+ # Upsample using PixelShuffle
198+ x = self .pixel_shuffle (x ) # [B, C_out, H_in, W_in]
199+
200+ # Final convolution to get desired output channels
201+ x = self .conv_out (x ) # [B, num_outputs, H_in, W_in]
202+
203+ return x
0 commit comments