Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions src/cubedynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
from . import verbs
from . import tubes
from .demo_vase import demo
import xarray as xr
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import xarray as xr

# Legacy, fully implemented APIs -------------------------------------------------
from .data.gridmet import load_gridmet_cube
Expand Down Expand Up @@ -70,8 +73,6 @@
from .stats.spatial import mask_by_threshold, spatial_coarsen_mean, spatial_smooth_mean
from .stats.tails import rolling_tail_dep_vs_center
from .utils.chunking import coarsen_and_stride
from .viz.lexcube_viz import show_cube_lexcube
from .viz.qa_plots import plot_median_over_space
from .ops import (
anomaly,
month_filter,
Expand Down Expand Up @@ -173,3 +174,19 @@ def plot(
from .ops_fire.fired_api import fired_event

__all__ += ["gridmet", "fired_event"]


def show_cube_lexcube(*args, **kwargs):
"""Lazy wrapper to avoid importing visualization stack at package import time."""

from .viz.lexcube_viz import show_cube_lexcube as _show_cube_lexcube

return _show_cube_lexcube(*args, **kwargs)


def plot_median_over_space(*args, **kwargs):
"""Lazy wrapper to avoid importing visualization stack at package import time."""

from .viz.qa_plots import plot_median_over_space as _plot_median_over_space

return _plot_median_over_space(*args, **kwargs)
20 changes: 16 additions & 4 deletions src/cubedynamics/verbs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@

from __future__ import annotations

import cubedynamics.viz as viz
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from IPython.display import display

from ..config import TIME_DIM, X_DIM, Y_DIM
Expand Down Expand Up @@ -57,9 +55,17 @@
from .stats import anomaly, mean, rolling_tail_dep_vs_center, variance, zscore


def _import_xarray():
"""Import xarray lazily to avoid import-time hard dependency failures."""

import xarray as xr

return xr


def _unwrap_dataarray(
obj: xr.DataArray | VirtualCube | None,
) -> tuple[xr.DataArray, xr.DataArray | VirtualCube]:
obj,
):
"""
Normalize a verb input to an (xarray.DataArray, original_obj) pair.

Expand All @@ -73,6 +79,8 @@ def _unwrap_dataarray(
if obj is None:
raise ValueError("extract() requires an input cube/DataArray; got None.")

xr = _import_xarray()

if isinstance(obj, VirtualCube):
base_da = obj.materialize()
if not isinstance(base_da, xr.DataArray):
Expand Down Expand Up @@ -154,6 +162,8 @@ def show_cube_lexcube(**kwargs):
"""

def _op(obj):
xr = _import_xarray()

# normalize to DataArray if needed (Dataset with 1 var)
if isinstance(obj, xr.Dataset):
if len(obj.data_vars) != 1:
Expand All @@ -173,6 +183,8 @@ def _op(obj):
)

da = da.transpose(TIME_DIM, Y_DIM, X_DIM)
import cubedynamics.viz as viz

widget = viz.show_cube_lexcube(da, **kwargs)
display(widget)

Expand Down
22 changes: 22 additions & 0 deletions tests/test_verbs_import_no_eager_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

import importlib
import sys


def _clear_cubedynamics_modules() -> None:
for name in list(sys.modules):
if name == "cubedynamics" or name.startswith("cubedynamics."):
sys.modules.pop(name, None)


def test_verbs_import_does_not_eager_import_viz_modules():
"""Importing verbs for custom plotting should not eagerly load lexcube viz modules."""

_clear_cubedynamics_modules()

verbs = importlib.import_module("cubedynamics.verbs")

assert callable(verbs.plot)
assert "cubedynamics.viz" not in sys.modules
assert "cubedynamics.viz.lexcube_viz" not in sys.modules
Loading