Skip to content

Commit 34ac722

Browse files
make local functions
1 parent 31379c8 commit 34ac722

1 file changed

Lines changed: 147 additions & 143 deletions

File tree

src/spatialdata_io/readers/visium_hd.py

Lines changed: 147 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
rasterize_bins,
2626
rasterize_bins_link_table_to_labels,
2727
)
28-
from spatialdata.models import Image2DModel, ShapesModel, TableModel
28+
from spatialdata.models import Image2DModel, ShapesModel, TableModel, PointsModel
2929
from spatialdata.transformations import Affine, Identity, Scale, set_transformation
3030
from xarray import DataArray
3131

@@ -41,142 +41,6 @@
4141

4242
RNG = default_rng(0)
4343

44-
def make_filtered_nucleus_adata(
45-
filtered_matrix_h5_path: str,
46-
barcode_mappings_parquet_path: str,
47-
bin_col_name: str = 'square_002um',
48-
aggregate_col_name: str = 'cell_id'
49-
) -> anndata.AnnData:
50-
"""Generate a filtered AnnData object by aggregating 2um binned data
51-
based on nucleus segmentation.
52-
Uses a 2um filtered_feature_bc_matrix.h5 file and a barcode_mappings.parquet file containing
53-
barcode mappings, filters the data to include only valid nucleus mappings,
54-
and aggregates the data based on specified bin into cell IDs which only contain
55-
the 2um square data under segmented nuclei.
56-
57-
Parameters:
58-
-----------
59-
filtered_matrix_h5_path : str
60-
Path to the 10x Genomics HDF5 matrix file.
61-
barcode_mappings_parquet_path : str
62-
Path to the Parquet file containing barcode mappings.
63-
bin_col_name : str, optional
64-
Column name in the barcode mappings that specifies the spatial bin (default is 'square_002um').
65-
aggregate_col_name : str, optional
66-
Column name in the barcode mappings that specifies the aggregate cell ID (default is 'cell_id').
67-
Returns:
68-
--------
69-
anndata.AnnData
70-
An AnnData object where the observations correspond to filtered cell IDs
71-
and the variables correspond to the original features from the input data.
72-
"""
73-
# Read in the necessary files
74-
adata_2um = sc.read_10x_h5(filtered_matrix_h5_path)
75-
barcode_mappings = pq.read_table(barcode_mappings_parquet_path)
76-
77-
# Filter to only include valid cell IDs that are in both nucleus and cell
78-
barcode_mappings = barcode_mappings.filter((barcode_mappings['cell_id'].is_valid()) and barcode_mappings["in_nucleus"])
79-
80-
# Filter the 2um adata to only include squares present in the barcode mappings
81-
valid_squares = barcode_mappings[bin_col_name].unique()
82-
squares_to_keep = np.intersect1d(adata_2um.obs_names, valid_squares)
83-
adata_filtered = adata_2um[squares_to_keep, :].copy()
84-
85-
# Map each square to its corresponding cell ID
86-
square_to_cell_map = dict(zip(
87-
barcode_mappings[bin_col_name].to_pylist(),
88-
barcode_mappings[aggregate_col_name].to_pylist()
89-
90-
))
91-
ordered_cell_ids = [square_to_cell_map[square] for square in adata_filtered.obs_names]
92-
unique_cells = list(dict.fromkeys(ordered_cell_ids).keys())
93-
cell_to_idx = {cell: i for i, cell in enumerate(unique_cells)}
94-
95-
# Make the aggregation matrix
96-
col_indices = [cell_to_idx[cell] for cell in ordered_cell_ids]
97-
row_indices = np.arange(len(ordered_cell_ids))
98-
data = np.ones_like(row_indices)
99-
100-
aggregation_matrix = csc_matrix(
101-
(data, (row_indices, col_indices)),
102-
shape=(adata_filtered.n_obs, len(unique_cells))
103-
)
104-
105-
# Make the final AnnData object where cell IDs are filtered
106-
# to the data under the segmented nuclei
107-
nucleus_matrix_sparse = adata_filtered.X.T.dot(aggregation_matrix)
108-
adata_nucleus = sc.AnnData(nucleus_matrix_sparse.T)
109-
adata_nucleus.obs_names = unique_cells
110-
adata_nucleus.var = adata_filtered.var
111-
112-
return adata_nucleus
113-
def extract_geometries_from_geojson(adata: anndata.AnnData, geojson_features_map: dict[str, Any]) -> GeoDataFrame:
114-
"""Extract geometries and create a GeoDataFrame from a GeoJSON features map.
115-
116-
Parameters
117-
----------
118-
cell_adata : anndata.AnnData
119-
AnnData object containing cell data.
120-
geojson_features_map : dict[str, Any]
121-
Dictionary mapping cell IDs to GeoJSON features.
122-
123-
Returns
124-
-------
125-
GeoDataFrame
126-
A GeoDataFrame containing cell IDs and their corresponding geometries.
127-
"""
128-
geometries = []
129-
cell_ids_ordered = []
130-
131-
for obs_index_str in adata.obs.index:
132-
feature = geojson_features_map.get(obs_index_str)
133-
if feature:
134-
polygon_coords = np.array(feature['geometry']['coordinates'][0])
135-
geometries.append(Polygon(polygon_coords))
136-
cell_ids_ordered.append(obs_index_str)
137-
else:
138-
geometries.append(None)
139-
cell_ids_ordered.append(obs_index_str)
140-
141-
valid_indices = [i for i, geom in enumerate(geometries) if geom is not None]
142-
geometries = [geometries[i] for i in valid_indices]
143-
cell_ids_ordered = [cell_ids_ordered[i] for i in valid_indices]
144-
145-
return GeoDataFrame({
146-
'cell_id': cell_ids_ordered,
147-
'geometry': geometries
148-
}, index=cell_ids_ordered)
149-
def make_shapes_transformation(scale_factors_path: Path, dataset_id: str) -> dict[str, Scale]:
150-
"""Load scale factors for lowres and hires images and create transformations.
151-
152-
Parameters
153-
----------
154-
scale_factors_path : Path
155-
Path to the scale factors JSON file.
156-
dataset_id : str
157-
Unique identifier of the dataset.
158-
159-
Returns
160-
-------
161-
dict[str, Scale]
162-
A dictionary containing the transformations for lowres and hires images.
163-
"""
164-
with open(scale_factors_path, 'r') as f:
165-
scale_data_hd = json.load(f)
166-
lowres_scale_factor_hd = scale_data_hd['tissue_lowres_scalef']
167-
hires_scale_factor_hd = scale_data_hd['tissue_hires_scalef']
168-
169-
return {
170-
f"{dataset_id}_downscaled_lowres": Scale(np.array([lowres_scale_factor_hd, lowres_scale_factor_hd]), axes=("x", "y")),
171-
f"{dataset_id}_downscaled_hires": Scale(np.array([hires_scale_factor_hd, hires_scale_factor_hd]), axes=("x", "y"))
172-
}
173-
def make_geojson_features_map(geojson_path: Path) -> dict[str, Any]:
174-
with open(geojson_path, 'r') as f:
175-
geojson_data = json.load(f)
176-
return {
177-
f"cellid_{feature['properties']['cell_id']:09d}-1": feature
178-
for feature in geojson_data['features']
179-
}
18044
@inject_docs(vx=VisiumHDKeys)
18145
def visium_hd(
18246
path: str | Path,
@@ -425,9 +289,9 @@ def _get_bins(path_bins: Path) -> list[str]:
425289
cell_adata_hd = sc.read_10x_h5(COUNT_MATRIX_PATH)
426290
cell_adata_hd.var_names_make_unique()
427291

428-
shapes_transformations_hd = make_shapes_transformation(scale_factors_path=SCALE_FACTORS_PATH, dataset_id=dataset_id) # Used for both cell and nucleus segmentations
429-
cell_geojson_features_map = make_geojson_features_map(CELL_GEOJSON_PATH)
430-
cell_shapes_gdf = extract_geometries_from_geojson(cell_adata_hd, geojson_features_map=cell_geojson_features_map)
292+
shapes_transformations_hd = _make_shapes_transformation(scale_factors_path=SCALE_FACTORS_PATH, dataset_id=dataset_id) # Used for both cell and nucleus segmentations
293+
cell_geojson_features_map = _make_geojson_features_map(CELL_GEOJSON_PATH)
294+
cell_shapes_gdf = _extract_geometries_from_geojson(cell_adata_hd, geojson_features_map=cell_geojson_features_map)
431295

432296
SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.CELL_SEG_KEY_HD}"
433297
cell_adata_hd.obs['cell_id'] = cell_adata_hd.obs.index
@@ -447,9 +311,9 @@ def _get_bins(path_bins: Path) -> list[str]:
447311
if nucleus_segmentation_files_exist and load_nucleus_segmentations:
448312
print("Found nucleus segmentation data. Incorporating nucleus_segmentations.")
449313

