|
| 1 | +from typing import Literal |
| 2 | + |
| 3 | +from matplotlib.figure import Figure |
| 4 | + |
| 5 | +import scanpy as sc |
| 6 | +import scanpy.logging as logg |
| 7 | +from anndata import AnnData |
| 8 | + |
| 9 | + |
| 10 | +def plot_embedding( |
| 11 | + adata: AnnData, |
| 12 | + figsize: tuple[int, int] = (6, 6), |
| 13 | + aspect: list[Literal["auto", "equal"] | float] | Literal["auto", "equal"] | float = "auto", |
| 14 | + return_fig: bool = False, |
| 15 | + **kwargs, |
| 16 | +) -> Figure | None: |
| 17 | + """Plot embedding and set figure size and aspect ratio. |
| 18 | +
|
| 19 | + Parameters |
| 20 | + ---------- |
| 21 | + adata |
| 22 | + Annotated data matrix. |
| 23 | + figsize |
| 24 | + Figure width and height in inches. |
| 25 | + aspect |
| 26 | + Aspect ratio of the Axes scaling, i.e. y/x-scale. Possible values: |
| 27 | + * "auto": fill the position rectangle with data. |
| 28 | + * "equal": same as aspect=1, i.e. same scaling for x and y. |
| 29 | + * float: The displayed size of 1 unit in y-data coordinates will be aspect times the displayed size of 1 unit in |
| 30 | + x-data coordinates; e.g. for aspect=2 a square in data coordinates will be rendered with a height of twice |
| 31 | + its width. |
| 32 | + return_fig |
| 33 | + Flag to return the matplotlib figure. |
| 34 | + kwargs |
| 35 | + Keyword arguments passed to Scanpy's `pl.embedding` function. |
| 36 | +
|
| 37 | +
|
| 38 | + Returns |
| 39 | + ------- |
| 40 | + Is `return_fig==True` the Matplotlib figure object. |
| 41 | + """ |
| 42 | + fig = sc.pl.embedding(adata, return_fig=True, **kwargs) |
| 43 | + |
| 44 | + fig.set_size_inches(*figsize) |
| 45 | + axes = fig.get_axes() |
| 46 | + if len(axes) > 0 and isinstance(aspect, (str, float)): |
| 47 | + aspect = [aspect] * len(axes) |
| 48 | + elif len(axes) != len(aspect): |
| 49 | + logg.warning("The aspect list is shorter than the number of panels. Using `aspect='auto'` for all panels.") |
| 50 | + aspect["auto"] * len(axes) |
| 51 | + |
| 52 | + for ax_id, ax in enumerate(axes): |
| 53 | + ax.collections[0].set_rasterized(True) |
| 54 | + ax.set_aspect(aspect[ax_id]) |
| 55 | + |
| 56 | + if return_fig: |
| 57 | + return fig |
0 commit comments