|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import os |
| 4 | +import re |
| 5 | +from collections.abc import Mapping |
| 6 | +from pathlib import Path |
| 7 | +from types import MappingProxyType |
| 8 | +from typing import Any |
| 9 | + |
| 10 | +import anndata as ad |
| 11 | +import numpy as np |
| 12 | +import pandas as pd |
| 13 | +from dask_image.imread import imread |
| 14 | +from spatialdata import SpatialData |
| 15 | +from spatialdata.models import ( |
| 16 | + Image2DModel, |
| 17 | + Labels2DModel, |
| 18 | + PointsModel, |
| 19 | + ShapesModel, |
| 20 | + TableModel, |
| 21 | +) |
| 22 | +from spatialdata.transformations import Identity |
| 23 | + |
| 24 | +from spatialdata_io._constants._constants import SeqfishKeys as SK |
| 25 | +from spatialdata_io._docs import inject_docs |
| 26 | + |
| 27 | +__all__ = ["seqfish"] |
| 28 | + |
| 29 | + |
| 30 | +@inject_docs(vx=SK) |
| 31 | +def seqfish( |
| 32 | + path: str | Path, |
| 33 | + load_images: bool = True, |
| 34 | + load_labels: bool = True, |
| 35 | + load_points: bool = True, |
| 36 | + sections: list[int] | None = None, |
| 37 | + imread_kwargs: Mapping[str, Any] = MappingProxyType({}), |
| 38 | +) -> SpatialData: |
| 39 | + """ |
| 40 | + Read *seqfish* formatted dataset. |
| 41 | +
|
| 42 | + This function reads the following files: |
| 43 | +
|
| 44 | + - ```{vx.COUNTS_FILE!r}{vx.SECTION!r}{vx.CSV_FILE!r}```: Counts and metadata file. |
| 45 | + - ```{vx.CELL_COORDINATES!r}{vx.SECTION!r}{vx.CSV_FILE!r}```: Cell coordinates file. |
| 46 | + - ```{vx.DAPI!r}{vx.SECTION!r}{vx.OME_TIFF_FILE!r}```: High resolution tiff image. |
| 47 | + - ```{vx.CELL_MASK_FILE!r}{vx.SECTION!r}{vx.TIFF_FILE!r}```: Cell mask file. |
| 48 | + - ```{vx.TRANSCRIPT_COORDINATES!r}{vx.SECTION!r}{vx.CSV_FILE!r}```: Transcript coordinates file. |
| 49 | +
|
| 50 | + .. seealso:: |
| 51 | +
|
| 52 | + - `seqfish output <https://spatialgenomics.com/data/>`_. |
| 53 | +
|
| 54 | + Parameters |
| 55 | + ---------- |
| 56 | + path |
| 57 | + Path to the directory containing the data. |
| 58 | + load_images |
| 59 | + Whether to load the images. |
| 60 | + load_labels |
| 61 | + Whether to load the labels. |
| 62 | + load_points |
| 63 | + Whether to load the points. |
| 64 | + sections |
| 65 | + Which sections (specified as integers) to load. By default, all sections are loaded. |
| 66 | + imread_kwargs |
| 67 | + Keyword arguments to pass to :func:`dask_image.imread.imread`. |
| 68 | +
|
| 69 | + Returns |
| 70 | + ------- |
| 71 | + :class:`spatialdata.SpatialData` |
| 72 | + """ |
| 73 | + path = Path(path) |
| 74 | + count_file_pattern = re.compile(rf"(.*?)_{SK.CELL_COORDINATES}_{SK.SECTION}[0-9]+" + re.escape(SK.CSV_FILE)) |
| 75 | + count_files = [i for i in os.listdir(path) if count_file_pattern.match(i)] |
| 76 | + if not count_files: |
| 77 | + # no file matching tbe pattern found |
| 78 | + raise ValueError( |
| 79 | + f"No files matching the pattern {count_file_pattern} were found. Cannot infer the naming scheme." |
| 80 | + ) |
| 81 | + matched = count_file_pattern.match(count_files[0]) |
| 82 | + if matched is None: |
| 83 | + raise ValueError(f"File {count_files[0]} does not match the pattern {count_file_pattern}") |
| 84 | + prefix = matched.group(1) |
| 85 | + |
| 86 | + n = len(count_files) |
| 87 | + all_sections = list(range(1, n + 1)) |
| 88 | + if sections is None: |
| 89 | + sections = all_sections |
| 90 | + else: |
| 91 | + for section in sections: |
| 92 | + if section not in all_sections: |
| 93 | + raise ValueError(f"Section {section} not found in the data.") |
| 94 | + sections_str = [f"{SK.SECTION}{x}" for x in sections] |
| 95 | + |
| 96 | + def get_cell_file(section: str) -> str: |
| 97 | + return f"{prefix}_{SK.CELL_COORDINATES}_{section}{SK.CSV_FILE}" |
| 98 | + |
| 99 | + def get_count_file(section: str) -> str: |
| 100 | + return f"{prefix}_{SK.COUNTS_FILE}_{section}{SK.CSV_FILE}" |
| 101 | + |
| 102 | + def get_dapi_file(section: str) -> str: |
| 103 | + return f"{prefix}_{SK.DAPI}_{section}{SK.OME_TIFF_FILE}" |
| 104 | + |
| 105 | + def get_cell_mask_file(section: str) -> str: |
| 106 | + return f"{prefix}_{SK.CELL_MASK_FILE}_{section}{SK.TIFF_FILE}" |
| 107 | + |
| 108 | + def get_transcript_file(section: str) -> str: |
| 109 | + return f"{prefix}_{SK.TRANSCRIPT_COORDINATES}_{section}{SK.CSV_FILE}" |
| 110 | + |
| 111 | + adatas: dict[str, ad.AnnData] = {} |
| 112 | + for section in sections_str: # type: ignore[assignment] |
| 113 | + assert isinstance(section, str) |
| 114 | + cell_file = get_cell_file(section) |
| 115 | + count_matrix = get_count_file(section) |
| 116 | + adata = ad.read_csv(path / count_matrix, delimiter=",") |
| 117 | + cell_info = pd.read_csv(path / cell_file, delimiter=",") |
| 118 | + adata.obsm[SK.SPATIAL_KEY] = cell_info[[SK.CELL_X, SK.CELL_Y]].to_numpy() |
| 119 | + adata.obs[SK.AREA] = np.reshape(cell_info[SK.AREA].to_numpy(), (-1, 1)) |
| 120 | + region = f"cells_{section}" |
| 121 | + adata.obs[SK.REGION_KEY] = region |
| 122 | + adata.obs[SK.INSTANCE_KEY_TABLE] = adata.obs.index.astype(int) |
| 123 | + adatas[section] = adata |
| 124 | + |
| 125 | + scale_factors = [2, 2, 2, 2] |
| 126 | + |
| 127 | + if load_images: |
| 128 | + images = { |
| 129 | + f"image_{x}": Image2DModel.parse( |
| 130 | + imread(path / get_dapi_file(x), **imread_kwargs), |
| 131 | + dims=("c", "y", "x"), |
| 132 | + scale_factors=scale_factors, |
| 133 | + transformations={x: Identity()}, |
| 134 | + ) |
| 135 | + for x in sections_str |
| 136 | + } |
| 137 | + else: |
| 138 | + images = {} |
| 139 | + |
| 140 | + if load_labels: |
| 141 | + labels = { |
| 142 | + f"labels_{x}": Labels2DModel.parse( |
| 143 | + imread(path / get_cell_mask_file(x), **imread_kwargs).squeeze(), |
| 144 | + dims=("y", "x"), |
| 145 | + scale_factors=scale_factors, |
| 146 | + transformations={x: Identity()}, |
| 147 | + ) |
| 148 | + for x in sections_str |
| 149 | + } |
| 150 | + else: |
| 151 | + labels = {} |
| 152 | + |
| 153 | + if load_points: |
| 154 | + points = { |
| 155 | + f"transcripts_{x}": PointsModel.parse( |
| 156 | + pd.read_csv(path / get_transcript_file(x), delimiter=","), |
| 157 | + coordinates={"x": SK.TRANSCRIPTS_X, "y": SK.TRANSCRIPTS_Y}, |
| 158 | + feature_key=SK.FEATURE_KEY.value, |
| 159 | + instance_key=SK.INSTANCE_KEY_POINTS.value, |
| 160 | + transformations={x: Identity()}, |
| 161 | + ) |
| 162 | + for x in sections_str |
| 163 | + } |
| 164 | + else: |
| 165 | + points = {} |
| 166 | + |
| 167 | + adata = ad.concat(adatas.values()) |
| 168 | + adata.obs[SK.REGION_KEY] = adata.obs[SK.REGION_KEY].astype("category") |
| 169 | + adata.obs = adata.obs.reset_index(drop=True) |
| 170 | + table = TableModel.parse( |
| 171 | + adata, |
| 172 | + region=[f"cells_{x}" for x in sections_str], |
| 173 | + region_key=SK.REGION_KEY.value, |
| 174 | + instance_key=SK.INSTANCE_KEY_TABLE.value, |
| 175 | + ) |
| 176 | + |
| 177 | + shapes = { |
| 178 | + f"cells_{x}": ShapesModel.parse( |
| 179 | + adata.obsm[SK.SPATIAL_KEY], |
| 180 | + geometry=0, |
| 181 | + radius=np.sqrt(adata.obs[SK.AREA].to_numpy() / np.pi), |
| 182 | + index=adata.obs[SK.INSTANCE_KEY_TABLE].copy(), |
| 183 | + transformations={x: Identity()}, |
| 184 | + ) |
| 185 | + for x, adata in adatas.items() |
| 186 | + } |
| 187 | + |
| 188 | + sdata = SpatialData(images=images, labels=labels, points=points, table=table, shapes=shapes) |
| 189 | + |
| 190 | + return sdata |
0 commit comments