From 61f37e345753ca67d38d65762d62f91f5859e625 Mon Sep 17 00:00:00 2001 From: tobiaspk Date: Tue, 26 May 2026 14:43:46 -0400 Subject: [PATCH 1/9] Improved logging --- .gitignore | 3 +- src/segger/cli/debug.py | 2 ++ src/segger/debug/prediction.py | 66 ++++++++++++++++++++++++++++++++++ src/segger/utils.py | 3 +- 4 files changed, 72 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index a9f0f1e..b6f56f5 100644 --- a/.gitignore +++ b/.gitignore @@ -206,4 +206,5 @@ __marimo__/ .dev .dev/* *.pyc -*memray* \ No newline at end of file +*memray* +nvidia/ \ No newline at end of file diff --git a/src/segger/cli/debug.py b/src/segger/cli/debug.py index 41808a9..cd49eed 100644 --- a/src/segger/cli/debug.py +++ b/src/segger/cli/debug.py @@ -9,6 +9,7 @@ from ..debug.segmentation import run_segmentation_only from ..debug.prediction import run_prediction_only +from ..utils import setup_logging debug = App(name="debug", help="Utilities for debugging and testing individual components.") @@ -41,6 +42,7 @@ def predict_only_cli( )], ): """Run prediction only.""" + setup_logging(level="DEBUG", debug=True) run_prediction_only( path_checkpoint=path_checkpoint, path_outputs=path_outputs, diff --git a/src/segger/debug/prediction.py b/src/segger/debug/prediction.py index 910bc66..94344f3 100644 --- a/src/segger/debug/prediction.py +++ b/src/segger/debug/prediction.py @@ -1,6 +1,69 @@ """Run only prediction, followed by segmentation.""" import os +import logging +import pickle +from pathlib import Path +from types import SimpleNamespace + + +def _patch_load_from_cache(): + """Monkey-patch ISTDataModule.load to restore from `debug_dir` cache. + + The original `load()` always re-runs setup_anndata + setup_heterodata + tiling + (~2 h on cervical). For predict-only that's wasted: when --debug was set on + the original run, those outputs are already on disk in `debug_dir` as + `data.pt`, `tiles.pkl`, and `adata_debug.h5ad`. Restore from them instead. + """ + import scanpy as sc + import polars as pl + import torch + from segger.data import ISTDataModule + from segger.io.fields import StandardTranscriptFields + from segger.io import get_preprocessor + + logger = logging.getLogger(__name__) + original_load = ISTDataModule.load + + def cached_load(self): + d = Path(self.debug_dir) if self.debug_dir is not None else None + cached = { + "data": d / "data.pt" if d else None, + "tiles": d / "tiles.pkl" if d else None, + "adata": d / "adata_debug.h5ad" if d else None, + } + if d is None or not all(p.exists() for p in cached.values()): + logger.info("Cached artifacts not found; falling back to full rebuild.") + return original_load(self) + + logger.info(f"Restoring cached datamodule state from {d}") + tx_fields = StandardTranscriptFields() + + # Raw transcripts/boundaries — only used by writer.write_anndata; cheap to re-read. + pp = get_preprocessor(self.input_directory) + self.tx = pp.transcripts + self.bd = pp.boundaries + + # Cached artifacts + self.ad = sc.read_h5ad(cached["adata"]) + self.data = torch.load(cached["data"], weights_only=False) + with open(cached["tiles"], "rb") as f: + tiles = pickle.load(f) + # Predict only accesses `self.tiling.tiles[idx]`; a SimpleNamespace shell suffices. + self.tiling = SimpleNamespace(tiles=tiles) + + # Model-side embeddings/similarities — rebuilt from adata + self.tx_embedding = ( + pl.from_numpy(self.ad.varm['X_corr']) + .cast(pl.Float32) + .with_columns(pl.Series(self.ad.var.index).alias(tx_fields.feature)) + ) + self.tx_similarity = torch.tensor(self.ad.uns['gene_cluster_similarities']) + self.bd_similarity = torch.tensor(self.ad.uns['cell_cluster_similarities']) + logger.debug("Data loading is complete (cache).") + + ISTDataModule.load = cached_load + def run_prediction_only( path_checkpoint, @@ -18,6 +81,9 @@ def run_prediction_only( os.makedirs(path_outputs, exist_ok=True) + # Skip the full setup_anndata/setup_heterodata/tiling rebuild + _patch_load_from_cache() + # load objects (analogous to segment.py) csvlogger = CSVLogger(path_outputs) writer = ISTSegmentationWriter(path_outputs, debug=True) diff --git a/src/segger/utils.py b/src/segger/utils.py index f523983..2eff9e2 100644 --- a/src/segger/utils.py +++ b/src/segger/utils.py @@ -24,8 +24,9 @@ def setup_logging(level: str = "WARNING", log_file: str = None, debug: bool = Fa for handler in handlers: handler.addFilter(MemFilter()) + root_level = "WARNING" if debug else level logging.basicConfig( - level=getattr(logging, level.upper()), + level=getattr(logging, root_level.upper()), format=fmt, datefmt=datefmt, handlers=handlers, From 91851127d7cb45b37bd8689adc9152b32e457c51 Mon Sep 17 00:00:00 2001 From: tobiaspk Date: Tue, 26 May 2026 14:44:08 -0400 Subject: [PATCH 2/9] Clean slate of chunked 0 fix. Prepare for more native CSR-like indexing fix. --- src/segger/__init__.py | 5 --- src/segger/_patches.py | 74 ------------------------------------------ 2 files changed, 79 deletions(-) delete mode 100644 src/segger/_patches.py diff --git a/src/segger/__init__.py b/src/segger/__init__.py index 2f6eadc..f1914d1 100644 --- a/src/segger/__init__.py +++ b/src/segger/__init__.py @@ -13,11 +13,6 @@ torch.cuda.memory.change_current_allocator(rmm_torch_allocator) enable_statistics() -# Apply pytorch patches for issue pytorch/pytorch#51871 (CUDA nonzero INT_MAX limit). -# Must run BEFORE any segger module imports HeteroData / bipartite_subgraph. -from ._patches import apply as _apply_patches -_apply_patches() - def free_mem_str() -> str: stats = get_statistics() return ( diff --git a/src/segger/_patches.py b/src/segger/_patches.py deleted file mode 100644 index 51360b0..0000000 --- a/src/segger/_patches.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Workaround for pytorch/pytorch#51871 (CUDA `nonzero` INT_MAX limit). - -Patches `torch_geometric.utils.bipartite_subgraph` and the references already -imported by `torch_geometric.data.hetero_data` / `._subgraph` so that -`HeteroData.subgraph` falls back to a chunked-nonzero path when the edge -tensor on CUDA exceeds INT_MAX (~2.15B) elements. - -See: https://github.com/dpeerlab/segger/issues/44 -""" -import torch -import torch_geometric.utils._subgraph as _sg -import torch_geometric.utils as _tgu -import torch_geometric.data.hetero_data as _hd -from torch_geometric.utils import index_to_mask -from torch_geometric.utils.map import map_index - -_INT_MAX = 2**31 - 1 -_pyg_bipartite = _sg.bipartite_subgraph - - -def chunked_nonzero(mask: torch.Tensor, chunk: int = 2**30) -> torch.Tensor: - """Chunked version of `mask.nonzero()` that works on CUDA tensors with > INT_MAX elements.""" - if mask.numel() <= _INT_MAX or mask.device.type != "cuda": - return mask.nonzero(as_tuple=False).flatten() - parts = [] - for i, m in enumerate(mask.split(chunk)): - idx = m.nonzero(as_tuple=False).flatten() - if idx.numel(): - parts.append(idx + i * chunk) - return torch.cat(parts) - - -def bipartite_safe(subset, edge_index, edge_attr=None, relabel_nodes=False, - size=None, return_edge_mask=False): - """ - Replacement for `torch_geometric.utils.bipartite_subgraph`. - Falls back to a chunked subgraph version when the edge_index is too large for CUDA. - """ - # original - if edge_index.numel() <= _INT_MAX or edge_index.device.type != "cuda": - return _pyg_bipartite(subset, edge_index, edge_attr, relabel_nodes, - size, return_edge_mask) - - # same as source - src_sub, dst_sub = subset - src_mask = index_to_mask(src_sub, size=size[0]) - dst_mask = index_to_mask(dst_sub, size=size[1]) - edge_mask = src_mask[edge_index[0]] & dst_mask[edge_index[1]] - - # replaced this - idx = chunked_nonzero(edge_mask) - - # same as source (but indices instead of mask) - edge_index = edge_index[:, idx] - edge_attr = edge_attr[edge_mask] if edge_attr is not None else None - if relabel_nodes: - src_index, _ = map_index(edge_index[0], src_sub, max_index=size[0], inclusive=True) - dst_index, _ = map_index(edge_index[1], dst_sub, max_index=size[1], inclusive=True) - edge_index = torch.stack([src_index, dst_index], dim=0) - return (edge_index, edge_attr, edge_mask) if return_edge_mask else (edge_index, edge_attr) - - -_patches_applied = False - - -def apply(): - """Apply the patches.""" - global _patches_applied - if _patches_applied: - return - _sg.bipartite_subgraph = bipartite_safe - _tgu.bipartite_subgraph = bipartite_safe - _hd.bipartite_subgraph = bipartite_safe - _patches_applied = True From 11a66f6c11e9067523f2bf49cb648316eda63447 Mon Sep 17 00:00:00 2001 From: tobiaspk Date: Tue, 26 May 2026 16:35:53 -0400 Subject: [PATCH 3/9] Sort edges by src. Helps downstream with faster indexing. --- src/segger/data/utils/heterodata.py | 44 ++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/src/segger/data/utils/heterodata.py b/src/segger/data/utils/heterodata.py index e20a9bf..d323059 100644 --- a/src/segger/data/utils/heterodata.py +++ b/src/segger/data/utils/heterodata.py @@ -15,6 +15,20 @@ logger = logging.getLogger(__name__) + +def _sort_by_src(edge_index: torch.Tensor) -> torch.Tensor: + """Sort edges by source. This helps with faster subgraph extraction in `_patches.py`.""" + if edge_index.device.type == "cuda": + # torch on CUDA can run into INT_MAX issues for large graphs (2**31 edges). Cupy can handle larger arrays. + import cupy as cp + perm = torch.as_tensor( + cp.argsort(cp.asarray(edge_index[0]), kind='stable'), + device=edge_index.device, + ).long() + else: + perm = torch.argsort(edge_index[0], stable=True) + return edge_index[:, perm].contiguous() + def setup_heterodata( transcripts: pl.DataFrame, boundaries: gpd.GeoDataFrame, @@ -135,29 +149,31 @@ def setup_heterodata( # Transcript neighbors graph logger.debug("Setting up transcript neighbors graph") - data['tx', 'neighbors', 'tx'].edge_index = setup_transcripts_graph( - transcripts, - max_k=transcripts_graph_max_k, - max_dist=transcripts_graph_max_dist, + data['tx', 'neighbors', 'tx'].edge_index = _sort_by_src( + setup_transcripts_graph(transcripts, max_k=transcripts_graph_max_k, max_dist=transcripts_graph_max_dist) ) logger.info(f" tx-neighbors-tx edges: {data['tx', 'neighbors', 'tx'].edge_index.shape[1]:,}") # Reference segmentation graph logger.debug("Setting up segmentation graph") - data['tx', 'belongs', 'bd'].edge_index = setup_segmentation_graph( - transcripts, - segmentation_mask=segmentation_mask, + data['tx', 'belongs', 'bd'].edge_index = _sort_by_src( + setup_segmentation_graph(transcripts, segmentation_mask=segmentation_mask) ) logger.info(f" tx-belongs-bd edges: {data['tx', 'belongs', 'bd'].edge_index.shape[1]:,}") # Transcript-cell graph for prediction - logger.debug("Setting up prediction graph") - data['tx', 'neighbors', 'bd'].edge_index = setup_prediction_graph( - transcripts, - boundaries, - max_k=prediction_graph_max_k, - buffer_ratio=prediction_graph_buffer_ratio, - mode=prediction_graph_mode, + logger.debug( + f"Prediction graph: {len(transcripts)} tx vs {len(boundaries)} bd " + f"(mode='{prediction_graph_mode}') → quadtree on tx" + ) + data['tx', 'neighbors', 'bd'].edge_index = _sort_by_src( + setup_prediction_graph( + transcripts, + boundaries, + max_k=prediction_graph_max_k, + buffer_ratio=prediction_graph_buffer_ratio, + mode=prediction_graph_mode, + ) ) logger.info(f" tx-neighbors-bd edges: {data['tx', 'neighbors', 'bd'].edge_index.shape[1]:,}") From 58530c4861a947d02e90923d3f21405463ee5120 Mon Sep 17 00:00:00 2001 From: tobiaspk Date: Wed, 27 May 2026 10:06:19 -0400 Subject: [PATCH 4/9] Implement CSR tile and edge querying logic --- src/segger/data/tile_dataset.py | 248 ++++++++++++++++++++++++-------- 1 file changed, 192 insertions(+), 56 deletions(-) diff --git a/src/segger/data/tile_dataset.py b/src/segger/data/tile_dataset.py index 470c2b4..bf3a93e 100644 --- a/src/segger/data/tile_dataset.py +++ b/src/segger/data/tile_dataset.py @@ -2,13 +2,43 @@ from torch_geometric.data.storage import NodeStorage from torch_geometric.data import Data, HeteroData from torch.utils.data import Dataset +from torch_geometric.utils.map import map_index +from torch_geometric.index import index2ptr +import logging import shapely import torch from .partition import PartitionDataset from .tiling import Tiling -from .._patches import chunked_nonzero as _chunked_nonzero + +logger = logging.getLogger(__name__) + + +def query_ptr(csr, query) -> torch.Tensor: + """Gather values for bucket(s) `query` from a `(ptr, values)` CSR. + + `query` may be a scalar (one bucket) or a 1-D tensor (concatenated in + the given order). + """ + ptr, values = csr + + # single value + if not (torch.is_tensor(query) and query.dim() > 0): + q = int(query) + return values[ptr[q]:ptr[q + 1]] + + # tensor of values + starts = ptr[query] + ends = ptr[query + 1] + counts = ends - starts + total = int(counts.sum()) + if total == 0: + return values.new_empty(0) + base = torch.repeat_interleave(starts, counts) + within = (torch.arange(total, device=values.device) - torch.repeat_interleave(counts.cumsum(0) - counts, counts)) + return values[base + within] + class TileFitDataset(PartitionDataset): """ @@ -115,32 +145,42 @@ def _get_partition(self, data: Data | HeteroData) -> torch.Tensor: """ Generates partition labels for all nodes using the tiling object. """ + n_tiles = len(self.tiling.tiles) if isinstance(data, HeteroData): partition = dict() for node_type in data.node_types: - partition[node_type] = self.tiling.label( - data[node_type][self.geometry_key] + geom = data[node_type][self.geometry_key] + logger.debug( + f"TileFit label '{node_type}': {len(geom)} geoms vs {n_tiles} tiles → quadtree" ) + partition[node_type] = self.tiling.label(geom) return partition else: # isinstance(data, Data) - return self.tiling.label(data[self.geometry_key]) - + geom = data[self.geometry_key] + logger.debug(f"TileFit label: {len(geom)} geoms vs {n_tiles} tiles → quadtree") + return self.tiling.label(geom) + def _mask_data(self, data: Data | HeteroData) -> Data | HeteroData: """ Adds a boolean 'mask' attribute to each node indicating whether it is within a specified margin of a tile's boundary. """ + n_tiles = len(self.tiling.tiles) if isinstance(data, HeteroData): for node_type in data.node_types: - data[node_type]['mask'] = self.tiling.mask( - data[node_type][self.geometry_key], - self.margin, + geom = data[node_type][self.geometry_key] + logger.debug( + f"TileFit mask '{node_type}': {len(geom)} geoms vs {n_tiles} tiles " + f"(margin={self.margin}) → quadtree" ) + data[node_type]['mask'] = self.tiling.mask(geom, self.margin) else: # isinstance(data, Data) - data['mask'] = self.tiling.mask( - data[self.geometry_key], - self.margin + geom = data[self.geometry_key] + logger.debug( + f"TileFit mask: {len(geom)} geoms vs {n_tiles} tiles " + f"(margin={self.margin}) → quadtree" ) + data['mask'] = self.tiling.mask(geom, self.margin) return data def _drop_geometry(self, data: Data | HeteroData) -> Data | HeteroData: @@ -197,6 +237,40 @@ def __init__( elif 'pos' not in self.data.node_attrs(): raise ValueError("Graph must contain 'pos' attribute.") + # Precompute CSRs for fast per-tile subsetting (one-time cost). + if self._is_hetero: + logger.debug("Building tile/edge pointers for fast subsetting...") + self._tile_ptr_inner = self._build_tile_ptr(margin=0.0) + self._tile_ptr_outer = self._build_tile_ptr(margin=self.margin) + self._edges_ptr = self._build_edge_ptr() + + def _build_tile_ptr(self, margin: float) -> dict: + """Builds CSR-like structure of {node_type: (ptr[tile_id], node_id)} for fast subsetting.""" + n_tiles = len(getattr(self.tiling, 'tiles', self.tiling)) + out = {} + for nt in self.data.node_types: + pairs = self._get_tiles_to_nodes_edges(nt, margin=margin) + ptr = index2ptr(pairs[0], size=n_tiles) + out[nt] = (ptr, pairs[1]) + return out + + def _build_edge_ptr(self) -> dict: + """Builds CSR-like structure of {edge_type: (ptr[src-node], dst-node)} for edges in each tile. + + Assumes `edge_index` is sorted by src; values are the identity range + so `query_ptr` returns original-column positions in `edge_index`. + """ + out = {} + for et in self.data.edge_types: + ei = self.data[et].edge_index + assert (ei[0][1:] >= ei[0][:-1]).all(), f"edge_index[0] for {et} not sorted by src" + n_src = self.data[et[0]]["pos"].shape[0] + out[et] = ( + index2ptr(ei[0], size=n_src), + torch.arange(ei.shape[1], device=ei.device), + ) + return out + def __len__(self) -> int: """Number of tiles in the dataset.""" return len(self.tiling.tiles) @@ -213,59 +287,121 @@ def __getitem__(self, idx: int) -> Data | HeteroData: f"Requested {idx}, but tiling only contains {len(self)} tiles." ) geometry = self.tiling.tiles[idx] - return self._subset(geometry) + return self._subset_new(idx) - def _subset(self, bounds: shapely.Polygon) -> Data | HeteroData: - """Slices all node attributes within bounds. + def _get_tiles_to_nodes_edges(self, node_type: str, margin: float) -> torch.Tensor: + """ + Create edges `(tile_id, node_id)` for nodes in each tile's margined bbox. - TODO: Long Description. + Return tuples, sorted by `tile_id`. """ - inner = bounds.bounds - outer = bounds.buffer(self.margin).bounds - - if self._is_hetero: - subset = dict() - p_mask = dict() - for node_type in self.data.node_types: - pos: torch.Tensor = self.data[node_type]['pos'] - # Row indices of masked elements inside tile w/ margin - subset[node_type] = _chunked_nonzero( - (pos[:, 0] >= outer[0]) & - (pos[:, 0] < outer[2]) & - (pos[:, 1] >= outer[1]) & - (pos[:, 1] < outer[3]) - ) - p_mask[node_type] = ( - (pos[subset[node_type], 0] >= inner[0]) & - (pos[subset[node_type], 0] <= inner[2]) & - (pos[subset[node_type], 1] >= inner[1]) & - (pos[subset[node_type], 1] <= inner[3]) + pos: torch.Tensor = self.data[node_type]['pos'].to(torch.float32) + tiles_geom = getattr(self.tiling, 'tiles', self.tiling) + bounds = tiles_geom.bounds.to_numpy().astype("float32") + bounds = torch.from_numpy(bounds).to(pos.device) + bounds[:, :2] -= margin + bounds[:, 2:] += margin + + # Chunk tiles & nodes, cap at ~128 MB intermediate + K, N = bounds.shape[0], pos.shape[0] + budget = 2 ** 27 # ~128M bools = 128MB + chunk_K = max(8, min(256, K)) + chunk_N = max(1, min(N, budget // (8 * max(chunk_K, 1)))) + + tile_ids, node_ids = [], [] + + # for each batch of tiles + for s_t in range(0, K, chunk_K): + ch = bounds[s_t:min(s_t + chunk_K, K)] + + # for each batch of nodes + for s_n in range(0, N, chunk_N): + px = pos[s_n:min(s_n + chunk_N, N), 0] + py = pos[s_n:min(s_n + chunk_N, N), 1] + + # create boundary mask. results in a (chunked) binary matrix of (chunk_k, chunk_n) where "True" indicates assignment + m = ( + (ch[:, None, 0] <= px[None, :]) & (ch[:, None, 2] > px[None, :]) & + (ch[:, None, 1] <= py[None, :]) & (ch[:, None, 3] > py[None, :]) ) - sample = self.data.subgraph(subset) - sample.set_value_dict('predict_mask', p_mask) - return sample - - else: # is homogenous Data - pos: torch.Tensor = self.data['pos'] - subset = ( - (pos[:, 0] >= outer[0]) & - (pos[:, 0] < outer[2]) & - (pos[:, 1] >= outer[1]) & - (pos[:, 1] < outer[3]) - ) - subset = _chunked_nonzero(subset) - sample = self.data.subgraph(subset) - sample['predict_mask'] = ( - (pos[subset, 0] >= inner[0]) & - (pos[subset, 0] <= inner[2]) & - (pos[subset, 1] >= inner[1]) & - (pos[subset, 1] <= inner[3]) - ) - return sample + # extract pairs + ki, ni = torch.nonzero(m, as_tuple=True) + tile_ids.append(ki + s_t) + node_ids.append(ni + s_n) + + tile_ids = torch.cat(tile_ids) + node_ids = torch.cat(node_ids) + + # sort by tile_id (and preserve node order) + perm = torch.argsort(tile_ids, stable=True) + return torch.stack([tile_ids[perm], node_ids[perm]], 0) + + def _subset_new(self, idx) -> Data | HeteroData: + """Subset the Heterograph to nodes and edges within tile `idx`. + + Uses CSRs precomputed in `__init__` (`_tile_ptr_outer`, `_tile_ptr_inner`, `_edges_ptr`). + """ + subset = HeteroData() + + # create nodes + for node_type in self.data.node_types: + + # get list of nodes in tile + nodes_subset_idx = query_ptr(self._tile_ptr_outer[node_type], idx) + + # populate metadata for these nodes + for key, value in self.data[node_type].items(): + if key == 'num_nodes': + subset[node_type].num_nodes = len(nodes_subset_idx) + elif self.data[node_type].is_node_attr(key): + subset[node_type][key] = value[nodes_subset_idx] + else: + subset[node_type][key] = value + + # get mask (mask for nodes within margined tiles) + nodes_margin_idx = query_ptr(self._tile_ptr_inner[node_type], idx) + subset[node_type]['predict_mask'] = torch.isin(nodes_subset_idx, nodes_margin_idx) + + + # create edges + for edge_type in self.data.edge_types: + + # get src (source) and dst (destination) nodes that fall within the tile + src, _, dst = edge_type + src_subset = query_ptr(self._tile_ptr_outer[src], idx) + dst_subset = query_ptr(self._tile_ptr_outer[dst], idx) + + # get edges where src node is in tile + edge_src_subset_idx = query_ptr(self._edges_ptr[edge_type], src_subset) + candidate_edges = self.data[edge_type].edge_index[:, edge_src_subset_idx] + + # get edges where also dst node is in tile + edge_dst_subset_idx = torch.isin(candidate_edges[1], dst_subset) + + # store mask + kept_orig = edge_src_subset_idx[edge_dst_subset_idx] + edge_index_new = candidate_edges[:, edge_dst_subset_idx] + + # map indices to new subset + src_index, _ = map_index(edge_index_new[0], src_subset, max_index=self.data[src]["pos"].shape[0]) + dst_index, _ = map_index(edge_index_new[1], dst_subset, max_index=self.data[dst]["pos"].shape[0]) + edge_index_mapped = torch.stack([src_index, dst_index], dim=0) + + # populate heterodata + for key, value in self.data[edge_type].items(): + if key == 'edge_index': + subset[edge_type].edge_index = edge_index_mapped + elif self.data[edge_type].is_edge_attr(key): + subset[edge_type][key] = value[kept_orig] + else: + subset[edge_type][key] = value + + return subset class DynamicBatchSamplerPatch(DynamicBatchSampler): """TODO: Description """ def __len__(self): return len(self.dataset) # ceiling on dataset length + From 0ba0cb328e2e2994c0a09c6707c77107d7b6449d Mon Sep 17 00:00:00 2001 From: tobiaspk Date: Wed, 20 May 2026 10:27:40 -0400 Subject: [PATCH 5/9] Add caller-level debug logs for quadtree builds --- src/segger/data/data_module.py | 8 +++++++- src/segger/data/tile_dataset.py | 2 ++ src/segger/geometry/query.py | 4 ++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/segger/data/data_module.py b/src/segger/data/data_module.py index ffb6d3d..72794eb 100644 --- a/src/segger/data/data_module.py +++ b/src/segger/data/data_module.py @@ -240,12 +240,18 @@ def load(self): ) # Tile graph dataset - self.logger.debug("Tiling graph dataset...") node_positions = torch.vstack([ self.data['tx']['pos'], self.data['bd']['pos'], ]) + self.logger.debug( + f"Tiling graph: {len(node_positions)} positions, " + f"mode='{self.tiling_mode}'" + ) if self.tiling_mode == "adaptive": + self.logger.debug( + f" → QuadTreeTiling (max_tile_size={self.tiling_nodes_per_tile})" + ) self.tiling = QuadTreeTiling( positions=node_positions, max_tile_size=self.tiling_nodes_per_tile, diff --git a/src/segger/data/tile_dataset.py b/src/segger/data/tile_dataset.py index bf3a93e..8dd9cef 100644 --- a/src/segger/data/tile_dataset.py +++ b/src/segger/data/tile_dataset.py @@ -40,6 +40,8 @@ def query_ptr(csr, query) -> torch.Tensor: return values[base + within] +logger = logging.getLogger(__name__) + class TileFitDataset(PartitionDataset): """ Partitions a PyG graph based on a geometric tiling of its nodes. diff --git a/src/segger/geometry/query.py b/src/segger/geometry/query.py index dd4a0c3..6573060 100644 --- a/src/segger/geometry/query.py +++ b/src/segger/geometry/query.py @@ -147,6 +147,10 @@ def _points_in_polygons_intersects( pts_ixn = points.iloc[idx_missing] ply_ixn = polygons_to_geoseries(polygons, backend='geopandas') if len(pts_ixn) >= max_unassigned_points: + logger.debug( + f"intersects buffer-filter: {len(pts_ixn)} unassigned pts vs " + f"{len(ply_ixn)} buffered polys → 2nd quadtree" + ) ply_buf = polygons_to_geoseries( ply_ixn.buffer(boundary_buffer), backend='cuspatial', From f70d6dfe4891dc26f9ce76e77aadcbe1bb003906 Mon Sep 17 00:00:00 2001 From: tobiaspk Date: Fri, 22 May 2026 14:13:12 -0400 Subject: [PATCH 6/9] Update quadtree retry logic --- src/segger/geometry/quadtree.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/segger/geometry/quadtree.py b/src/segger/geometry/quadtree.py index a55c9e1..19ce142 100644 --- a/src/segger/geometry/quadtree.py +++ b/src/segger/geometry/quadtree.py @@ -143,7 +143,7 @@ def get_quadtree_index( points: cuspatial.GeoSeries, max_size: int, with_bounds: bool = True, - max_retries: int = 5, + max_retries: int = 8, ) -> tuple[cudf.Series, cudf.DataFrame, dict]: """Build a cuSpatial quadtree from 2D point data. @@ -174,6 +174,7 @@ def get_quadtree_index( y_max = kwargs['y_max'] scale = kwargs['scale'] max_depth = kwargs['max_depth'] + max_size_input = max_size logger.debug(f"Building quadtree on {len(points)} points with max_size={max_size}, max_depth={max_depth}") @@ -193,8 +194,8 @@ def get_quadtree_index( # check if valid (see segger issue #40) if is_quadtree_valid(quadtree, len(points)): break - logger.warning(f"Invalid quadtree generated with max_size={max_size}. Retry with max_size={max_size + 10000}.") - max_size += 10000 + logger.warning(f"Invalid quadtree generated with max_size={max_size}. Retry with max_size={max_size + max_size_input}.") + max_size += max_size_input else: raise RuntimeError( f"cuSpatial returned an invalid quadtree after {max_retries + 1} " From 527d684727d35e55fead3fd7724ab1ee63a4430f Mon Sep 17 00:00:00 2001 From: tobiaspk Date: Fri, 22 May 2026 14:14:46 -0400 Subject: [PATCH 7/9] Store all checkpoints if --debug --- src/segger/cli/segment.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/segger/cli/segment.py b/src/segger/cli/segment.py index 111b567..bb2b5b6 100644 --- a/src/segger/cli/segment.py +++ b/src/segger/cli/segment.py @@ -389,6 +389,7 @@ def segment( from lightning.pytorch.loggers import CSVLogger from lightning.pytorch import Trainer + from lightning.pytorch.callbacks import ModelCheckpoint from ..data import ISTSegmentationWriter csvlogger = CSVLogger(output_directory) @@ -397,11 +398,23 @@ def segment( save_anndata=save_anndata, debug=debug, ) + callbacks = [writer] + + if debug: + checkpoint_callback = ModelCheckpoint( + dirpath=Path(output_directory) / "checkpoints", + filename="epoch={epoch:02d}", + save_top_k=-1, + every_n_epochs=1, + ) + callbacks.append(checkpoint_callback) + + trainer = Trainer( logger=csvlogger, max_epochs=n_epochs, reload_dataloaders_every_n_epochs=1, - callbacks=[writer], + callbacks=callbacks, ) # Training From d5b2d4aaa79d3092dad61753957e9bb35a695af0 Mon Sep 17 00:00:00 2001 From: tobiaspk Date: Fri, 22 May 2026 11:03:52 -0400 Subject: [PATCH 8/9] Move gc.collect outside the loop --- src/segger/data/writer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/segger/data/writer.py b/src/segger/data/writer.py index d0d68bc..a622498 100644 --- a/src/segger/data/writer.py +++ b/src/segger/data/writer.py @@ -244,7 +244,9 @@ def assign_transcripts_to_cells( # cleanup del arr - gc.collect() + + # move from loop to after: ctx_hp 55.4s → 2.3s, pancreas 83.0s → 1.4s) + gc.collect() # backfill failed features in using the 80% quantile of thresholds global_threshold = np.quantile([t["similarity_threshold"] for t in thresholds], .5) From 0e3f94b1851d2618dd053c38eee1f4c9ae1387cf Mon Sep 17 00:00:00 2001 From: tobiaspk Date: Wed, 27 May 2026 11:27:43 -0400 Subject: [PATCH 9/9] Attribute prior contributors to #38 The squash-merge of #38 dropped co-author trailers. This empty commit credits them so their contributions are reflected in the repo's contributor graph. Co-authored-by: Ananya Nandula <225912921+ananya-nandula@users.noreply.github.com> Co-authored-by: Mostafa Shahhosseini <33085339+mossishahi@users.noreply.github.com> Co-authored-by: Sam Rose <4626152+srose89@users.noreply.github.com>