1111import torch
1212from ..io import TrainingTranscriptFields , TrainingBoundaryFields
1313from . import ISTDataModule
14+ from .utils .anndata import anndata_from_transcripts
1415
1516class ISTSegmentationWriter (BasePredictionWriter ):
1617 """TODO: Description
@@ -20,11 +21,16 @@ class ISTSegmentationWriter(BasePredictionWriter):
2021 output_directory : Path
2122 Path to write outputs.
2223 """
23-
24- def __init__ (self , output_directory : Path , debug : bool = False ):
24+ def __init__ (
25+ self ,
26+ output_directory : Path ,
27+ save_anndata : bool = True ,
28+ debug : bool = False
29+ ):
2530 # "write" callback at the end of prediction epoch
2631 super ().__init__ (write_interval = "epoch" )
2732 self .output_directory = Path (output_directory )
33+ self .save_anndata = save_anndata
2834 self .segger_logger = logging .getLogger (__name__ )
2935
3036 # setup debugging
@@ -71,6 +77,53 @@ def write_on_epoch_end(
7177 self .segger_logger .debug (f"Writing segmentation output to { self .output_directory } ..." )
7278 segmentation .write_parquet (self .output_directory / 'segger_segmentation.parquet' )
7379
80+ # write anndata
81+ self .segger_logger .debug ("Writing AnnData output..." )
82+ if self .save_anndata :
83+ self .write_anndata (trainer , segmentation )
84+
85+ def write_anndata (
86+ self ,
87+ trainer : Trainer ,
88+ segmentation : pl .DataFrame
89+ ):
90+ # Get fields
91+ tx_fields = TrainingTranscriptFields ()
92+
93+ tx = trainer .datamodule .tx
94+ transcripts = (
95+ segmentation
96+ .join (
97+ tx .select ([
98+ tx_fields .row_index ,
99+ tx_fields .x ,
100+ tx_fields .y ,
101+ tx_fields .feature ,
102+ ]),
103+ on = tx_fields .row_index ,
104+ how = 'left' ,
105+ )
106+ .rename ({tx_fields .feature : "segger_gene" })
107+ .select ([
108+ tx_fields .row_index ,
109+ "segger_gene" ,
110+ "segger_cell_id" ,
111+ "segger_similarity" ,
112+ "similarity_threshold" ,
113+ tx_fields .x ,
114+ tx_fields .y ,
115+ ])
116+ )
117+
118+ adata = anndata_from_transcripts (
119+ transcripts ,
120+ feature_column = "segger_gene" ,
121+ cell_id_column = "segger_cell_id" ,
122+ score_column = "segger_similarity" ,
123+ coordinate_columns = [tx_fields .x , tx_fields .y ],
124+ )
125+ adata .write_h5ad (self .output_directory / 'segger_anndata.h5ad' )
126+
74127 @classmethod
75128 def assign_transcripts_to_cells (
76129 cls ,
0 commit comments