Skip to content

Commit bbc73ba

Browse files
committed
Add regression test for no eager viz import in verbs
1 parent a1200cc commit bbc73ba

3 files changed

Lines changed: 58 additions & 7 deletions

File tree

src/cubedynamics/__init__.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@
3131
from . import verbs
3232
from . import tubes
3333
from .demo_vase import demo
34-
import xarray as xr
34+
from typing import TYPE_CHECKING
35+
36+
if TYPE_CHECKING:
37+
import xarray as xr
3538

3639
# Legacy, fully implemented APIs -------------------------------------------------
3740
from .data.gridmet import load_gridmet_cube
@@ -70,8 +73,6 @@
7073
from .stats.spatial import mask_by_threshold, spatial_coarsen_mean, spatial_smooth_mean
7174
from .stats.tails import rolling_tail_dep_vs_center
7275
from .utils.chunking import coarsen_and_stride
73-
from .viz.lexcube_viz import show_cube_lexcube
74-
from .viz.qa_plots import plot_median_over_space
7576
from .ops import (
7677
anomaly,
7778
month_filter,
@@ -173,3 +174,19 @@ def plot(
173174
from .ops_fire.fired_api import fired_event
174175

175176
__all__ += ["gridmet", "fired_event"]
177+
178+
179+
def show_cube_lexcube(*args, **kwargs):
180+
"""Lazy wrapper to avoid importing visualization stack at package import time."""
181+
182+
from .viz.lexcube_viz import show_cube_lexcube as _show_cube_lexcube
183+
184+
return _show_cube_lexcube(*args, **kwargs)
185+
186+
187+
def plot_median_over_space(*args, **kwargs):
188+
"""Lazy wrapper to avoid importing visualization stack at package import time."""
189+
190+
from .viz.qa_plots import plot_median_over_space as _plot_median_over_space
191+
192+
return _plot_median_over_space(*args, **kwargs)

src/cubedynamics/verbs/__init__.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@
2222

2323
from __future__ import annotations
2424

25-
import cubedynamics.viz as viz
2625
import matplotlib.pyplot as plt
2726
import numpy as np
28-
import xarray as xr
2927
from IPython.display import display
3028

3129
from ..config import TIME_DIM, X_DIM, Y_DIM
@@ -57,9 +55,17 @@
5755
from .stats import anomaly, mean, rolling_tail_dep_vs_center, variance, zscore
5856

5957

58+
def _import_xarray():
59+
"""Import xarray lazily to avoid import-time hard dependency failures."""
60+
61+
import xarray as xr
62+
63+
return xr
64+
65+
6066
def _unwrap_dataarray(
61-
obj: xr.DataArray | VirtualCube | None,
62-
) -> tuple[xr.DataArray, xr.DataArray | VirtualCube]:
67+
obj,
68+
):
6369
"""
6470
Normalize a verb input to an (xarray.DataArray, original_obj) pair.
6571
@@ -73,6 +79,8 @@ def _unwrap_dataarray(
7379
if obj is None:
7480
raise ValueError("extract() requires an input cube/DataArray; got None.")
7581

82+
xr = _import_xarray()
83+
7684
if isinstance(obj, VirtualCube):
7785
base_da = obj.materialize()
7886
if not isinstance(base_da, xr.DataArray):
@@ -154,6 +162,8 @@ def show_cube_lexcube(**kwargs):
154162
"""
155163

156164
def _op(obj):
165+
xr = _import_xarray()
166+
157167
# normalize to DataArray if needed (Dataset with 1 var)
158168
if isinstance(obj, xr.Dataset):
159169
if len(obj.data_vars) != 1:
@@ -173,6 +183,8 @@ def _op(obj):
173183
)
174184

175185
da = da.transpose(TIME_DIM, Y_DIM, X_DIM)
186+
import cubedynamics.viz as viz
187+
176188
widget = viz.show_cube_lexcube(da, **kwargs)
177189
display(widget)
178190

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from __future__ import annotations
2+
3+
import importlib
4+
import sys
5+
6+
7+
def _clear_cubedynamics_modules() -> None:
8+
for name in list(sys.modules):
9+
if name == "cubedynamics" or name.startswith("cubedynamics."):
10+
sys.modules.pop(name, None)
11+
12+
13+
def test_verbs_import_does_not_eager_import_viz_modules():
14+
"""Importing verbs for custom plotting should not eagerly load lexcube viz modules."""
15+
16+
_clear_cubedynamics_modules()
17+
18+
verbs = importlib.import_module("cubedynamics.verbs")
19+
20+
assert callable(verbs.plot)
21+
assert "cubedynamics.viz" not in sys.modules
22+
assert "cubedynamics.viz.lexcube_viz" not in sys.modules

0 commit comments

Comments
 (0)