Skip to content

Commit 3e63bf3

Browse files
remove elle, we'll bring it back later. bunch of other small cleanups
1 parent 4b6bb0e commit 3e63bf3

9 files changed

Lines changed: 5 additions & 175 deletions

File tree

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
- run: uv sync --locked --all-extras --group dev
2222
- run: uv run ruff check . --exclude "*.ipynb"
2323
- run: uv run ruff format --check . --exclude "*.ipynb"
24-
- run: uv run ty check claymodel training finetune trainer.py tests --exclude "*.ipynb"
24+
- run: uv run ty check claymodel tests --exclude "*.ipynb"
2525

2626
test:
2727
runs-on: ubuntu-latest
@@ -35,4 +35,4 @@ jobs:
3535
with:
3636
python-version: ${{ matrix.python-version }}
3737
- run: uv sync --locked --all-extras --group dev
38-
- run: uv run pytest tests/ -v --cov=claymodel --cov-report=term-missing
38+
- run: uv run pytest tests/ -v --cov=claymodel --cov-report=term-missing --cov-fail-under=90

claymodel/api.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,13 @@
22

33
__all__ = ["EmbeddingResult", "embed", "load_metadata", "load_model", "normalize"]
44

5-
import warnings
65
from dataclasses import dataclass, field
76
from importlib.resources import files
87
from pathlib import Path
98

109
import numpy as np
1110
import torch
1211

13-
from claymodel.inference.elle import ELLEProbe
1412
from claymodel.metadata import PlatformMetadata, load_metadata_yaml
1513
from claymodel.module import ClayMAEModule
1614

@@ -125,7 +123,6 @@ def embed( # noqa: PLR0913
125123
device: str = "cpu",
126124
time: torch.Tensor | None = None,
127125
latlon: torch.Tensor | None = None,
128-
quality: bool = False,
129126
metadata: dict[str, PlatformMetadata] | None = None,
130127
) -> EmbeddingResult:
131128
"""Embed pixels or a GeoTIFF with a Clay model."""
@@ -181,21 +178,9 @@ def embed( # noqa: PLR0913
181178
encoded, *_ = model.encoder(datacube)
182179
cls_embeddings = encoded[:, 0, :]
183180

184-
result = EmbeddingResult(
181+
return EmbeddingResult(
185182
embeddings=cls_embeddings,
186183
sensor=sensor,
187184
gsd=float(sensor_meta.gsd),
188185
metadata={"latlon": latlon, "time": time},
189186
)
190-
191-
if quality:
192-
try:
193-
probe = ELLEProbe.default()
194-
result.metadata["quality_score"] = probe.score(cls_embeddings)
195-
except FileNotFoundError:
196-
warnings.warn(
197-
"ELLE probe not available. Install or provide probe weights.",
198-
stacklevel=2,
199-
)
200-
201-
return result

claymodel/inference/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Inference utilities for Clay Foundation Model."""
22

33
from claymodel.inference.deterministic import DeterministicInference
4-
from claymodel.inference.elle import ELLEProbe
54
from claymodel.inference.masking import PatchAnalyzer
65

7-
__all__ = ["DeterministicInference", "ELLEProbe", "PatchAnalyzer"]
6+
__all__ = ["DeterministicInference", "PatchAnalyzer"]

claymodel/inference/elle.py

Lines changed: 0 additions & 86 deletions
This file was deleted.

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ known-first-party = ["claymodel", "training", "finetune"]
143143
"claymodel/__init__.py" = ["PLR0911"]
144144
"claymodel/api.py" = ["PLR0912", "PLR0915", "PLR2004"]
145145
"claymodel/module.py" = ["PLR2004"]
146-
"claymodel/inference/elle.py" = ["PLC0415"]
147146
"claymodel/inference/masking.py" = ["PLR2004"]
148147
"tests/*.py" = [
149148
"ANN001",

tests/test_api.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Test the high-level API."""
22

3-
import warnings
43
from importlib.resources import files
54

65
import numpy as np
@@ -176,19 +175,6 @@ def test_embed_sentinel1_db_conversion():
176175
assert not torch.isnan(result.embeddings).any()
177176

178177

179-
def test_embed_quality_warns_when_probe_missing():
180-
"""quality=True should warn, not error, when probe weights missing."""
181-
model = _make_tiny_module()
182-
pixels = torch.randn(1, 10, 64, 64)
183-
with warnings.catch_warnings(record=True) as w:
184-
warnings.simplefilter("always")
185-
result = embed(pixels, sensor="sentinel-2-l2a", model=model, quality=True)
186-
assert result.embeddings.shape == (1, 192)
187-
# Should have warned about missing probe
188-
probe_warned = any("ELLE" in str(w_.message) or "probe" in str(w_.message) for w_ in w)
189-
assert probe_warned
190-
191-
192178
def test_normalize_output_device_matches_input():
193179
"""Normalized output should be on the same device as input."""
194180
pixels = torch.randn(1, 10, 64, 64)

tests/test_elle.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

tests/test_imports.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
clay_mae_tiny,
1313
load_metadata,
1414
)
15-
from claymodel.inference import DeterministicInference, ELLEProbe, PatchAnalyzer
15+
from claymodel.inference import DeterministicInference, PatchAnalyzer
1616

1717

1818
def test_version_is_string():
@@ -54,5 +54,4 @@ def test_import_unknown_raises_attribute_error():
5454

5555
def test_inference_package_exports():
5656
assert callable(DeterministicInference)
57-
assert ELLEProbe is not None
5857
assert PatchAnalyzer is not None

trainer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,3 @@ def cli_main() -> LightningCLI:
2020

2121
if __name__ == "__main__":
2222
cli_main()
23-
24-
print("Done!")

0 commit comments

Comments
 (0)