450-
nucleus_adata_hd = make_filtered_nucleus_adata(filtered_matrix_h5_path=FILTERED_MATRIX_2U_PATH,barcode_mappings_parquet_path=BARCODE_MAPPINGS_PATH)
451-
geojson_features_map = make_geojson_features_map(NUCLEUS_GEOJSON_PATH)
452-
nucleus_shapes_gdf = extract_geometries_from_geojson(adata=nucleus_adata_hd, geojson_features_map=geojson_features_map)
314+
nucleus_adata_hd = _make_filtered_nucleus_adata(filtered_matrix_h5_path=FILTERED_MATRIX_2U_PATH,barcode_mappings_parquet_path=BARCODE_MAPPINGS_PATH)
315+
geojson_features_map = _make_geojson_features_map(NUCLEUS_GEOJSON_PATH)
316+
nucleus_shapes_gdf = _extract_geometries_from_geojson(adata=nucleus_adata_hd, geojson_features_map=geojson_features_map)
453317

454318
SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.NUCLEUS_SEG_KEY_HD}"
455319
nucleus_adata_hd.obs['cell_id'] = nucleus_adata_hd.obs.index
@@ -779,3 +643,143 @@ def _get_transform_matrices(metadata: dict[str, Any], hd_layout: dict[str, Any])
779643
transform_matrices[key] = np.array(coefficients).reshape(3, 3)
780644

