Skip to content

Commit ea21a1b

Browse files
authored
Merge pull request #15 from dpeerlab/bugfix/threshold_convergence
Bugfix/threshold convergence
2 parents 9af8c25 + 32ed3d7 commit ea21a1b

3 files changed

Lines changed: 69 additions & 37 deletions

File tree

src/segger/data/utils/threshold.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from skimage.filters import threshold_li
2+
3+
def threshold_li_custom(arr, max_iter=100):
4+
"""Fallback to StopIteration if can't converge. Not implemented in threshold_li."""
5+
n_iter = 0
6+
def _callback(threshold):
7+
nonlocal n_iter
8+
n_iter += 1
9+
if n_iter > max_iter:
10+
raise StopIteration
11+
12+
return threshold_li(arr, iter_callback=_callback)

src/segger/data/writer.py

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
1-
from copyreg import pickle
2-
import os
1+
import gc
32
import logging
3+
import numpy as np
44
from lightning.pytorch.callbacks import BasePredictionWriter
5-
from skimage.filters import threshold_li, threshold_yen
5+
from skimage.filters import threshold_yen
6+
from .utils.threshold import threshold_li_custom
67
from lightning.pytorch import Trainer, LightningModule
78
from typing import Sequence, Any
89
from pathlib import Path
910
import polars as pl
1011
import torch
11-
1212
from ..io import TrainingTranscriptFields, TrainingBoundaryFields
1313
from . import ISTDataModule
1414

15-
# TODO: import datamodule, not trainer
16-
1715
class ISTSegmentationWriter(BasePredictionWriter):
1816
"""TODO: Description
1917
@@ -32,10 +30,11 @@ def __init__(self, output_directory: Path, debug: bool = False):
3230
# setup debugging
3331
self.debug = debug
3432
self.path_debug = None
33+
self.n_tx_predicted = 0
3534
if debug:
3635
logging.getLogger("segger").setLevel("DEBUG")
3736
self.path_debug = output_directory / "debug"
38-
self.path_debug.mkdir(exist_ok=True)
37+
self.path_debug.mkdir(exist_ok=True, parents=True)
3938

4039
def write_on_epoch_end(
4140
self,
@@ -148,56 +147,77 @@ def assign_transcripts_to_cells(
148147
)
149148

150149
# Per-gene thresholding (iterative to reduce memory usage)
151-
logger.debug("Calculating per-gene similarity thresholds...")
152-
feature_counts = (
150+
logger.debug(f"Calculating per-gene similarity thresholds, using {segmentation.shape[0]/1e6:.1f}M transcripts...")
151+
152+
segmentation_group = (
153153
segmentation
154154
.filter(pl.col('segger_cell_id').is_not_null())
155-
.select(tx_fields.feature)
156-
.to_series()
157-
.value_counts()
155+
.group_by(tx_fields.feature)
158156
)
159-
thresholds = []
157+
160158
n = 10_000_000
161-
for feature, count in feature_counts.iter_rows():
162-
similarities = (
163-
segmentation
164-
.filter(
165-
(pl.col(tx_fields.feature) == feature) &
166-
(pl.col('segger_cell_id').is_not_null())
167-
)
168-
.select('segger_similarity')
169-
)
170-
if count > n:
171-
similarities = similarities.sample(n=n, seed=0)
172-
similarities = similarities.to_series().to_numpy()
173-
threshold_value = min(
174-
threshold_li( similarities),
175-
threshold_yen(similarities),
176-
)
177-
thresholds.append({
178-
tx_fields.feature: feature,
179-
'similarity_threshold': threshold_value,
180-
})
181-
thresholds = pl.DataFrame(thresholds)
182-
159+
thresholds = []
160+
failed_to_converge = []
161+
n_groups = segmentation_group.len().height
162+
163+
for i, (feature, group) in enumerate(segmentation_group):
164+
165+
# log step
166+
if (i + 1) % 50 == 0:
167+
logger.debug(f"Processing feature {i+1}/{n_groups} (feature {feature[0]} | transcripts {group.shape[0]/1e3:.1f}K)...")
168+
169+
# sample if too many
170+
arr = group["segger_similarity"]
171+
if arr.shape[0] > n:
172+
arr = arr.sample(n=n, seed=0)
173+
arr = arr.to_numpy()
174+
175+
# threshold
176+
try:
177+
tye = threshold_yen(arr)
178+
tli = threshold_li_custom(arr, max_iter=250)
179+
threshold = min(tye, tli)
180+
except StopIteration:
181+
logger.debug(f"Failed to converge {feature[0]}. Will use 50% quantile of segger similarities of other genes as cutoff.")
182+
failed_to_converge.append(feature[0])
183+
continue
184+
185+
# append threshold
186+
thresholds.append({tx_fields.feature: feature[0], "similarity_threshold": threshold.item(), "converged": True})
187+
188+
# cleanup
189+
del arr
190+
gc.collect()
191+
192+
# backfill failed features in using the 80% quantile of thresholds
193+
global_threshold = np.quantile([t["similarity_threshold"] for t in thresholds], .5)
194+
for feature in failed_to_converge:
195+
thresholds.append({tx_fields.feature: feature, "similarity_threshold": global_threshold, "converged": False})
196+
logger.debug(f"Global Threshold: {global_threshold} | Used this to backfill {len(failed_to_converge)} features.")
197+
183198
# Join
184199
logger.debug("Joining thresholds with segmentation...")
200+
thresholds = pl.DataFrame(thresholds)
185201
segmentation = (
186202
segmentation
187203
.join(thresholds, on=tx_fields.feature, how='left')
188204
.drop(tx_fields.feature)
189205
)
206+
207+
logger.debug("Segmentation complete.")
190208
return segmentation
191209

192210

193211
# Debugging callbacks
194212
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
213+
mask = batch['tx']['predict_mask']
214+
self.n_tx_predicted += mask.sum().item()
195215
if not self.debug:
196216
return
197217
log_every = 50
198218
if batch_idx % log_every == 0:
199219
self.segger_logger.info(
200-
f"Finished prediction batch '{batch_idx}'."
220+
f"Finished prediction batch '{batch_idx}'. # TX so far {self.n_tx_predicted / 1e6:.1f}M"
201221
)
202222

203223
def on_fit_start(self, trainer, pl_module):

src/segger/debug/segmentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def run_segmentation_only(
1818
SLURMEnvironment.detect = lambda: False
1919

2020
# Load data
21-
writer = ISTSegmentationWriter(path_outputs)
21+
writer = ISTSegmentationWriter(path_outputs, debug=True)
2222
adata = ad.read_h5ad(path_adata)
2323
with open(path_predictions, "rb") as f:
2424
predictions = pickle.load(f)

0 commit comments

Comments
 (0)