Skip to content

Commit 31379c8

Browse files
use only in_nucleus
1 parent a6d7e9f commit 31379c8

1 file changed

Lines changed: 50 additions & 59 deletions

File tree

src/spatialdata_io/readers/visium_hd.py

Lines changed: 50 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def make_filtered_nucleus_adata(
7575
barcode_mappings = pq.read_table(barcode_mappings_parquet_path)
7676

7777
# Filter to only include valid cell IDs that are in both nucleus and cell
78-
barcode_mappings = barcode_mappings.filter((barcode_mappings['cell_id'].is_valid()) and barcode_mappings["in_nucleus"] and barcode_mappings["in_cell"])
78+
barcode_mappings = barcode_mappings.filter((barcode_mappings['cell_id'].is_valid()) and barcode_mappings["in_nucleus"])
7979

8080
# Filter the 2um adata to only include squares present in the barcode mappings
8181
valid_squares = barcode_mappings[bin_col_name].unique()
@@ -110,6 +110,42 @@ def make_filtered_nucleus_adata(
110110
adata_nucleus.var = adata_filtered.var
111111

112112
return adata_nucleus
113+
def extract_geometries_from_geojson(adata: anndata.AnnData, geojson_features_map: dict[str, Any]) -> GeoDataFrame:
114+
"""Extract geometries and create a GeoDataFrame from a GeoJSON features map.
115+
116+
Parameters
117+
----------
118+
cell_adata : anndata.AnnData
119+
AnnData object containing cell data.
120+
geojson_features_map : dict[str, Any]
121+
Dictionary mapping cell IDs to GeoJSON features.
122+
123+
Returns
124+
-------
125+
GeoDataFrame
126+
A GeoDataFrame containing cell IDs and their corresponding geometries.
127+
"""
128+
geometries = []
129+
cell_ids_ordered = []
130+
131+
for obs_index_str in adata.obs.index:
132+
feature = geojson_features_map.get(obs_index_str)
133+
if feature:
134+
polygon_coords = np.array(feature['geometry']['coordinates'][0])
135+
geometries.append(Polygon(polygon_coords))
136+
cell_ids_ordered.append(obs_index_str)
137+
else:
138+
geometries.append(None)
139+
cell_ids_ordered.append(obs_index_str)
140+
141+
valid_indices = [i for i, geom in enumerate(geometries) if geom is not None]
142+
geometries = [geometries[i] for i in valid_indices]
143+
cell_ids_ordered = [cell_ids_ordered[i] for i in valid_indices]
144+
145+
return GeoDataFrame({
146+
'cell_id': cell_ids_ordered,
147+
'geometry': geometries
148+
}, index=cell_ids_ordered)
113149
def make_shapes_transformation(scale_factors_path: Path, dataset_id: str) -> dict[str, Scale]:
114150
"""Load scale factors for lowres and hires images and create transformations.
115151
@@ -386,88 +422,43 @@ def _get_bins(path_bins: Path) -> list[str]:
386422
# Integrate the segmentation data (skipped if segmentation files are not found)
387423
if cell_segmentation_files_exist:
388424
print("Found segmentation data. Incorporating cell_segmentations.")
389-
adata_hd = sc.read_10x_h5(COUNT_MATRIX_PATH)
390-
adata_hd.var_names_make_unique()
425+
cell_adata_hd = sc.read_10x_h5(COUNT_MATRIX_PATH)
426+
cell_adata_hd.var_names_make_unique()
391427

392-
shapes_transformations_hd = make_shapes_transformation(scale_factors_path=SCALE_FACTORS_PATH, dataset_id=dataset_id)
393-
394-
geojson_features_map = make_geojson_features_map(CELL_GEOJSON_PATH)
395-
396-
geometries = []
397-
cell_ids_ordered = []
398-
399-
for obs_index_str in adata_hd.obs.index:
400-
feature = geojson_features_map.get(obs_index_str)
401-
if feature:
402-
polygon_coords = np.array(feature['geometry']['coordinates'][0])
403-
geometries.append(Polygon(polygon_coords))
404-
cell_ids_ordered.append(obs_index_str)
405-
else:
406-
geometries.append(None)
407-
cell_ids_ordered.append(obs_index_str)
408-
409-
valid_indices = [i for i, geom in enumerate(geometries) if geom is not None]
410-
geometries = [geometries[i] for i in valid_indices]
411-
cell_ids_ordered = [cell_ids_ordered[i] for i in valid_indices]
412-
413-
shapes_gdf = GeoDataFrame({
414-
'cell_id': cell_ids_ordered,
415-
'geometry': geometries
416-
}, index=cell_ids_ordered)
428+
shapes_transformations_hd = make_shapes_transformation(scale_factors_path=SCALE_FACTORS_PATH, dataset_id=dataset_id) # Used for both cell and nucleus segmentations
429+
cell_geojson_features_map = make_geojson_features_map(CELL_GEOJSON_PATH)
430+
cell_shapes_gdf = extract_geometries_from_geojson(cell_adata_hd, geojson_features_map=cell_geojson_features_map)
417431

418432
SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.CELL_SEG_KEY_HD}"
419-
cell_adata_hd = adata_hd.copy()
420433
cell_adata_hd.obs['cell_id'] = cell_adata_hd.obs.index
421434
cell_adata_hd.obs['region'] = SHAPES_KEY_HD
422435
cell_adata_hd.obs['region'] = cell_adata_hd.obs['region'].astype('category')
423-
cell_adata_hd = cell_adata_hd[shapes_gdf.index].copy()
436+
cell_adata_hd = cell_adata_hd[cell_shapes_gdf.index].copy()
424437

425-
shapes[SHAPES_KEY_HD] = ShapesModel.parse(shapes_gdf, transformations=shapes_transformations_hd)
438+
shapes[SHAPES_KEY_HD] = ShapesModel.parse(cell_shapes_gdf, transformations=shapes_transformations_hd)
426439
tables[VisiumHDKeys.CELL_SEG_KEY_HD] = TableModel.parse(
427440
cell_adata_hd,
428441
region=SHAPES_KEY_HD,
429442
region_key='region',
430443
instance_key='cell_id'
431444
)
445+
432446
# load nucleus segmentations if available
433447
if nucleus_segmentation_files_exist and load_nucleus_segmentations:
434448
print("Found nucleus segmentation data. Incorporating nucleus_segmentations.")
435449

436450
nucleus_adata_hd = make_filtered_nucleus_adata(filtered_matrix_h5_path=FILTERED_MATRIX_2U_PATH,barcode_mappings_parquet_path=BARCODE_MAPPINGS_PATH)
437-
438-
with open(NUCLEUS_GEOJSON_PATH, 'r') as f:
439-
geojson_data = json.load(f)
440-
441-
geometries = []
442-
cell_ids_ordered = []
443-
444-
for obs_index_str in adata_hd.obs.index:
445-
feature = geojson_features_map.get(obs_index_str)
446-
if feature:
447-
polygon_coords = np.array(feature['geometry']['coordinates'][0])
448-
geometries.append(Polygon(polygon_coords))
449-
cell_ids_ordered.append(obs_index_str)
450-
else:
451-
geometries.append(None)
452-
cell_ids_ordered.append(obs_index_str)
453-
454-
valid_indices = [i for i, geom in enumerate(geometries) if geom is not None]
455-
geometries = [geometries[i] for i in valid_indices]
456-
cell_ids_ordered = [cell_ids_ordered[i] for i in valid_indices]
457-
458-
shapes_gdf = GeoDataFrame({
459-
'cell_id': cell_ids_ordered,
460-
'geometry': geometries
461-
}, index=cell_ids_ordered)
451+
geojson_features_map = make_geojson_features_map(NUCLEUS_GEOJSON_PATH)
452+
nucleus_shapes_gdf = extract_geometries_from_geojson(adata=nucleus_adata_hd, geojson_features_map=geojson_features_map)
462453

463454
SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.NUCLEUS_SEG_KEY_HD}"
464455
nucleus_adata_hd.obs['cell_id'] = nucleus_adata_hd.obs.index
465456
nucleus_adata_hd.obs['region'] = SHAPES_KEY_HD
466457
nucleus_adata_hd.obs['region'] = nucleus_adata_hd.obs['region'].astype('category')
467-
nucleus_adata_hd = nucleus_adata_hd[shapes_gdf.index].copy()
458+
nucleus_adata_hd = nucleus_adata_hd[nucleus_shapes_gdf.index].copy()
468459

469-
shapes[SHAPES_KEY_HD] = ShapesModel.parse(shapes_gdf, transformations=shapes_transformations_hd)
470-
tables[VisiumHDKeys.CELL_SEG_KEY_HD] = TableModel.parse(
460+
shapes[SHAPES_KEY_HD] = ShapesModel.parse(nucleus_shapes_gdf, transformations=shapes_transformations_hd)
461+
tables[VisiumHDKeys.NUCLEUS_SEG_KEY_HD] = TableModel.parse(
471462
nucleus_adata_hd,
472463
region=SHAPES_KEY_HD,
473464
region_key='region',

0 commit comments

Comments
 (0)