Skip to content

Commit 1431911

Browse files
Merge pull request #363 from scverse/performance/xenium-points-parsing
Faster xenium points parsing by precomputing categories
2 parents 785fa9a + 3b756da commit 1431911

1 file changed

Lines changed: 19 additions & 4 deletions

File tree

src/spatialdata_io/readers/xenium.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ def xenium(
352352
raise ValueError(invalid_format_msg) from e
353353
channels.append((channel_idx, ome_ch.name))
354354
channel_names = dict(sorted(channels))
355+
355356
# this reads the scale 0 for all the 1 or 4 channels (the other files are parsed automatically)
356357
# dask.image.imread will call tifffile.imread which will give a warning saying that reading multi-file
357358
# pyramids is not supported; since we are reading the full scale image and reconstructing the pyramid, we
@@ -550,10 +551,24 @@ def _get_cells_metadata_table_from_zarr(
550551

551552
def _get_points(path: Path, specs: dict[str, Any]) -> Table:
552553
table = read_parquet(path / XeniumKeys.TRANSCRIPTS_FILE)
553-
table["feature_name"] = table["feature_name"].apply(
554-
lambda x: x.decode("utf-8") if isinstance(x, bytes) else str(x),
555-
meta=("feature_name", "object"),
556-
)
554+
555+
# check if we need to decode bytes
556+
sample = table[XeniumKeys.FEATURE_NAME].head(1)
557+
needs_decode = isinstance(sample.iloc[0], bytes)
558+
559+
# get unique categories (fast)
560+
categories = table[XeniumKeys.FEATURE_NAME].drop_duplicates().compute()
561+
if needs_decode:
562+
categories = categories.str.decode("utf-8")
563+
cat_dtype = pd.CategoricalDtype(categories=categories)
564+
565+
# decode column if needed, then convert to categorical
566+
if needs_decode:
567+
table[XeniumKeys.FEATURE_NAME] = table[XeniumKeys.FEATURE_NAME].map_partitions(
568+
lambda s: s.str.decode("utf-8").astype(cat_dtype), meta=pd.Series(dtype=cat_dtype)
569+
)
570+
else:
571+
table[XeniumKeys.FEATURE_NAME] = table[XeniumKeys.FEATURE_NAME].astype(cat_dtype)
557572

558573
transform = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y"))
559574
points = PointsModel.parse(

0 commit comments

Comments
 (0)