
refactor library #382

Draft
geospatial-jeff wants to merge 43 commits into main from refactor

Conversation

@geospatial-jeff

No description provided.

@geospatial-jeff geospatial-jeff marked this pull request as draft April 26, 2026 13:44
Comment thread .github/workflows/ci.yml Fixed
Comment thread .github/workflows/ci.yml Fixed
Comment thread .github/workflows/ci.yml Fixed
@brunosan brunosan requested a review from Copilot April 26, 2026 16:40
Contributor

Copilot AI left a comment


Pull request overview

Refactors the repository into a more installable/usable Python package (claymodel) by introducing a typed metadata model, a high-level public API/CLI, reorganizing training vs inference code, and expanding automated tests and documentation to match the new interfaces.

Changes:

  • Introduces claymodel.api (load/normalize/embed), claymodel.cli, and claymodel.metadata (Pydantic) and updates docs/tutorials accordingly.
  • Refactors training + finetune code to use the new utilities (shared encoder weight loading, metadata loading) and new module paths.
  • Adds CI + Hatch-based dev workflows, Ruff configuration updates, and a broad new test suite.

Reviewed changes

Copilot reviewed 77 out of 88 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
utils/split_npz.py Removes a one-off data splitting script from utils.
utils/check_data_sanity.py Removes a one-off data deletion/sanity script from utils.
training/datamodule.py Switches metadata loading to Metadata and removes unused sampler/prints.
training/callbacks_wandb.py Adds new W&B callback module under training/.
trainer.py Ensures training defaults are configured and imports training datamodule from training.*.
tests/test_utils.py Adds tests for posemb utilities and shared encoder-weight loading.
tests/test_mrl.py Adds tests for MRL module/loss behavior.
tests/test_module.py Adds tests for ClayMAEModule training orchestration/logging/optimizer config.
tests/test_model.py Adds extensive shape/smoke tests for encoder/decoder/model factories and flags.
tests/test_metadata.py Adds tests for bundled + custom metadata loading/validation.
tests/test_masking.py Adds tests for PatchAnalyzer chip-quality helpers.
tests/test_imports.py Adds tests that package exports and inference exports are importable.
tests/test_elle.py Adds tests for ELLE probe behavior and serialization.
tests/test_deterministic.py Adds tests for deterministic inference context manager.
tests/test_cli.py Adds tests for CLI commands and expected output.
tests/test_api.py Adds end-to-end tests for load_model, normalize, and embed API.
tests/conftest.py Adds shared fixtures for lightweight model/datacube/metadata construction.
ruff.toml Updates per-file ignores to cover new modules and tests.
pyproject.toml Splits dependencies into optional extras, adds Hatch build/env config, and registers clay script.
finetune/segment/segment.py Updates imports to new finetune.* module paths.
finetune/segment/preprocess_data.py Adds a Chesapeake preprocessing script for segmentation chips.
finetune/segment/factory.py Refactors checkpoint loading to shared load_encoder_weights.
finetune/segment/chesapeake_model.py Updates imports to new finetune.* module paths.
finetune/segment/chesapeake_inference.ipynb Adds a segmentation inference notebook using new paths.
finetune/segment/chesapeake_datamodule.py Uses Metadata.from_yaml instead of Box/yaml.
finetune/regression/regression.py Updates imports to new finetune.* module paths.
finetune/regression/preprocess_data.py Adds a BioMasters preprocessing CLI for building npz cubes.
finetune/regression/factory.py Refactors checkpoint loading to shared load_encoder_weights.
finetune/regression/biomasters_model.py Updates imports to new finetune.* module paths.
finetune/regression/biomasters_inference.ipynb Adds a regression inference notebook using new paths.
finetune/regression/biomasters_datamodule.py Uses Metadata.from_yaml instead of Box/yaml.
finetune/embedder/factory.py Uses shared load_encoder_weights and simplifies weight loading.
finetune/classify/factory.py Uses shared load_encoder_weights and removes duplicated ckpt parsing.
finetune/classify/eurosat_model.py Updates imports to new finetune.* module paths.
finetune/classify/eurosat_datamodule.py Uses Metadata.from_yaml instead of Box/yaml.
finetune/classify/classify.py Updates imports to new finetune.* module paths.
finetune/init.py Adds a top-level finetune package docstring/module.
docs/tutorials/wall-to-wall.ipynb Updates tutorial to prefer installed package metadata/API while preserving dev path fallback.
docs/release-notes/specification.md Updates documented loss weighting and chip size details.
docs/references.bib Fixes formatting/line numbering for docs build.
docs/index.md Updates embeddings links and wording.
docs/getting-started/quickstart.md Updates quickstart to use load_model, load_metadata, and datacube dict input.
docs/getting-started/migration-guide.md Updates migration guide for new imports + new datacube encoder signature.
docs/getting-started/installation.md Updates installation docs to match new public API usage.
docs/getting-started/basic_use.md Updates basic usage docs to match new public API and datacube input format.
docs/concepts/how-clay-works.md Adds new conceptual documentation describing model behavior and input/output formats.
docs/clay-v0/tutorial_digital_earth_pacific_patch_level.ipynb Updates notebook imports to use training.datamodule for training-only components.
docs/clay-v0/patch_level_cloud_cover.ipynb Updates notebook imports to use training.datamodule for training-only components.
docs/clay-v0/partial-inputs.ipynb Updates notebook imports to use training.datamodule for training-only components.
docs/clay-v0/partial-inputs-flood-tutorial.ipynb Updates notebook imports to use training.datamodule for training-only components.
docs/clay-v0/clay-v0-reconstruction.ipynb Updates notebook imports to use training.datamodule for training-only components.
docs/clay-v0/clay-v0-location-embeddings.ipynb Updates notebook imports and fixes minor formatting.
docs/clay-v0/clay-v0-interpolation.ipynb Updates notebook imports to use training.datamodule for training-only components.
docs/_toc.yml Reorders/extends docs TOC, adds Concepts section and tutorial entry.
docs/_config.yml Updates docs execution exclusions to include inference tutorial.
configs/regression_biomasters.yaml Removes commented-out callback config.
configs/metadata.yaml Adds explicit comment about Sentinel-1 synthetic wavelengths convention.
configs/config.yaml Updates callback path to training.callbacks_wandb.*.
claymodel/utils.py Adds __all__, typing, and shared load_encoder_weights helper.
claymodel/mrl.py Adds docstring, typing, and clarifies MRL behavior.
claymodel/module.py Moves metadata handling to Metadata, adds encoder property, adds matryoshka flag wiring.
claymodel/model.py Splits layers/embedding imports, adds configure_training_defaults, restores MRL option, adds typing.
claymodel/metadata.py Adds Pydantic models and YAML loader for sensor metadata.
claymodel/layers.py Moves transformer code into dedicated module with typing and __all__.
claymodel/inference/masking.py Adds PatchAnalyzer utilities for quality filtering.
claymodel/inference/elle.py Adds ELLE probe implementation for quality scoring from embeddings.
claymodel/inference/deterministic.py Adds deterministic inference context manager.
claymodel/inference/init.py Adds inference subpackage exports.
claymodel/finetune/embedder/how-to-embed.ipynb Removes duplicated/legacy notebook under claymodel/finetune.
claymodel/finetune/init.py Removes legacy finetune subpackage exports (moved to top-level finetune/).
claymodel/embedding.py Adds __all__ + typing improvements for dynamic embedding module.
claymodel/configs/metadata.yaml Adds bundled metadata file under package resources.
claymodel/cli.py Adds clay CLI group with embed/info/benchmark commands.
claymodel/callbacks_wandb.py Removes legacy callback module (moved to training/).
claymodel/callbacks.py Removes legacy callbacks module.
claymodel/api.py Adds public API: metadata loading, normalization, model loading, embedding + export.
claymodel/init.py Updates public exports and switches version to importlib.metadata.version.
README.md Updates example imports to match new package surface.
CONTRIBUTING.md Adds contribution guide, dev workflow, and repo structure notes.
.ruff.toml Removes legacy Ruff config file (consolidated into ruff.toml).
.pre-commit-config.yaml Updates hooks to exclude notebooks from whitespace fixes and removes notebook ruff hook.
.github/workflows/ci.yml Adds CI workflow for linting (ruff) and tests (hatch + pytest).
Comments suppressed due to low confidence (4)

finetune/segment/factory.py:60

  • self.device is set to CUDA when available, but the module parameters are still on CPU during __init__. Passing device=self.device into load_encoder_weights will load checkpoint tensors onto CUDA and then attempt to .copy_() them into CPU tensors, which will raise a device mismatch error. Load weights onto CPU (or the encoder’s current device), or move self to self.device before loading.

finetune/regression/factory.py:54

  • Same device-mismatch issue as in the segmentation factory: the encoder is initialized on CPU but load_encoder_weights(..., device=self.device) may load checkpoint tensors onto CUDA, causing .copy_() into CPU tensors to fail. Load weights on CPU / the encoder’s current device, or move the module to the target device before loading.

finetune/classify/factory.py:72

  • self.clay_encoder is still on CPU when self.device is set to CUDA. Calling load_encoder_weights(..., device=self.device) will load checkpoint tensors onto CUDA and then try to copy them into CPU tensors, which will raise a device mismatch error. Use CPU map_location during weight loading (or move the encoder to self.device first) and consider deriving the map_location from next(self.clay_encoder.parameters()).device.

claymodel/utils.py:52

  • posemb_sincos_2d_with_gsd declares gsd as torch.Tensor | float, but then calls gsd.to(...). If a float is passed (including the default 1.0), this will raise an AttributeError. Convert non-tensors via torch.as_tensor(gsd, device=...) (or branch on type) before calling .to().
def posemb_sincos_2d_with_gsd(
    h: int,
    w: int,
    dim: int,
    gsd: torch.Tensor | float = 1.0,
    temperature: int = 10000,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

    gsd = gsd.to(x.device)
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** (2 * omega / dim)) * (gsd / 1.0)  # Adjusted for g
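A minimal sketch of the suggested fix, converting non-tensor `gsd` values via `torch.as_tensor` before any `.to()` call (the helper name is illustrative, not part of the codebase):

```python
import torch

def as_gsd_tensor(gsd, device, dtype=torch.float32):
    # torch.as_tensor accepts both tensors and plain floats, so the
    # default gsd=1.0 no longer raises AttributeError on .to().
    return torch.as_tensor(gsd, device=device, dtype=dtype)

y, x = torch.meshgrid(torch.arange(4), torch.arange(4), indexing="ij")
gsd = as_gsd_tensor(1.0, x.device)  # float input is now accepted
print(gsd.item())
```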


Comment thread claymodel/utils.py
Comment on lines +85 to +103
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
state_dict = ckpt.get("state_dict", ckpt)

# Extract encoder weights and strip the "model.encoder." prefix
encoder_state_dict = {
    re.sub(r"^model\.encoder\.", "", name): param
    for name, param in state_dict.items()
    if name.startswith("model.encoder")
}

# Load matching weights into the encoder's state dict
model_state_dict = encoder.state_dict()
loaded_keys = []
skipped_keys = []

for name, param in encoder_state_dict.items():
    if name in model_state_dict and param.size() == model_state_dict[name].size():
        model_state_dict[name].copy_(param)
        loaded_keys.append(name)

Copilot AI Apr 26, 2026


load_encoder_weights loads checkpoint tensors onto the provided device and then copies them directly into encoder.state_dict() tensors. If the encoder module is still on CPU (common during __init__) but device is CUDA, .copy_() will error due to device mismatch. Consider loading the checkpoint to the encoder’s current parameter device (e.g., next(encoder.parameters()).device) or always loading to CPU and letting the caller move the module afterward.
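One way to apply this suggestion, sketched under the assumption that deriving the target device from the encoder's own parameters is acceptable (the function name and signature here are illustrative, not the PR's actual API):

```python
import torch
import torch.nn as nn

def load_encoder_weights_safe(encoder: nn.Module, state_dict: dict) -> None:
    # Derive the device from the encoder's current parameters so
    # checkpoint tensors always match the device of the targets.
    target_device = next(encoder.parameters()).device
    model_state_dict = encoder.state_dict()
    for name, param in state_dict.items():
        if name in model_state_dict and param.size() == model_state_dict[name].size():
            model_state_dict[name].copy_(param.to(target_device))

encoder = nn.Linear(4, 4, bias=False)  # stand-in for the Clay encoder
load_encoder_weights_safe(encoder, {"weight": torch.ones(4, 4)})
print(encoder.weight.sum().item())  # 16.0
```

With a real checkpoint file, the same idea applies by passing `map_location=target_device` to `torch.load` instead of the caller-supplied device.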

Comment thread training/callbacks_wandb.py Outdated
Comment on lines +62 to +65
val_dl = iter(trainer.val_dataloaders)
for i in range(6):
    batch = next(val_dl)
    platform = batch["platform"][0]

Copilot AI Apr 26, 2026


trainer.val_dataloaders is typically a list/collection of DataLoaders. Iterating over it yields DataLoader objects, not batches, so batch = next(val_dl) will not be a batch dict and will fail when indexing batch["platform"]. Use an actual validation DataLoader (e.g., trainer.val_dataloaders[0]) and then iterate over that, and consider handling multiple val dataloaders explicitly.
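A sketch of the corrected iteration, using plain lists as a hypothetical stand-in for `trainer.val_dataloaders` and its batches:

```python
# Stand-in: a list holding one "DataLoader" whose items are batch dicts.
val_dataloaders = [[{"platform": ["sentinel-2"]}, {"platform": ["landsat"]}]]

# Iterate batches of the first DataLoader, not the list of loaders itself.
val_dl = iter(val_dataloaders[0])
batch = next(val_dl)
platform = batch["platform"][0]
print(platform)  # sentinel-2
```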

Comment on lines +123 to +148
fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 8))

for j in range(n_cols):
    # Plot actual images in rows 0 and 2
    axs[0, j].imshow(batch["pixels"][j][0], cmap="viridis")
    axs[0, j].set_title(f"Actual {j}")
    axs[0, j].axis("off")

    axs[2, j].imshow(
        batch["pixels"][j + n_cols][0],
        cmap="viridis",
    )
    axs[2, j].set_title(f"Actual {j + n_cols}")
    axs[2, j].axis("off")

    # Plot predicted images in rows 1 and 3
    axs[1, j].imshow(pixels[j][0], cmap="viridis")
    axs[1, j].set_title(f"Pred {j}")
    axs[1, j].axis("off")

    axs[3, j].imshow(pixels[j + n_cols][0], cmap="viridis")
    axs[3, j].set_title(f"Pred {j + n_cols}")
    axs[3, j].axis("off")

self.logger.experiment.log({f"{platform}": wandb.Image(fig)})
plt.close(fig)

Copilot AI Apr 26, 2026


Figures are created inside the loop but only the last fig is closed after the loop. This can leak figure objects/memory over epochs. Close each figure inside the loop (after logging) or explicitly close all created figures.
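The fix can be sketched as follows; `matplotlib` in headless (Agg) mode stands in for the callback's plotting loop:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt

for _ in range(3):
    fig, axs = plt.subplots(1, 2)
    # ... build panels and log wandb.Image(fig) here ...
    plt.close(fig)  # close inside the loop so figures never accumulate

print(len(plt.get_fignums()))  # 0: nothing left open
```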

Comment thread claymodel/api.py Outdated
Comment on lines +88 to +93
mean = torch.tensor(list(sensor_meta.bands.mean.values()), dtype=pixels.dtype).view(
    1, -1, 1, 1
)
std = torch.tensor(list(sensor_meta.bands.std.values()), dtype=pixels.dtype).view(
    1, -1, 1, 1
)

Copilot AI Apr 26, 2026


normalize() builds mean/std via sensor_meta.bands.mean.values() / .std.values(). This relies on dict insertion order matching the channel order of pixels, which is not guaranteed for custom metadata and is easy to get wrong. Use sensor_meta.band_order to build mean/std in the same order as the input channels (and do the same for wavelengths in embed()).

Suggested change
mean = torch.tensor(list(sensor_meta.bands.mean.values()), dtype=pixels.dtype).view(
    1, -1, 1, 1
)
std = torch.tensor(list(sensor_meta.bands.std.values()), dtype=pixels.dtype).view(
    1, -1, 1, 1
)
mean = torch.tensor(
    [sensor_meta.bands.mean[band] for band in sensor_meta.band_order],
    dtype=pixels.dtype,
).view(1, -1, 1, 1)
std = torch.tensor(
    [sensor_meta.bands.std[band] for band in sensor_meta.band_order],
    dtype=pixels.dtype,
).view(1, -1, 1, 1)

Comment thread claymodel/api.py Outdated
Comment on lines +201 to +205
if "latlon" in self.metadata and self.metadata["latlon"] is not None:
    lat = self.metadata["latlon"][i][0].item()
    lon = self.metadata["latlon"][i][1].item()
    record["geometry"] = Point(lon, lat)
else:

Copilot AI Apr 26, 2026


EmbeddingResult.to_geoparquet() assumes metadata["latlon"] contains (lat, lon) in degrees, but embed() accepts/constructs latlon in the model’s expected sin/cos encoding (shape [B,4]). Exporting those encoded values as a Point will produce incorrect geometries. Either (a) store raw lat/lon degrees separately (e.g., coords), (b) require degrees for export, or (c) provide helpers to convert degrees -> sin/cos for the model and keep degrees for metadata/export.
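Option (c) could look like the following helper. The 4-value order `[sin(lat), cos(lat), sin(lon), cos(lon)]` is an assumption about the model's encoding, and the function name is hypothetical:

```python
import math

def latlon_degrees_to_sincos(lat_deg: float, lon_deg: float) -> list[float]:
    # Keep the raw degrees for geometry export; feed the model this
    # assumed [sin(lat), cos(lat), sin(lon), cos(lon)] encoding.
    lat, lon = math.radians(lat_deg), math.radians(lon_deg)
    return [math.sin(lat), math.cos(lat), math.sin(lon), math.cos(lon)]

print([round(v, 6) for v in latlon_degrees_to_sincos(0.0, 90.0)])
```

Storing the degrees alongside the encoded values keeps `to_geoparquet()` trivial and lossless.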

Comment thread pyproject.toml Outdated
Comment on lines +78 to +81
[tool.hatch.envs.default.scripts]
test = "pytest tests/ -v --cov=claymodel --cov-report=term-missing"
lint = "pre-commit run --all"


Copilot AI Apr 26, 2026


The Hatch lint script runs pre-commit, but pre-commit is not listed in the dev extra dependencies. In a fresh hatch run lint environment this will fail with command not found. Add pre-commit to the dev optional dependencies (or change the lint script to use only tools that are installed).
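A hedged sketch of the fix, assuming the dev tools live in a `dev` optional-dependency group (the exact group name and existing entries in this PR may differ):

```toml
[project.optional-dependencies]
dev = [
    "pytest",
    "pytest-cov",
    "pre-commit",  # needed by the Hatch `lint` script
]
```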

Comment on lines +13 to +17
try:
    import wandb
except ImportError:
    wandb = None


Copilot AI Apr 26, 2026


wandb is set to None on ImportError but later used unconditionally (wandb.Image). This will raise an AttributeError when the callback runs in environments without wandb installed. Fail fast (e.g., raise a clear error in __init__) or guard the logging path when wandb is None.
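A minimal fail-fast sketch; the callback class name is illustrative rather than the PR's actual class, and `wandb = None` simulates the missing-package environment:

```python
wandb = None  # simulate `import wandb` having failed

class LogPredictionsCallback:
    """Hypothetical W&B callback that refuses to construct without wandb."""

    def __init__(self):
        if wandb is None:
            raise ImportError(
                "wandb is required for LogPredictionsCallback; "
                "install it with `pip install wandb`"
            )

try:
    LogPredictionsCallback()
except ImportError as exc:
    print(f"raised: {exc}")
```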

Comment thread .github/workflows/ci.yml
runs-on: ubuntu-latest
steps:
  - uses: actions/checkout@v4
  - uses: astral-sh/setup-uv@v8.1.0
Comment thread .github/workflows/ci.yml
python-version: ["3.11", "3.12", "3.13"]
steps:
  - uses: actions/checkout@v4
  - uses: astral-sh/setup-uv@v8.1.0