Skip to content

Commit e6f2066

Browse files
mossishahiTobiaspk
authored and committed
Add save_anndata parameter to ISTSegmentationWriter and update segment function.
The user can now pass --save_anndata to also save an AnnData file built from the segmentation output.
1 parent ea21a1b commit e6f2066

1 file changed

Lines changed: 55 additions & 2 deletions

File tree

src/segger/data/writer.py

Lines changed: 55 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import torch
1212
from ..io import TrainingTranscriptFields, TrainingBoundaryFields
1313
from . import ISTDataModule
14+
from .utils.anndata import anndata_from_transcripts
1415

1516
class ISTSegmentationWriter(BasePredictionWriter):
1617
"""TODO: Description
@@ -20,11 +21,16 @@ class ISTSegmentationWriter(BasePredictionWriter):
2021
output_directory : Path
2122
Path to write outputs.
2223
"""
23-
24-
def __init__(self, output_directory: Path, debug: bool = False):
24+
def __init__(
25+
self,
26+
output_directory: Path,
27+
save_anndata: bool = True,
28+
debug: bool = False
29+
):
2530
# "write" callback at the end of prediction epoch
2631
super().__init__(write_interval="epoch")
2732
self.output_directory = Path(output_directory)
33+
self.save_anndata = save_anndata
2834
self.segger_logger = logging.getLogger(__name__)
2935

3036
# setup debugging
@@ -71,6 +77,53 @@ def write_on_epoch_end(
7177
self.segger_logger.debug(f"Writing segmentation output to {self.output_directory}...")
7278
segmentation.write_parquet(self.output_directory / 'segger_segmentation.parquet')
7379

80+
# write anndata
81+
self.segger_logger.debug("Writing AnnData output...")
82+
if self.save_anndata:
83+
self.write_anndata(trainer, segmentation)
84+
85+
def write_anndata(
    self,
    trainer: Trainer,
    segmentation: pl.DataFrame
):
    """Build an AnnData object from the segmentation and save it as h5ad.

    The segmentation table is left-joined (on the transcript row index)
    with the x/y coordinates and feature column taken from
    ``trainer.datamodule.tx``. The feature column is renamed to
    ``"segger_gene"`` and the result is passed to
    ``anndata_from_transcripts``. The returned object is written to
    ``segger_anndata.h5ad`` inside ``self.output_directory``.

    Parameters
    ----------
    trainer : Trainer
        Trainer whose ``datamodule.tx`` table supplies the transcript
        coordinates and feature names.
    segmentation : pl.DataFrame
        Segmentation table; it must provide the row-index column plus
        ``"segger_cell_id"``, ``"segger_similarity"`` and
        ``"similarity_threshold"``, which are selected below.
    """
    fields = TrainingTranscriptFields()
    xy = (fields.x, fields.y)

    # Transcript-level columns to attach to every segmentation row
    tx_table = trainer.datamodule.tx
    tx_subset = tx_table.select([fields.row_index, *xy, fields.feature])

    # Left join keeps every segmentation row, matched or not
    joined = segmentation.join(tx_subset, on=fields.row_index, how="left")
    joined = joined.rename({fields.feature: "segger_gene"})

    output_columns = [
        fields.row_index,
        "segger_gene",
        "segger_cell_id",
        "segger_similarity",
        "similarity_threshold",
        *xy,
    ]
    transcripts = joined.select(output_columns)

    adata = anndata_from_transcripts(
        transcripts,
        feature_column="segger_gene",
        cell_id_column="segger_cell_id",
        score_column="segger_similarity",
        coordinate_columns=list(xy),
    )

    target = self.output_directory / "segger_anndata.h5ad"
    adata.write_h5ad(target)
126+
74127
@classmethod
75128
def assign_transcripts_to_cells(
76129
cls,

0 commit comments

Comments
 (0)