Skip to content

Commit a6d7e9f

Browse files
make nuc segmentation and cut dup code
1 parent b1e6d06 commit a6d7e9f

2 files changed

Lines changed: 172 additions & 32 deletions

File tree

src/spatialdata_io/_constants/_constants.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,10 @@ class VisiumHDKeys(ModeEnum):
359359
FILTERED_COUNTS_FILE = "filtered_feature_bc_matrix.h5"
360360
RAW_COUNTS_FILE = "raw_feature_bc_matrix.h5"
361361
TISSUE_POSITIONS_FILE = "tissue_positions.parquet"
362+
BARCODE_MAPPINGS_FILE = "barcode_mappings.parquet"
362363
FILTERED_CELL_COUNTS_FILE = "filtered_feature_cell_matrix.h5"
363-
SEGMENTATION_GEOJSON_PATH = "cell_segmentations.geojson"
364+
CELL_SEGMENTATION_GEOJSON_PATH = "cell_segmentations.geojson"
365+
NUCLEUS_SEGMENTATION_GEOJSON_PATH = "nucleus_segmentations.geojson"
364366

365367
# images
366368
IMAGE_HIRES_FILE = "tissue_hires_image.png"
@@ -405,3 +407,4 @@ class VisiumHDKeys(ModeEnum):
405407

406408
# Cell Segmentation keys
407409
CELL_SEG_KEY_HD = 'cell_segmentations'
410+
NUCLEUS_SEG_KEY_HD = 'nucleus_segmentations'

src/spatialdata_io/readers/visium_hd.py

Lines changed: 168 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from types import MappingProxyType
88
from typing import TYPE_CHECKING, Any
99
from shapely.geometry import Polygon
10-
10+
import pyarrow.parquet as pq
11+
import anndata
12+
from scipy.sparse import csc_matrix
1113
import h5py
1214
import numpy as np
1315
import pandas as pd
@@ -39,12 +41,113 @@
3941

4042
RNG = default_rng(0)
4143

