Skip to content

Commit d8533e2

Browse files
authored
Add plot_embedding (#49)
1 parent 208e49b commit d8533e2

2 files changed

Lines changed: 60 additions & 0 deletions

File tree

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ._embedding import plot_embedding
2+
3+
__all__ = ["plot_embedding"]
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from typing import Literal
2+
3+
from matplotlib.figure import Figure
4+
5+
import scanpy as sc
6+
import scanpy.logging as logg
7+
from anndata import AnnData
8+
9+
10+
def plot_embedding(
11+
adata: AnnData,
12+
figsize: tuple[int, int] = (6, 6),
13+
aspect: list[Literal["auto", "equal"] | float] | Literal["auto", "equal"] | float = "auto",
14+
return_fig: bool = False,
15+
**kwargs,
16+
) -> Figure | None:
17+
"""Plot embedding and set figure size and aspect ratio.
18+
19+
Parameters
20+
----------
21+
adata
22+
Annotated data matrix.
23+
figsize
24+
Figure width and height in inches.
25+
aspect
26+
Aspect ratio of the Axes scaling, i.e. y/x-scale. Possible values:
27+
* "auto": fill the position rectangle with data.
28+
* "equal": same as aspect=1, i.e. same scaling for x and y.
29+
* float: The displayed size of 1 unit in y-data coordinates will be aspect times the displayed size of 1 unit in
30+
x-data coordinates; e.g. for aspect=2 a square in data coordinates will be rendered with a height of twice
31+
its width.
32+
return_fig
33+
Flag to return the matplotlib figure.
34+
kwargs
35+
Keyword arguments passed to Scanpy's `pl.embedding` function.
36+
37+
38+
Returns
39+
-------
40+
Is `return_fig==True` the Matplotlib figure object.
41+
"""
42+
fig = sc.pl.embedding(adata, return_fig=True, **kwargs)
43+
44+
fig.set_size_inches(*figsize)
45+
axes = fig.get_axes()
46+
if len(axes) > 0 and isinstance(aspect, (str, float)):
47+
aspect = [aspect] * len(axes)
48+
elif len(axes) != len(aspect):
49+
logg.warning("The aspect list is shorter than the number of panels. Using `aspect='auto'` for all panels.")
50+
aspect["auto"] * len(axes)
51+
52+
for ax_id, ax in enumerate(axes):
53+
ax.collections[0].set_rasterized(True)
54+
ax.set_aspect(aspect[ax_id])
55+
56+
if return_fig:
57+
return fig

0 commit comments

Comments
 (0)