Skip to content

Commit 5f61dd5

Browse files
committed
init
1 parent 95d8b18 commit 5f61dd5

8 files changed

Lines changed: 97 additions & 25 deletions

File tree

src/squidpy/_constants/_pkg_constants.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,16 @@ def _sort_haystack(
198198
class obsp:
199199
@classmethod
200200
def spatial_dist(cls, value: str | None = None) -> str:
201-
return f"{Key.obsm.spatial}_distances" if value is None else f"{value}_distances"
201+
if value is None:
202+
return f"{Key.obsm.spatial}_distances"
203+
if value.endswith("_distances"):
204+
return value
205+
return f"{value}_distances"
202206

203207
@classmethod
204208
def spatial_conn(cls, value: str | None = None) -> str:
205-
return f"{Key.obsm.spatial}_connectivities" if value is None else f"{value}_connectivities"
209+
if value is None:
210+
return f"{Key.obsm.spatial}_connectivities"
211+
if value.endswith("_connectivities"):
212+
return value
213+
return f"{value}_connectivities"

src/squidpy/gr/_ligrec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
_check_tuple_needles,
2828
_genesymbols,
2929
_save_data,
30+
extract_table_if_spatialdata,
3031
)
3132

3233
__all__ = ["ligrec", "PermutationTest"]
@@ -630,6 +631,7 @@ def prepare(
630631

631632

632633
@d.dedent
634+
@extract_table_if_spatialdata
633635
def ligrec(
634636
adata: AnnData | SpatialData,
635637
cluster_key: str,
@@ -659,8 +661,6 @@ def ligrec(
659661
-------
660662
%(ligrec_test_returns)s
661663
""" # noqa: D400
662-
if isinstance(adata, SpatialData):
663-
adata = adata.table
664664
with _genesymbols(adata, key=gene_symbols, use_raw=use_raw, make_unique=False):
665665
return ( # type: ignore[no-any-return]
666666
PermutationTest(adata, use_raw=use_raw)

src/squidpy/gr/_nhood.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
_assert_positive,
2828
_save_data,
2929
_shuffle_group,
30+
extract_table_if_spatialdata,
3031
)
3132

3233
__all__ = ["nhood_enrichment", "centrality_scores", "interaction_matrix"]
@@ -134,6 +135,7 @@ def _create_function(n_cls: int, parallel: bool = False) -> Callable[[NDArrayA,
134135

135136
@d.get_sections(base="nhood_ench", sections=["Parameters"])
136137
@d.dedent
138+
@extract_table_if_spatialdata
137139
def nhood_enrichment(
138140
adata: AnnData | SpatialData,
139141
cluster_key: str,
@@ -171,8 +173,6 @@ def nhood_enrichment(
171173
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['zscore']`` - the enrichment z-score.
172174
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['count']`` - the enrichment count.
173175
"""
174-
if isinstance(adata, SpatialData):
175-
adata = adata.table
176176
connectivity_key = Key.obsp.spatial_conn(connectivity_key)
177177
_assert_categorical_obs(adata, cluster_key)
178178
_assert_connectivity_key(adata, connectivity_key)
@@ -230,6 +230,7 @@ def nhood_enrichment(
230230

231231
@d.dedent
232232
@inject_docs(c=Centrality)
233+
@extract_table_if_spatialdata
233234
def centrality_scores(
234235
adata: AnnData | SpatialData,
235236
cluster_key: str,
@@ -268,8 +269,6 @@ def centrality_scores(
268269
- :attr:`anndata.AnnData.uns` ``['{{cluster_key}}_centrality_scores']`` - the centrality scores,
269270
as mentioned above.
270271
"""
271-
if isinstance(adata, SpatialData):
272-
adata = adata.table
273272
connectivity_key = Key.obsp.spatial_conn(connectivity_key)
274273
_assert_categorical_obs(adata, cluster_key)
275274
_assert_connectivity_key(adata, connectivity_key)
@@ -326,6 +325,7 @@ def centrality_scores(
326325

327326

328327
@d.dedent
328+
@extract_table_if_spatialdata
329329
def interaction_matrix(
330330
adata: AnnData | SpatialData,
331331
cluster_key: str,
@@ -356,8 +356,6 @@ def interaction_matrix(
356356
357357
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_interactions']`` - the interaction matrix.
358358
"""
359-
if isinstance(adata, SpatialData):
360-
adata = adata.table
361359
connectivity_key = Key.obsp.spatial_conn(connectivity_key)
362360
_assert_categorical_obs(adata, cluster_key)
363361
_assert_connectivity_key(adata, connectivity_key)

src/squidpy/gr/_ppatterns.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
_assert_positive,
3131
_assert_spatial_basis,
3232
_save_data,
33+
extract_table_if_spatialdata,
3334
)
3435

3536
__all__ = ["spatial_autocorr", "co_occurrence"]
@@ -45,6 +46,7 @@
4546

4647
@d.dedent
4748
@inject_docs(key=Key.obsp.spatial_conn(), sp=SpatialAutocorr)
49+
@extract_table_if_spatialdata
4850
def spatial_autocorr(
4951
adata: AnnData | SpatialData,
5052
connectivity_key: str = Key.obsp.spatial_conn(),
@@ -128,8 +130,6 @@ def spatial_autocorr(
128130
- :attr:`anndata.AnnData.uns` ``['moranI']`` - the above mentioned dataframe, if ``mode = {sp.MORAN.s!r}``.
129131
- :attr:`anndata.AnnData.uns` ``['gearyC']`` - the above mentioned dataframe, if ``mode = {sp.GEARY.s!r}``.
130132
"""
131-
if isinstance(adata, SpatialData):
132-
adata = adata.table
133133
_assert_connectivity_key(adata, connectivity_key)
134134

135135
def extract_X(adata: AnnData, genes: str | Sequence[str] | None) -> tuple[NDArrayA | spmatrix, Sequence[Any]]:
@@ -342,6 +342,7 @@ def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs
342342

343343

344344
@d.dedent
345+
@extract_table_if_spatialdata
345346
def co_occurrence(
346347
adata: AnnData | SpatialData,
347348
cluster_key: str,
@@ -381,9 +382,6 @@ def co_occurrence(
381382
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds
382383
computed at ``interval``.
383384
"""
384-
385-
if isinstance(adata, SpatialData):
386-
adata = adata.table
387385
_assert_categorical_obs(adata, key=cluster_key)
388386
_assert_spatial_basis(adata, key=spatial_key)
389387

src/squidpy/gr/_ripley.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@
1919
from squidpy._constants._pkg_constants import Key
2020
from squidpy._docs import d, inject_docs
2121
from squidpy._utils import NDArrayA
22-
from squidpy.gr._utils import _assert_categorical_obs, _assert_spatial_basis, _save_data
22+
from squidpy.gr._utils import _assert_categorical_obs, _assert_spatial_basis, _save_data, extract_table_if_spatialdata
2323

2424
__all__ = ["ripley"]
2525

2626

2727
@d.dedent
2828
@inject_docs(key=Key.obsm.spatial, rp=RipleyStat)
29+
@extract_table_if_spatialdata
2930
def ripley(
3031
adata: AnnData | SpatialData,
3132
cluster_key: str,
@@ -104,8 +105,6 @@ def ripley(
104105
`Wikipedia <https://en.wikipedia.org/wiki/Spatial_descriptive_statistics#Ripley's_K_and_L_functions>`_
105106
or :cite:`Baddeley2015-lm`.
106107
"""
107-
if isinstance(adata, SpatialData):
108-
adata = adata.table
109108
_assert_categorical_obs(adata, key=cluster_key)
110109
_assert_spatial_basis(adata, key=spatial_key)
111110
coordinates = adata.obsm[spatial_key]

src/squidpy/gr/_sepal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121
_assert_spatial_basis,
2222
_extract_expression,
2323
_save_data,
24+
extract_table_if_spatialdata,
2425
)
2526

2627
__all__ = ["sepal"]
2728

2829

2930
@d.dedent
3031
@inject_docs(key=Key.obsp.spatial_conn())
32+
@extract_table_if_spatialdata
3133
def sepal(
3234
adata: AnnData | SpatialData,
3335
max_neighs: Literal[4, 6],
@@ -93,8 +95,6 @@ def sepal(
9395
If some genes in :attr:`anndata.AnnData.uns` ``['sepal_score']`` are `NaN`,
9496
consider re-running the function with increased ``n_iter``.
9597
"""
96-
if isinstance(adata, SpatialData):
97-
adata = adata.table
9898
_assert_connectivity_key(adata, connectivity_key)
9999
_assert_spatial_basis(adata, key=spatial_key)
100100
if max_neighs not in (4, 6):

src/squidpy/gr/_utils.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Hashable, Iterable, Sequence
5+
import functools
6+
import inspect
7+
from collections.abc import Callable, Hashable, Iterable, Sequence
68
from contextlib import contextmanager
79
from typing import Any
810

@@ -14,12 +16,81 @@
1416
from pandas.api.types import infer_dtype
1517
from scanpy import logging as logg
1618
from scipy.sparse import csc_matrix, csr_matrix, spmatrix
19+
from spatialdata import SpatialData
1720

1821
from squidpy._compat import ArrayView, SparseCSCView, SparseCSRView
1922
from squidpy._docs import d
2023
from squidpy._utils import NDArrayA, _unique_order_preserving
2124

2225

26+
_TABLE_KEY_DOC = """ table_key
27+
Key in :attr:`spatialdata.SpatialData.tables` where the table is stored.
28+
Only used if ``adata`` is a :class:`spatialdata.SpatialData`."""
29+
30+
31+
def extract_table_if_spatialdata(fn: Callable[..., Any]) -> Callable[..., Any]:
32+
"""Decorator that resolves a :class:`~spatialdata.SpatialData` to an :class:`~anndata.AnnData`.
33+
34+
Adds a ``table_key`` parameter (default ``"table"``) to the wrapped
35+
function's signature **and** appends its documentation to the docstring.
36+
When the first positional argument (``adata``) is a
37+
:class:`~spatialdata.SpatialData`, the table is looked up via
38+
``adata.tables[table_key]`` and passed through in its place.
39+
"""
40+
sig = inspect.signature(fn)
41+
42+
table_param = inspect.Parameter(
43+
"table_key",
44+
inspect.Parameter.KEYWORD_ONLY,
45+
default="table",
46+
)
47+
params = list(sig.parameters.values())
48+
kw_only_start = next(
49+
(i for i, p in enumerate(params) if p.kind == inspect.Parameter.KEYWORD_ONLY),
50+
len(params),
51+
)
52+
var_kw = [i for i, p in enumerate(params) if p.kind == inspect.Parameter.VAR_KEYWORD]
53+
insert_pos = var_kw[0] if var_kw else kw_only_start
54+
params.insert(insert_pos, table_param)
55+
new_sig = sig.replace(parameters=params)
56+
57+
@functools.wraps(fn)
58+
def wrapper(*args: Any, **kwargs: Any) -> Any:
59+
bound = new_sig.bind(*args, **kwargs)
60+
bound.apply_defaults()
61+
table_key: str = bound.arguments.pop("table_key")
62+
63+
adata = bound.arguments.get("adata")
64+
if isinstance(adata, SpatialData):
65+
if table_key not in adata.tables:
66+
raise ValueError(
67+
f"Table {table_key!r} not found in SpatialData. "
68+
f"Available tables: {list(adata.tables.keys())}"
69+
)
70+
bound.arguments["adata"] = adata.tables[table_key]
71+
72+
return fn(*bound.args, **bound.kwargs)
73+
74+
wrapper.__signature__ = new_sig # type: ignore[attr-defined]
75+
76+
if wrapper.__doc__ is not None:
77+
# Insert table_key docs before the "Returns" / "Notes" / "References"
78+
# section, or append at the end of Parameters if none found.
79+
doc = wrapper.__doc__
80+
for marker in ("Returns\n", "Notes\n", "References\n"):
81+
idx = doc.find(marker)
82+
if idx != -1:
83+
# Back up past the " -------\n" underline to the section header
84+
header_start = doc.rfind("\n", 0, idx - 1)
85+
doc = doc[:header_start] + "\n" + _TABLE_KEY_DOC + "\n" + doc[header_start:]
86+
break
87+
else:
88+
doc = doc.rstrip() + "\n" + _TABLE_KEY_DOC + "\n"
89+
wrapper.__doc__ = doc
90+
91+
return wrapper
92+
93+
2394
def _check_tuple_needles(
2495
needles: Sequence[tuple[Any, Any]],
2596
haystack: Sequence[Any],

src/squidpy/tl/_sliding_window.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
from spatialdata import SpatialData
1010

1111
from squidpy._docs import d
12-
from squidpy.gr._utils import _save_data
12+
from squidpy.gr._utils import _save_data, extract_table_if_spatialdata
1313

1414
__all__ = ["sliding_window"]
1515

1616

1717
@d.dedent
18+
@extract_table_if_spatialdata
1819
def sliding_window(
1920
adata: AnnData | SpatialData,
2021
library_key: str | None = None,
@@ -55,9 +56,6 @@ def sliding_window(
5556
if overlap < 0:
5657
raise ValueError("Overlap must be non-negative.")
5758

58-
if isinstance(adata, SpatialData):
59-
adata = adata.table
60-
6159
# we don't want to modify the original adata in case of copy=True
6260
if copy:
6361
adata = adata.copy()

0 commit comments

Comments
 (0)