77from types import MappingProxyType
88from typing import TYPE_CHECKING , Any
99from shapely .geometry import Polygon
10-
10+ import pyarrow .parquet as pq
11+ import anndata
12+ from scipy .sparse import csc_matrix
1113import h5py
1214import numpy as np
1315import pandas as pd
3941
4042RNG = default_rng (0 )
4143
def make_filtered_nucleus_adata(
    filtered_matrix_h5_path: str | Path,
    barcode_mappings_parquet_path: str | Path,
    bin_col_name: str = "square_002um",
    aggregate_col_name: str = "cell_id",
) -> anndata.AnnData:
    """Aggregate 2 um bin counts into per-cell counts restricted to nuclei.

    Reads the 2 um ``filtered_feature_bc_matrix.h5`` and the
    ``barcode_mappings.parquet`` table, keeps only the mapping rows that have
    a non-null aggregate ID and are flagged both ``in_nucleus`` and
    ``in_cell``, and sums the counts of the remaining bins per aggregate ID.

    Parameters
    ----------
    filtered_matrix_h5_path
        Path to the 10x Genomics HDF5 matrix file of the 2 um bins.
    barcode_mappings_parquet_path
        Path to the Parquet file containing the barcode mappings. It must
        provide ``bin_col_name``, ``aggregate_col_name``, ``in_nucleus`` and
        ``in_cell`` columns.
    bin_col_name
        Column of the barcode mappings holding the spatial bin barcode.
    aggregate_col_name
        Column of the barcode mappings holding the ID the bins are
        aggregated into.

    Returns
    -------
    anndata.AnnData
        One observation per aggregate ID (in order of first appearance among
        the retained bins) and the same variables as the input matrix.
    """
    # Local import: only this function needs the Arrow compute kernels.
    import pyarrow.compute as pc

    # Read in the necessary files.
    adata_2um = sc.read_10x_h5(filtered_matrix_h5_path)
    barcode_mappings = pq.read_table(barcode_mappings_parquet_path)

    # Build the row mask with Arrow kernels. Python's ``and`` is not an
    # element-wise operator on Arrow arrays, so chaining the columns with it
    # does not combine the three conditions per row.
    keep_mask = pc.and_(
        pc.and_(
            pc.is_valid(barcode_mappings[aggregate_col_name]),
            barcode_mappings["in_nucleus"],
        ),
        barcode_mappings["in_cell"],
    )
    barcode_mappings = barcode_mappings.filter(keep_mask)

    # Restrict the 2 um matrix to the bins that survived the filter.
    valid_squares = barcode_mappings[bin_col_name].unique()
    squares_to_keep = np.intersect1d(adata_2um.obs_names, valid_squares)
    adata_filtered = adata_2um[squares_to_keep, :].copy()

    # Map each bin to its aggregate ID. If a bin occurs in several rows the
    # last row wins (dict semantics).
    square_to_cell_map = dict(
        zip(
            barcode_mappings[bin_col_name].to_pylist(),
            barcode_mappings[aggregate_col_name].to_pylist(),
        )
    )
    ordered_cell_ids = [square_to_cell_map[square] for square in adata_filtered.obs_names]
    # dict.fromkeys de-duplicates while preserving first-appearance order.
    unique_cells = list(dict.fromkeys(ordered_cell_ids))
    cell_to_idx = {cell: idx for idx, cell in enumerate(unique_cells)}

    # Indicator matrix of shape (n_bins, n_cells): entry (i, j) is 1 when
    # bin i belongs to cell j.
    col_indices = [cell_to_idx[cell] for cell in ordered_cell_ids]
    row_indices = np.arange(len(ordered_cell_ids))
    data = np.ones_like(row_indices)
    aggregation_matrix = csc_matrix(
        (data, (row_indices, col_indices)),
        shape=(adata_filtered.n_obs, len(unique_cells)),
    )

    # (genes x bins) . (bins x cells) -> genes x cells; transpose so that
    # observations are cells.
    nucleus_matrix_sparse = adata_filtered.X.T.dot(aggregation_matrix)
    adata_nucleus = anndata.AnnData(nucleus_matrix_sparse.T)
    adata_nucleus.obs_names = unique_cells
    adata_nucleus.var = adata_filtered.var

    return adata_nucleus
