Skip to content

Commit fa67083

Browse files
BREAKING: move the teacher model into the training module so inference no longer downloads it
1 parent 798d5b6 commit fa67083

16 files changed

Lines changed: 220 additions & 399 deletions

README.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,9 @@ The easiest way to install Clay Foundation Model is via `uv`:
3030
This will install the `claymodel` package and all its dependencies. You can then import and use it in your Python code:
3131

3232
```python
33-
from claymodel import ClayMAEModule
33+
from claymodel import load_model, embed
3434
```
3535

36-
If you want the `clay` CLI, install the `cli` extra:
37-
38-
uv pip install "claymodel[cli]"
39-
4036
### Development Installation
4137

4238
For development or advanced usage, clone the repository and install with dev extras:

claymodel/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44

55
from claymodel.api import EmbeddingResult, embed, load_metadata, load_model, normalize
66
from claymodel.metadata import PlatformMetadata
7-
from claymodel.model import clay_mae_base, clay_mae_large, clay_mae_small, clay_mae_tiny
8-
from claymodel.module import ClayMAEModule
7+
from claymodel.model import Encoder, clay_mae_base, clay_mae_large, clay_mae_small, clay_mae_tiny
98

109
__version__: str = version("claymodel")
1110

1211

1312
__all__ = [
14-
"ClayMAEModule",
1513
"EmbeddingResult",
14+
"Encoder",
1615
"PlatformMetadata",
1716
"clay_mae_base",
1817
"clay_mae_large",

claymodel/api.py

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@
1010
import torch
1111

1212
from claymodel.metadata import PlatformMetadata, load_metadata_yaml
13-
from claymodel.module import ClayMAEModule
13+
from claymodel.model import Encoder
14+
from claymodel.utils import load_encoder_weights
15+
16+
_ENCODER_CONFIGS: dict[str, dict[str, int]] = {
17+
"tiny": {"dim": 192, "depth": 6, "heads": 4, "dim_head": 48, "mlp_ratio": 2},
18+
"small": {"dim": 384, "depth": 6, "heads": 6, "dim_head": 64, "mlp_ratio": 2},
19+
"base": {"dim": 768, "depth": 12, "heads": 12, "dim_head": 64, "mlp_ratio": 4},
20+
"large": {"dim": 1024, "depth": 24, "heads": 16, "dim_head": 64, "mlp_ratio": 4},
21+
}
1422

1523

1624
def load_metadata(
@@ -48,57 +56,40 @@ def load_model(
4856
size: str = "large",
4957
ckpt_path: str | None = None,
5058
device: str = "cpu",
51-
metadata_path: str | Path | None = None,
52-
) -> ClayMAEModule:
53-
"""Load a Clay MAE model ready for inference.
54-
55-
Creates a ClayMAEModule and optionally loads weights from a checkpoint.
56-
The model is returned in eval mode with mask_ratio=0 and shuffle=False
57-
for deterministic inference.
59+
) -> Encoder:
60+
"""Load a Clay encoder ready for inference.
5861
59-
Note: The model includes a DINOv2 teacher (~300MB) that is downloaded
60-
on first use. The teacher is frozen and not needed for embedding
61-
extraction, but is part of the architecture.
62+
Creates an Encoder and optionally loads weights from a checkpoint.
63+
The encoder is returned in eval mode with mask_ratio=0 and shuffle=False
64+
for deterministic inference. No teacher model is downloaded.
6265
6366
Args:
6467
size: Model size - "tiny", "small", "base", or "large".
65-
ckpt_path: Path to checkpoint file. If None, creates model with
68+
ckpt_path: Path to checkpoint file. If None, creates encoder with
6669
random weights (useful for testing).
6770
device: Device to load model onto ("cpu", "cuda", etc.).
68-
metadata_path: Path to a custom metadata YAML file. If None,
69-
uses the bundled metadata with common public sensors.
7071
7172
Returns:
72-
ClayMAEModule instance in eval mode.
73+
Encoder instance in eval mode.
7374
7475
Example:
75-
>>> model = load_model("large", ckpt_path="clay-v1.5.ckpt")
76-
>>> model = load_model("large", metadata_path="my_sensors.yaml")
76+
>>> encoder = load_model("large", ckpt_path="clay-v1.5.ckpt")
7777
"""
78-
resolved_path = (
79-
str(metadata_path)
80-
if metadata_path
81-
else str(files("claymodel").joinpath("configs/metadata.yaml"))
78+
if size not in _ENCODER_CONFIGS:
79+
raise ValueError(f"Invalid size {size!r}. Expected one of {list(_ENCODER_CONFIGS.keys())}")
80+
81+
encoder = Encoder(
82+
mask_ratio=0.0,
83+
patch_size=8,
84+
shuffle=False,
85+
**_ENCODER_CONFIGS[size],
8286
)
8387

8488
if ckpt_path is not None:
85-
model = ClayMAEModule.load_from_checkpoint(
86-
ckpt_path,
87-
metadata_path=resolved_path,
88-
map_location=device,
89-
)
90-
else:
91-
model = ClayMAEModule(
92-
model_size=size,
93-
mask_ratio=0.0,
94-
shuffle=False,
95-
metadata_path=resolved_path,
96-
)
89+
load_encoder_weights(encoder, ckpt_path, device=device, freeze=False)
9790

98-
model.model.encoder.mask_ratio = 0.0
99-
model.model.encoder.shuffle = False
100-
model.eval()
101-
return model.to(device)
91+
encoder.eval()
92+
return encoder.to(device)
10293

10394

10495
@dataclass
@@ -118,7 +109,7 @@ def shape(self) -> torch.Size:
118109
def embed( # noqa: PLR0913
119110
input_data: torch.Tensor | np.ndarray,
120111
sensor: str,
121-
model: ClayMAEModule | None = None,
112+
model: Encoder | None = None,
122113
ckpt_path: str | None = None,
123114
device: str = "cpu",
124115
time: torch.Tensor | None = None,
@@ -188,7 +179,7 @@ def embed( # noqa: PLR0913
188179
model = load_model(ckpt_path=ckpt_path, device=device)
189180

190181
with torch.no_grad():
191-
encoded, *_ = model.encoder(datacube)
182+
encoded, *_ = model(datacube)
192183
cls_embeddings = encoded[:, 0, :]
193184

194185
return EmbeddingResult(

claymodel/model.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,15 @@
1010
]
1111

1212
import math
13-
from typing import TYPE_CHECKING, Any, TypedDict, cast
13+
from typing import TYPE_CHECKING, Any, TypedDict
1414

15-
import timm
1615
import torch
1716
import torch.nn.functional as F
1817
from einops import rearrange, reduce, repeat
1918
from torch import nn
20-
from torchvision.transforms import v2
2119

2220
from claymodel.embedding import DynamicEmbedding
2321
from claymodel.layers import Transformer
24-
from claymodel.mrl import MRL, MRLLoss
2522
from claymodel.utils import posemb_sincos_2d_with_gsd
2623

2724
if TYPE_CHECKING:
@@ -375,14 +372,17 @@ def forward( # noqa: PLR0913
375372

376373

377374
class ClayMAE(nn.Module):
375+
"""Clay Masked Autoencoder: encoder + decoder.
376+
377+
Does not include the teacher model or representation loss components,
378+
which live in ClayMAEModule (the training wrapper).
379+
"""
380+
378381
mask_ratio: float
379382
patch_size: int
380383
norm_pix_loss: bool
381384
shuffle: bool
382385
metadata: dict[str, "PlatformMetadata"]
383-
teacher: nn.Module
384-
teacher_chip_size: int
385-
matryoshka: bool
386386
encoder: Encoder
387387
decoder: Decoder
388388

@@ -393,9 +393,6 @@ def __init__( # noqa: PLR0913
393393
norm_pix_loss: bool,
394394
shuffle: bool,
395395
metadata: dict[str, "PlatformMetadata"],
396-
teacher: str,
397-
dolls: list[int],
398-
doll_weights: list[float],
399396
# ENCODER
400397
dim: int,
401398
depth: int,
@@ -408,7 +405,6 @@ def __init__( # noqa: PLR0913
408405
decoder_heads: int,
409406
decoder_dim_head: int,
410407
decoder_mlp_ratio: float,
411-
matryoshka: bool = False,
412408
**kwargs: object,
413409
) -> None:
414410
super().__init__()
@@ -417,16 +413,6 @@ def __init__( # noqa: PLR0913
417413
self.norm_pix_loss = norm_pix_loss
418414
self.shuffle = shuffle
419415
self.metadata = metadata
420-
self.teacher = timm.create_model(teacher, pretrained=True, num_classes=0)
421-
teacher_features = cast("int", self.teacher.num_features)
422-
self.teacher_chip_size = 518
423-
self.teacher_resize = v2.Resize(size=(self.teacher_chip_size, self.teacher_chip_size))
424-
self.matryoshka = matryoshka
425-
if matryoshka:
426-
self.mrl = MRL(features=teacher_features, dolls=dolls)
427-
self.mrl_loss = MRLLoss(weights=doll_weights)
428-
else:
429-
self.proj = nn.Linear(dim, teacher_features)
430416

431417
self.encoder = Encoder(
432418
mask_ratio=mask_ratio,
@@ -450,13 +436,6 @@ def __init__( # noqa: PLR0913
450436
mlp_ratio=decoder_mlp_ratio,
451437
)
452438

453-
self.freeze_teacher()
454-
455-
def freeze_teacher(self) -> None:
456-
for param in self.teacher.parameters():
457-
param.requires_grad = False
458-
self.teacher.eval()
459-
460439
def per_pixel_loss(
461440
self, cube: torch.Tensor, pixels: torch.Tensor, masked_matrix: torch.Tensor
462441
) -> torch.Tensor:

claymodel/module.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
import random
44
from collections.abc import Mapping
55
from importlib.resources import files
6-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, Any, cast
77

88
import lightning as L
9+
import timm
910
import torch
1011
import torch.nn.functional as F
1112
from lightning.pytorch.utilities.types import OptimizerLRScheduler
13+
from torch import nn
14+
from torchvision.transforms import v2
1215

1316
from claymodel.metadata import load_metadata_yaml
1417
from claymodel.model import (
@@ -19,6 +22,7 @@
1922
clay_mae_small,
2023
clay_mae_tiny,
2124
)
25+
from claymodel.mrl import MRL, MRLLoss
2226

2327
if TYPE_CHECKING:
2428
from claymodel.metadata import PlatformMetadata
@@ -65,19 +69,57 @@ def __init__( # noqa: PLR0913
6569
"norm_pix_loss": norm_pix_loss,
6670
"shuffle": shuffle,
6771
"metadata": self.metadata,
68-
"teacher": teacher,
69-
"dolls": dolls,
70-
"doll_weights": doll_weights,
71-
"matryoshka": matryoshka,
7272
}
7373
self.model = model_map[model_size](**model_args)
7474
else:
7575
raise ValueError(
7676
f"Invalid model size {model_size}. Expected one of {list(model_map.keys())}"
7777
)
7878

79+
# Teacher model and representation loss components (training only)
80+
self.teacher = timm.create_model(teacher, pretrained=True, num_classes=0)
81+
teacher_features = cast("int", self.teacher.num_features)
82+
self.teacher_chip_size = 518
83+
self.teacher_resize = v2.Resize(size=(self.teacher_chip_size, self.teacher_chip_size))
84+
self.matryoshka = matryoshka
85+
if matryoshka:
86+
self.mrl = MRL(features=teacher_features, dolls=dolls)
87+
self.mrl_loss = MRLLoss(weights=doll_weights)
88+
else:
89+
self.proj = nn.Linear(self.model.encoder.dim, teacher_features)
90+
91+
self._freeze_teacher()
92+
93+
def _freeze_teacher(self) -> None:
    """Freeze the teacher: disable gradients on its parameters and set eval mode."""
    # requires_grad_(False) flips the flag on every parameter of the module
    # in place, replacing the manual loop over parameters().
    self.teacher.requires_grad_(False)
    self.teacher.eval()
97+
98+
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Remap old checkpoint keys where the teacher components lived under ``model.*``.

    Keys that start with one of the legacy prefixes (teacher, projection head,
    MRL head/loss, teacher resize) have their leading ``model.`` stripped, e.g.
    ``model.teacher.x`` becomes ``teacher.x``. All other keys, including the
    remaining ``model.*`` entries, are left untouched. The state dict is
    modified in place.

    Args:
        checkpoint: Loaded checkpoint mapping. Its ``"state_dict"`` entry is
            remapped when present; otherwise the mapping itself is treated
            as the state dict.
    """
    state_dict = checkpoint.get("state_dict", checkpoint)
    legacy_prefixes = (
        "model.teacher.",
        "model.proj.",
        "model.mrl.",
        "model.mrl_loss.",
        "model.teacher_resize.",
    )
    # Snapshot the matching keys first: the dict must not change size while
    # it is being iterated.
    legacy_keys = [key for key in state_dict if key.startswith(legacy_prefixes)]
    # Pop every legacy entry before inserting the renamed ones, so the final
    # contents match a delete-then-update pass.
    remapped = {key.removeprefix("model."): state_dict.pop(key) for key in legacy_keys}
    state_dict.update(remapped)
120+
79121
def on_train_epoch_start(self) -> None:
80-
self.model.teacher.eval()
122+
self.teacher.eval()
81123

82124
@property
83125
def encoder(self) -> Encoder:
@@ -165,19 +207,19 @@ def _teacher_target(
165207
else:
166208
indices = self.metadata[platform].rgb_indices
167209
rgb = pixels[:, indices, :, :]
168-
rgb = self.model.teacher_resize(rgb)
169-
return self.model.teacher(rgb)
210+
rgb = self.teacher_resize(rgb)
211+
return self.teacher(rgb)
170212

171213
def _representation_loss(
172214
self,
173215
cls_token: torch.Tensor,
174216
target: torch.Tensor,
175217
) -> torch.Tensor:
176218
"""Compute representation loss (proj or MRL)."""
177-
if self.model.matryoshka:
178-
representations = self.model.mrl(cls_token)
179-
return self.model.mrl_loss(representations, target)
180-
representations = self.model.proj(cls_token)
219+
if self.matryoshka:
220+
representations = self.mrl(cls_token)
221+
return self.mrl_loss(representations, target)
222+
representations = self.proj(cls_token)
181223
return 1.0 - F.cosine_similarity(representations, target).mean()
182224

183225
def _log_losses(

0 commit comments

Comments
 (0)