diff --git a/CHANGELOG.md b/CHANGELOG.md
index 25316318e2..6140e135ad 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ to [Semantic Versioning]. The full commit history is available in the [commit lo

 #### Added

+- Add shared memory support for DDP to deduplicate `adata.X` across ranks on the same node,
+  reducing per-rank memory, {pr}`3754`.
 - Add support for Python 3.14, {pr}`3563`.
 - Add support for Pandas3, {pr}`3638`.

diff --git a/docs/user_guide/use_case/multi_gpu_training.md b/docs/user_guide/use_case/multi_gpu_training.md
index f38e4d3254..16d0853d3b 100644
--- a/docs/user_guide/use_case/multi_gpu_training.md
+++ b/docs/user_guide/use_case/multi_gpu_training.md
@@ -38,6 +38,7 @@ We can see the advantage with larger data, while for the small data, there's no

 ### 3. **Memory Utilization**
 - **Larger Memory Pool:** When using multiple GPUs, each GPU can hold a part of the model and data, effectively creating a larger memory pool. This allows for larger batch sizes or more complex models.
+- **Shared Memory Deduplication:** In subprocess-based DDP, each rank loads its own copy of `adata` into memory, which can be costly for large datasets. scvi-tools automatically uses POSIX shared memory to deduplicate `adata.X` across all ranks on the same node, so only one physical copy is kept in memory. For example, with 100K obs x 1000 genes (~400 MB), 8 GPUs go from 3.2 GB total to just 400 MB (see the verification sketch at the end of this page). This is auto-enabled when DDP is detected and works with both dense numpy and scipy sparse matrices.

 ## Using MultiGPU training in SCVI-Tools
@@ -69,7 +70,35 @@ model.train(
 )
 ```

-3. There are a few limitations with the current implementation:
+3. **Shared memory for data deduplication:**
+
+   By default, shared memory is auto-enabled when DDP is detected (`share_memory=None`). You can explicitly control this behavior:
+
+   ```python
+   # Explicitly enable shared memory
+   model.train(
+       ...,
+       accelerator="gpu",
+       devices=-1,
+       strategy="ddp_find_unused_parameters_true",
+       datasplitter_kwargs={"share_memory": True},
+   )
+
+   # Explicitly disable shared memory
+   model.train(
+       ...,
+       accelerator="gpu",
+       devices=-1,
+       strategy="ddp_find_unused_parameters_true",
+       datasplitter_kwargs={"share_memory": False},
+   )
+   ```
+
+   :::{note}
+   Shared memory deduplication only applies to dense numpy or scipy sparse `adata.X`. If `adata` is backed (h5ad on disk) or uses dask arrays, shared memory is automatically skipped since those formats already avoid full in-memory copies.
+   :::
+
+4. There are a few limitations with the current implementation:
 - During an interactive session, like in a Jupyter notebook, we can only train one model in multi-GPU mode per session. This means we can't train a SCANVI model from an SCVI model if the SCVI model was trained in the same notebook; the SCVI model needs to be trained and saved in one session and loaded in another. This is a torch lightning caveat.
 - Early stopping can't be used right now (and some models, like totalvi, use early stopping by default), so we disable early stopping when running with DDP. The reason is that the validation loop should run on one device only, not multi-GPU.
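+
+To verify the deduplication, here is a minimal sketch (not part of the scvi-tools API). It assumes a Linux node, where POSIX shared memory is exposed under `/dev/shm`, and relies on the internal `scvi_*` block-name prefix; run it while training is in flight, since the blocks are removed at teardown:
+
+```python
+import glob
+import os
+
+# During DDP training, each node should expose one shared block for dense
+# adata.X (or three blocks -- data/indices/indptr -- for sparse), instead
+# of one full copy per rank.
+for path in sorted(glob.glob("/dev/shm/scvi_*")):
+    print(path, f"{os.path.getsize(path) / 1024**2:.1f} MB")
+```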
diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index 3abdaf3518..4a3a0b7423 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -1,3 +1,4 @@ +import logging import warnings from math import ceil, floor @@ -20,6 +21,8 @@ from scvi.model._utils import parse_device_args from scvi.utils._docstrings import devices_dsp +logger = logging.getLogger(__name__) + def validate_data_split( n_samples: int, @@ -207,6 +210,11 @@ class DataSplitter(pl.LightningDataModule): external_indexing A list of data split indices in the order of training, validation, and test sets. Validation and test set are not required and can be left empty. + share_memory + ``EXPERIMENTAL`` If ``True``, uses POSIX shared memory to deduplicate ``adata.X`` + across DDP ranks on the same node. If ``None`` (default), auto-enables when DDP + is detected. If ``False``, disables shared memory. Only applies to dense numpy + or scipy sparse ``adata.X``; backed and dask arrays are skipped. **kwargs Keyword args for data loader. If adata has labeled data, the data loader class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`, @@ -233,6 +241,7 @@ def __init__( load_sparse_tensor: bool = False, pin_memory: bool = False, external_indexing: list[np.array, np.array, np.array] | None = None, + share_memory: bool | None = None, **kwargs, ): super().__init__() @@ -246,6 +255,8 @@ def __init__( self.data_loader_kwargs = kwargs self.pin_memory = pin_memory self.external_indexing = external_indexing + self.share_memory = share_memory + self._shm_registry = None if self.external_indexing is not None: self.n_train, self.n_val = validate_data_split_with_external_indexing( @@ -264,6 +275,25 @@ def __init__( self.train_size_is_none, ) + def _should_share_memory(self) -> bool: + """Determine whether to use shared memory for adata.X.""" + if self.share_memory is False: + return False + + try: + import torch.distributed as dist + + if not dist.is_initialized() or dist.get_world_size() <= 1: + return False + except ImportError: + return False + + if self.share_memory is True: + return True + + # share_memory is None (auto): enable for DDP + return True + def setup(self, stage: str | None = None): """Split indices in train/test/val sets.""" if self.external_indexing is not None: @@ -286,6 +316,18 @@ def setup(self, stage: str | None = None): self.train_idx = indices[n_val : (n_val + n_train)] self.test_idx = indices[(n_val + n_train) :] + # Shared memory for DDP data deduplication + if self._should_share_memory(): + from scvi.dataloaders._shared_memory import setup_shared_memory + + self._shm_registry = setup_shared_memory(self.adata_manager) + + def teardown(self, stage: str | None = None): + """Clean up shared memory if used.""" + if self._shm_registry is not None: + self._shm_registry.cleanup() + self._shm_registry = None + def train_dataloader(self): """Create a train data loader.""" return self.data_loader_cls( diff --git a/src/scvi/dataloaders/_shared_memory.py b/src/scvi/dataloaders/_shared_memory.py new file mode 100644 index 0000000000..09427ee04e --- /dev/null +++ b/src/scvi/dataloaders/_shared_memory.py @@ -0,0 +1,343 @@ +"""Shared memory utilities for DDP data deduplication. + +In subprocess-based DDP, each GPU rank re-executes the full script and loads +its own copy of ``adata`` into memory. This module provides utilities to share +a single physical copy of ``adata.X`` across all ranks on the same node using +POSIX shared memory. 
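+
+The entry point is :func:`setup_shared_memory`, called from
+``DataSplitter.setup()`` once ``torch.distributed`` is initialized. A sketch
+of the lifecycle (names as defined in this module; the calling code shown
+here is illustrative)::
+
+    registry = setup_shared_memory(adata_manager)  # runs on every rank
+    # rank 0 copies adata.X into shared memory; other ranks attach a view,
+    # then all ranks point adata.X at the shared block
+    ...
+    if registry is not None:
+        registry.cleanup()  # every rank closes its handles; rank 0 unlinks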
+""" + +from __future__ import annotations + +import atexit +import gc +import logging +import os +import platform +from multiprocessing import shared_memory + +import numpy as np +from scipy.sparse import csr_matrix, issparse + +logger = logging.getLogger(__name__) + + +def share_dense_array( + name: str, + arr: np.ndarray, +) -> tuple[shared_memory.SharedMemory, np.ndarray]: + """Create shared memory from a dense numpy array (rank 0). + + Parameters + ---------- + name + Name for the shared memory block. + arr + Dense numpy array to share. + + Returns + ------- + Tuple of (SharedMemory handle, numpy view into shared memory). + """ + shm = shared_memory.SharedMemory(name=name, create=True, size=arr.nbytes) + shared_arr = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf) + np.copyto(shared_arr, arr) + return shm, shared_arr + + +def attach_dense_array( + name: str, + shape: tuple[int, ...], + dtype: np.dtype, +) -> tuple[shared_memory.SharedMemory, np.ndarray]: + """Attach to existing shared memory as a dense numpy array (rank 1+). + + Parameters + ---------- + name + Name of the existing shared memory block. + shape + Shape of the array. + dtype + Data type of the array. + + Returns + ------- + Tuple of (SharedMemory handle, numpy view into shared memory). + """ + shm = shared_memory.SharedMemory(name=name, create=False) + shared_arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf) + return shm, shared_arr + + +def share_sparse_csr( + name: str, + mat: csr_matrix, +) -> tuple[dict, list[shared_memory.SharedMemory]]: + """Create shared memory from a sparse CSR matrix (rank 0). + + Shares 3 arrays (data, indices, indptr) as separate shared memory blocks. + + Parameters + ---------- + name + Base name for the shared memory blocks. + mat + Sparse CSR matrix to share. + + Returns + ------- + Tuple of (metadata dict for broadcast, list of SharedMemory handles). + """ + mat = csr_matrix(mat) # ensure CSR format + + shm_data, shared_data = share_dense_array(f"{name}_data", mat.data) + shm_indices, shared_indices = share_dense_array(f"{name}_indices", mat.indices) + shm_indptr, shared_indptr = share_dense_array(f"{name}_indptr", mat.indptr) + + metadata = { + "shape": mat.shape, + "data_dtype": str(mat.data.dtype), + "data_shape": mat.data.shape, + "indices_dtype": str(mat.indices.dtype), + "indices_shape": mat.indices.shape, + "indptr_dtype": str(mat.indptr.dtype), + "indptr_shape": mat.indptr.shape, + } + return metadata, [shm_data, shm_indices, shm_indptr] + + +def attach_sparse_csr( + name: str, + metadata: dict, +) -> tuple[csr_matrix, list[shared_memory.SharedMemory]]: + """Attach to existing shared memory and reconstruct a sparse CSR matrix (rank 1+). + + Parameters + ---------- + name + Base name for the shared memory blocks. + metadata + Metadata dict from :func:`share_sparse_csr` (broadcast from rank 0). + + Returns + ------- + Tuple of (CSR matrix backed by shared memory, list of SharedMemory handles). 
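+
+    Examples
+    --------
+    Round trip with :func:`share_sparse_csr` (illustrative; ``mat`` is any
+    scipy CSR matrix)::
+
+        meta, handles = share_sparse_csr("demo_X", mat)       # rank 0
+        shared, handles2 = attach_sparse_csr("demo_X", meta)  # rank 1+
+        assert (shared != mat).nnz == 0  # same contents, shared storage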
+    """
+    shm_data, shared_data = attach_dense_array(
+        f"{name}_data",
+        metadata["data_shape"],
+        np.dtype(metadata["data_dtype"]),
+    )
+    shm_indices, shared_indices = attach_dense_array(
+        f"{name}_indices",
+        metadata["indices_shape"],
+        np.dtype(metadata["indices_dtype"]),
+    )
+    shm_indptr, shared_indptr = attach_dense_array(
+        f"{name}_indptr",
+        metadata["indptr_shape"],
+        np.dtype(metadata["indptr_dtype"]),
+    )
+
+    mat = csr_matrix(
+        (shared_data, shared_indices, shared_indptr),
+        shape=metadata["shape"],
+        copy=False,
+    )
+    return mat, [shm_data, shm_indices, shm_indptr]
+
+
+class SharedMemoryRegistry:
+    """Tracks shared memory blocks and handles cleanup.
+
+    Parameters
+    ----------
+    is_rank0
+        Whether this process is rank 0 (responsible for unlinking).
+    """
+
+    def __init__(self, is_rank0: bool = False):
+        self.is_rank0 = is_rank0
+        self._handles: list[shared_memory.SharedMemory] = []
+        self._cleaned_up = False
+
+    def register(self, shm: shared_memory.SharedMemory | list[shared_memory.SharedMemory]):
+        """Register shared memory handle(s) for cleanup."""
+        if isinstance(shm, list):
+            self._handles.extend(shm)
+        else:
+            self._handles.append(shm)
+
+    def cleanup(self):
+        """Close all handles and unlink on rank 0."""
+        if self._cleaned_up:
+            return
+        self._cleaned_up = True
+
+        for shm in self._handles:
+            try:
+                shm.close()
+            except OSError:
+                pass
+            if self.is_rank0:
+                try:
+                    shm.unlink()
+                except OSError:
+                    pass
+        self._handles.clear()
+
+
+def _malloc_trim():
+    """Call malloc_trim on Linux to return freed pages to the OS."""
+    if platform.system() == "Linux":
+        try:
+            import ctypes
+
+            libc = ctypes.CDLL("libc.so.6")
+            libc.malloc_trim(0)
+        except OSError:
+            pass
+
+
+def _is_shareable(X) -> bool:
+    """Check if the data matrix X is shareable (dense numpy or scipy sparse).
+
+    Returns False for backed (h5py) or dask arrays.
+    """
+    import h5py
+
+    if isinstance(X, h5py.Dataset):
+        return False
+    if isinstance(X, np.ndarray):
+        return True
+    if issparse(X):
+        return True
+    try:
+        import dask.array as da
+
+        if isinstance(X, da.Array):
+            return False
+    except ImportError:
+        pass
+    return False
+
+
+def setup_shared_memory(
+    adata_manager,
+    registry: SharedMemoryRegistry | None = None,
+) -> SharedMemoryRegistry | None:
+    """Orchestrate shared memory setup for DDP.
+
+    Called from ``DataSplitter.setup()`` after index splitting. Rank 0 copies
+    ``adata.X`` into shared memory; all ranks replace ``adata.X`` with a view
+    into the shared block.
+
+    Parameters
+    ----------
+    adata_manager
+        The AnnDataManager whose adata.X should be shared.
+    registry
+        Existing registry, or None to create a new one.
+
+    Returns
+    -------
+    SharedMemoryRegistry if shared memory was set up, None otherwise.
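+
+    Examples
+    --------
+    Typical call site, as used by ``DataSplitter.setup()``::
+
+        if self._should_share_memory():
+            self._shm_registry = setup_shared_memory(self.adata_manager)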
+    """
+    import torch.distributed as dist
+
+    if not dist.is_initialized():
+        return None
+
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    if world_size <= 1:
+        return None
+
+    adata = adata_manager.adata
+    X = adata.X
+
+    if not _is_shareable(X):
+        logger.info("adata.X is not shareable (backed or dask); skipping shared memory.")
+        return None
+
+    if registry is None:
+        registry = SharedMemoryRegistry(is_rank0=(rank == 0))
+
+    # Use rank 0's PID for unique naming across training sessions
+    pid_list = [os.getpid() if rank == 0 else 0]
+    dist.broadcast_object_list(pid_list, src=0)
+    base_name = f"scvi_{pid_list[0]}_X"
+
+    is_sparse = issparse(X)
+
+    if rank == 0:
+        # Clean up stale shared memory with the same name
+        _cleanup_stale_shm(base_name, is_sparse)
+
+        if is_sparse:
+            metadata, shm_handles = share_sparse_csr(base_name, X)
+            registry.register(shm_handles)
+            broadcast_data = [{"sparse": True, "metadata": metadata}]
+            shm_nbytes = sum(h.size for h in shm_handles)
+        else:
+            shm, shared_arr = share_dense_array(base_name, X)
+            registry.register(shm)
+            broadcast_data = [
+                {
+                    "sparse": False,
+                    "shape": X.shape,
+                    "dtype": str(X.dtype),
+                }
+            ]
+            shm_nbytes = shm.size
+        # scipy sparse matrices have no .nbytes attribute, so report the size
+        # of the shared memory blocks themselves
+        logger.info(
+            f"[rank 0] Created shared memory for adata.X "
+            f"({shm_nbytes / 1024**2:.1f} MB, sparse={is_sparse})"
+        )
+    else:
+        broadcast_data = [None]
+
+    dist.broadcast_object_list(broadcast_data, src=0)
+    info = broadcast_data[0]
+
+    if rank != 0:
+        if info["sparse"]:
+            shared_X, shm_handles = attach_sparse_csr(base_name, info["metadata"])
+            registry.register(shm_handles)
+        else:
+            shm, shared_X = attach_dense_array(base_name, info["shape"], np.dtype(info["dtype"]))
+            registry.register(shm)
+        logger.info(f"[rank {rank}] Attached to shared memory for adata.X")
+    else:
+        if is_sparse:
+            # Rank 0: reconstruct from shared memory too so we can free the original
+            shared_X, attach_handles = attach_sparse_csr(base_name, info["metadata"])
+            registry.register(attach_handles)
+        else:
+            shared_X = shared_arr
+
+    # Replace adata.X with the shared view
+    del X
+    adata.X = shared_X
+    gc.collect()
+    _malloc_trim()
+
+    # Register atexit cleanup as a safety net
+    atexit.register(registry.cleanup)
+
+    dist.barrier()
+    return registry
+
+
+def _cleanup_stale_shm(base_name: str, is_sparse: bool):
+    """Try to clean up stale shared memory from a previous run."""
+    names = (
+        [base_name]
+        if not is_sparse
+        else [f"{base_name}_data", f"{base_name}_indices", f"{base_name}_indptr"]
+    )
+    for name in names:
+        try:
+            old = shared_memory.SharedMemory(name=name, create=False)
+            old.close()
+            old.unlink()
+        except FileNotFoundError:
+            pass
diff --git a/src/scvi/train/_trainrunner.py b/src/scvi/train/_trainrunner.py
index 50030028fe..84e236ebc4 100644
--- a/src/scvi/train/_trainrunner.py
+++ b/src/scvi/train/_trainrunner.py
@@ -143,6 +143,9 @@ def _run_training_core(self):
             self.trainer.fit(self.training_plan, self.data_splitter, ckpt_path=self.ckpt_path)
         except BaseException as e:
             self._update_history()
+            # In DDP, exit non-rank-0 workers immediately to prevent zombies
+            if not self.trainer.is_global_zero:
+                self._exit_non_rank0(exit_code=1)
             print("Exception raised during training.", NameError, e)

             gc.collect()
@@ -158,6 +161,12 @@ def _run_training_core(self):
             raise

         self._update_history()

+        # In DDP, non-rank-0 worker subprocesses must exit after training.
+        # Without this, they become zombies holding GPU memory at 100% utilization
+        # because the subprocess continues executing the rest of the user's script.
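+        # Rank 0 is unaffected: it continues below, records the split indices
+        # and training history on the model, and returns control to the caller.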
+        if not self.trainer.is_global_zero:
+            self._exit_non_rank0()
+
         # data splitter only gets these attrs after fit
         self.model.train_indices = getattr(self.data_splitter, "train_idx", None)
         self.model.test_indices = getattr(self.data_splitter, "test_idx", None)
@@ -170,6 +179,33 @@ def _run_training_core(self):

         return

+    def _exit_non_rank0(self, exit_code: int = 0):
+        """Clean up and terminate non-rank-0 DDP worker processes.
+
+        In subprocess-based DDP, Lightning re-launches the user's script for
+        each rank via ``subprocess.Popen()``. After ``trainer.fit()`` returns,
+        every rank continues executing the rest of the script. Non-rank-0
+        processes have no purpose post-training and, if left alive, become
+        zombie processes that hold GPU memory and spin at 100% GPU utilization.
+
+        This method releases CUDA resources, destroys the distributed process
+        group, and terminates the process. ``DataSplitter.teardown()`` has
+        already been called by Lightning before ``fit()`` returns, so shared
+        memory is already cleaned up.
+        """
+        logger.debug("Non-rank-0 DDP worker: cleaning up and exiting.")
+        gc.collect()
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        import os  # local import, in case os is not imported at module level
+
+        import torch.distributed as dist
+
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+        os._exit(exit_code)
+
     def __call__(self):
         """Run training."""
         if hasattr(self.data_splitter, "n_train"):
diff --git a/tests/dataloaders/test_shared_memory.py b/tests/dataloaders/test_shared_memory.py
new file mode 100644
index 0000000000..2faaff58d1
--- /dev/null
+++ b/tests/dataloaders/test_shared_memory.py
@@ -0,0 +1,199 @@
+"""Tests for shared memory utilities."""
+
+import h5py
+import numpy as np
+import pytest
+from scipy.sparse import csr_matrix
+from scipy.sparse import random as sparse_random
+
+from scvi.dataloaders._shared_memory import (
+    SharedMemoryRegistry,
+    _is_shareable,
+    attach_dense_array,
+    attach_sparse_csr,
+    share_dense_array,
+    share_sparse_csr,
+)
+
+
+@pytest.fixture
+def dense_array():
+    rng = np.random.default_rng(42)
+    return rng.standard_normal((100, 50)).astype(np.float32)
+
+
+@pytest.fixture
+def sparse_csr():
+    return sparse_random(100, 50, density=0.1, format="csr", dtype=np.float32, random_state=42)
+
+
+class TestShareDenseArray:
+    def test_round_trip(self, dense_array):
+        """Test creating and attaching shared memory for a dense array."""
+        name = "test_dense_rt"
+        shm, shared_arr = share_dense_array(name, dense_array)
+        try:
+            np.testing.assert_array_equal(shared_arr, dense_array)
+
+            shm2, attached_arr = attach_dense_array(name, dense_array.shape, dense_array.dtype)
+            np.testing.assert_array_equal(attached_arr, dense_array)
+
+            # Verify they share the same memory
+            attached_arr[0, 0] = 999.0
+            assert shared_arr[0, 0] == 999.0
+
+            shm2.close()
+        finally:
+            shm.close()
+            shm.unlink()
+
+    def test_different_dtypes(self):
+        """Test with various numpy dtypes."""
+        for dtype in [np.float32, np.float64, np.int32, np.int64]:
+            name = f"test_dtype_{dtype.__name__}"
+            arr = np.arange(20, dtype=dtype).reshape(4, 5)
+            shm, _ = share_dense_array(name, arr)
+            try:
+                shm2, attached = attach_dense_array(name, arr.shape, arr.dtype)
+                np.testing.assert_array_equal(attached, arr)
+                assert attached.dtype == dtype
+                shm2.close()
+            finally:
+                shm.close()
+                shm.unlink()
+
+
+class TestShareSparseCsr:
+    def test_round_trip(self, sparse_csr):
+        """Test creating and attaching shared memory for a sparse CSR matrix."""
+        name = "test_sparse_rt"
+        metadata, shm_handles = share_sparse_csr(name, sparse_csr)
+        try:
+            mat, shm_handles2 = attach_sparse_csr(name, metadata)
+
+            np.testing.assert_array_almost_equal(mat.toarray(), sparse_csr.toarray())
+            assert mat.shape == sparse_csr.shape
+            assert mat.nnz == sparse_csr.nnz
+
+            for h in shm_handles2:
+                h.close()
+        finally:
+            for h in shm_handles:
+                h.close()
+                h.unlink()
+
+    def test_shared_data_modification(self, sparse_csr):
+        """Test that modifying the attached sparse matrix affects the shared one."""
+        name = "test_sparse_mod"
+        metadata, shm_handles = share_sparse_csr(name, sparse_csr)
+        try:
+            mat, shm_handles2 = attach_sparse_csr(name, metadata)
+
+            # The fixture's density guarantees nonzero entries, so index 0 exists
+            assert mat.nnz > 0
+            mat.data[0] = 12345.0
+
+            mat2, shm_handles3 = attach_sparse_csr(name, metadata)
+            assert mat2.data[0] == 12345.0
+
+            for h in shm_handles3:
+                h.close()
+
+            for h in shm_handles2:
+                h.close()
+        finally:
+            for h in shm_handles:
+                h.close()
+                h.unlink()
+
+    def test_metadata_fields(self, sparse_csr):
+        """Test that metadata contains the expected fields."""
+        name = "test_sparse_meta"
+        metadata, shm_handles = share_sparse_csr(name, sparse_csr)
+        try:
+            assert "shape" in metadata
+            assert "data_dtype" in metadata
+            assert "indices_dtype" in metadata
+            assert "indptr_dtype" in metadata
+            assert metadata["shape"] == sparse_csr.shape
+        finally:
+            for h in shm_handles:
+                h.close()
+                h.unlink()
+
+
+class TestSharedMemoryRegistry:
+    def test_register_and_cleanup(self):
+        """Test registering handles and cleaning them up."""
+        registry = SharedMemoryRegistry(is_rank0=True)
+
+        arr = np.zeros(100, dtype=np.float32)
+        shm1, _ = share_dense_array("test_reg_1", arr)
+        shm2, _ = share_dense_array("test_reg_2", arr)
+
+        registry.register(shm1)
+        registry.register(shm2)
+
+        registry.cleanup()
+
+        with pytest.raises(FileNotFoundError):
+            attach_dense_array("test_reg_1", (100,), np.float32)
+        with pytest.raises(FileNotFoundError):
+            attach_dense_array("test_reg_2", (100,), np.float32)
+
+    def test_register_list(self):
+        """Test registering a list of handles at once."""
+        registry = SharedMemoryRegistry(is_rank0=True)
+
+        arr = np.zeros(100, dtype=np.float32)
+        shm1, _ = share_dense_array("test_reglist_1", arr)
+        shm2, _ = share_dense_array("test_reglist_2", arr)
+
+        registry.register([shm1, shm2])
+        registry.cleanup()
+
+        with pytest.raises(FileNotFoundError):
+            attach_dense_array("test_reglist_1", (100,), np.float32)
+
+    def test_non_rank0_does_not_unlink(self):
+        """Test that non-rank-0 registry closes but doesn't unlink."""
+        arr = np.zeros(100, dtype=np.float32)
+        shm_creator, _ = share_dense_array("test_noreg_unlink", arr)
+
+        shm_attacher, _ = attach_dense_array("test_noreg_unlink", (100,), np.float32)
+        registry = SharedMemoryRegistry(is_rank0=False)
+        registry.register(shm_attacher)
+        registry.cleanup()
+
+        # Should still be accessible (not unlinked)
+        shm_check, _ = attach_dense_array("test_noreg_unlink", (100,), np.float32)
+        shm_check.close()
+
+        shm_creator.close()
+        shm_creator.unlink()
+
+    def test_double_cleanup_is_safe(self):
+        """Test that calling cleanup twice doesn't raise."""
+        registry = SharedMemoryRegistry(is_rank0=True)
+        arr = np.zeros(100, dtype=np.float32)
+        shm, _ = share_dense_array("test_double_cleanup", arr)
+        registry.register(shm)
+
+        registry.cleanup()
+        registry.cleanup()  # should be a no-op
+
+
+class TestIsShareable:
+    def test_dense_numpy(self):
+        assert _is_shareable(np.zeros((10, 5)))
+
+    def test_sparse_scipy(self):
+        assert _is_shareable(csr_matrix(np.zeros((10, 5))))
+
+    def test_h5py_dataset(self, tmp_path):
+        path = tmp_path / "backed.h5"
+        with h5py.File(path, "w") as hf:
+            hf.create_dataset("X", data=np.zeros((10, 5)))
+        with h5py.File(path, "r") as hf:
+            assert not _is_shareable(hf["X"])
diff --git a/tests/model/test_multigpu.py b/tests/model/test_multigpu.py
index 91ae836d4a..f09efcbfe1 100644
--- a/tests/model/test_multigpu.py
+++ b/tests/model/test_multigpu.py
@@ -225,6 +225,27 @@ def test_linearcvi_multigpu():
     assert model.is_trained


+@pytest.mark.multigpu
+def test_scvi_shared_memory_multigpu():
+    """Test SCVI training with shared memory enabled in DDP."""
+    from scvi.model import SCVI
+
+    adata = scvi.data.synthetic_iid()
+    SCVI.setup_anndata(adata)
+
+    model = SCVI(adata)
+    model.train(
+        max_epochs=1,
+        check_val_every_n_epoch=1,
+        accelerator="gpu",
+        devices=-1,
+        strategy="ddp_find_unused_parameters_true",
+        datasplitter_kwargs={"share_memory": True},
+    )
+    assert len(model.history["elbo_train"]) == 1
+    assert model.is_trained
+
+
 @pytest.mark.multigpu
 def test_scvi_train_ddp(save_path: str):
     training_code = """