def make_shapes_transformation(scale_factors_path: Path, dataset_id: str) -> dict[str, Scale]:
    """Load scale factors for lowres and hires images and create transformations.

    Parameters
    ----------
    scale_factors_path : Path
        Path to the scale factors JSON file. It must contain the keys
        ``tissue_lowres_scalef`` and ``tissue_hires_scalef``.
    dataset_id : str
        Unique identifier of the dataset; used as the prefix of the returned
        keys.

    Returns
    -------
    dict[str, Scale]
        ``"{dataset_id}_downscaled_lowres"`` and
        ``"{dataset_id}_downscaled_hires"``, each mapped to a ``Scale`` that
        applies the corresponding factor to both the x and y axes.

    Raises
    ------
    KeyError
        If one of the two scale factor keys is missing from the JSON file.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(scale_factors_path, encoding="utf-8") as f:
        scale_data_hd = json.load(f)

    # Look both factors up before building anything so a missing key fails early.
    lowres_scale_factor_hd = scale_data_hd["tissue_lowres_scalef"]
    hires_scale_factor_hd = scale_data_hd["tissue_hires_scalef"]

    def _isotropic_scale(factor: float) -> Scale:
        # Same factor on x and y.
        return Scale(np.array([factor, factor]), axes=("x", "y"))

    return {
        f"{dataset_id}_downscaled_lowres": _isotropic_scale(lowres_scale_factor_hd),
        f"{dataset_id}_downscaled_hires": _isotropic_scale(hires_scale_factor_hd),
    }
def make_geojson_features_map(geojson_path: Path) -> dict[str, Any]:
    """Index the features of a segmentation GeoJSON file by cell barcode.

    Parameters
    ----------
    geojson_path : Path
        Path to a GeoJSON file whose top-level ``features`` list holds
        features with an integer ``properties.cell_id``.

    Returns
    -------
    dict[str, Any]
        Maps ``"cellid_<cell_id zero-padded to 9 digits>-1"`` to the
        corresponding feature dictionary. If two features share a cell ID the
        later one wins.

    Raises
    ------
    KeyError
        If ``features``, ``properties`` or ``cell_id`` is missing.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(geojson_path, encoding="utf-8") as f:
        geojson_data = json.load(f)

    return {
        f"cellid_{feature['properties']['cell_id']:09d}-1": feature
        for feature in geojson_data["features"]
    }
