Skip to content

Commit 5a5e8a1

Browse files
clean up inference pipeline
1 parent 0926539 commit 5a5e8a1

4 files changed

Lines changed: 90 additions & 93 deletions

File tree

claymodel/cli.py

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,14 @@
11
"""Clay model CLI."""
22

3-
import importlib.util
4-
53
import click
64

75
from claymodel.api import load_metadata, load_model
8-
9-
10-
def _check_inference_deps() -> None:
11-
"""Check that inference optional dependencies are installed."""
12-
missing = [
13-
pkg
14-
for pkg in ("rustac", "lazycogs", "geopandas", "shapely")
15-
if importlib.util.find_spec(pkg) is None
16-
]
17-
if missing:
18-
raise click.ClickException(
19-
f"Missing inference dependencies: {', '.join(missing)}\n"
20-
"Install them with: pip install claymodel[inference]"
21-
)
6+
from claymodel.inference.pipeline import (
7+
generate_embeddings,
8+
load_scene,
9+
save_embeddings_geoparquet,
10+
search_scene,
11+
)
2212

2313

2414
@click.group()
@@ -97,17 +87,7 @@ def embed(
9787
Searches Earth Search, loads COG bands via lazycogs, chips into patches,
9888
runs the Clay encoder, and saves embeddings as GeoParquet.
9989
100-
Requires: pip install claymodel[inference]
10190
"""
102-
_check_inference_deps()
103-
104-
from claymodel.inference.pipeline import (
105-
generate_embeddings,
106-
load_scene,
107-
save_embeddings_geoparquet,
108-
search_scene,
109-
)
110-
11191
click.echo(f"Scene: {scene_id}")
11292
click.echo(f"Model: {size} (checkpoint: {ckpt})")
11393
click.echo(f"Device: {device}")

claymodel/inference/pipeline.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
Provides functions to search Earth Search STAC, load Sentinel-2 COGs via
44
lazycogs, chip into model-ready patches, and generate embeddings with
55
spatial metadata. Usable both programmatically and via the CLI.
6-
7-
Requires optional dependencies: ``pip install claymodel[inference]``
86
"""
97

108
__all__ = [
@@ -25,8 +23,12 @@
2523
from datetime import datetime
2624
from pathlib import Path
2725

26+
import geopandas as gpd
27+
import lazycogs
2828
import numpy as np
29+
import rustac
2930
import torch
31+
from shapely.geometry import box
3032

3133
from claymodel.api import load_metadata, load_model, normalize
3234
from claymodel.metadata import PlatformMetadata
@@ -130,8 +132,6 @@ async def _search_scene_async(
130132
Returns:
131133
Number of items found.
132134
"""
133-
import rustac
134-
135135
return await rustac.search_to(
136136
str(parquet_path),
137137
EARTH_SEARCH_URL,
@@ -158,14 +158,6 @@ def search_scene(
158158
ValueError: If the scene is not found.
159159
ImportError: If rustac is not installed.
160160
"""
161-
try:
162-
import rustac # noqa: F401
163-
except ImportError as exc:
164-
raise ImportError(
165-
"rustac is required for the inference pipeline. "
166-
"Install it with: pip install claymodel[inference]"
167-
) from exc
168-
169161
if output_dir is None:
170162
output_dir = Path(tempfile.mkdtemp(prefix="clay_"))
171163
else:
@@ -191,8 +183,6 @@ def search_scene(
191183
)
192184

193185
# Read back the parquet to extract metadata
194-
import geopandas as gpd
195-
196186
gdf = gpd.read_parquet(parquet_path)
197187
row = gdf.iloc[0]
198188

@@ -238,14 +228,6 @@ def load_scene(
238228
Raises:
239229
ImportError: If lazycogs is not installed.
240230
"""
241-
try:
242-
import lazycogs
243-
except ImportError as exc:
244-
raise ImportError(
245-
"lazycogs is required for the inference pipeline. "
246-
"Install it with: pip install claymodel[inference]"
247-
) from exc
248-
249231
parquet_str = str(search_result.parquet_path)
250232
store = lazycogs.store_for(parquet_str, skip_signature=True)
251233

@@ -469,9 +451,6 @@ def save_embeddings_geoparquet(
469451
Raises:
470452
ImportError: If geopandas is not installed.
471453
"""
472-
import geopandas as gpd
473-
from shapely.geometry import box
474-
475454
output_path = Path(output_path)
476455
output_path.parent.mkdir(parents=True, exist_ok=True)
477456

pyproject.toml

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,14 @@ license-files = ["LICENSE"]
2929
dependencies = [
3030
"click>=8.1",
3131
"einops~=0.7.0",
32+
"geopandas>=0.14",
33+
"lazycogs",
3234
"lightning>=2.0.0",
3335
"pydantic>=2.0",
36+
"pyarrow",
3437
"pyyaml",
38+
"rustac",
39+
"shapely>=2.0",
3540
"timm>=0.6.0",
3641
"torch>=2.4.0",
3742
"torchvision>=0.19.0",
@@ -51,14 +56,8 @@ finetune = [
5156
"tifffile",
5257
"rasterio",
5358
]
54-
inference = [
55-
"geopandas>=0.14",
56-
"lazycogs",
57-
"rustac",
58-
"shapely>=2.0",
59-
]
6059
all = [
61-
"claymodel[train,finetune,inference]",
60+
"claymodel[train,finetune]",
6261
]
6362

6463
[project.scripts]
@@ -149,8 +148,8 @@ known-first-party = ["claymodel", "training", "finetune"]
149148
"claymodel/api.py" = ["PLR0912", "PLR0915", "PLR2004"]
150149
"claymodel/module.py" = ["PLR2004"]
151150
"claymodel/inference/masking.py" = ["PLR2004"]
152-
"claymodel/inference/pipeline.py" = ["PLC0415", "PLR0913", "PLR2004"]
153-
"claymodel/cli.py" = ["PLC0415", "PLR0913"]
151+
"claymodel/inference/pipeline.py" = ["PLR0913", "PLR2004"]
152+
"claymodel/cli.py" = ["PLR0913"]
154153
"tests/*.py" = [
155154
"ANN001",
156155
"ANN003",

0 commit comments

Comments
 (0)