diff --git a/docs/releases.md b/docs/releases.md index ecb83f39..04fb8a18 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -42,6 +42,8 @@ By [Tom Nicholas](https://github.com/TomNicholas). - Fix `ZarrParser` not correctly parsing scalar variables from v2 native zarr stores ([#936](https://github.com/zarr-developers/VirtualiZarr/pull/936)). By [Julius Buseceke](https://github.com/jbusecke) +- Add a `.shutdown()` method to the custom executors (dask, lithops), preventing unbounded memory growth in the case of lithops ([#925](https://github.com/zarr-developers/VirtualiZarr/pull/925)). + By [Julius Busecke](https://github.com/jbusecke) ### Documentation diff --git a/virtualizarr/parallel.py b/virtualizarr/parallel.py index 7341261d..2fcadd68 100644 --- a/virtualizarr/parallel.py +++ b/virtualizarr/parallel.py @@ -137,6 +137,9 @@ def map( """ return map(fn, *iterables) + def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: + self._futures.clear() + class DaskDelayedExecutor(Executor): """ @@ -230,6 +233,9 @@ def map( # Compute all tasks return iter(dask.compute(*delayed_tasks)) + def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: + self._futures.clear() + class LithopsEagerFunctionExecutor(Executor): """ @@ -270,6 +276,23 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: def __init__(self, **kwargs) -> None: import lithops # type: ignore[import-untyped] + # Fix for unbounded memory growth on repeated `open_virtual_mfdataset` calls + # see https://github.com/zarr-developers/VirtualiZarr/issues/926 + + # Users are encouraged to provide configs for lithops via file. + # If someone instead imports this class and configures lithops themselves, they have + # to provide all of the details below explicitly via the `config=` argument. 
+ if "config" not in kwargs: + _config_file = lithops.config.load_config() + if _config_file["lithops"].get("backend") == "localhost": + # We currently only want to apply this fix for the localhost executor + kwargs["config"] = { + "lithops": { + "data_cleaner": False, # prevents atexit registration of `.lithops_client.clean` method + "backend": "localhost", # if this is not provided lithops will default to aws lambda + } + } + # Create Lithops client with optional configuration self.lithops_client = lithops.FunctionExecutor(**kwargs).__enter__() @@ -372,4 +395,16 @@ def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None: wait Whether to wait for pending futures. """ + if wait: + # ensure all futures are completed before exiting + self.lithops_client.wait(show_progressbar=False) + + self._futures.clear() + + # Free output memory and clear lithops internal futures list + for f in self.lithops_client.futures: + f._call_output = None + self.lithops_client.futures.clear() + + # Exit context manager entered during __init__ self.lithops_client.__exit__(None, None, None) diff --git a/virtualizarr/tests/test_parallel.py b/virtualizarr/tests/test_parallel.py index 9aad3fdc..d274cb58 100644 --- a/virtualizarr/tests/test_parallel.py +++ b/virtualizarr/tests/test_parallel.py @@ -2,8 +2,13 @@ import pytest -from virtualizarr.parallel import LithopsEagerFunctionExecutor, get_executor -from virtualizarr.tests import requires_lithops +from virtualizarr.parallel import ( + DaskDelayedExecutor, + LithopsEagerFunctionExecutor, + SerialExecutor, + get_executor, +) +from virtualizarr.tests import requires_dask, requires_lithops @pytest.mark.flaky @@ -43,3 +48,55 @@ def test_get_executor_process_pool_mode(): assert ctx is not None, "Expected executor to have a multiprocessing context" assert ctx.get_start_method() == "forkserver" + + +def _make_executor(executor_cls): + """Create a pytest param for an executor class with appropriate marks.""" + marks = { + 
"DaskDelayedExecutor": [requires_dask], + "LithopsEagerFunctionExecutor": [requires_lithops], + } + return pytest.param( + executor_cls, + id=executor_cls.__name__, + marks=marks.get(executor_cls.__name__, []), + ) + + +ALL_CUSTOM_EXECUTORS = [ + _make_executor(SerialExecutor), + _make_executor(DaskDelayedExecutor), + _make_executor(LithopsEagerFunctionExecutor), +] + + +@pytest.mark.parametrize("executor_cls", ALL_CUSTOM_EXECUTORS) +class TestExecutorShutdown: + def test_shutdown_clears_futures(self, executor_cls): + """Internal _futures list should be empty after shutdown.""" + with executor_cls() as executor: + executor.submit(lambda x: x * 2, 1) + executor.submit(lambda x: x + 1, 2) + assert len(executor._futures) == 2 + if executor_cls is LithopsEagerFunctionExecutor: + # grab refs before they get cleared + lithops_futures = list(executor.lithops_client.futures) + assert len(lithops_futures) == 2 + + assert len(executor._futures) == 0 + + # Lithops-specific: verify lithops internal futures are also cleared + if executor_cls is LithopsEagerFunctionExecutor: + assert len(executor.lithops_client.futures) == 0 + assert all(f._call_output is None for f in lithops_futures) + + # Testing idempotency + executor.shutdown() + assert len(executor._futures) == 0 + + +@requires_lithops +def test_lithops_executor_data_cleaner_disabled(): + """Data_cleaner must be False to prevent atexit registration of lithops' clean method.""" + with LithopsEagerFunctionExecutor() as executor: + assert executor.lithops_client.data_cleaner is False