refactor library #382
Conversation
… out of installable package, move wandb into training
for more information, see https://pre-commit.ci
…ntain permissions' Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
…, fix more imports
Pull request overview
Refactors the repository into a more installable/usable Python package (claymodel) by introducing a typed metadata model, a high-level public API/CLI, reorganizing training vs inference code, and expanding automated tests and documentation to match the new interfaces.
Changes:
- Introduces `claymodel.api` (load/normalize/embed), `claymodel.cli`, and `claymodel.metadata` (Pydantic), and updates docs/tutorials accordingly (see the usage sketch after this list).
- Refactors training + finetune code to use the new utilities (shared encoder weight loading, metadata loading) and new module paths.
- Adds CI + Hatch-based dev workflows, Ruff configuration updates, and a broad new test suite.
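For orientation, here is a usage sketch of the new public surface. The function names (`load_model`, `load_metadata`, `embed`) come from the file summaries below, but the exact signatures, the `"sentinel-2-l2a"` sensor key, and the datacube dict layout are assumptions, not verified API:

```python
# Hedged sketch only -- keys and shapes are assumptions based on this PR's
# docs changes (quickstart/basic_use), not a confirmed interface.
import torch

from claymodel.api import embed, load_metadata, load_model

model = load_model()                    # load pretrained Clay weights
meta = load_metadata("sentinel-2-l2a")  # bundled Pydantic sensor metadata

# Minimal dict-style datacube with placeholder tensors.
datacube = {
    "platform": "sentinel-2-l2a",
    "pixels": torch.randn(1, 10, 256, 256),  # B, C, H, W
    "time": torch.zeros(1, 4),
    "latlon": torch.zeros(1, 4),
}
result = embed(model, datacube)
```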
Reviewed changes
Copilot reviewed 77 out of 88 changed files in this pull request and generated 7 comments.
Summary per file:
| File | Description |
|---|---|
| utils/split_npz.py | Removes a one-off data splitting script from utils. |
| utils/check_data_sanity.py | Removes a one-off data deletion/sanity script from utils. |
| training/datamodule.py | Switches metadata loading to Metadata and removes unused sampler/prints. |
| training/callbacks_wandb.py | Adds new W&B callback module under training/. |
| trainer.py | Ensures training defaults are configured and imports training datamodule from training.*. |
| tests/test_utils.py | Adds tests for posemb utilities and shared encoder-weight loading. |
| tests/test_mrl.py | Adds tests for MRL module/loss behavior. |
| tests/test_module.py | Adds tests for ClayMAEModule training orchestration/logging/optimizer config. |
| tests/test_model.py | Adds extensive shape/smoke tests for encoder/decoder/model factories and flags. |
| tests/test_metadata.py | Adds tests for bundled + custom metadata loading/validation. |
| tests/test_masking.py | Adds tests for PatchAnalyzer chip-quality helpers. |
| tests/test_imports.py | Adds tests that package exports and inference exports are importable. |
| tests/test_elle.py | Adds tests for ELLE probe behavior and serialization. |
| tests/test_deterministic.py | Adds tests for deterministic inference context manager. |
| tests/test_cli.py | Adds tests for CLI commands and expected output. |
| tests/test_api.py | Adds end-to-end tests for load_model, normalize, and embed API. |
| tests/conftest.py | Adds shared fixtures for lightweight model/datacube/metadata construction. |
| ruff.toml | Updates per-file ignores to cover new modules and tests. |
| pyproject.toml | Splits dependencies into optional extras, adds Hatch build/env config, and registers clay script. |
| finetune/segment/segment.py | Updates imports to new finetune.* module paths. |
| finetune/segment/preprocess_data.py | Adds a Chesapeake preprocessing script for segmentation chips. |
| finetune/segment/factory.py | Refactors checkpoint loading to shared load_encoder_weights. |
| finetune/segment/chesapeake_model.py | Updates imports to new finetune.* module paths. |
| finetune/segment/chesapeake_inference.ipynb | Adds a segmentation inference notebook using new paths. |
| finetune/segment/chesapeake_datamodule.py | Uses Metadata.from_yaml instead of Box/yaml. |
| finetune/regression/regression.py | Updates imports to new finetune.* module paths. |
| finetune/regression/preprocess_data.py | Adds a BioMasters preprocessing CLI for building npz cubes. |
| finetune/regression/factory.py | Refactors checkpoint loading to shared load_encoder_weights. |
| finetune/regression/biomasters_model.py | Updates imports to new finetune.* module paths. |
| finetune/regression/biomasters_inference.ipynb | Adds a regression inference notebook using new paths. |
| finetune/regression/biomasters_datamodule.py | Uses Metadata.from_yaml instead of Box/yaml. |
| finetune/embedder/factory.py | Uses shared load_encoder_weights and simplifies weight loading. |
| finetune/classify/factory.py | Uses shared load_encoder_weights and removes duplicated ckpt parsing. |
| finetune/classify/eurosat_model.py | Updates imports to new finetune.* module paths. |
| finetune/classify/eurosat_datamodule.py | Uses Metadata.from_yaml instead of Box/yaml. |
| finetune/classify/classify.py | Updates imports to new finetune.* module paths. |
| finetune/__init__.py | Adds a top-level finetune package docstring/module. |
| docs/tutorials/wall-to-wall.ipynb | Updates tutorial to prefer installed package metadata/API while preserving dev path fallback. |
| docs/release-notes/specification.md | Updates documented loss weighting and chip size details. |
| docs/references.bib | Fixes formatting/line numbering for docs build. |
| docs/index.md | Updates embeddings links and wording. |
| docs/getting-started/quickstart.md | Updates quickstart to use load_model, load_metadata, and datacube dict input. |
| docs/getting-started/migration-guide.md | Updates migration guide for new imports + new datacube encoder signature. |
| docs/getting-started/installation.md | Updates installation docs to match new public API usage. |
| docs/getting-started/basic_use.md | Updates basic usage docs to match new public API and datacube input format. |
| docs/concepts/how-clay-works.md | Adds new conceptual documentation describing model behavior and input/output formats. |
| docs/clay-v0/tutorial_digital_earth_pacific_patch_level.ipynb | Updates notebook imports to use training.datamodule for training-only components. |
| docs/clay-v0/patch_level_cloud_cover.ipynb | Updates notebook imports to use training.datamodule for training-only components. |
| docs/clay-v0/partial-inputs.ipynb | Updates notebook imports to use training.datamodule for training-only components. |
| docs/clay-v0/partial-inputs-flood-tutorial.ipynb | Updates notebook imports to use training.datamodule for training-only components. |
| docs/clay-v0/clay-v0-reconstruction.ipynb | Updates notebook imports to use training.datamodule for training-only components. |
| docs/clay-v0/clay-v0-location-embeddings.ipynb | Updates notebook imports and fixes minor formatting. |
| docs/clay-v0/clay-v0-interpolation.ipynb | Updates notebook imports to use training.datamodule for training-only components. |
| docs/_toc.yml | Reorders/extends docs TOC, adds Concepts section and tutorial entry. |
| docs/_config.yml | Updates docs execution exclusions to include inference tutorial. |
| configs/regression_biomasters.yaml | Removes commented-out callback config. |
| configs/metadata.yaml | Adds explicit comment about Sentinel-1 synthetic wavelengths convention. |
| configs/config.yaml | Updates callback path to training.callbacks_wandb.*. |
| claymodel/utils.py | Adds __all__, typing, and shared load_encoder_weights helper. |
| claymodel/mrl.py | Adds docstring, typing, and clarifies MRL behavior. |
| claymodel/module.py | Moves metadata handling to Metadata, adds encoder property, adds matryoshka flag wiring. |
| claymodel/model.py | Splits layers/embedding imports, adds configure_training_defaults, restores MRL option, adds typing. |
| claymodel/metadata.py | Adds Pydantic models and YAML loader for sensor metadata. |
| claymodel/layers.py | Moves transformer code into dedicated module with typing and __all__. |
| claymodel/inference/masking.py | Adds PatchAnalyzer utilities for quality filtering. |
| claymodel/inference/elle.py | Adds ELLE probe implementation for quality scoring from embeddings. |
| claymodel/inference/deterministic.py | Adds deterministic inference context manager. |
| claymodel/inference/__init__.py | Adds inference subpackage exports. |
| claymodel/finetune/embedder/how-to-embed.ipynb | Removes duplicated/legacy notebook under claymodel/finetune. |
| claymodel/finetune/__init__.py | Removes legacy finetune subpackage exports (moved to top-level finetune/). |
| claymodel/embedding.py | Adds __all__ + typing improvements for dynamic embedding module. |
| claymodel/configs/metadata.yaml | Adds bundled metadata file under package resources. |
| claymodel/cli.py | Adds clay CLI group with embed/info/benchmark commands. |
| claymodel/callbacks_wandb.py | Removes legacy callback module (moved to training/). |
| claymodel/callbacks.py | Removes legacy callbacks module. |
| claymodel/api.py | Adds public API: metadata loading, normalization, model loading, embedding + export. |
| claymodel/__init__.py | Updates public exports and switches version to importlib.metadata.version. |
| README.md | Updates example imports to match new package surface. |
| CONTRIBUTING.md | Adds contribution guide, dev workflow, and repo structure notes. |
| .ruff.toml | Removes legacy Ruff config file (consolidated into ruff.toml). |
| .pre-commit-config.yaml | Updates hooks to exclude notebooks from whitespace fixes and removes notebook ruff hook. |
| .github/workflows/ci.yml | Adds CI workflow for linting (ruff) and tests (hatch + pytest). |
Comments suppressed due to low confidence (4)
finetune/segment/factory.py:60
`self.device` is set to CUDA when available, but the module parameters are still on CPU during `__init__`. Passing `device=self.device` into `load_encoder_weights` will load checkpoint tensors onto CUDA and then attempt to `.copy_()` them into CPU tensors, which will raise a device mismatch error. Load weights onto CPU (or the encoder's current device), or move `self` to `self.device` before loading.
finetune/regression/factory.py:54
Same device-mismatch issue as in the segmentation factory: the encoder is initialized on CPU, but `load_encoder_weights(..., device=self.device)` may load checkpoint tensors onto CUDA, causing `.copy_()` into CPU tensors to fail. Load weights on CPU / the encoder's current device, or move the module to the target device before loading.
finetune/classify/factory.py:72
`self.clay_encoder` is still on CPU when `self.device` is set to CUDA. Calling `load_encoder_weights(..., device=self.device)` will load checkpoint tensors onto CUDA and then try to copy them into CPU tensors, which will raise a device mismatch error. Use a CPU `map_location` during weight loading (or move the encoder to `self.device` first), and consider deriving the `map_location` from `next(self.clay_encoder.parameters()).device`.
claymodel/utils.py:52
`posemb_sincos_2d_with_gsd` declares `gsd` as `torch.Tensor | float`, but then calls `gsd.to(...)`. If a float is passed (including the default `1.0`), this will raise an `AttributeError`. Convert non-tensors via `torch.as_tensor(gsd, device=...)` (or branch on type) before calling `.to()`.
```python
def posemb_sincos_2d_with_gsd(
    h: int,
    w: int,
    dim: int,
    gsd: torch.Tensor | float = 1.0,
    temperature: int = 10000,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    gsd = gsd.to(x.device)
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** (2 * omega / dim)) * (gsd / 1.0)  # Adjusted for gsd
```
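A minimal sketch of the fix suggested in the comment, assuming nothing beyond the snippet above; the `_gsd_to_tensor` helper name is hypothetical:

```python
import torch

def _gsd_to_tensor(gsd: "torch.Tensor | float", device: torch.device) -> torch.Tensor:
    # torch.as_tensor handles both cases: it moves an existing tensor to the
    # target device and converts a Python float (e.g. the default 1.0).
    return torch.as_tensor(gsd, dtype=torch.float32, device=device)

# Inside posemb_sincos_2d_with_gsd, the failing line
#     gsd = gsd.to(x.device)
# would become
#     gsd = _gsd_to_tensor(gsd, x.device)
```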
```python
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
state_dict = ckpt.get("state_dict", ckpt)

# Extract encoder weights and strip the "model.encoder." prefix
encoder_state_dict = {
    re.sub(r"^model\.encoder\.", "", name): param
    for name, param in state_dict.items()
    if name.startswith("model.encoder")
}

# Load matching weights into the encoder's state dict
model_state_dict = encoder.state_dict()
loaded_keys = []
skipped_keys = []

for name, param in encoder_state_dict.items():
    if name in model_state_dict and param.size() == model_state_dict[name].size():
        model_state_dict[name].copy_(param)
        loaded_keys.append(name)
```
`load_encoder_weights` loads checkpoint tensors onto the provided device and then copies them directly into `encoder.state_dict()` tensors. If the encoder module is still on CPU (common during `__init__`) but `device` is CUDA, `.copy_()` will error due to a device mismatch. Consider loading the checkpoint to the encoder's current parameter device (e.g., `next(encoder.parameters()).device`) or always loading to CPU and letting the caller move the module afterward.
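A sketch of the safer variant described above. This adapts the quoted snippet by deriving the load device from the encoder itself; it is illustrative, not the PR's actual helper:

```python
import re

import torch

def load_encoder_weights_safe(encoder: torch.nn.Module, ckpt_path: str) -> list[str]:
    """Sketch: map checkpoint tensors to wherever the encoder's parameters
    currently live, so .copy_() never crosses devices. The caller can move
    the module to CUDA afterward."""
    device = next(encoder.parameters()).device
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    state_dict = ckpt.get("state_dict", ckpt)

    # Strip the "model.encoder." prefix, as in the quoted snippet.
    encoder_state_dict = {
        re.sub(r"^model\.encoder\.", "", name): param
        for name, param in state_dict.items()
        if name.startswith("model.encoder")
    }

    model_state_dict = encoder.state_dict()
    loaded_keys = []
    for name, param in encoder_state_dict.items():
        if name in model_state_dict and param.size() == model_state_dict[name].size():
            model_state_dict[name].copy_(param)  # both tensors now share a device
            loaded_keys.append(name)
    return loaded_keys
```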
```python
val_dl = iter(trainer.val_dataloaders)
for i in range(6):
    batch = next(val_dl)
    platform = batch["platform"][0]
```
`trainer.val_dataloaders` is typically a list/collection of DataLoaders. Iterating over it yields DataLoader objects, not batches, so `batch = next(val_dl)` will not be a batch dict and will fail when indexing `batch["platform"]`. Use an actual validation DataLoader (e.g., `trainer.val_dataloaders[0]`), iterate over that, and consider handling multiple val dataloaders explicitly.
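A sketch of the suggested fix, normalizing `trainer.val_dataloaders` to a single DataLoader before drawing batches (illustrative only; how Lightning exposes the collection depends on version and configuration):

```python
# trainer.val_dataloaders may be a single DataLoader or a list of them;
# take the first loader explicitly instead of iterating the collection.
loaders = trainer.val_dataloaders
val_loader = loaders[0] if isinstance(loaders, (list, tuple)) else loaders

val_dl = iter(val_loader)
for i in range(6):
    batch = next(val_dl)  # now a batch dict, not a DataLoader
    platform = batch["platform"][0]
```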
```python
fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 8))

for j in range(n_cols):
    # Plot actual images in rows 0 and 2
    axs[0, j].imshow(batch["pixels"][j][0], cmap="viridis")
    axs[0, j].set_title(f"Actual {j}")
    axs[0, j].axis("off")

    axs[2, j].imshow(
        batch["pixels"][j + n_cols][0],
        cmap="viridis",
    )
    axs[2, j].set_title(f"Actual {j + n_cols}")
    axs[2, j].axis("off")

    # Plot predicted images in rows 1 and 3
    axs[1, j].imshow(pixels[j][0], cmap="viridis")
    axs[1, j].set_title(f"Pred {j}")
    axs[1, j].axis("off")

    axs[3, j].imshow(pixels[j + n_cols][0], cmap="viridis")
    axs[3, j].set_title(f"Pred {j + n_cols}")
    axs[3, j].axis("off")

self.logger.experiment.log({f"{platform}": wandb.Image(fig)})
plt.close(fig)
```
Figures are created inside the loop, but only the last `fig` is closed after the loop. This can leak figure objects/memory over epochs. Close each figure inside the loop (after logging) or explicitly close all created figures.
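A structural sketch of the fix (the helper names are hypothetical): create, log, and close each figure within the same loop iteration.

```python
import matplotlib.pyplot as plt

def log_prediction_figures(batches, experiment, make_figure):
    """Hypothetical helper: `make_figure` builds the actual/pred grid for one
    batch. Closing inside the loop prevents figures accumulating per epoch."""
    for batch in batches:
        fig = make_figure(batch)
        experiment.log({batch["platform"][0]: fig})
        plt.close(fig)  # close this iteration's figure, not just the last one
```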
```python
mean = torch.tensor(list(sensor_meta.bands.mean.values()), dtype=pixels.dtype).view(
    1, -1, 1, 1
)
std = torch.tensor(list(sensor_meta.bands.std.values()), dtype=pixels.dtype).view(
    1, -1, 1, 1
)
```
`normalize()` builds mean/std via `sensor_meta.bands.mean.values()` / `.std.values()`. This relies on dict insertion order matching the channel order of `pixels`, which is not guaranteed for custom metadata and is easy to get wrong. Use `sensor_meta.band_order` to build mean/std in the same order as the input channels (and do the same for wavelengths in `embed()`).
Suggested change:

```python
mean = torch.tensor(
    [sensor_meta.bands.mean[band] for band in sensor_meta.band_order],
    dtype=pixels.dtype,
).view(1, -1, 1, 1)
std = torch.tensor(
    [sensor_meta.bands.std[band] for band in sensor_meta.band_order],
    dtype=pixels.dtype,
).view(1, -1, 1, 1)
```
| if "latlon" in self.metadata and self.metadata["latlon"] is not None: | ||
| lat = self.metadata["latlon"][i][0].item() | ||
| lon = self.metadata["latlon"][i][1].item() | ||
| record["geometry"] = Point(lon, lat) | ||
| else: |
`EmbeddingResult.to_geoparquet()` assumes `metadata["latlon"]` contains (lat, lon) in degrees, but `embed()` accepts/constructs `latlon` in the model's expected sin/cos encoding (shape `[B, 4]`). Exporting those encoded values as a `Point` will produce incorrect geometries. Either (a) store raw lat/lon degrees separately (e.g., `coords`), (b) require degrees for export, or (c) provide helpers to convert degrees -> sin/cos for the model and keep degrees for metadata/export.
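A sketch of option (c): keep degrees for export and derive the model's encoding from them. The 4-value layout used here, `[sin(lat), cos(lat), sin(lon), cos(lon)]`, is an assumption and would need to match what `embed()` actually expects:

```python
import math

import torch

def latlon_degrees_to_encoding(lat_deg: float, lon_deg: float) -> torch.Tensor:
    """Hypothetical helper: degrees -> the model's sin/cos location encoding.
    The [sin(lat), cos(lat), sin(lon), cos(lon)] order is an assumption."""
    lat = math.radians(lat_deg)
    lon = math.radians(lon_deg)
    return torch.tensor([math.sin(lat), math.cos(lat), math.sin(lon), math.cos(lon)])

# to_geoparquet() would then build Point geometries from raw degrees stored in
# e.g. metadata["coords"] (a hypothetical key), never from the encoded values.
```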
```toml
[tool.hatch.envs.default.scripts]
test = "pytest tests/ -v --cov=claymodel --cov-report=term-missing"
lint = "pre-commit run --all"
```
The Hatch `lint` script runs `pre-commit`, but `pre-commit` is not listed in the `dev` extra dependencies. In a fresh `hatch run lint` environment this will fail with `command not found`. Add `pre-commit` to the `dev` optional dependencies (or change the `lint` script to use only tools that are installed).
```python
try:
    import wandb
except ImportError:
    wandb = None
```
`wandb` is set to `None` on `ImportError` but is later used unconditionally (`wandb.Image`). This will raise an `AttributeError` when the callback runs in environments without `wandb` installed. Fail fast (e.g., raise a clear error in `__init__`) or guard the logging path when `wandb` is `None`.
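A sketch of the fail-fast option; the callback class name is hypothetical, and the `Callback` base import assumes the Lightning 2.x layout:

```python
from lightning.pytorch.callbacks import Callback

try:
    import wandb
except ImportError:
    wandb = None

class WandbImageLogger(Callback):  # hypothetical name for illustration
    def __init__(self) -> None:
        super().__init__()
        if wandb is None:
            # Fail fast with a clear message instead of an AttributeError
            # later, when wandb.Image is called during validation.
            raise ImportError("wandb must be installed to use this callback")
```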
* build: adopt uv, ruff, and ty tooling
* ci: run checks on refactor PRs
```yaml
runs-on: ubuntu-latest
steps:
  - uses: actions/checkout@v4
  - uses: astral-sh/setup-uv@v8.1.0
```

```yaml
python-version: ["3.11", "3.12", "3.13"]
steps:
  - uses: actions/checkout@v4
  - uses: astral-sh/setup-uv@v8.1.0
```
… history that code is right and docs are wrong
…ack later with modern libs
…different in clayv1 with segment encoder including a feature pyramid network, intermediate feature extraction, and multi-scale output. This functionality was removed in v1.5, leaving these two classes nearly identical
…ore embed forward pass
No description provided.