Skip to content
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,5 @@ __marimo__/
.dev
.dev/*
*.pyc
*memray*
*memray*
nvidia/
5 changes: 0 additions & 5 deletions src/segger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
enable_statistics()

# Apply pytorch patches for issue pytorch/pytorch#51871 (CUDA nonzero INT_MAX limit).
# Must run BEFORE any segger module imports HeteroData / bipartite_subgraph.
from ._patches import apply as _apply_patches
_apply_patches()

def free_mem_str() -> str:
stats = get_statistics()
return (
Expand Down
74 changes: 0 additions & 74 deletions src/segger/_patches.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/segger/cli/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ..debug.segmentation import run_segmentation_only
from ..debug.prediction import run_prediction_only
from ..utils import setup_logging

debug = App(name="debug", help="Utilities for debugging and testing individual components.")

Expand Down Expand Up @@ -41,6 +42,7 @@ def predict_only_cli(
)],
):
"""Run prediction only."""
setup_logging(level="DEBUG", debug=True)
run_prediction_only(
path_checkpoint=path_checkpoint,
path_outputs=path_outputs,
Expand Down
15 changes: 14 additions & 1 deletion src/segger/cli/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def segment(

from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from ..data import ISTSegmentationWriter

csvlogger = CSVLogger(output_directory)
Expand All @@ -397,11 +398,23 @@ def segment(
save_anndata=save_anndata,
debug=debug,
)
callbacks = [writer]

if debug:
checkpoint_callback = ModelCheckpoint(
dirpath=Path(output_directory) / "checkpoints",
filename="epoch={epoch:02d}",
save_top_k=-1,
every_n_epochs=1,
)
callbacks.append(checkpoint_callback)


trainer = Trainer(
logger=csvlogger,
max_epochs=n_epochs,
reload_dataloaders_every_n_epochs=1,
callbacks=[writer],
callbacks=callbacks,
)

# Training
Expand Down
8 changes: 7 additions & 1 deletion src/segger/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,18 @@ def load(self):
)

# Tile graph dataset
self.logger.debug("Tiling graph dataset...")
node_positions = torch.vstack([
self.data['tx']['pos'],
self.data['bd']['pos'],
])
self.logger.debug(
f"Tiling graph: {len(node_positions)} positions, "
f"mode='{self.tiling_mode}'"
)
if self.tiling_mode == "adaptive":
self.logger.debug(
f" → QuadTreeTiling (max_tile_size={self.tiling_nodes_per_tile})"
)
self.tiling = QuadTreeTiling(
positions=node_positions,
max_tile_size=self.tiling_nodes_per_tile,
Expand Down
Loading