44+
def make_filtered_nucleus_adata(
45+
filtered_matrix_h5_path: str,
46+
barcode_mappings_parquet_path: str,
47+
bin_col_name: str = 'square_002um',
48+
aggregate_col_name: str = 'cell_id'
49+
) -> anndata.AnnData:
50+
"""Generate a filtered AnnData object by aggregating 2um binned data
51+
based on nucleus segmentation.
52+
Uses a 2um filtered_feature_bc_matrix.h5 file and a barcode_mappings.parquet file containing
53+
barcode mappings, filters the data to include only valid nucleus mappings,
54+
and aggregates the data based on specified bin into cell IDs which only contain
55+
the 2um square data under segmented nuclei.
56+
57+
Parameters:
58+
-----------
59+
filtered_matrix_h5_path : str
60+
Path to the 10x Genomics HDF5 matrix file.
61+
barcode_mappings_parquet_path : str
62+
Path to the Parquet file containing barcode mappings.
63+
bin_col_name : str, optional
64+
Column name in the barcode mappings that specifies the spatial bin (default is 'square_002um').
65+
aggregate_col_name : str, optional
66+
Column name in the barcode mappings that specifies the aggregate cell ID (default is 'cell_id').
67+
Returns:
68+
--------
69+
anndata.AnnData
70+
An AnnData object where the observations correspond to filtered cell IDs
71+
and the variables correspond to the original features from the input data.
72+
"""
73+
# Read in the necessary files
74+
adata_2um = sc.read_10x_h5(filtered_matrix_h5_path)
75+
barcode_mappings = pq.read_table(barcode_mappings_parquet_path)
76+
77+
# Filter to only include valid cell IDs that are in both nucleus and cell
78+
barcode_mappings = barcode_mappings.filter((barcode_mappings['cell_id'].is_valid()) and barcode_mappings["in_nucleus"] and barcode_mappings["in_cell"])
79+
80+
# Filter the 2um adata to only include squares present in the barcode mappings
81+
valid_squares = barcode_mappings[bin_col_name].unique()
82+
squares_to_keep = np.intersect1d(adata_2um.obs_names, valid_squares)
83+
adata_filtered = adata_2um[squares_to_keep, :].copy()
84+
85+
# Map each square to its corresponding cell ID
86+
square_to_cell_map = dict(zip(
87+
barcode_mappings[bin_col_name].to_pylist(),
88+
barcode_mappings[aggregate_col_name].to_pylist()
89+
90+
))
91+
ordered_cell_ids = [square_to_cell_map[square] for square in adata_filtered.obs_names]
92+
unique_cells = list(dict.fromkeys(ordered_cell_ids).keys())
93+
cell_to_idx = {cell: i for i, cell in enumerate(unique_cells)}
94+
95+
# Make the aggregation matrix
96+
col_indices = [cell_to_idx[cell] for cell in ordered_cell_ids]
97+
row_indices = np.arange(len(ordered_cell_ids))
98+
data = np.ones_like(row_indices)
99+
100+
aggregation_matrix = csc_matrix(
101+
(data, (row_indices, col_indices)),
102+
shape=(adata_filtered.n_obs, len(unique_cells))
103+
)
104+
105+
# Make the final AnnData object where cell IDs are filtered
106+
# to the data under the segmented nuclei
107+
nucleus_matrix_sparse = adata_filtered.X.T.dot(aggregation_matrix)
108+
adata_nucleus = sc.AnnData(nucleus_matrix_sparse.T)
109+
adata_nucleus.obs_names = unique_cells
110+
adata_nucleus.var = adata_filtered.var
111+
112+
return adata_nucleus
113+
def make_shapes_transformation(scale_factors_path: Path, dataset_id: str) -> dict[str, Scale]:
114+
"""Load scale factors for lowres and hires images and create transformations.
115+
116+
Parameters
117+
----------
118+
scale_factors_path : Path
119+
Path to the scale factors JSON file.
120+
dataset_id : str
121+
Unique identifier of the dataset.
122+
123+
Returns
124+
-------
125+
dict[str, Scale]
126+
A dictionary containing the transformations for lowres and hires images.
127+
"""
128+
with open(scale_factors_path, 'r') as f:
129+
scale_data_hd = json.load(f)
130+
lowres_scale_factor_hd = scale_data_hd['tissue_lowres_scalef']
131+
hires_scale_factor_hd = scale_data_hd['tissue_hires_scalef']
132+
133+
return {
134+
f"{dataset_id}_downscaled_lowres": Scale(np.array([lowres_scale_factor_hd, lowres_scale_factor_hd]), axes=("x", "y")),
135+
f"{dataset_id}_downscaled_hires": Scale(np.array([hires_scale_factor_hd, hires_scale_factor_hd]), axes=("x", "y"))
136+
}
137+
def make_geojson_features_map(geojson_path: Path) -> dict[str, Any]:
138+
with open(geojson_path, 'r') as f:
139+
geojson_data = json.load(f)
140+
return {
141+
f"cellid_{feature['properties']['cell_id']:09d}-1": feature
142+
for feature in geojson_data['features']
143+
}
42144
@inject_docs(vx=VisiumHDKeys)
43145
def visium_hd(
44146
path: str | Path,
45147
dataset_id: str | None = None,
46148
filtered_counts_file: bool = True,
47149
load_segmentations_only: bool = True,
150+
load_nucleus_segmentations: bool = False,
48151
bin_size: int | list[int] | None = None,
49152
bins_as_squares: bool = True,
50153
annotate_table_by_labels: bool = False,
@@ -108,10 +211,14 @@ def visium_hd(
108211
# Check for segmentation files
109212
SEGMENTED_OUTPUTS_PATH = path / VisiumHDKeys.SEGMENTATION_OUTPUTS
110213
COUNT_MATRIX_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.FILTERED_CELL_COUNTS_FILE
111-
GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.SEGMENTATION_GEOJSON_PATH
214+
CELL_GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.CELL_SEGMENTATION_GEOJSON_PATH
215+
NUCLEUS_GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.NUCLEUS_SEGMENTATION_GEOJSON_PATH
112216
SCALE_FACTORS_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.SPATIAL / VisiumHDKeys.SCALEFACTORS_FILE
113-
segmentation_files_exist = COUNT_MATRIX_PATH.exists() and GEOJSON_PATH.exists() and SCALE_FACTORS_PATH.exists()
114-
217+
BARCODE_MAPPINGS_PATH = path / VisiumHDKeys.BARCODE_MAPPINGS_FILE
218+
FILTERED_MATRIX_2U_PATH = path / VisiumHDKeys.BINNED_OUTPUTS / f"{VisiumHDKeys.BIN_PREFIX}002um" / VisiumHDKeys.FILTERED_COUNTS_FILE
219+
cell_segmentation_files_exist = COUNT_MATRIX_PATH.exists() and CELL_GEOJSON_PATH.exists() and SCALE_FACTORS_PATH.exists()
220+
nucleus_segmentation_files_exist = NUCLEUS_GEOJSON_PATH.exists() and BARCODE_MAPPINGS_PATH.exists() and FILTERED_MATRIX_2U_PATH.exists()
221+
115222
if dataset_id is None:
116223
dataset_id = _infer_dataset_id(path)
117224

@@ -277,29 +384,14 @@ def _get_bins(path_bins: Path) -> list[str]:
277384
tables[bin_size_str].var_names_make_unique()
278385

279386
# Integrate the segmentation data (skipped if segmentation files are not found)
280-
if segmentation_files_exist:
387+
if cell_segmentation_files_exist:
281388
print("Found segmentation data. Incorporating cell_segmentations.")
282389
adata_hd = sc.read_10x_h5(COUNT_MATRIX_PATH)
283390
adata_hd.var_names_make_unique()
284-
285-
with open(SCALE_FACTORS_PATH, 'r') as f:
286-
scale_data_hd = json.load(f)
287-
288-
with open(GEOJSON_PATH, 'r') as f:
289-
geojson_data = json.load(f)
290-
291-
lowres_scale_factor_hd = scale_data_hd['tissue_lowres_scalef']
292-
hires_scale_factor_hd = scale_data_hd['tissue_hires_scalef']
293-
294-
shapes_transformations_hd = {
295-
f"{dataset_id}_downscaled_lowres": Scale(np.array([lowres_scale_factor_hd, lowres_scale_factor_hd]), axes=("x", "y")),
296-
f"{dataset_id}_downscaled_hires": Scale(np.array([hires_scale_factor_hd, hires_scale_factor_hd]), axes=("x", "y"))
297-
}
298-
299-
geojson_features_map = {
300-
f"cellid_{feature['properties']['cell_id']:09d}-1": feature
301-
for feature in geojson_data['features']
302-
}
391+
392+
shapes_transformations_hd = make_shapes_transformation(scale_factors_path=SCALE_FACTORS_PATH, dataset_id=dataset_id)
393+
394+
geojson_features_map = make_geojson_features_map(CELL_GEOJSON_PATH)
303395

304396
geometries = []
305397
cell_ids_ordered = []
@@ -324,18 +416,63 @@ def _get_bins(path_bins: Path) -> list[str]:
324416
}, index=cell_ids_ordered)
325417

