1010import re
1111
1212import torch
13+ import torch .nn .functional as F
1314from einops import rearrange , repeat
1415from torch import nn
1516
@@ -37,7 +38,6 @@ def __init__( # noqa: PLR0913
3738 heads ,
3839 dim_head ,
3940 mlp_ratio ,
40- feature_maps ,
4141 ckpt_path = None ,
4242 ):
4343 super ().__init__ (
@@ -50,27 +50,6 @@ def __init__( # noqa: PLR0913
5050 dim_head ,
5151 mlp_ratio ,
5252 )
53- self .feature_maps = feature_maps
54-
55- # Define Feature Pyramid Network (FPN) layers
56- self .fpn1 = nn .Sequential (
57- nn .ConvTranspose2d (dim , dim , kernel_size = 2 , stride = 2 ),
58- nn .BatchNorm2d (dim ),
59- nn .GELU (),
60- nn .ConvTranspose2d (dim , dim , kernel_size = 2 , stride = 2 ),
61- )
62-
63- self .fpn2 = nn .Sequential (
64- nn .ConvTranspose2d (dim , dim , kernel_size = 2 , stride = 2 ),
65- )
66-
67- self .fpn3 = nn .Identity ()
68-
69- self .fpn4 = nn .Sequential (
70- nn .MaxPool2d (kernel_size = 2 , stride = 2 ),
71- )
72-
73- self .fpn5 = nn .Identity ()
7453
7554 # Set device
7655 self .device = (
@@ -143,25 +122,10 @@ def forward(self, datacube):
143122 cls_tokens = repeat (self .cls_token , "1 1 D -> B 1 D" , B = B ) # [B 1 D]
144123 patches = torch .cat ((cls_tokens , patches ), dim = 1 ) # [B (1 + L) D]
145124
146- features = []
147- for idx , (attn , ff ) in enumerate (self .transformer .layers ):
148- patches = attn (patches ) + patches
149- patches = ff (patches ) + patches
150- if idx in self .feature_maps :
151- _cube = rearrange (
152- patches [:, 1 :, :], "B (H W) D -> B D H W" , H = H // 8 , W = W // 8
153- )
154- features .append (_cube )
155- patches = self .transformer .norm (patches )
156- _cube = rearrange (patches [:, 1 :, :], "B (H W) D -> B D H W" , H = H // 8 , W = W // 8 )
157- features .append (_cube )
158-
159- # Apply FPN layers
160- ops = [self .fpn1 , self .fpn2 , self .fpn3 , self .fpn4 , self .fpn5 ]
161- for i in range (len (features )):
162- features [i ] = ops [i ](features [i ])
125+ patches = self .transformer (patches )
126+ patches = patches [:, 1 :, :] # [B L D]
163127
164- return features
128+ return patches
165129
166130
167131class Segmentor (nn .Module ):
@@ -175,7 +139,7 @@ class Segmentor(nn.Module):
175139 ckpt_path (str): Path to the checkpoint file.
176140 """
177141
178- def __init__ (self , num_classes , feature_maps , ckpt_path ):
142+ def __init__ (self , num_classes , ckpt_path ):
179143 super ().__init__ ()
180144 # Default values are for the clay mae base model.
181145 self .encoder = SegmentEncoder (
@@ -187,14 +151,26 @@ def __init__(self, num_classes, feature_maps, ckpt_path):
187151 heads = 16 ,
188152 dim_head = 64 ,
189153 mlp_ratio = 4.0 ,
190- feature_maps = feature_maps ,
191154 ckpt_path = ckpt_path ,
192155 )
193- self .upsamples = [nn .Upsample (scale_factor = 2 ** i ) for i in range (4 )] + [
194- nn .Upsample (scale_factor = 4 ),
195- ]
196- self .fusion = nn .Conv2d (self .encoder .dim * 5 , self .encoder .dim , kernel_size = 1 )
197- self .seg_head = nn .Conv2d (self .encoder .dim , num_classes , kernel_size = 1 )
156+
157+ # Freeze the encoder parameters
158+ for param in self .encoder .parameters ():
159+ param .requires_grad = False
160+
161+ # Define layers after the encoder
162+ D = self .encoder .dim # embedding dimension
163+ hidden_dim = 512
164+ C_out = 64
165+ r = self .encoder .patch_size # upscale factor (patch_size)
166+
167+ self .conv1 = nn .Conv2d (D , hidden_dim , kernel_size = 3 , padding = 1 )
168+ self .bn1 = nn .BatchNorm2d (hidden_dim )
169+ self .conv2 = nn .Conv2d (hidden_dim , hidden_dim , kernel_size = 3 , padding = 1 )
170+ self .bn2 = nn .BatchNorm2d (hidden_dim )
171+ self .conv_ps = nn .Conv2d (hidden_dim , C_out * r * r , kernel_size = 3 , padding = 1 )
172+ self .pixel_shuffle = nn .PixelShuffle (upscale_factor = r )
173+ self .conv_out = nn .Conv2d (C_out , num_classes , kernel_size = 3 , padding = 1 )
198174
199175 def forward (self , datacube ):
200176 """
@@ -207,12 +183,26 @@ def forward(self, datacube):
207183 Returns:
208184 torch.Tensor: The segmentation logits.
209185 """
210- features = self .encoder (datacube )
211- for i in range (len (features )):
212- features [i ] = self .upsamples [i ](features [i ])
186+ cube = datacube ["pixels" ] # [B C H_in W_in]
187+ B , C , H_in , W_in = cube .shape
188+
189+ # Get embeddings from the encoder
190+ patches = self .encoder (datacube ) # [B, L, D]
191+
192+ # Reshape embeddings to [B, D, H', W']
193+ H_patches = H_in // self .encoder .patch_size
194+ W_patches = W_in // self .encoder .patch_size
195+ x = rearrange (patches , "B (H W) D -> B D H W" , H = H_patches , W = W_patches )
196+
197+ # Pass through convolutional layers
198+ x = F .relu (self .bn1 (self .conv1 (x )))
199+ x = F .relu (self .bn2 (self .conv2 (x )))
200+ x = self .conv_ps (x ) # [B, C_out * r^2, H', W']
201+
202+ # Upsample using PixelShuffle
203+ x = self .pixel_shuffle (x ) # [B, C_out, H_in, W_in]
213204
214- fused = torch . cat ( features , dim = 1 )
215- fused = self .fusion ( fused )
205+ # Final convolution to get desired output channels
206+ x = self .conv_out ( x ) # [B, num_outputs, H_in, W_in]
216207
217- logits = self .seg_head (fused )
218- return logits
208+ return x