1- from copyreg import pickle
2- import os
1+ import gc
32import logging
3+ import numpy as np
44from lightning .pytorch .callbacks import BasePredictionWriter
5- from skimage .filters import threshold_li , threshold_yen
5+ from skimage .filters import threshold_yen
6+ from .utils .threshold import threshold_li_custom
67from lightning .pytorch import Trainer , LightningModule
78from typing import Sequence , Any
89from pathlib import Path
910import polars as pl
1011import torch
11-
1212from ..io import TrainingTranscriptFields , TrainingBoundaryFields
1313from . import ISTDataModule
1414
15- # TODO: import datamodule, not trainer
16-
1715class ISTSegmentationWriter (BasePredictionWriter ):
1816 """TODO: Description
1917
@@ -32,10 +30,11 @@ def __init__(self, output_directory: Path, debug: bool = False):
3230 # setup debugging
3331 self .debug = debug
3432 self .path_debug = None
33+ self .n_tx_predicted = 0
3534 if debug :
3635 logging .getLogger ("segger" ).setLevel ("DEBUG" )
3736 self .path_debug = output_directory / "debug"
38- self .path_debug .mkdir (exist_ok = True )
37+ self .path_debug .mkdir (exist_ok = True , parents = True )
3938
4039 def write_on_epoch_end (
4140 self ,
@@ -148,56 +147,77 @@ def assign_transcripts_to_cells(
148147 )
149148
150149 # Per-gene thresholding (iterative to reduce memory usage)
151- logger .debug ("Calculating per-gene similarity thresholds..." )
152- feature_counts = (
150+ logger .debug (f"Calculating per-gene similarity thresholds, using { segmentation .shape [0 ]/ 1e6 :.1f} M transcripts..." )
151+
152+ segmentation_group = (
153153 segmentation
154154 .filter (pl .col ('segger_cell_id' ).is_not_null ())
155- .select (tx_fields .feature )
156- .to_series ()
157- .value_counts ()
155+ .group_by (tx_fields .feature )
158156 )
159- thresholds = []
157+
160158 n = 10_000_000
161- for feature , count in feature_counts .iter_rows ():
162- similarities = (
163- segmentation
164- .filter (
165- (pl .col (tx_fields .feature ) == feature ) &
166- (pl .col ('segger_cell_id' ).is_not_null ())
167- )
168- .select ('segger_similarity' )
169- )
170- if count > n :
171- similarities = similarities .sample (n = n , seed = 0 )
172- similarities = similarities .to_series ().to_numpy ()
173- threshold_value = min (
174- threshold_li ( similarities ),
175- threshold_yen (similarities ),
176- )
177- thresholds .append ({
178- tx_fields .feature : feature ,
179- 'similarity_threshold' : threshold_value ,
180- })
181- thresholds = pl .DataFrame (thresholds )
182-
159+ thresholds = []
160+ failed_to_converge = []
161+ n_groups = segmentation_group .len ().height
162+
163+ for i , (feature , group ) in enumerate (segmentation_group ):
164+
165+ # log step
166+ if (i + 1 ) % 50 == 0 :
167+ logger .debug (f"Processing feature { i + 1 } /{ n_groups } (feature { feature [0 ]} | transcripts { group .shape [0 ]/ 1e3 :.1f} K)..." )
168+
169+ # sample if too many
170+ arr = group ["segger_similarity" ]
171+ if arr .shape [0 ] > n :
172+ arr = arr .sample (n = n , seed = 0 )
173+ arr = arr .to_numpy ()
174+
175+ # threshold
176+ try :
177+ tye = threshold_yen (arr )
178+ tli = threshold_li_custom (arr , max_iter = 250 )
179+ threshold = min (tye , tli )
180+ except StopIteration :
181+ logger .debug (f"Failed to converge { feature [0 ]} . Will use 50% quantile of segger similarities of other genes as cutoff." )
182+ failed_to_converge .append (feature [0 ])
183+ continue
184+
185+ # append threshold
186+ thresholds .append ({tx_fields .feature : feature [0 ], "similarity_threshold" : threshold .item (), "converged" : True })
187+
188+ # cleanup
189+ del arr
190+ gc .collect ()
191+
192+ # backfill features that failed to converge using the 50% quantile (median) of the converged thresholds
193+ global_threshold = np .quantile ([t ["similarity_threshold" ] for t in thresholds ], .5 )
194+ for feature in failed_to_converge :
195+ thresholds .append ({tx_fields .feature : feature , "similarity_threshold" : global_threshold , "converged" : False })
196+ logger .debug (f"Global Threshold: { global_threshold } | Used this to backfill { len (failed_to_converge )} features." )
197+
183198 # Join
184199 logger .debug ("Joining thresholds with segmentation..." )
200+ thresholds = pl .DataFrame (thresholds )
185201 segmentation = (
186202 segmentation
187203 .join (thresholds , on = tx_fields .feature , how = 'left' )
188204 .drop (tx_fields .feature )
189205 )
206+
207+ logger .debug ("Segmentation complete." )
190208 return segmentation
191209
192210
193211 # Debugging callbacks
194212 def on_predict_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx = 0 ):
213+ mask = batch ['tx' ]['predict_mask' ]
214+ self .n_tx_predicted += mask .sum ().item ()
195215 if not self .debug :
196216 return
197217 log_every = 50
198218 if batch_idx % log_every == 0 :
199219 self .segger_logger .info (
200- f"Finished prediction batch '{ batch_idx } '."
220+ f"Finished prediction batch '{ batch_idx } '. # TX so far { self . n_tx_predicted / 1e6 :.1f } M "
201221 )
202222
203223 def on_fit_start (self , trainer , pl_module ):
0 commit comments