42144@inject_docs (vx = VisiumHDKeys )
43145def visium_hd (
44146 path : str | Path ,
45147 dataset_id : str | None = None ,
46148 filtered_counts_file : bool = True ,
47149 load_segmentations_only : bool = True ,
150+ load_nucleus_segmentations : bool = False ,
48151 bin_size : int | list [int ] | None = None ,
49152 bins_as_squares : bool = True ,
50153 annotate_table_by_labels : bool = False ,
@@ -108,10 +211,14 @@ def visium_hd(
108211 # Check for segmentation files
109212 SEGMENTED_OUTPUTS_PATH = path / VisiumHDKeys .SEGMENTATION_OUTPUTS
110213 COUNT_MATRIX_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys .FILTERED_CELL_COUNTS_FILE
111- GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys .SEGMENTATION_GEOJSON_PATH
214+ CELL_GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys .CELL_SEGMENTATION_GEOJSON_PATH
215+ NUCLEUS_GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys .NUCLEUS_SEGMENTATION_GEOJSON_PATH
112216 SCALE_FACTORS_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys .SPATIAL / VisiumHDKeys .SCALEFACTORS_FILE
113- segmentation_files_exist = COUNT_MATRIX_PATH .exists () and GEOJSON_PATH .exists () and SCALE_FACTORS_PATH .exists ()
114-
217+ BARCODE_MAPPINGS_PATH = path / VisiumHDKeys .BARCODE_MAPPINGS_FILE
218+ FILTERED_MATRIX_2U_PATH = path / VisiumHDKeys .BINNED_OUTPUTS / f"{ VisiumHDKeys .BIN_PREFIX } 002um" / VisiumHDKeys .FILTERED_COUNTS_FILE
219+ cell_segmentation_files_exist = COUNT_MATRIX_PATH .exists () and CELL_GEOJSON_PATH .exists () and SCALE_FACTORS_PATH .exists ()
220+ nucleus_segmentation_files_exist = NUCLEUS_GEOJSON_PATH .exists () and BARCODE_MAPPINGS_PATH .exists () and FILTERED_MATRIX_2U_PATH .exists ()
221+
115222 if dataset_id is None :
116223 dataset_id = _infer_dataset_id (path )
117224
@@ -277,29 +384,14 @@ def _get_bins(path_bins: Path) -> list[str]:
277384 tables [bin_size_str ].var_names_make_unique ()
278385
279386 # Integrate the segmentation data (skipped if segmentation files are not found)
280- if segmentation_files_exist :
387+ if cell_segmentation_files_exist :
281388 print ("Found segmentation data. Incorporating cell_segmentations." )
282389 adata_hd = sc .read_10x_h5 (COUNT_MATRIX_PATH )
283390 adata_hd .var_names_make_unique ()
284-
285- with open (SCALE_FACTORS_PATH , 'r' ) as f :
286- scale_data_hd = json .load (f )
287-
288- with open (GEOJSON_PATH , 'r' ) as f :
289- geojson_data = json .load (f )
290-
291- lowres_scale_factor_hd = scale_data_hd ['tissue_lowres_scalef' ]
292- hires_scale_factor_hd = scale_data_hd ['tissue_hires_scalef' ]
293-
294- shapes_transformations_hd = {
295- f"{ dataset_id } _downscaled_lowres" : Scale (np .array ([lowres_scale_factor_hd , lowres_scale_factor_hd ]), axes = ("x" , "y" )),
296- f"{ dataset_id } _downscaled_hires" : Scale (np .array ([hires_scale_factor_hd , hires_scale_factor_hd ]), axes = ("x" , "y" ))
297- }
298-
299- geojson_features_map = {
300- f"cellid_{ feature ['properties' ]['cell_id' ]:09d} -1" : feature
301- for feature in geojson_data ['features' ]
302- }
391+
392+ shapes_transformations_hd = make_shapes_transformation (scale_factors_path = SCALE_FACTORS_PATH , dataset_id = dataset_id )
393+
394+ geojson_features_map = make_geojson_features_map (CELL_GEOJSON_PATH )
303395
304396 geometries = []
305397 cell_ids_ordered = []
@@ -324,18 +416,63 @@ def _get_bins(path_bins: Path) -> list[str]:
324416 }, index = cell_ids_ordered )
325417
326418 SHAPES_KEY_HD = f"{ dataset_id } _{ VisiumHDKeys .CELL_SEG_KEY_HD } "
327- adata_hd .obs ['cell_id' ] = adata_hd .obs .index
328- adata_hd .obs ['region' ] = SHAPES_KEY_HD
329- adata_hd .obs ['region' ] = adata_hd .obs ['region' ].astype ('category' )
330- adata_hd = adata_hd [shapes_gdf .index ].copy ()
419+ cell_adata_hd = adata_hd .copy ()
420+ cell_adata_hd .obs ['cell_id' ] = cell_adata_hd .obs .index
421+ cell_adata_hd .obs ['region' ] = SHAPES_KEY_HD
422+ cell_adata_hd .obs ['region' ] = cell_adata_hd .obs ['region' ].astype ('category' )
423+ cell_adata_hd = cell_adata_hd [shapes_gdf .index ].copy ()
331424
332425 shapes [SHAPES_KEY_HD ] = ShapesModel .parse (shapes_gdf , transformations = shapes_transformations_hd )
333426 tables [VisiumHDKeys .CELL_SEG_KEY_HD ] = TableModel .parse (
334- adata_hd ,
427+ cell_adata_hd ,
335428 region = SHAPES_KEY_HD ,
336429 region_key = 'region' ,
337430 instance_key = 'cell_id'
338431 )
432+ # load nucleus segmentations if available
433+ if nucleus_segmentation_files_exist and load_nucleus_segmentations :
434+ print ("Found nucleus segmentation data. Incorporating nucleus_segmentations." )
435+
436+ nucleus_adata_hd = make_filtered_nucleus_adata (filtered_matrix_h5_path = FILTERED_MATRIX_2U_PATH ,barcode_mappings_parquet_path = BARCODE_MAPPINGS_PATH )
437+
438+ with open (NUCLEUS_GEOJSON_PATH , 'r' ) as f :
439+ geojson_data = json .load (f )
440+
441+ geometries = []
442+ cell_ids_ordered = []
443+
444+ for obs_index_str in adata_hd .obs .index :
445+ feature = geojson_features_map .get (obs_index_str )
446+ if feature :
447+ polygon_coords = np .array (feature ['geometry' ]['coordinates' ][0 ])
448+ geometries .append (Polygon (polygon_coords ))
449+ cell_ids_ordered .append (obs_index_str )
450+ else :
451+ geometries .append (None )
452+ cell_ids_ordered .append (obs_index_str )
453+
454+ valid_indices = [i for i , geom in enumerate (geometries ) if geom is not None ]
455+ geometries = [geometries [i ] for i in valid_indices ]
456+ cell_ids_ordered = [cell_ids_ordered [i ] for i in valid_indices ]
457+
458+ shapes_gdf = GeoDataFrame ({
459+ 'cell_id' : cell_ids_ordered ,
460+ 'geometry' : geometries
461+ }, index = cell_ids_ordered )
462+
463+ SHAPES_KEY_HD = f"{ dataset_id } _{ VisiumHDKeys .NUCLEUS_SEG_KEY_HD } "
464+ nucleus_adata_hd .obs ['cell_id' ] = nucleus_adata_hd .obs .index
465+ nucleus_adata_hd .obs ['region' ] = SHAPES_KEY_HD
466+ nucleus_adata_hd .obs ['region' ] = nucleus_adata_hd .obs ['region' ].astype ('category' )
467+ nucleus_adata_hd = nucleus_adata_hd [shapes_gdf .index ].copy ()
468+
469+ shapes [SHAPES_KEY_HD ] = ShapesModel .parse (shapes_gdf , transformations = shapes_transformations_hd )
470+ tables [VisiumHDKeys .CELL_SEG_KEY_HD ] = TableModel .parse (
471+ nucleus_adata_hd ,
472+ region = SHAPES_KEY_HD ,
473+ region_key = 'region' ,
474+ instance_key = 'cell_id'
475+ )
339476
340477 # Read all images and add transformations for both binning and segmentation
341478 fullres_image_file_paths = []
@@ -394,7 +531,7 @@ def _get_bins(path_bins: Path) -> list[str]:
394531 },
395532 set_all = True ,
396533 )
397- if segmentation_files_exist :
534+ if cell_segmentation_files_exist :
398535 set_transformation (
399536 images [dataset_id + "_hires_image" ],
400537 {f"{ dataset_id } _downscaled_hires" : Identity ()},
@@ -422,7 +559,7 @@ def _get_bins(path_bins: Path) -> list[str]:
422559 },
423560 set_all = True ,
424561 )
425- if segmentation_files_exist :
562+ if cell_segmentation_files_exist :
426563 set_transformation (
427564 images [dataset_id + "_lowres_image" ],
428565 {f"{ dataset_id } _downscaled_lowres" : Identity ()},
0 commit comments