|
2 | 2 |
|
3 | 3 | __all__ = ["EmbeddingResult", "embed", "load_metadata", "load_model", "normalize"] |
4 | 4 |
|
5 | | -import warnings |
6 | 5 | from dataclasses import dataclass, field |
7 | 6 | from importlib.resources import files |
8 | 7 | from pathlib import Path |
9 | 8 |
|
10 | 9 | import numpy as np |
11 | 10 | import torch |
12 | 11 |
|
13 | | -from claymodel.inference.elle import ELLEProbe |
14 | 12 | from claymodel.metadata import PlatformMetadata, load_metadata_yaml |
15 | 13 | from claymodel.module import ClayMAEModule |
16 | 14 |
|
@@ -125,7 +123,6 @@ def embed( # noqa: PLR0913 |
125 | 123 | device: str = "cpu", |
126 | 124 | time: torch.Tensor | None = None, |
127 | 125 | latlon: torch.Tensor | None = None, |
128 | | - quality: bool = False, |
129 | 126 | metadata: dict[str, PlatformMetadata] | None = None, |
130 | 127 | ) -> EmbeddingResult: |
131 | 128 | """Embed pixels or a GeoTIFF with a Clay model.""" |
@@ -181,21 +178,9 @@ def embed( # noqa: PLR0913 |
181 | 178 | encoded, *_ = model.encoder(datacube) |
182 | 179 | cls_embeddings = encoded[:, 0, :] |
183 | 180 |
|
184 | | - result = EmbeddingResult( |
| 181 | + return EmbeddingResult( |
185 | 182 | embeddings=cls_embeddings, |
186 | 183 | sensor=sensor, |
187 | 184 | gsd=float(sensor_meta.gsd), |
188 | 185 | metadata={"latlon": latlon, "time": time}, |
189 | 186 | ) |
190 | | - |
191 | | - if quality: |
192 | | - try: |
193 | | - probe = ELLEProbe.default() |
194 | | - result.metadata["quality_score"] = probe.score(cls_embeddings) |
195 | | - except FileNotFoundError: |
196 | | - warnings.warn( |
197 | | - "ELLE probe not available. Install or provide probe weights.", |
198 | | - stacklevel=2, |
199 | | - ) |
200 | | - |
201 | | - return result |
0 commit comments