From bbc73ba57de5e1ec9bbd08fe03571fecddeea88a Mon Sep 17 00:00:00 2001 From: Ty Tuff Date: Sat, 28 Mar 2026 11:01:07 -0600 Subject: [PATCH] Add regression test for no eager viz import in verbs --- src/cubedynamics/__init__.py | 23 ++++++++++++++++++++--- src/cubedynamics/verbs/__init__.py | 20 ++++++++++++++++---- tests/test_verbs_import_no_eager_viz.py | 22 ++++++++++++++++++++++ 3 files changed, 58 insertions(+), 7 deletions(-) create mode 100644 tests/test_verbs_import_no_eager_viz.py diff --git a/src/cubedynamics/__init__.py b/src/cubedynamics/__init__.py index a5db277..5e11e4e 100644 --- a/src/cubedynamics/__init__.py +++ b/src/cubedynamics/__init__.py @@ -31,7 +31,10 @@ from . import verbs from . import tubes from .demo_vase import demo -import xarray as xr +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import xarray as xr # Legacy, fully implemented APIs ------------------------------------------------- from .data.gridmet import load_gridmet_cube @@ -70,8 +73,6 @@ from .stats.spatial import mask_by_threshold, spatial_coarsen_mean, spatial_smooth_mean from .stats.tails import rolling_tail_dep_vs_center from .utils.chunking import coarsen_and_stride -from .viz.lexcube_viz import show_cube_lexcube -from .viz.qa_plots import plot_median_over_space from .ops import ( anomaly, month_filter, @@ -173,3 +174,19 @@ def plot( from .ops_fire.fired_api import fired_event __all__ += ["gridmet", "fired_event"] + + +def show_cube_lexcube(*args, **kwargs): + """Lazy wrapper to avoid importing visualization stack at package import time.""" + + from .viz.lexcube_viz import show_cube_lexcube as _show_cube_lexcube + + return _show_cube_lexcube(*args, **kwargs) + + +def plot_median_over_space(*args, **kwargs): + """Lazy wrapper to avoid importing visualization stack at package import time.""" + + from .viz.qa_plots import plot_median_over_space as _plot_median_over_space + + return _plot_median_over_space(*args, **kwargs) diff --git a/src/cubedynamics/verbs/__init__.py b/src/cubedynamics/verbs/__init__.py index b4dd60d..6ae14fa 100644 --- a/src/cubedynamics/verbs/__init__.py +++ b/src/cubedynamics/verbs/__init__.py @@ -22,10 +22,8 @@ from __future__ import annotations -import cubedynamics.viz as viz import matplotlib.pyplot as plt import numpy as np -import xarray as xr from IPython.display import display from ..config import TIME_DIM, X_DIM, Y_DIM @@ -57,9 +55,17 @@ from .stats import anomaly, mean, rolling_tail_dep_vs_center, variance, zscore +def _import_xarray(): + """Import xarray lazily to avoid import-time hard dependency failures.""" + + import xarray as xr + + return xr + + def _unwrap_dataarray( - obj: xr.DataArray | VirtualCube | None, -) -> tuple[xr.DataArray, xr.DataArray | VirtualCube]: + obj, +): """ Normalize a verb input to an (xarray.DataArray, original_obj) pair. @@ -73,6 +79,8 @@ def _unwrap_dataarray( if obj is None: raise ValueError("extract() requires an input cube/DataArray; got None.") + xr = _import_xarray() + if isinstance(obj, VirtualCube): base_da = obj.materialize() if not isinstance(base_da, xr.DataArray): @@ -154,6 +162,8 @@ def show_cube_lexcube(**kwargs): """ def _op(obj): + xr = _import_xarray() + # normalize to DataArray if needed (Dataset with 1 var) if isinstance(obj, xr.Dataset): if len(obj.data_vars) != 1: @@ -173,6 +183,8 @@ def _op(obj): ) da = da.transpose(TIME_DIM, Y_DIM, X_DIM) + import cubedynamics.viz as viz + widget = viz.show_cube_lexcube(da, **kwargs) display(widget) diff --git a/tests/test_verbs_import_no_eager_viz.py b/tests/test_verbs_import_no_eager_viz.py new file mode 100644 index 0000000..b10215c --- /dev/null +++ b/tests/test_verbs_import_no_eager_viz.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import importlib +import sys + + +def _clear_cubedynamics_modules() -> None: + for name in list(sys.modules): + if name == "cubedynamics" or name.startswith("cubedynamics."): + sys.modules.pop(name, None) + + +def test_verbs_import_does_not_eager_import_viz_modules(): + """Importing verbs for custom plotting should not eagerly load lexcube viz modules.""" + + _clear_cubedynamics_modules() + + verbs = importlib.import_module("cubedynamics.verbs") + + assert callable(verbs.plot) + assert "cubedynamics.viz" not in sys.modules + assert "cubedynamics.viz.lexcube_viz" not in sys.modules