781645
return transform_matrices
646+
647+
def _make_filtered_nucleus_adata(
    filtered_matrix_h5_path: str,
    barcode_mappings_parquet_path: str,
    bin_col_name: str = 'square_002um',
    aggregate_col_name: str = 'cell_id'
) -> anndata.AnnData:
    """Aggregate binned counts into per-cell counts restricted to nucleus bins.

    Reads a 10x ``filtered_feature_bc_matrix.h5`` file and a
    ``barcode_mappings.parquet`` file, keeps only the mapping rows that have a
    non-null ``cell_id`` and a true ``in_nucleus`` flag, and sums the counts of
    the surviving bins per value of ``aggregate_col_name``.

    Parameters
    ----------
    filtered_matrix_h5_path : str
        Path to the 10x Genomics HDF5 matrix file.
    barcode_mappings_parquet_path : str
        Path to the Parquet file containing barcode mappings. It must contain
        the columns ``cell_id``, ``in_nucleus``, ``bin_col_name`` and
        ``aggregate_col_name``.
    bin_col_name : str, optional
        Column name in the barcode mappings that specifies the spatial bin
        (default is 'square_002um').
    aggregate_col_name : str, optional
        Column name in the barcode mappings that specifies the aggregate cell ID
        (default is 'cell_id').

    Returns
    -------
    anndata.AnnData
        An AnnData object whose observations are the aggregate IDs (in order of
        first appearance among the kept bins) and whose variables are the
        features of the input matrix.
    """
    # Function-scope import: only this helper needs the pyarrow compute kernels.
    import pyarrow.compute as pc

    # Read in the necessary files
    adata_2um = sc.read_10x_h5(filtered_matrix_h5_path)
    barcode_mappings = pq.read_table(barcode_mappings_parquet_path)

    # Keep rows that have a valid cell ID AND lie under a nucleus. The mask must
    # be combined element-wise: Python's `and` would only test the truthiness of
    # the first array and hand back the second one, dropping the validity check.
    keep_mask = pc.and_(
        barcode_mappings['cell_id'].is_valid(),
        barcode_mappings['in_nucleus'],
    )
    barcode_mappings = barcode_mappings.filter(keep_mask)

    # Filter the binned adata to only include bins present in the barcode mappings
    valid_squares = barcode_mappings[bin_col_name].unique()
    squares_to_keep = np.intersect1d(adata_2um.obs_names, valid_squares)
    adata_filtered = adata_2um[squares_to_keep, :].copy()

    # Map each bin to its corresponding aggregate (cell) ID
    square_to_cell_map = dict(zip(
        barcode_mappings[bin_col_name].to_pylist(),
        barcode_mappings[aggregate_col_name].to_pylist(),
    ))
    ordered_cell_ids = [square_to_cell_map[square] for square in adata_filtered.obs_names]
    # dict.fromkeys de-duplicates while preserving first-appearance order
    unique_cells = list(dict.fromkeys(ordered_cell_ids))
    cell_to_idx = {cell: i for i, cell in enumerate(unique_cells)}

    # Build the (n_bins x n_cells) indicator matrix: entry (i, j) is 1 when
    # bin i belongs to cell j
    col_indices = [cell_to_idx[cell] for cell in ordered_cell_ids]
    row_indices = np.arange(len(ordered_cell_ids))
    data = np.ones_like(row_indices)

    aggregation_matrix = csc_matrix(
        (data, (row_indices, col_indices)),
        shape=(adata_filtered.n_obs, len(unique_cells))
    )

    # Sum bin counts per cell: (features x bins) . (bins x cells), transposed
    # back so that observations are cells
    nucleus_matrix_sparse = adata_filtered.X.T.dot(aggregation_matrix)
    adata_nucleus = sc.AnnData(nucleus_matrix_sparse.T)
    adata_nucleus.obs_names = unique_cells
    adata_nucleus.var = adata_filtered.var

    return adata_nucleus
