Commit c11c1ab

[main] Integration/peerlab (#38)
Major:
- CUDA/Xenium install #21: CUDA 12.1 support; Xenium v1 + binary column datatype support; pixi/conda install updates
- Reference gene correlation #32: Add custom reference gene correlation support via `--gene-corr-reference-path`
- Memory optimisations #36: GPU/memory optimisations for large datasets; chunked subgraphing fallback
- Large dataset fixes #44: Fix torch array overflow and related scaling issues
- Tiling fixes #42: Fix tiling/bincount errors and transcript thresholding

Minor:
- Logging improvements #31: Improve debug logging; support `SEGGER_LOG_LEVEL`
- Margin handling #34: Warn instead of fail when tiling margins are too large
- Misc cleanup: Refactors; assertions/debugging; remove unused code and comments

---------

Co-authored-by: Nandula
Co-authored-by: Kalfus
1 parent 32b1469 commit c11c1ab

26 files changed

Lines changed: 4496 additions & 150 deletions

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -205,4 +205,5 @@ __marimo__/
 # Custom
 .dev
 .dev/*
-*.pyc
+*.pyc
+*memray*

README.md

Lines changed: 24 additions & 3 deletions
@@ -1,8 +1,30 @@
 # Installation
 
-## pip
+We recommend CUDA 12.1 with `cu*` packages version ≥24.2 and <26.0. Ensure your CUDA driver version matches or exceeds your toolkit version (≥12.1 for CUDA 12.1).
+Adjust package versions in the environment files below if your system requires different package versions.
 
-Before installing **segger**, please install GPU-accelerated versions of PyTorch, RAPIDS, and related packages compatible with your system. *Please ensure all CUDA-enabled packages are compiled for the same CUDA version.*
+## Clone the repository
+```bash
+git clone https://github.com/dpeerlab/segger.git segger && cd segger
+```
+
+## Using `conda`
+```bash
+conda env create -n segger -f environment_cuda121.yml
+```
+
+Adjust `environment_cuda121.yml` for other CUDA versions (e.g., `environment_cuda118.yml` for CUDA 11.8).
+
+## Using `pixi`
+```bash
+pixi install -e cuda121
+```
+
+Adjust the environment name in `pixi.toml` as needed for other CUDA versions.
+
+## `pip`
+
+Install GPU-accelerated PyTorch and RAPIDS compatible with your CUDA version before installing **segger**. All CUDA-enabled packages must be compiled for the same CUDA version.
 
 - **PyTorch & torchvision:** [Installation guide](https://pytorch.org/get-started/locally/)
 - **torch_scatter:** [Installation guide](https://github.com/rusty1s/pytorch_scatter#installation)
@@ -28,7 +50,6 @@ pip install cupy-cuda12x
 
 ```bash
 # Clone segger repo and install locally
-git clone https://github.com/dpeerlab/segger.git segger && cd segger
 pip install -e .
 ```
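The version range recommended in the README (`cu*` packages ≥24.2 and <26.0) is easy to misjudge with plain string comparison, because `"24.10" < "24.2"` when compared as strings. The sketch below shows a numeric check. The helpers `parse_version` and `in_supported_range` are hypothetical names used only for illustration and are not part of segger; the sketch also assumes purely numeric, dot-separated version strings.

```python
def parse_version(version: str) -> tuple[int, ...]:
    """Turn '24.10.1' into (24, 10, 1); assumes purely numeric components."""
    return tuple(int(part) for part in version.split("."))


def in_supported_range(version: str,
                       minimum: str = "24.2",
                       upper: str = "26.0") -> bool:
    """Return True when minimum <= version < upper, compared numerically."""
    return parse_version(minimum) <= parse_version(version) < parse_version(upper)


# Hypothetical version strings, chosen to exercise both ends of the range.
checks = {v: in_supported_range(v) for v in ("24.1", "24.2", "24.10", "25.12", "26.0")}
print(checks)
```

Comparing integer tuples rather than strings is what places 24.10 above 24.2, which matters here because the environment files pin RAPIDS to 24.10.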

environment_cuda121.yml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+channels:
+  - pytorch
+  - nvidia
+  - rapidsai
+  - conda-forge
+
+dependencies:
+  - python>=3.11,<3.12
+
+  # RAPIDS 24.10 — last release compatible with CUDA 12.1 runtime
+  - cuspatial=24.10
+  - cudf=24.10
+  - cuml=24.10
+  - cugraph=24.10
+  - cupy
+
+  - pip
+  - pip:
+    - --extra-index-url https://pypi.nvidia.com
+    - --extra-index-url https://download.pytorch.org/whl/cu121
+    - --find-links https://data.pyg.org/whl/torch-2.5.0+cu121.html
+    - torch==2.5.*
+    - torchvision==0.20.*
+    - lightning
+    - torch-geometric
+    - torch-scatter
+    - torch-sparse
+    - cupy-cuda12x>=12.2,<13.0
+    - --editable .

pixi.lock

Lines changed: 3898 additions & 0 deletions
Some generated files are not rendered by default.

pixi.toml

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+[workspace]
+channels = ["pytorch", "nvidia", "rapidsai", "conda-forge"]
+platforms = ["linux-64"]
+
+[dependencies]
+python = ">=3.11"
+
+[pypi-dependencies]
+segger = { path = ".", editable = true }
+
+[feature.cuda121.dependencies]
+python = ">=3.11,<3.12"
+
+[feature.cuda121.system-requirements]
+cuda = "12"
+
+[feature.cuda121.pypi-options]
+index-strategy = "unsafe-best-match"
+index-url = "https://pypi.org/simple"
+extra-index-urls = [
+    "https://pypi.nvidia.com",
+    "https://download.pytorch.org/whl/cu121"
+]
+find-links = [
+    { url = "https://data.pyg.org/whl/torch-2.5.0+cu121.html" }
+]
+
+[feature.cuda121.pypi-dependencies]
+torch = "==2.5.*"
+torchvision = "==0.20.*"
+lightning = "*"
+torch_geometric = "*"
+torch-scatter = "*"
+torch-sparse = "*"
+cuspatial-cu12 = "==24.10.*"
+cudf-cu12 = "==24.10.*"
+cuml-cu12 = "==24.10.*"
+cugraph-cu12 = "==24.10.*"
+cupy-cuda12x = ">=12.2,<13.0"
+
+[tool.uv.extra-build-dependencies]
+torch-scatter = ["torch"]
+
+[environments]
+cuda121 = ["cuda121"]

pyproject.toml

Lines changed: 1 addition & 3 deletions
@@ -15,7 +15,6 @@ dependencies = [
     "anndata",
     "cyclopts",
     "geopandas",
-    "lightning",
     "numba",
     "numpy",
     "opencv-python",
@@ -27,8 +26,7 @@ dependencies = [
     "shapely",
     "scikit-image",
     "scikit-learn",
-    "tifffile",
-    "torch_geometric",
+    "tifffile"
 ]
 
 [build-system]

src/segger/__init__.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+from pathlib import Path
+import cupy as cp
+import torch
+import rmm
+from rmm.allocators.cupy import rmm_cupy_allocator
+from rmm.allocators.torch import rmm_torch_allocator
+from rmm.statistics import enable_statistics, get_statistics
+
+# Single RMM pool shared by CuPy/cuDF/cuSpatial AND PyTorch. Must be set before
+# any CUDA tensor is created.
+rmm.reinitialize(pool_allocator=True, managed_memory=True)
+cp.cuda.set_allocator(rmm_cupy_allocator)
+torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
+enable_statistics()
+
+# Apply pytorch patches for issue pytorch/pytorch#51871 (CUDA nonzero INT_MAX limit).
+# Must run BEFORE any segger module imports HeteroData / bipartite_subgraph.
+from ._patches import apply as _apply_patches
+_apply_patches()
+
+def free_mem_str() -> str:
+    stats = get_statistics()
+    return (
+        f"GPU: {stats.current_bytes / 1e9:.2f} GB "
+        f"(peak {stats.peak_bytes / 1e9:.2f} GB)"
+    )
+
+
+def print_free_mem():
+    print(free_mem_str())
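The comment in `src/segger/__init__.py` that the patches must run before any segger module imports `bipartite_subgraph` follows from how Python binds names: `from module import name` copies a reference at import time, so later rebinding `module.name` does not reach modules that already imported it. The sketch below illustrates this with stand-in modules. The names `lib`, `consumer`, and `patched_helper` are hypothetical and unrelated to segger or PyTorch Geometric.

```python
import types

# Hypothetical stand-ins: `lib` plays the library, `consumer` a module that
# ran `from lib import helper` before any patch was applied.
lib = types.ModuleType("lib")
lib.helper = lambda x: x + 1

consumer = types.ModuleType("consumer")
consumer.helper = lib.helper  # what `from lib import helper` does at import time


def patched_helper(x):
    """Replacement with visibly different behaviour."""
    return x + 100


# Patching only the defining module leaves the consumer's reference untouched.
lib.helper = patched_helper
before = consumer.helper(1)

# Rebinding the consumer's own name completes the patch.
consumer.helper = patched_helper
after = consumer.helper(1)
print(before, after)
```

This is presumably why `apply()` in `_patches.py` rebinds `bipartite_subgraph` in three modules rather than one.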

src/segger/_patches.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+"""Workaround for pytorch/pytorch#51871 (CUDA `nonzero` INT_MAX limit).
+
+Patches `torch_geometric.utils.bipartite_subgraph` and the references already
+imported by `torch_geometric.data.hetero_data` / `._subgraph` so that
+`HeteroData.subgraph` falls back to a chunked-nonzero path when the edge
+tensor on CUDA exceeds INT_MAX (~2.15B) elements.
+
+See: https://github.com/dpeerlab/segger/issues/44
+"""
+import torch
+import torch_geometric.utils._subgraph as _sg
+import torch_geometric.utils as _tgu
+import torch_geometric.data.hetero_data as _hd
+from torch_geometric.utils import index_to_mask
+from torch_geometric.utils.map import map_index
+
+_INT_MAX = 2**31 - 1
+_pyg_bipartite = _sg.bipartite_subgraph
+
+
+def chunked_nonzero(mask: torch.Tensor, chunk: int = 2**30) -> torch.Tensor:
+    """Chunked version of `mask.nonzero()` that works on CUDA tensors with > INT_MAX elements."""
+    if mask.numel() <= _INT_MAX or mask.device.type != "cuda":
+        return mask.nonzero(as_tuple=False).flatten()
+    parts = []
+    for i, m in enumerate(mask.split(chunk)):
+        idx = m.nonzero(as_tuple=False).flatten()
+        if idx.numel():
+            parts.append(idx + i * chunk)
+    return torch.cat(parts) if parts else mask.new_empty((0,), dtype=torch.long)
+
+
+def bipartite_safe(subset, edge_index, edge_attr=None, relabel_nodes=False,
+                   size=None, return_edge_mask=False):
+    """
+    Replacement for `torch_geometric.utils.bipartite_subgraph`.
+    Falls back to a chunked subgraph version when the edge_index is too large for CUDA.
+    """
+    # original
+    if edge_index.numel() <= _INT_MAX or edge_index.device.type != "cuda":
+        return _pyg_bipartite(subset, edge_index, edge_attr, relabel_nodes,
+                              size, return_edge_mask)
+
+    # same as source
+    src_sub, dst_sub = subset
+    src_mask = index_to_mask(src_sub, size=size[0])
+    dst_mask = index_to_mask(dst_sub, size=size[1])
+    edge_mask = src_mask[edge_index[0]] & dst_mask[edge_index[1]]
+
+    # replaced this
+    idx = chunked_nonzero(edge_mask)
+
+    # same as source (but indices instead of mask)
+    edge_index = edge_index[:, idx]
+    edge_attr = edge_attr[idx] if edge_attr is not None else None
+    if relabel_nodes:
+        src_index, _ = map_index(edge_index[0], src_sub, max_index=size[0], inclusive=True)
+        dst_index, _ = map_index(edge_index[1], dst_sub, max_index=size[1], inclusive=True)
+        edge_index = torch.stack([src_index, dst_index], dim=0)
+    return (edge_index, edge_attr, edge_mask) if return_edge_mask else (edge_index, edge_attr)
+
+
+_patches_applied = False
+
+
+def apply():
+    """Apply the patches."""
+    global _patches_applied
+    if _patches_applied:
+        return
+    _sg.bipartite_subgraph = bipartite_safe
+    _tgu.bipartite_subgraph = bipartite_safe
+    _hd.bipartite_subgraph = bipartite_safe
+    _patches_applied = True
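The chunked fallback in `chunked_nonzero` rests on one piece of arithmetic: indices found inside a chunk are local to that chunk, so each must be shifted by the chunk's starting offset before the results are concatenated. The sketch below reproduces that logic on plain Python lists so it runs without a GPU. It illustrates the offset bookkeeping only and is not the PyTorch implementation; `nonzero` here is a hypothetical stand-in for `Tensor.nonzero`.

```python
def nonzero(mask: list[bool]) -> list[int]:
    """Stand-in for `Tensor.nonzero().flatten()` on a 1-D boolean mask."""
    return [i for i, flag in enumerate(mask) if flag]


def chunked_nonzero(mask: list[bool], chunk: int) -> list[int]:
    """Collect nonzero indices chunk by chunk, shifting each by its offset."""
    if chunk <= 0:
        raise ValueError("chunk must be positive")
    indices: list[int] = []
    for start in range(0, len(mask), chunk):
        piece = mask[start:start + chunk]
        # Indices inside `piece` are local; add `start` to make them global.
        indices.extend(start + i for i in nonzero(piece))
    return indices


mask = [True, False, False, True, True, False, True]
result = chunked_nonzero(mask, chunk=3)
print(result)
```

A mask with no `True` entries yields an empty list here. The tensor version needs the same care, since `torch.cat` raises on an empty list of parts.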

src/segger/cli/main.py

Lines changed: 0 additions & 1 deletion
@@ -10,4 +10,3 @@
 
 # Debugging utilities
 app.command(debug)
-

src/segger/cli/segment.py

Lines changed: 33 additions & 4 deletions
@@ -1,7 +1,6 @@
 import os
 import logging
 from ..utils import setup_logging
-setup_logging(level=os.environ.get("LOG_LEVEL", "WARNING"))
 
 from cyclopts import App, Parameter, Group, validators
 from typing import Annotated, Literal
@@ -78,7 +77,6 @@ def segment(
         validator=validators.Path(exists=True, dir_okay=True),
     )] = registry.get_default("output_directory"),
 
-
     # Cell Representation
     node_representation_dim: Annotated[int, Parameter(
         help="Number of dimensions used to represent each node type.",
@@ -124,6 +122,20 @@ def segment(
         group=group_nodes,
     )] = registry.get_default("genes_clusters_resolution"),
 
+    gene_corr_reference_path: Annotated[Path | None, Parameter(
+        help=(
+            "Path to a reference AnnData .h5ad file used to compute a shared "
+            "gene-gene correlation matrix."
+        ),
+        group=group_nodes,
+    )] = None,
+
+
+    gene_missing_strategy: Annotated[Literal["error", "remove", "fill"], registry.get_parameter(
+        "gene_missing_strategy",
+        group=group_nodes,
+    )] = registry.get_default("gene_missing_strategy"),
+
 
     # Transcript-Transcript Graph
     transcripts_max_k: Annotated[int, registry.get_parameter(
@@ -154,7 +166,7 @@ def segment(
         group=group_prediction,
     )] = registry.get_default("prediction_graph_max_k"),
 
-    prediction_expansion_ratio: Annotated[float | None, registry.get_parameter(
+    prediction_graph_buffer_ratio: Annotated[float | None, registry.get_parameter(
         "prediction_graph_buffer_ratio",
         validator=validators.Number(gt=0),
         group=group_prediction,
@@ -289,6 +301,7 @@ def segment(
         group=group_loss,
     )] = registry.get_default("sg_weight_end"),
 
+    # Reference
     save_anndata: Annotated[bool, registry.get_parameter(
         "save_anndata",
         group=group_io,
@@ -301,8 +314,21 @@ def segment(
     """Run cell segmentation on spatial transcriptomics data."""
 
     # Setup logger and debug directory
+    setup_logging(level=os.environ.get("LOG_LEVEL", "WARNING"), debug=debug)
     logger = logging.getLogger(__name__)
 
+    debug_dir = None
+    if debug:
+        import json
+        debug_dir = Path(output_directory) / "debug"
+        debug_dir.mkdir(exist_ok=True, parents=True)
+        params = {k: (str(v) if not isinstance(v, (str, int, float, bool, type(None))) else v)
+                  for k, v in locals().items()
+                  if k not in {"logger", "debug_dir", "json"}}
+        with open(debug_dir / "params.json", "w") as f:
+            json.dump(params, f, indent=2, default=str)
+        logger.info(f"Saved run params to {debug_dir / 'params.json'}")
+
     # Remove SLURM environment autodetect
     from lightning.pytorch.plugins.environments import SLURMEnvironment
     SLURMEnvironment.detect = lambda: False
@@ -323,11 +349,14 @@ def segment(
         transcripts_graph_max_dist=transcripts_max_dist,
         prediction_graph_mode=prediction_mode,
         prediction_graph_max_k=prediction_max_k,
-        prediction_graph_buffer_ratio=prediction_expansion_ratio,
+        prediction_graph_buffer_ratio=prediction_graph_buffer_ratio,
         tiling_margin_training=tiling_margin_training,
         tiling_margin_prediction=tiling_margin_prediction,
         tiling_nodes_per_tile=max_nodes_per_tile,
         edges_per_batch=max_edges_per_batch,
+        gene_corr_reference_path=gene_corr_reference_path,
+        gene_missing_strategy=gene_missing_strategy,
+        debug_dir=debug_dir,
     )
 
     # Setup Lightning Model
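The debug branch added to `segment()` writes the run parameters to `debug/params.json`, converting anything that is not a JSON scalar (such as a `Path`) to a string first. The standalone sketch below shows that pattern. The helpers `to_jsonable` and `save_params` and the example values are hypothetical, chosen only for illustration; the actual code gathers its parameters from `locals()` instead.

```python
import json
import tempfile
from pathlib import Path

# Types that json can serialise directly; everything else is stringified.
_JSON_SCALARS = (str, int, float, bool, type(None))


def to_jsonable(params: dict) -> dict:
    """Keep JSON scalars as they are and convert other values with str()."""
    return {k: (v if isinstance(v, _JSON_SCALARS) else str(v))
            for k, v in params.items()}


def save_params(params: dict, debug_dir: Path) -> Path:
    """Write params to <debug_dir>/params.json, creating the directory if needed."""
    debug_dir.mkdir(parents=True, exist_ok=True)
    out = debug_dir / "params.json"
    out.write_text(json.dumps(to_jsonable(params), indent=2))
    return out


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical parameter values, not segger defaults.
    example = {
        "output_directory": Path("out"),
        "prediction_graph_max_k": 3,
        "gene_corr_reference_path": None,
    }
    path = save_params(example, Path(tmp) / "debug")
    loaded = json.loads(path.read_text())
print(loaded)
```

Passing an explicit dictionary makes it clear which values end up in the file, whereas `locals()` captures every local name that is not excluded by hand.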
