99
1010import torch
1111import torch .nn .functional as F
12- from einops import rearrange , repeat
12+ from einops import rearrange
1313from torch import nn
1414
1515from claymodel .model import Encoder
1616from claymodel .utils import load_encoder_weights
1717
1818
19- class SegmentEncoder (Encoder ):
20- """
21- Encoder class for segmentation tasks. Runs the Clay encoder and
22- returns the per-patch embeddings with the class token removed.
23-
24- Args:
25- ckpt_path (str | None): Path to the clay checkpoint file. If
26- given, the encoder weights are loaded from it.
27- Remaining arguments are passed unchanged to ``Encoder``.
28- """
29-
30- def __init__ ( # noqa: PLR0913
31- self ,
32- mask_ratio : float ,
33- patch_size : int ,
34- shuffle : bool ,
35- dim : int ,
36- depth : int ,
37- heads : int ,
38- dim_head : int ,
39- mlp_ratio : float ,
40- ckpt_path : str | None = None ,
41- ) -> None :
42- super ().__init__ (
43- mask_ratio ,
44- patch_size ,
45- shuffle ,
46- dim ,
47- depth ,
48- heads ,
49- dim_head ,
50- mlp_ratio ,
51- )
52-
53- # Set device
54- self .device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
55- # Load model from checkpoint if provided
56- if ckpt_path :
57- load_encoder_weights (self , ckpt_path , device = str (self .device ))
58-
59- def forward (self , datacube : dict [str , torch .Tensor ]) -> torch .Tensor : # ty: ignore[invalid-method-override]
60- """
61- Forward pass of the SegmentEncoder.
62-
63- Args:
64- datacube (dict): A dictionary containing the input datacube and
65- meta information like time, latlon, gsd & wavelengths.
66-
67- Returns:
68- torch.Tensor: Patch embeddings of shape [B L D], with the class token removed.
69- """
70- cube , time , latlon , gsd , waves = (
71- datacube ["pixels" ], # [B C H W]
72- datacube ["time" ], # [B 2]
73- datacube ["latlon" ], # [B 2]
74- datacube ["gsd" ], # 1
75- datacube ["waves" ], # [N]
76- )
77-
78- B = cube .shape [0 ]
79-
80- # Patchify and create embeddings per patch
81- patches , _ = self .to_patch_embed (cube , waves ) # [B L D]
82- patches = self .add_encodings (patches , time , latlon , gsd ) # [B L D]
83-
84- # Add class tokens
85- cls_tokens = repeat (self .cls_token , "1 1 D -> B 1 D" , B = B ) # [B 1 D]
86- patches = torch .cat ((cls_tokens , patches ), dim = 1 ) # [B (1 + L) D]
87-
88- patches = self .transformer (patches )
89- return patches [:, 1 :, :] # [B L D]
90-
91-
9219class Segmentor (nn .Module ):
9320 """
9421 Clay Segmentor class that combines the Encoder with FPN layers for semantic
@@ -102,8 +29,8 @@ class Segmentor(nn.Module):
10229
10330 def __init__ (self , num_classes : int , ckpt_path : str | None ) -> None :
10431 super ().__init__ ()
105- # Default values are for the clay mae base model.
106- self .encoder = SegmentEncoder (
32+ # Default values are for the clay mae large model.
33+ self .encoder = Encoder (
10734 mask_ratio = 0.0 ,
10835 patch_size = 8 ,
10936 shuffle = False ,
@@ -112,9 +39,13 @@ def __init__(self, num_classes: int, ckpt_path: str | None) -> None:
11239 heads = 16 ,
11340 dim_head = 64 ,
11441 mlp_ratio = 4.0 ,
115- ckpt_path = ckpt_path ,
11642 )
11743
44+ # Set device and load pretrained weights
45+ device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
46+ if ckpt_path :
47+ load_encoder_weights (self .encoder , ckpt_path , device = str (device ))
48+
11849 # Freeze the encoder parameters
11950 for param in self .encoder .parameters ():
12051 param .requires_grad = False
@@ -147,8 +78,9 @@ def forward(self, datacube: dict[str, torch.Tensor]) -> torch.Tensor:
14778 cube = datacube ["pixels" ] # [B C H_in W_in]
14879 _ , _ , H_in , W_in = cube .shape
14980
150- # Get embeddings from the encoder
151- patches = self .encoder (datacube ) # [B, L, D]
81+ # Get embeddings from the encoder (strip CLS token)
82+ encoded , * _ = self .encoder (datacube )
83+ patches = encoded [:, 1 :, :] # [B, L, D]
15284
15385 # Reshape embeddings to [B, D, H', W']
15486 H_patches = H_in // self .encoder .patch_size
0 commit comments