716+
717+
def _extract_geometries_from_geojson(adata: anndata.AnnData, geojson_features_map: dict[str, Any]) -> GeoDataFrame:
    """Extract geometries and create a GeoDataFrame from a GeoJSON features map.

    Observations of ``adata`` that have no (or an empty) entry in
    ``geojson_features_map`` are skipped, so the result may have fewer rows
    than ``adata`` has observations.

    Parameters
    ----------
    adata : anndata.AnnData
        AnnData object whose ``obs`` index holds the IDs to look up.
    geojson_features_map : dict[str, Any]
        Dictionary mapping cell IDs to GeoJSON features.

    Returns
    -------
    GeoDataFrame
        A GeoDataFrame with columns ``cell_id`` and ``geometry``, indexed by
        cell ID, in the order the IDs appear in ``adata.obs.index``.
    """
    cell_ids_ordered = []
    geometries = []

    for obs_index_str in adata.obs.index:
        feature = geojson_features_map.get(obs_index_str)
        if not feature:
            # No geometry available for this observation: leave it out.
            continue
        # Only the first entry of `coordinates` is used (for a GeoJSON Polygon
        # that is the outer ring; any further rings are ignored).
        polygon_coords = np.array(feature['geometry']['coordinates'][0])
        geometries.append(Polygon(polygon_coords))
        cell_ids_ordered.append(obs_index_str)

    return GeoDataFrame({
        'cell_id': cell_ids_ordered,
        'geometry': geometries
    }, index=cell_ids_ordered)
753+
754+
def _make_shapes_transformation(scale_factors_path: Path, dataset_id: str) -> dict[str, Scale]:
    """Build the lowres/hires scale transformations from a scale factors file.

    Parameters
    ----------
    scale_factors_path : Path
        Path to the scale factors JSON file. It must contain the keys
        ``tissue_lowres_scalef`` and ``tissue_hires_scalef``.
    dataset_id : str
        Unique identifier of the dataset, used as prefix of the returned keys.

    Returns
    -------
    dict[str, Scale]
        ``{"<dataset_id>_downscaled_lowres": Scale, "<dataset_id>_downscaled_hires": Scale}``,
        each applying the same factor along x and y.
    """
    with open(scale_factors_path, 'r') as handle:
        scale_factors = json.load(handle)

    # Look both factors up before building anything so a missing key fails early.
    lowres_factor = scale_factors['tissue_lowres_scalef']
    hires_factor = scale_factors['tissue_hires_scalef']

    def _isotropic(factor: Any) -> Scale:
        # Same factor on both spatial axes.
        return Scale(np.array([factor, factor]), axes=("x", "y"))

    transformations = {}
    transformations[f"{dataset_id}_downscaled_lowres"] = _isotropic(lowres_factor)
    transformations[f"{dataset_id}_downscaled_hires"] = _isotropic(hires_factor)
    return transformations
778+
779+
def _make_geojson_features_map(geojson_path: Path) -> dict[str, Any]:
    """Index the features of a GeoJSON file by formatted cell ID.

    Parameters
    ----------
    geojson_path : Path
        Path to a GeoJSON file with a top-level ``features`` list whose items
        carry an integer ``properties.cell_id``.

    Returns
    -------
    dict[str, Any]
        Mapping from ``"cellid_<cell_id zero-padded to 9 digits>-1"`` to the
        corresponding feature. If an ID occurs more than once, the last
        feature wins.
    """
    with open(geojson_path, 'r') as handle:
        features = json.load(handle)['features']

    features_by_id = {}
    for feature in features:
        cell_id = feature['properties']['cell_id']
        features_by_id[f"cellid_{cell_id:09d}-1"] = feature
    return features_by_id

0 commit comments

Comments
 (0)