326418
SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.CELL_SEG_KEY_HD}"
327-
adata_hd.obs['cell_id'] = adata_hd.obs.index
328-
adata_hd.obs['region'] = SHAPES_KEY_HD
329-
adata_hd.obs['region'] = adata_hd.obs['region'].astype('category')
330-
adata_hd = adata_hd[shapes_gdf.index].copy()
419+
cell_adata_hd = adata_hd.copy()
420+
cell_adata_hd.obs['cell_id'] = cell_adata_hd.obs.index
421+
cell_adata_hd.obs['region'] = SHAPES_KEY_HD
422+
cell_adata_hd.obs['region'] = cell_adata_hd.obs['region'].astype('category')
423+
cell_adata_hd = cell_adata_hd[shapes_gdf.index].copy()
331424

332425
shapes[SHAPES_KEY_HD] = ShapesModel.parse(shapes_gdf, transformations=shapes_transformations_hd)
333426
tables[VisiumHDKeys.CELL_SEG_KEY_HD] = TableModel.parse(
334-
adata_hd,
427+
cell_adata_hd,
335428
region=SHAPES_KEY_HD,
336429
region_key='region',
337430
instance_key='cell_id'
338431
)
432+
# load nucleus segmentations if available
433+
if nucleus_segmentation_files_exist and load_nucleus_segmentations:
434+
print("Found nucleus segmentation data. Incorporating nucleus_segmentations.")
435+
436+
nucleus_adata_hd = make_filtered_nucleus_adata(filtered_matrix_h5_path=FILTERED_MATRIX_2U_PATH,barcode_mappings_parquet_path=BARCODE_MAPPINGS_PATH)
437+
438+
with open(NUCLEUS_GEOJSON_PATH, 'r') as f:
439+
geojson_data = json.load(f)
440+
441+
geometries = []
442+
cell_ids_ordered = []
443+
444+
for obs_index_str in adata_hd.obs.index:
445+
feature = geojson_features_map.get(obs_index_str)
446+
if feature:
447+
polygon_coords = np.array(feature['geometry']['coordinates'][0])
448+
geometries.append(Polygon(polygon_coords))
449+
cell_ids_ordered.append(obs_index_str)
450+
else:
451+
geometries.append(None)
452+
cell_ids_ordered.append(obs_index_str)
453+
454+
valid_indices = [i for i, geom in enumerate(geometries) if geom is not None]
455+
geometries = [geometries[i] for i in valid_indices]
456+
cell_ids_ordered = [cell_ids_ordered[i] for i in valid_indices]
457+
458+
shapes_gdf = GeoDataFrame({
459+
'cell_id': cell_ids_ordered,
460+
'geometry': geometries
461+
}, index=cell_ids_ordered)
462+
463+
SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.NUCLEUS_SEG_KEY_HD}"
464+
nucleus_adata_hd.obs['cell_id'] = nucleus_adata_hd.obs.index
465+
nucleus_adata_hd.obs['region'] = SHAPES_KEY_HD
466+
nucleus_adata_hd.obs['region'] = nucleus_adata_hd.obs['region'].astype('category')
467+
nucleus_adata_hd = nucleus_adata_hd[shapes_gdf.index].copy()
468+
469+
shapes[SHAPES_KEY_HD] = ShapesModel.parse(shapes_gdf, transformations=shapes_transformations_hd)
470+
tables[VisiumHDKeys.CELL_SEG_KEY_HD] = TableModel.parse(
471+
nucleus_adata_hd,
472+
region=SHAPES_KEY_HD,
473+
region_key='region',
474+
instance_key='cell_id'
475+
)
339476

340477
# Read all images and add transformations for both binning and segmentation
341478
fullres_image_file_paths = []
@@ -394,7 +531,7 @@ def _get_bins(path_bins: Path) -> list[str]:
394531
},
395532
set_all=True,
396533
)
397-
if segmentation_files_exist:
534+
if cell_segmentation_files_exist:
398535
set_transformation(
399536
images[dataset_id + "_hires_image"],
400537
{f"{dataset_id}_downscaled_hires": Identity()},
@@ -422,7 +559,7 @@ def _get_bins(path_bins: Path) -> list[str]:
422559
},
423560
set_all=True,
424561
)
425-
if segmentation_files_exist:
562+
if cell_segmentation_files_exist:
426563
set_transformation(
427564
images[dataset_id + "_lowres_image"],
428565
{f"{dataset_id}_downscaled_lowres": Identity()},

0 commit comments

Comments
 (0)