2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ to [Semantic Versioning]. The full commit history is available in the [commit lo

#### Added

- Add shared memory support for DDP to deduplicate `adata.X` across ranks on the same node,
reducing per-rank memory, {pr}`3754`.
- Add support for Python 3.14, {pr}`3563`.
- Add support for Pandas3, {pr}`3638`.

31 changes: 30 additions & 1 deletion docs/user_guide/use_case/multi_gpu_training.md
@@ -38,6 +38,7 @@ We can see the advantage with larger data, while for the small data, there's no

### 3. **Memory Utilization**
- **Larger Memory Pool:** When using multiple GPUs, each GPU can hold a part of the model and data, effectively creating a larger memory pool. This allows for larger batch sizes or more complex models.
- **Shared Memory Deduplication:** In subprocess-based DDP, each rank loads its own copy of `adata` into memory, which can be costly for large datasets. scvi-tools automatically uses POSIX shared memory to deduplicate `adata.X` across all ranks on the same node, so only one physical copy is kept in memory. For example, with 100K obs × 1000 genes (~400 MB as float32), 8 ranks drop from 3.2 GB total to just 400 MB. This is auto-enabled when DDP is detected and works with both dense numpy and scipy sparse matrices.

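The mechanism can be illustrated with the standard library's `multiprocessing.shared_memory` (a minimal sketch of the general idea, not scvi-tools' internal API): one process per node writes the array into a named shared-memory block, and every other process on that node attaches to it as a zero-copy view.

```python
import numpy as np
from multiprocessing import shared_memory

# One rank per node copies the dense matrix into a POSIX shared-memory block.
X = np.arange(12, dtype=np.float32).reshape(3, 4)
shm = shared_memory.SharedMemory(create=True, size=X.nbytes)
buf = np.ndarray(X.shape, dtype=X.dtype, buffer=shm.buf)
buf[:] = X  # the single physical copy on this node

# Any other rank on the same node attaches by name: a zero-copy view.
attached = shared_memory.SharedMemory(name=shm.name)
view = np.ndarray(X.shape, dtype=X.dtype, buffer=attached.buf)
total = float(view.sum())  # reads the shared buffer, no duplication

# Cleanup: release numpy views first, then every rank closes; one rank unlinks.
del view, buf
attached.close()
shm.close()
shm.unlink()
```

For the example above: 100,000 × 1,000 float32 values is roughly 400 MB, so 8 ranks on one node would otherwise hold about 3.2 GB in aggregate.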
## Using multi-GPU training in scvi-tools

@@ -69,7 +70,35 @@ model.train(
)
```

3. There are a few limitations with the current implementation:
3. **Shared memory for data deduplication:**

By default, shared memory is auto-enabled when DDP is detected (`share_memory=None`). You can explicitly control this behavior:

```python
# Explicitly enable shared memory
model.train(
...,
accelerator="gpu",
devices=-1,
strategy="ddp_find_unused_parameters_true",
datasplitter_kwargs={"share_memory": True},
)

# Explicitly disable shared memory
model.train(
...,
accelerator="gpu",
devices=-1,
strategy="ddp_find_unused_parameters_true",
datasplitter_kwargs={"share_memory": False},
)
```

:::{note}
Shared memory deduplication only applies to dense numpy or scipy sparse `adata.X`. If `adata` is backed (h5ad on disk) or uses dask arrays, shared memory is automatically skipped since those formats already avoid full in-memory copies.
:::
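The eligibility rule described in the note can be approximated with a check like the following (a hypothetical helper written for illustration; the real decision logic is internal to scvi-tools):

```python
import numpy as np
import scipy.sparse as sp


def eligible_for_shared_memory(X) -> bool:
    """Hypothetical helper mirroring the documented rule, not scvi-tools' API."""
    if isinstance(X, np.ndarray):
        return True  # dense in-memory matrix: keep one shared copy per node
    if sp.issparse(X):
        return True  # scipy sparse: data/indices/indptr buffers can be shared
    return False  # backed h5ad or dask arrays: already avoid full in-memory copies
```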

4. There are a few limitations with the current implementation:
- During an interactive session, such as a Jupyter notebook, only one model can be trained in multi-GPU mode per session.
This means a SCANVI model cannot be trained from an SCVI model if the SCVI model was trained in the same notebook; instead, train and save the SCVI model in one session and load it in another. This is a PyTorch Lightning caveat.
- Early stopping cannot be used with DDP right now (and some models, like totalvi, use early stopping by default), so scvi-tools disables early stopping when running with DDP. The reason is that the validation loop must run on a single device rather than across multiple GPUs.
42 changes: 42 additions & 0 deletions src/scvi/dataloaders/_data_splitting.py
@@ -1,3 +1,4 @@
import logging
import warnings
from math import ceil, floor

@@ -20,6 +21,8 @@
from scvi.model._utils import parse_device_args
from scvi.utils._docstrings import devices_dsp

logger = logging.getLogger(__name__)


def validate_data_split(
n_samples: int,
@@ -207,6 +210,11 @@ class DataSplitter(pl.LightningDataModule):
external_indexing
A list of data split indices in the order of training, validation, and test sets.
Validation and test set are not required and can be left empty.
share_memory
``EXPERIMENTAL`` If ``True``, uses POSIX shared memory to deduplicate ``adata.X``
across DDP ranks on the same node. If ``None`` (default), auto-enables when DDP
is detected. If ``False``, disables shared memory. Only applies to dense numpy
or scipy sparse ``adata.X``; backed and dask arrays are skipped.
**kwargs
Keyword args for data loader. If adata has labeled data, the data loader
class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
@@ -233,6 +241,7 @@ def __init__(
load_sparse_tensor: bool = False,
pin_memory: bool = False,
external_indexing: list[np.array, np.array, np.array] | None = None,
share_memory: bool | None = None,
**kwargs,
):
super().__init__()
@@ -246,6 +255,8 @@ def __init__(
self.data_loader_kwargs = kwargs
self.pin_memory = pin_memory
self.external_indexing = external_indexing
self.share_memory = share_memory
self._shm_registry = None

if self.external_indexing is not None:
self.n_train, self.n_val = validate_data_split_with_external_indexing(
@@ -264,6 +275,25 @@ def __init__(
self.train_size_is_none,
)

def _should_share_memory(self) -> bool:
"""Determine whether to use shared memory for adata.X."""
if self.share_memory is False:
return False

try:
import torch.distributed as dist

if not dist.is_initialized() or dist.get_world_size() <= 1:
return False
except ImportError:
return False

if self.share_memory is True:
return True

# share_memory is None (auto): enable for DDP
return True

def setup(self, stage: str | None = None):
"""Split indices in train/test/val sets."""
if self.external_indexing is not None:
@@ -286,6 +316,18 @@ def setup(self, stage: str | None = None):
self.train_idx = indices[n_val : (n_val + n_train)]
self.test_idx = indices[(n_val + n_train) :]

# Shared memory for DDP data deduplication
if self._should_share_memory():
from scvi.dataloaders._shared_memory import setup_shared_memory

self._shm_registry = setup_shared_memory(self.adata_manager)

def teardown(self, stage: str | None = None):
"""Clean up shared memory if used."""
if self._shm_registry is not None:
self._shm_registry.cleanup()
self._shm_registry = None

def train_dataloader(self):
"""Create a train data loader."""
return self.data_loader_cls(