1010import torch
1111
1212from claymodel .metadata import PlatformMetadata , load_metadata_yaml
13- from claymodel .module import ClayMAEModule
13+ from claymodel .model import Encoder
14+ from claymodel .utils import load_encoder_weights
15+
16+ _ENCODER_CONFIGS : dict [str , dict [str , int ]] = {
17+ "tiny" : {"dim" : 192 , "depth" : 6 , "heads" : 4 , "dim_head" : 48 , "mlp_ratio" : 2 },
18+ "small" : {"dim" : 384 , "depth" : 6 , "heads" : 6 , "dim_head" : 64 , "mlp_ratio" : 2 },
19+ "base" : {"dim" : 768 , "depth" : 12 , "heads" : 12 , "dim_head" : 64 , "mlp_ratio" : 4 },
20+ "large" : {"dim" : 1024 , "depth" : 24 , "heads" : 16 , "dim_head" : 64 , "mlp_ratio" : 4 },
21+ }
1422
1523
1624def load_metadata (
@@ -48,57 +56,40 @@ def load_model(
4856 size : str = "large" ,
4957 ckpt_path : str | None = None ,
5058 device : str = "cpu" ,
51- metadata_path : str | Path | None = None ,
52- ) -> ClayMAEModule :
53- """Load a Clay MAE model ready for inference.
54-
55- Creates a ClayMAEModule and optionally loads weights from a checkpoint.
56- The model is returned in eval mode with mask_ratio=0 and shuffle=False
57- for deterministic inference.
59+ ) -> Encoder :
60+ """Load a Clay encoder ready for inference.
5861
59- Note: The model includes a DINOv2 teacher (~300MB) that is downloaded
60- on first use. The teacher is frozen and not needed for embedding
61- extraction, but is part of the architecture .
62+ Creates an Encoder and optionally loads weights from a checkpoint.
63+ The encoder is returned in eval mode with mask_ratio=0 and shuffle=False
64+ for deterministic inference. No teacher model is downloaded .
6265
6366 Args:
6467 size: Model size - "tiny", "small", "base", or "large".
65- ckpt_path: Path to checkpoint file. If None, creates model with
68+ ckpt_path: Path to checkpoint file. If None, creates encoder with
6669 random weights (useful for testing).
6770 device: Device to load model onto ("cpu", "cuda", etc.).
68- metadata_path: Path to a custom metadata YAML file. If None,
69- uses the bundled metadata with common public sensors.
7071
7172 Returns:
72- ClayMAEModule instance in eval mode.
73+ Encoder instance in eval mode.
7374
7475 Example:
75- >>> model = load_model("large", ckpt_path="clay-v1.5.ckpt")
76- >>> model = load_model("large", metadata_path="my_sensors.yaml")
76+ >>> encoder = load_model("large", ckpt_path="clay-v1.5.ckpt")
7777 """
78- resolved_path = (
79- str (metadata_path )
80- if metadata_path
81- else str (files ("claymodel" ).joinpath ("configs/metadata.yaml" ))
78+ if size not in _ENCODER_CONFIGS :
79+ raise ValueError (f"Invalid size { size !r} . Expected one of { list (_ENCODER_CONFIGS .keys ())} " )
80+
81+ encoder = Encoder (
82+ mask_ratio = 0.0 ,
83+ patch_size = 8 ,
84+ shuffle = False ,
85+ ** _ENCODER_CONFIGS [size ],
8286 )
8387
8488 if ckpt_path is not None :
85- model = ClayMAEModule .load_from_checkpoint (
86- ckpt_path ,
87- metadata_path = resolved_path ,
88- map_location = device ,
89- )
90- else :
91- model = ClayMAEModule (
92- model_size = size ,
93- mask_ratio = 0.0 ,
94- shuffle = False ,
95- metadata_path = resolved_path ,
96- )
89+ load_encoder_weights (encoder , ckpt_path , device = device , freeze = False )
9790
98- model .model .encoder .mask_ratio = 0.0
99- model .model .encoder .shuffle = False
100- model .eval ()
101- return model .to (device )
91+ encoder .eval ()
92+ return encoder .to (device )
10293
10394
10495@dataclass
@@ -118,7 +109,7 @@ def shape(self) -> torch.Size:
118109def embed ( # noqa: PLR0913
119110 input_data : torch .Tensor | np .ndarray ,
120111 sensor : str ,
121- model : ClayMAEModule | None = None ,
112+ model : Encoder | None = None ,
122113 ckpt_path : str | None = None ,
123114 device : str = "cpu" ,
124115 time : torch .Tensor | None = None ,
@@ -188,7 +179,7 @@ def embed( # noqa: PLR0913
188179 model = load_model (ckpt_path = ckpt_path , device = device )
189180
190181 with torch .no_grad ():
191- encoded , * _ = model . encoder (datacube )
182+ encoded , * _ = model (datacube )
192183 cls_embeddings = encoded [:, 0 , :]
193184
194185 return EmbeddingResult (
0 commit comments