diff --git a/dynamo/movie/__init__.py b/dynamo/movie/__init__.py index ff6c2f633..ee227bb7c 100644 --- a/dynamo/movie/__init__.py +++ b/dynamo/movie/__init__.py @@ -1,4 +1,4 @@ """Mapping Vector Field of Single Cells """ -from .fate import StreamFuncAnim, animate_fates +from .fate import PyvistaAnim, StreamFuncAnim, StreamFuncAnim3D, animate_fates diff --git a/dynamo/movie/fate.py b/dynamo/movie/fate.py index 8c5f908c4..814b654f7 100755 --- a/dynamo/movie/fate.py +++ b/dynamo/movie/fate.py @@ -1,20 +1,30 @@ import warnings from typing import Optional, Union +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + import matplotlib import numpy as np from anndata import AnnData from scipy.integrate import odeint from ..dynamo_logger import main_info, main_tqdm, main_warning -from ..plot.topography import topography +from ..plot.topography import topography, topography_3D from ..vectorfield.scVectorField import SvcVectorField from .utils import remove_particles -class StreamFuncAnim: - """Animating cell fate commitment prediction via reconstructed vector field function.""" +class BaseAnim: + """Base class for animating cell fate commitment prediction via reconstructed vector field function. + This class creates necessary components to produce an animation that describes the exact speed of a set of cells + at each time point, its movement in gene expression and the long range trajectory predicted by the reconstructed + vector field. Thus, it provides intuitive visual understanding of the RNA velocity, speed, acceleration, and cell + fate commitment in action. 
+ """ def __init__( self, adata: AnnData, @@ -24,108 +34,36 @@ def __init__( n_steps: int = 100, cell_states: Union[int, list, None] = None, color: str = "ntr", - fig: Optional[matplotlib.figure.Figure] = None, - ax: matplotlib.axes.Axes = None, logspace: bool = False, max_time: Optional[float] = None, - frame_color=None, ): - """Animating cell fate commitment prediction via reconstructed vector field function. - - This class creates necessary components to produce an animation that describes the exact speed of a set of cells - at each time point, its movement in gene expression and the long range trajectory predicted by the reconstructed - vector field. Thus it provides intuitive visual understanding of the RNA velocity, speed, acceleration, and cell - fate commitment in action. - - This function is originally inspired by https://tonysyu.github.io/animating-particles-in-a-flow.html and relies on - animation module from matplotlib. Note that you may need to install `imagemagick` in order to properly show or save - the animation. See for example, http://louistiao.me/posts/notebooks/save-matplotlib-animations-as-gifs/ for more - details. - - Parameters - ---------- - adata: :class:`~anndata.AnnData` - AnnData object that already went through the fate prediction. - basis: `str` or None (default: `umap`) - The embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the reconstructed - trajectory will be projected back to high dimensional space via the `inverse_transform` function. - space. - fps_basis: `str` or None (default: `None`) - The basis that will be used for identifying or retrieving fixed points. Note that if `fps_basis` is - different from `basis`, the nearest cells of the fixed point from the `fps_basis` will be found and used to - visualize the position of the fixed point on `basis` embedding. 
- dims: `list` or `None` (default: `None') - The dimensions of low embedding space where cells will be drawn and it should corresponds to the space + """Construct a class that can be used to animate cell fate commitment prediction via reconstructed vector field + function. + + Args: + adata: annData object that already went through the fate prediction. + basis: the embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the + reconstructed trajectory will be projected back to high dimensional space via the `inverse_transform` + function space. + fp_basis: the basis that will be used for identifying or retrieving fixed points. Note that if `fps_basis` + is different from `basis`, the nearest cells of the fixed point from the `fps_basis` will be found and + used to visualize the position of the fixed point on `basis` embedding. + dims: the dimensions of low embedding space where cells will be drawn, and it should correspond to the space fate prediction take place. - n_steps: `int` (default: `100`) - The number of times steps (frames) fate prediction will take. - cell_states: `int`, `list` or `None` (default: `None`) - The number of cells state that will be randomly selected (if `int`), the indices of the cells states (if - `list`) or all cell states which fate prediction executed (if `None`) - fig: `matplotlib.figure.Figure` or None (default: `None`) - The figure that will contain both the background and animated components. - ax: `matplotlib.Axis` (optional, default `None`) - The matplotlib axes object that will be used as background plot of the vector field animation. If `ax` - is None, `topography(adata, basis=basis, color=color, ax=ax, save_show_or_return='return')` will be used - to create an axes. - logspace: `bool` (default: `False`) - Whether or to sample time points linearly on log space. 
If not, the sorted unique set of all time points - from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time points. - - Returns - ------- + n_steps: the number of times steps (frames) fate prediction will take. + cell_states: the number of cells state that will be randomly selected (if `int`), the indices of the cells + states (if `list`) or all cell states which fate prediction executed (if `None`) + color: the key of the data that will be used to color the embedding. + logspace: `whether or to sample time points linearly on log space. If not, the sorted unique set of all-time + points from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time + points. + max_time: the maximum time that will be used to scale the time vector. + + Returns: A class that contains .fig attribute and .update, .init_background that can be used to produce an animation of the prediction of cell fate commitment. - - Examples 1 - ---------- - >>> from matplotlib import animation - >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] - >>> fate_progenitor = progenitor - >>> info_genes = adata.var_names[adata.var.use_for_transition] - >>> dyn.pd.fate(adata, basis='umap', init_cells=fate_progenitor, interpolation_num=100, direction='forward', - ... inverse_transform=False, average=False) - >>> instance = dyn.mv.StreamFuncAnim(adata=adata, fig=None, ax=None) - >>> anim = animation.FuncAnimation(instance.fig, instance.update, init_func=instance.init_background, - ... frames=np.arange(100), interval=100, blit=True) - >>> from IPython.core.display import display, HTML - >>> HTML(anim.to_jshtml()) # embedding to jupyter notebook. - >>> anim.save('fate_ani.gif',writer="imagemagick") # save as gif file. 
- - Examples 2 - ---------- - >>> from matplotlib import animation - >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] - >>> fate_progenitor = progenitor - >>> info_genes = adata.var_names[adata.var.use_for_transition] - >>> dyn.pd.fate(adata, basis='umap', init_cells=fate_progenitor, interpolation_num=100, direction='forward', - ... inverse_transform=False, average=False) - >>> fig, ax = plt.subplots() - >>> ax = dyn.pl.topography(adata_old, color='time', ax=ax, save_show_or_return='return', color_key_cmap='viridis') - >>> ax.set_xlim(xlim) - >>> ax.set_ylim(ylim) - >>> instance = dyn.mv.StreamFuncAnim(adata=adata, fig=fig, ax=ax) - >>> anim = animation.FuncAnimation(fig, instance.update, init_func=instance.init_background, - ... frames=np.arange(100), interval=100, blit=True) - >>> from IPython.core.display import display, HTML - >>> HTML(anim.to_jshtml()) # embedding to jupyter notebook. - >>> anim.save('fate_ani.gif',writer="imagemagick") # save as gif file. - - Examples 3 - ---------- - >>> from matplotlib import animation - >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] - >>> fate_progenitor = progenitor - >>> info_genes = adata.var_names[adata.var.use_for_transition] - >>> dyn.pd.fate(adata, basis='umap', init_cells=fate_progenitor, interpolation_num=100, direction='forward', - ... inverse_transform=False, average=False) - >>> dyn.mv.animate_fates(adata) - - See also:: :func:`animate_fates` """ - import matplotlib.pyplot as plt - self.adata = adata self.basis = basis self.fp_basis = basis if fp_basis is None else fp_basis @@ -185,10 +123,125 @@ def __init__( M = M + 0.01 * np.abs(M - m) self.xlim = [m[0], M[0]] self.ylim = [m[1], M[1]] + if X_data.shape[1] == 3: + self.zlim = [m[2], M[2]] # self.ax.set_aspect("equal") self.color = color - self.frame_color = frame_color + + +class StreamFuncAnim(BaseAnim): + """The class for animating cell fate commitment prediction with matplotlib. 
+ + This function is originally inspired by https://tonysyu.github.io/animating-particles-in-a-flow.html and relies on + animation module from matplotlib. Note that you may need to install `imagemagick` in order to properly show or save + the animation. See for example, http://louistiao.me/posts/notebooks/save-matplotlib-animations-as-gifs/ for more + details. + + Examples 1 + ---------- + >>> from matplotlib import animation + >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] + >>> fate_progenitor = progenitor + >>> info_genes = adata.var_names[adata.var.use_for_transition] + >>> dyn.pd.fate(adata, basis='umap', init_cells=fate_progenitor, interpolation_num=100, direction='forward', + ... inverse_transform=False, average=False) + >>> instance = dyn.mv.StreamFuncAnim(adata=adata, fig=None, ax=None) + >>> anim = animation.FuncAnimation(instance.fig, instance.update, init_func=instance.init_background, + ... frames=np.arange(100), interval=100, blit=True) + >>> from IPython.core.display import display, HTML + >>> HTML(anim.to_jshtml()) # embedding to jupyter notebook. + >>> anim.save('fate_ani.gif',writer="imagemagick") # save as gif file. + + Examples 2 + ---------- + >>> from matplotlib import animation + >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] + >>> fate_progenitor = progenitor + >>> info_genes = adata.var_names[adata.var.use_for_transition] + >>> dyn.pd.fate(adata, basis='umap', init_cells=fate_progenitor, interpolation_num=100, direction='forward', + ... inverse_transform=False, average=False) + >>> fig, ax = plt.subplots() + >>> ax = dyn.pl.topography(adata_old, color='time', ax=ax, save_show_or_return='return', color_key_cmap='viridis') + >>> ax.set_xlim(xlim) + >>> ax.set_ylim(ylim) + >>> instance = dyn.mv.StreamFuncAnim(adata=adata, fig=fig, ax=ax) + >>> anim = animation.FuncAnimation(fig, instance.update, init_func=instance.init_background, + ... 
frames=np.arange(100), interval=100, blit=True) + >>> from IPython.core.display import display, HTML + >>> HTML(anim.to_jshtml()) # embedding to jupyter notebook. + >>> anim.save('fate_ani.gif',writer="imagemagick") # save as gif file. + + Examples 3 + ---------- + >>> from matplotlib import animation + >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] + >>> fate_progenitor = progenitor + >>> info_genes = adata.var_names[adata.var.use_for_transition] + >>> dyn.pd.fate(adata, basis='umap', init_cells=fate_progenitor, interpolation_num=100, direction='forward', + ... inverse_transform=False, average=False) + >>> dyn.mv.animate_fates(adata) + + See also:: :func:`animate_fates` + """ + + def __init__( + self, + adata: AnnData, + basis: str = "umap", + fp_basis: Union[str, None] = None, + dims: Optional[list] = None, + n_steps: int = 100, + cell_states: Union[int, list, None] = None, + color: str = "ntr", + fig: Optional[matplotlib.figure.Figure] = None, + ax: matplotlib.axes.Axes = None, + logspace: bool = False, + max_time: Optional[float] = None, + ): + """Construct the StreamFuncAnim class. + + Args: + adata: annData object that already went through the fate prediction. + basis: the embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the + reconstructed trajectory will be projected back to high dimensional space via the `inverse_transform` + function space. + fp_basis: the basis that will be used for identifying or retrieving fixed points. Note that if `fps_basis` + is different from `basis`, the nearest cells of the fixed point from the `fps_basis` will be found and + used to visualize the position of the fixed point on `basis` embedding. + dims: the dimensions of low embedding space where cells will be drawn, and it should correspond to the space + fate prediction take place. + n_steps: the number of times steps (frames) fate prediction will take. 
+ cell_states: the number of cells state that will be randomly selected (if `int`), the indices of the cells + states (if `list`) or all cell states which fate prediction executed (if `None`) + color: the key of the data that will be used to color the embedding. + fig: the figure that will contain both the background and animated components. + ax: the matplotlib axes object that will be used as background plot of the vector field animation. If `ax` + is None, `topography(adata, basis=basis, color=color, ax=ax, save_show_or_return='return')` will be used + to create an axes. + logspace: `whether or to sample time points linearly on log space. If not, the sorted unique set of all-time + points from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time + points. + max_time: the maximum time that will be used to scale the time vector. + + Returns: + A class that contains .fig attribute and .update, .init_background that can be used to produce an animation + of the prediction of cell fate commitment. + """ + + import matplotlib.pyplot as plt + + super().__init__( + adata=adata, + basis=basis, + fp_basis=fp_basis, + dims=dims, + n_steps=n_steps, + cell_states=cell_states, + color=color, + logspace=logspace, + max_time=max_time, + ) # Animation objects must create `fig` and `ax` attributes. 
if ax is None or fig is None: @@ -205,13 +258,14 @@ def __init__( self.fig = fig self.ax = ax - (self.ln,) = self.ax.plot([], [], "ro") + (self.ln,) = self.ax.plot([], [], "ro", zs=[]) if dims is not None and len(dims) == 3 else self.ax.plot([], [], "ro") def init_background(self): + """Initialize background of the animation.""" return (self.ln,) def update(self, frame): - """Update locations of "particles" in flow on each frame frame.""" + """Update locations of "particles" in flow on each frame.""" init_states = self.init_states time_vec = self.time_vec @@ -249,6 +303,47 @@ def update(self, frame): return (self.ln,) # return line so that blit works properly +class StreamFuncAnim3D(StreamFuncAnim): + """The class of 3D animation instance for matplotlib FuncAnimation function.""" + def update(self, frame): + """The function to call at each frame. Update the position of the line object in the animation.""" + init_states = self.init_states + time_vec = self.time_vec + + pts = [i.tolist() for i in init_states] + + if frame == 0: + x, y, z = init_states.T + + for line in self.ax.get_lines(): + line.remove() + + (self.ln,) = self.ax.plot(x, y, z, "ro", zorder=20) + return (self.ln,) # return line so that blit works properly + else: + pts = [self.displace(cur_pts, time_vec[frame])[1].tolist() for cur_pts in pts] + pts = np.asarray(pts) + + pts = np.asarray(pts) + + pts = remove_particles(pts, self.xlim, self.ylim, self.zlim) + + x, y, z = np.asarray(pts).transpose() + + for line in self.ax.get_lines(): + line.remove() + + (self.ln,) = self.ax.plot(x, y, z, "ro", zorder=20) + + if self.time_scaler is not None: + vf_time = (time_vec[frame] - time_vec[frame - 1]) * self.time_scaler + self.ax.set_title("current vector field time is: {:12.2f}".format(vf_time)) + + # anim.event_source.interval = (time_vec[frame] - time_vec[frame - 1]) / 100 + + return (self.ln,) # return line so that blit works properly + + def animate_fates( adata, basis="umap", @@ -260,7 +355,6 @@ def 
animate_fates( ax=None, logspace=False, max_time=None, - frame_color=None, interval=100, blit=True, save_show_or_return="show", @@ -271,7 +365,7 @@ def animate_fates( This class creates necessary components to produce an animation that describes the exact speed of a set of cells at each time point, its movement in gene expression and the long range trajectory predicted by the reconstructed - vector field. Thus it provides intuitive visual understanding of the RNA velocity, speed, acceleration, and cell + vector field. Thus, it provides intuitive visual understanding of the RNA velocity, speed, acceleration, and cell fate commitment in action. This function is originally inspired by https://tonysyu.github.io/animating-particles-in-a-flow.html and relies on @@ -279,49 +373,37 @@ def animate_fates( the animation. See for example, http://louistiao.me/posts/notebooks/save-matplotlib-animations-as-gifs/ for more details. - Parameters - ---------- - adata: :class:`~anndata.AnnData` - AnnData object that already went through the fate prediction. - basis: `str` or None (default: `None`) - The embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the reconstructed - trajectory will be projected back to high dimensional space via the `inverse_transform` function. - space. - dims: `list` or `None` (default: `None') - The dimensions of low embedding space where cells will be drawn and it should corresponds to the space + Args: + adata: annData object that already went through the fate prediction. + basis: the embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the + reconstructed trajectory will be projected back to high dimensional space via the `inverse_transform` + function space. + dims: the dimensions of low embedding space where cells will be drawn, and it should correspond to the space fate prediction take place. 
- n_steps: `int` (default: `100`) - The number of times steps (frames) fate prediction will take. - cell_states: `int`, `list` or `None` (default: `None`) - The number of cells state that will be randomly selected (if `int`), the indices of the cells states (if - `list`) or all cell states which fate prediction executed (if `None`) - fig: `matplotlib.figure.Figure` or None (default: `None`) - The figure that will contain both the background and animated components. - ax: `matplotlib.Axis` (optional, default `None`) - The matplotlib axes object that will be used as background plot of the vector field animation. If `ax` - is None, `topography(adata, basis=basis, color=color, ax=ax, save_show_or_return='return')` will be used - to create an axes. - logspace: `bool` (default: `False`) - Whether or to sample time points linearly on log space. If not, the sorted unique set of all time points - from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time points. - interval: `float` (default: `200`) - Delay between frames in milliseconds. - blit: `bool` (default: `False`) - Whether blitting is used to optimize drawing. Note: when using blitting, any animated artists will be drawn + n_steps: the number of times steps (frames) fate prediction will take. + cell_states: the number of cells state that will be randomly selected (if `int`), the indices of the cells + states (if `list`) or all cell states which fate prediction executed (if `None`) + color: the key of the data that will be used to color the embedding. + fig: the figure that will contain both the background and animated components. + ax: the matplotlib axes object that will be used as background plot of the vector field animation. If `ax` + is None, `topography(adata, basis=basis, color=color, ax=ax, save_show_or_return='return')` will be used + to create an axes. + logspace: `whether or to sample time points linearly on log space. 
If not, the sorted unique set of all-time + points from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time + points. + max_time: the maximum time that will be used to scale the time vector. + interval: delay between frames in milliseconds. + blit: whether blitting is used to optimize drawing. Note: when using blitting, any animated artists will be drawn according to their zorder; however, they will be drawn on top of any previous artists, regardless of their zorder. - save_show_or_return: `str` {'save', 'show', 'return'} (default: `save`) - Whether to save, show or return the figure. By default a gif will be used. - save_kwargs: `dict` (default: `{}`) - A dictionary that will passed to the anim.save. By default it is an empty dictionary and the save_fig function - will use the {"filename": 'fate_ani.gif', "writer": "imagemagick"} as its parameters. Otherwise you can - provide a dictionary that properly modify those keys according to your needs. see + save_show_or_return: whether to save, show or return the figure. By default, a gif will be used. + save_kwargs: a dictionary that will be passed to the anim.save. By default, it is an empty dictionary and the + save_fig function will use the {"filename": 'fate_ani.gif', "writer": "imagemagick"} as its parameters. + Otherwise, you can provide a dictionary that properly modify those keys according to your needs. see https://matplotlib.org/api/_as_gen/matplotlib.animation.Animation.save.html for more details. - kwargs: - Additional arguments passed to animation.FuncAnimation. + kwargs: additional arguments passed to animation.FuncAnimation. - Returns - ------- + Returns: Nothing but produce an animation that will be embedded to jupyter notebook or saved to disk. 
Examples 1 @@ -350,7 +432,6 @@ def animate_fates( ax=ax, logspace=logspace, max_time=max_time, - frame_color=frame_color, ) anim = animation.FuncAnimation( @@ -372,3 +453,239 @@ def animate_fates( HTML(anim.to_jshtml()) # embedding to jupyter notebook. else: anim + + +class PyvistaAnim(BaseAnim): + """The class for animating cell fate commitment prediction with pyvista.""" + def __init__( + self, + adata: AnnData, + basis: str = "umap", + fp_basis: Union[str, None] = None, + dims: Optional[list] = None, + n_steps: int = 15, + cell_states: Union[int, list, None] = None, + color: str = "ntr", + point_size: float = 15, + pl=None, + logspace: bool = False, + max_time: Optional[float] = None, + filename: str = "fate_animation.gif", + ): + """Construct the PyvistaAnim class. + + Args: + adata: annData object that already went through the fate prediction. + basis: the embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the + reconstructed trajectory will be projected back to high dimensional space via the `inverse_transform` + function space. + fp_basis: the basis that will be used for identifying or retrieving fixed points. Note that if `fps_basis` + is different from `basis`, the nearest cells of the fixed point from the `fps_basis` will be found and + used to visualize the position of the fixed point on `basis` embedding. + dims: the dimensions of low embedding space where cells will be drawn, and it should correspond to the space + fate prediction take place. + n_steps: the number of times steps (frames) fate prediction will take. + cell_states: the number of cells state that will be randomly selected (if `int`), the indices of the cells + states (if `list`) or all cell states which fate prediction executed (if `None`) + color: the key of the data that will be used to color the embedding. + point_size: the size of the points that will be used to draw the cells. + pl: the pyvista plotter object that will be used to draw the cells. 
If `pl` is None, `topography_3D(adata, + basis=basis, fps_basis=fp_basis, color=color, ax=pl, save_show_or_return='return')` will be used to + create a plotter. + logspace: whether to sample time points linearly on log space. If not, the sorted unique set of all-time + points from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time + points. + max_time: the maximum time that will be used to scale the time vector. + filename: the name of the gif file that will be saved to disk. + + Returns: + A class that contains .animate, that can be used to produce a gif of the prediction of cell fate + commitment. + """ + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + + super().__init__( + adata=adata, + basis=basis, + fp_basis=fp_basis, + dims=dims, + n_steps=n_steps, + cell_states=cell_states, + color=color, + logspace=logspace, + max_time=max_time, + ) + + self.filename = filename + + if pl is None: + self.pl = topography_3D( + self.adata, + basis=self.basis, + fps_basis=self.fp_basis, + color=self.color, + ax=pl, + save_show_or_return="return", + ) + else: + self.pl = pl + + self.n_steps = n_steps + self.point_size = point_size + + def animate(self): + """Animate the cell fate commitment prediction.""" + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + + pts = [i.tolist() for i in self.init_states] + + self.pl.open_gif(self.filename) + + pts = [self.displace(cur_pts, self.time_vec[0])[1].tolist() for cur_pts in pts] + pts = np.asarray(pts) + pts = remove_particles(pts, self.xlim, self.ylim, self.zlim) + + mesh = pv.PolyData(pts) + self.pl.add_mesh(mesh, color="red", render_points_as_spheres=True, point_size=self.point_size) + + for frame in range(1, self.n_steps): + pts = [self.displace(cur_pts, self.time_vec[frame])[1].tolist() for cur_pts in pts] + pts = np.asarray(pts) + # pts = remove_particles(pts, self.xlim, 
self.ylim, self.zlim) + mesh.points = pv.PolyData(pts).points + + self.pl.write_frame() + + self.pl.close() + + +class PlotlyAnim(BaseAnim): + """The class for animating cell fate commitment prediction with plotly.""" + def __init__( + self, + adata: AnnData, + basis: str = "umap", + fp_basis: Union[str, None] = None, + dims: Optional[list] = None, + n_steps: int = 15, + cell_states: Union[int, list, None] = None, + color: str = "ntr", + pl=None, + logspace: bool = False, + max_time: Optional[float] = None, + ): + """Construct the PlotlyAnim class. + + Args: + adata: annData object that already went through the fate prediction. + basis: the embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the + reconstructed trajectory will be projected back to high dimensional space via the `inverse_transform` + function space. + fp_basis: the basis that will be used for identifying or retrieving fixed points. Note that if `fps_basis` + is different from `basis`, the nearest cells of the fixed point from the `fps_basis` will be found and + used to visualize the position of the fixed point on `basis` embedding. + dims: the dimensions of low embedding space where cells will be drawn, and it should correspond to the space + fate prediction take place. + n_steps: the number of times steps (frames) fate prediction will take. + cell_states: the number of cells state that will be randomly selected (if `int`), the indices of the cells + states (if `list`) or all cell states which fate prediction executed (if `None`) + color: the key of the data that will be used to color the embedding. + pl: the plotly figure object that will be used to draw the cells. If `pl` is None, `topography_3D(adata, + basis=basis, fps_basis=fp_basis, color=color, ax=pl, save_show_or_return='return')` will be used to + create a plotter. + logspace: `whether or to sample time points linearly on log space. 
If not, the sorted unique set of all-time + points from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time + points. + max_time: the maximum time that will be used to scale the time vector. + + Returns: + A class that contains .animate, that can be used to show a plotly animation of the prediction of cell fate + commitment. + """ + try: + import plotly + except ImportError: + raise ImportError("Please install plotly first.") + + super().__init__( + adata=adata, + basis=basis, + fp_basis=fp_basis, + dims=dims, + n_steps=n_steps, + cell_states=cell_states, + color=color, + logspace=logspace, + max_time=max_time, + ) + + if pl is None: + self.pl = topography_3D( + self.adata, + basis=self.basis, + fps_basis=self.fp_basis, + color=self.color, + plot_method="plotly", + ax=pl, + save_show_or_return="return", + ) + else: + self.pl = pl + + self.n_steps = n_steps + self.pts_history = [] + + def calculate_pts_history(self): + """Calculate the history of the cell states.""" + pts = [i.tolist() for i in self.init_states] + + self.pts_history.append(pts) + + for frame in range(0, self.n_steps): + pts = [self.displace(cur_pts, self.time_vec[frame])[1].tolist() for cur_pts in pts] + pts = np.asarray(pts) + pts = remove_particles(pts, self.xlim, self.ylim, self.zlim) + self.pts_history.append(np.asarray(pts)) + + def animate(self): + """Animate the cell fate commitment prediction.""" + try: + import plotly.graph_objects as go + except ImportError: + raise ImportError("Please install plotly first.") + + if len(self.pts_history) == 0: + self.calculate_pts_history() + + fig = go.Figure( + data=self.pl, + layout=go.Layout(title="Moving Frenet Frame Along a Planar Curve", + updatemenus=[dict(type="buttons", + buttons=[dict(label="Play", + method="animate", + args=[None])])]), + frames=[ + go.Frame( + data=[ + go.Scatter3d( + x=self.pts_history[k][:, 0], + y=self.pts_history[k][:, 1], + z=self.pts_history[k][:, 2], + mode="markers", 
marker=dict( + color="red", + size=20, + ), + ) + ] + ) for k in range(1, self.n_steps) + ] + ) + + fig.show() diff --git a/dynamo/movie/utils.py b/dynamo/movie/utils.py index dfb0018b9..b7f55bd22 100644 --- a/dynamo/movie/utils.py +++ b/dynamo/movie/utils.py @@ -1,10 +1,30 @@ -from typing import Union +from typing import Optional, Union +import numpy as np -def remove_particles(pts: list, xlim: Union[tuple, list], ylim: Union[tuple, list]): + +def remove_particles( + pts: np.ndarray, + xlim: Union[tuple, list], + ylim: Union[tuple, list], + zlim: Optional[Union[tuple, list]] = None, +) -> np.ndarray: + """Remove particles that fall outside specified coordinate ranges. + + Args: + pts: an array of points. + xlim: X-coordinate limits specified as a tuple or list of two values: (min_x, max_x). + ylim: Y-coordinate limits specified as a tuple or list of two values: (min_y, max_y). + zlim: Z-coordinate limits specified as a tuple or list of two values: (min_z, max_z). If not provided (default), + only 2D filtering based on xlim and ylim is performed. + + Returns: + An array of points that fall within the specified coordinate ranges. 
+ """ if len(pts) == 0: return [] outside_xlim = (pts[:, 0] < xlim[0]) | (pts[:, 0] > xlim[1]) outside_ylim = (pts[:, 1] < ylim[0]) | (pts[:, 1] > ylim[1]) - keep = ~(outside_xlim | outside_ylim) + outside_zlim = np.full(outside_xlim.shape, False) if zlim is None else (pts[:, 2] < zlim[0]) | (pts[:, 2] > zlim[1]) + keep = ~(outside_xlim | outside_ylim | outside_zlim) return pts[keep] diff --git a/dynamo/plot/__init__.py b/dynamo/plot/__init__.py index f7f1c81f9..ce6f2ce29 100755 --- a/dynamo/plot/__init__.py +++ b/dynamo/plot/__init__.py @@ -31,7 +31,7 @@ variance_explained, ) from .pseudotime import plot_dim_reduced_direct_graph -from .scatters import scatters +from .scatters import scatters, scatters_interactive from .scPotential import show_landscape from .sctransform import sctransform_plot_fit, plot_residual_var from .scVectorField import ( # , plot_LIC_gray @@ -61,6 +61,7 @@ plot_separatrix, plot_traj, topography, + topography_3D, ) # from .theme import points @@ -81,6 +82,7 @@ "quiver_autoscaler", "save_fig", "scatters", + "scatters_interactive", "basic_stats", "show_fraction", "feature_genes", @@ -116,6 +118,7 @@ "plot_separatrix", "plot_traj", "topography", + "topography_3D", "speed", "acceleration", "curl", diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 164d27f81..0d7460af2 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -27,12 +27,15 @@ from ..tools.utils import update_dict from ..vectorfield.topography import VectorField from ..vectorfield.utils import vecfld_from_adata -from .scatters import docstrings, scatters +from .scatters import docstrings, scatters, scatters_interactive from .utils import ( _get_adata_color_vec, default_quiver_args, quiver_autoscaler, + retrieve_plot_save_path, save_fig, + save_plotly_figure, + save_pyvista_plotter, set_arrow_alpha, set_stream_line_alpha, ) @@ -61,6 +64,7 @@ def cell_wise_vectors_3d( V: Union[np.ndarray, spmatrix] = None, color: Union[str, 
List[str]] = None, layer: str = "X", + plot_method: Literal["pv", "matplotlib"] = "pv", background: Optional[str] = "white", ncols: int = 4, figsize: Tuple[float] = (6, 4), @@ -71,10 +75,11 @@ def cell_wise_vectors_3d( save_show_or_return: str = "show", save_kwargs: Dict[str, Any] = {}, quiver_3d_kwargs: Dict[str, Any] = { - "zorder": 3, - "length": 2, - "linewidth": 5, - "arrow_length_ratio": 5, + "linewidth": 1, + "edgecolors": "white", + "alpha": 1, + "length": 8, + "arrow_length_ratio": 1, "norm": cm.colors.Normalize(), "cmap": cm.PRGn, }, @@ -84,8 +89,35 @@ def cell_wise_vectors_3d( elev: Optional[float] = None, azim: Optional[float] = None, alpha: Optional[float] = None, - show_magnitude: bool = False, + show_magnitude: bool = True, titles: Optional[List[str]] = None, + highlights: Optional[list] = None, + labels: Optional[list] = None, + values: Optional[list] = None, + theme: Optional[ + Literal[ + "blue", + "red", + "green", + "inferno", + "fire", + "viridis", + "darkblue", + "darkred", + "darkgreen", + ] + ] = None, + plotly_color: str = "Reds", + cmap: Optional[str] = None, + color_key: Union[Dict[str, str], List[str], None] = None, + color_key_cmap: Optional[str] = None, + pointsize: Optional[float] = None, + use_smoothed: bool = True, + sort: Literal["raw", "abs", "neg"] = "raw", + aggregate: Optional[str] = None, + show_arrowed_spines: bool = False, + frontier: bool = False, + s_kwargs_dict: Dict[str, Any] = {}, **cell_wise_kwargs, ) -> np.ndarray: """Plot the velocity or acceleration vector of each cell. @@ -105,6 +137,7 @@ def cell_wise_vectors_3d( V: the velocity array. If None, the array would be determined by `vkey` provided. Defaults to None. color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr". layer: the layer of data to use for the scatter plot. Defaults to "X". + plot_method: the method to plot 3D vectors. Options include `pv` (pyvista) and `matplotlib`. 
background: the background color of the figure. Defaults to "white". ncols: the number of sub-plot columns. Defaults to 4. figsize: the size of each sub-plot panel. Defaults to (6, 4). @@ -132,6 +165,47 @@ def cell_wise_vectors_3d( alpha: the transparency of the colors. Defaults to None. show_magnitude: whether to show original values or normalize the data. Defaults to False. titles: the titles of the subplots. Defaults to None. + highlights: the color group that will be highlighted. If highligts is a list of lists, each list is relate to + each color element. Defaults to None. + labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for + coloring the points in the plot according to their label. Note that this option is mutually exclusive to the + `values` option. Defaults to None. + values: an array of values (assumed float or continuous), one for each sample. This will be used for coloring + the points in the plot according to a colorscale associated to the total range of values. Note that this + option is mutually exclusive to the `labels` option. Defaults to None. + theme: A color theme to use for plotting. A small set of predefined themes are provided which have relatively + good aesthetics. Available themes are: {'blue', 'red', 'green', 'inferno', 'fire', 'viridis', 'darkblue', + 'darkred', 'darkgreen'}. Defaults to None. + plotly_color: the color of the Plotly Cone plot. It must be an array containing arrays mapping a normalized + value to a rgb, rgba, hex, hsl, hsv, or named color string. + cmap: The name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed + this will be used for shading points according to density (largely only of relevance for very large + datasets). If values are passed this will be used for shading according the value. Note that if theme is + passed then this value will be overridden by the corresponding option of the theme. 
Defaults to None. + color_key: the method to assign colors to categoricals. This can either be an explicit dict mapping labels to + colors (as strings of form '#RRGGBB'), or an array like object providing one color for each distinct + category being provided in `labels`. Either way this mapping will be used to color points according to the + label. Note that if theme is passed then this value will be overridden by the corresponding option of the + theme. Defaults to None. + color_key_cmap: the name of a matplotlib colormap to use for categorical coloring. If an explicit `color_key` is + not given a color mapping for categories can be generated from the label list and selecting a matching list + of colors from the given colormap. Note that if theme is passed then this value will be overridden by the + corresponding option of the theme. Defaults to None. + pointsize: the scale of the point size. Actual point cell size is calculated as + `500.0 / np.sqrt(adata.shape[0]) * pointsize`. Defaults to None. + use_smoothed: whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to + True. + sort: the method to reorder data so that high values points will be on top of background points. Can be one of + {'raw', 'abs', 'neg'}, i.e. sorted by raw data, sort by absolute values or sort by negative values. Defaults + to "raw". + aggregate: the column in adata.obs that will be used to aggregate data points. Defaults to None. + show_arrowed_spines: whether to show a pair of arrowed spines representing the basis of the scatter is currently + using. Defaults to False. + frontier: whether to add the frontier. Scatter plots can be enhanced by using transparency (alpha) in order to + show area of high density and multiple scatter plots can be used to delineate a frontier. See matplotlib + tips & tricks cheatsheet (https://github.com/matplotlib/cheatsheets). 
Originally inspired by figures from + scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. Defaults to False. + s_kwargs_dict: any other kwargs that will be passed to `dynamo.pl.scatters`. Defaults to {}. Raises: ValueError: invalid `x`, `y`, or `z`. @@ -248,65 +322,195 @@ def add_axis_label(ax, labels): nrows += 1 ncols = min(ncols, len(color)) - figure, axes = plt.subplots(nrows, ncols, figsize=figsize, subplot_kw=dict(projection="3d")) - axes = np.array(axes) - axes_flatten = axes.flatten() - - for i in range(len(color)): - ax = axes_flatten[i] - ax.set_title(color[i]) - norm = quiver_3d_kwargs["norm"] - cmap = quiver_3d_kwargs["cmap"] - color_vec = _get_adata_color_vec(adata, layer=layer, col=color[i]) - assert len(color_vec) > 0, "color vector or data vector size is 0" - - # convet categorical string data colors to labels - if type(color_vec[0]) is str: - unique_vals, color_vec = np.unique(color_vec, return_inverse=True) - - color_vec = cmap(norm(color_vec)) - - # TODO due to matplotlib quiver3 impl, we need to add colors for arrow head segments - # TODO if matplotlib changes its detailed impl, we may not need the following line - color_vec = list(color_vec) + [element for element in list(color_vec) for _ in range(2)] - # color_vec = matplotlib.colors.to_rgba(color_vec, alpha=alpha) - main_debug("color vec len: " + str(len(color_vec))) - ax.view_init(elev=elev, azim=azim) - ax.quiver( - x0, - x1, - x2, - v0, - v1, - v2, - color=color_vec, - # facecolors=color_vec, - **quiver_3d_kwargs, + if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + + pl, colors_list = scatters_interactive( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + labels=labels, + values=values, + cmap=cmap, + theme=theme, + background=background, + color_key=color_key, + color_key_cmap=color_key_cmap, + use_smoothed=use_smoothed, + save_show_or_return="return", + 
render_points_as_spheres=True, ) - ax.set_title(titles[i]) - ax.set_facecolor(background) - add_axis_label(ax, axis_labels) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "cell_wise_vectors_3d", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) + point_cloud = pv.PolyData(np.column_stack((x0.values, x1.values, x2.values))) + point_cloud['vectors'] = np.column_stack((v0.values, v1.values, v2.values)) - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False + r, c = pl.shape[0], pl.shape[1] + subplot_indices = [[i, j] for i in range(r) for j in range(c)] + cur_subplot = 0 - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.show() - if save_show_or_return in ["return", "all"]: - return axes + for i in range(len(color)): + point_cloud.point_data["colors"] = np.stack(colors_list[i]) + + arrows = point_cloud.glyph( + orient='vectors', + factor=3.5, + ) + + if r * c != 1: + pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) + cur_subplot += 1 + + pl.add_mesh(arrows, scalars="colors", preference='point', rgb=True) + + return save_pyvista_plotter( + pl=pl, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) + + elif plot_method == "plotly": + try: + import plotly.graph_objects as go + except ImportError: + raise ImportError("Please install plotly first.") + + pl, colors_list = scatters_interactive( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + plot_method="plotly", + labels=labels, + values=values, + cmap=cmap, + theme=theme, + background=background, + color_key=color_key, + color_key_cmap=color_key_cmap, + use_smoothed=use_smoothed, + save_show_or_return="return", + opacity=0.5, + ) + + r, c = pl._get_subplot_rows_columns() + subplot_indices = [[i, j] for i in range(list(r)[-1]) for j in 
range(list(c)[-1])] + cur_subplot = 0 + + for i in range(len(color)): + # colors = [[index, "rgb({},{},{})".format(int(row[0] * 255), int(row[1] * 255), int(row[2] * 255))] for index, row in enumerate(colors_list[i])] + + pl.add_trace( + go.Cone( + x=x0.values, + y=x1.values, + z=x2.values, + u=v0.values, + v=v1.values, + w=v2.values, + colorscale=plotly_color, + # colorscale=colors, + sizemode="absolute", + sizeref=1, + ), + row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, + ) # TODO: implement customized color for individual cone + + cur_subplot += 1 + + return save_plotly_figure( + pl=pl, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) + + else: + axes_list, color_list, _ = scatters( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + highlights=highlights, + labels=labels, + values=values, + theme=theme, + cmap=cmap, + color_key=color_key, + color_key_cmap=color_key_cmap, + background=background, + ncols=ncols, + pointsize=pointsize, + figsize=figsize, + show_legend=None, + use_smoothed=use_smoothed, + aggregate=aggregate, + show_arrowed_spines=show_arrowed_spines, + ax=ax, + sort=sort, + save_show_or_return="return", + frontier=frontier, + projection="3d", + **s_kwargs_dict, + return_all=True, + ) + + if type(axes_list) != list: + axes_list = [axes_list] + color_list = [color_list] + + for i in range(len(color)): + ax = axes_list[i] + ax.set_title(color[i]) + cmap_3d = [element for element in color_list[i]] + [element for element in color_list[i] for _ in range(2)] + main_debug("color vec len: " + str(len(cmap_3d))) + ax.view_init(elev=elev, azim=azim) + ax.quiver( + x0, + x1, + x2, + v0, + v1, + v2, + color=cmap_3d, + # facecolors=color_vec, + **quiver_3d_kwargs, + ) + ax.set_title(titles[i]) + ax.set_facecolor(background) + add_axis_label(ax, axis_labels) + + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": 
"cell_wise_vectors_3d", + "dpi": None, + "ext": "pdf", + "transparent": True, + "close": True, + "verbose": True, + } + s_kwargs = update_dict(s_kwargs, save_kwargs) + + if save_show_or_return in ["both", "all"]: + s_kwargs["close"] = False + + save_fig(**s_kwargs) + if save_show_or_return in ["show", "both", "all"]: + plt.show() + if save_show_or_return in ["return", "all"]: + return axes_list def grid_vectors_3d(): @@ -754,6 +958,7 @@ def cell_wise_vectors( df = pd.DataFrame({"x": X[:, 0], "y": X[:, 1], "u": V[:, 0], "v": V[:, 1]}) elif projection == "3d": df = pd.DataFrame({"x": X[:, 0], "y": X[:, 1], "z": X[:, 2], "u": V[:, 0], "v": V[:, 1], "w": V[:, 2]}) + show_legend = None else: raise NotImplementedError("Projection method %s is not implemented" % projection) @@ -799,7 +1004,14 @@ def cell_wise_vectors( "zorder": 10, } quiver_kwargs = update_dict(quiver_kwargs, cell_wise_kwargs) - quiver_3d_kwargs = {"arrow_length_ratio": scale} + quiver_3d_kwargs = { + "linewidth": 1, + "edgecolors": "white", + "alpha": 1, + "length": 8, + "arrow_length_ratio": scale, + + } axes_list, color_list, _ = scatters( adata=adata, @@ -857,6 +1069,7 @@ def cell_wise_vectors( **quiver_kwargs, ) elif projection == "3d": + cmap_3d = [element for element in color_list[i]] + [element for element in color_list[i] for _ in range(2)] ax.quiver( x0, x1, @@ -864,7 +1077,7 @@ def cell_wise_vectors( v0, v1, v2, - # color=color_list[i], + color=cmap_3d, # facecolors=color_list[i], **quiver_3d_kwargs, ) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 5e0cf339a..e00b520d2 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -13,9 +13,10 @@ import numpy as np import pandas as pd from anndata import AnnData -from matplotlib import patches +from matplotlib import patches, rcParams from matplotlib.axes import Axes from matplotlib.lines import Line2D +from matplotlib.colors import rgb2hex, to_hex from pandas.api.types import is_categorical_dtype from 
..configuration import _themes, reset_rcParams @@ -30,13 +31,17 @@ _matplotlib_points, _select_font_color, arrowed_spines, + calculate_colors, deaxis_all, despline_all, is_cell_anno_column, is_gene_name, is_layer_keys, is_list_of_lists, + retrieve_plot_save_path, save_fig, + save_plotly_figure, + save_pyvista_plotter, ) docstrings = DocstringProcessor() @@ -264,8 +269,6 @@ def scatters( """ import matplotlib.pyplot as plt - from matplotlib import rcParams - from matplotlib.colors import rgb2hex, to_hex # 2d is not a projection in matplotlib, default is None (rectilinear) if projection == "2d": @@ -912,3 +915,451 @@ def _plot_basis_layer(cur_b, cur_l): return (axes_list, color_list, font_color) if total_panels > 1 else (ax, color_out, font_color) else: return axes_list if total_panels > 1 else ax + + +def map_to_points( + _adata: AnnData, + axis_x: str, + axis_y: str, + axis_z: str, + basis_key: str, + cur_c: str, + cur_b: str, + cur_l_smoothed: str, +) -> Tuple[pd.DataFrame, str]: + """A helper function to map the given axis to corresponding coordinates in current embedding space. + + Args: + _adata: an AnnData object. + axis_x: the column index of the low dimensional embedding for the x-axis in current space. + axis_y: the column index of the low dimensional embedding for the y-axis in current space. + axis_z: the column index of the low dimensional embedding for the z-axis in current space. + basis_key: the basis key constructed by current basis and layer. + cur_c: the current key to color the data. + cur_b: the current basis key representing the reduced dimension. + cur_l_smoothed: the smoothed layer of data to use. + + Returns: + The 3D DataFrame with coordinates of each sample and the title of the plot. + """ + gene_title = [] + anno_title = [] + + def _map_cur_axis(cur: str) -> Tuple[np.ndarray, str]: + """A helper function to map an axis. + + Args: + cur: the current axis to map. + + Returns: + The coordinates and the column names. 
+ """ + nonlocal gene_title, anno_title + + if is_gene_name(_adata, cur): + points_df_data = (_adata.obs_vector(k=cur, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur, layer=cur_l_smoothed)) + points_column = cur + " (" + cur_l_smoothed + ")" + gene_title.append(cur) + elif is_cell_anno_column(_adata, cur): + points_df_data = _adata.obs_vector(cur) + points_column = cur + anno_title.append(cur) + elif is_layer_keys(_adata, cur): + points_df_data = _adata[:, cur_b].layers[cur] + points_column = flatten(points_df_data) + else: + raise ValueError("Make sure your `x`, `y` and `z` are integers, gene names, column names in .obs, etc.") + + return points_df_data, points_column + + if type(axis_x) is int and type(axis_y) is int and type(axis_z) is int: + x_col_name = cur_b + "_0" + y_col_name = cur_b + "_1" + z_col_name = cur_b + "_2" + + points = pd.DataFrame( + { + x_col_name: _adata.obsm[basis_key][:, axis_x], + y_col_name: _adata.obsm[basis_key][:, axis_y], + z_col_name: _adata.obsm[basis_key][:, axis_z], + } + ) + points.columns = [x_col_name, y_col_name, z_col_name] + + cur_title = cur_c + + return points, cur_title + elif type(axis_x) in [anndata._core.views.ArrayView, np.ndarray] and type(axis_y) in [ + anndata._core.views.ArrayView, + np.ndarray, + ]: + points = pd.DataFrame({"x": flatten(axis_x), "y": flatten(axis_y), "z": flatten(axis_z)}) + points.columns = ["x", "y", "z"] + else: + x_points_df_data, x_points_column = _map_cur_axis(axis_x) + y_points_df_data, y_points_column = _map_cur_axis(axis_y) + z_points_df_data, z_points_column = _map_cur_axis(axis_z) + points = pd.DataFrame({ + axis_x: x_points_df_data, + axis_y: y_points_df_data, + axis_z: z_points_df_data, + }) + points.columns = [x_points_column, y_points_column, z_points_column] + + if len(gene_title) != 0: + cur_title = " VS ".join(gene_title) + elif len(anno_title) == 3: + cur_title = " VS ".join(anno_title) + else: + cur_title = cur_b + + return points, cur_title + + +def 
scatters_interactive( + adata: AnnData, + basis: str = "umap", + x: Union[int, str] = 0, + y: Union[int, str] = 1, + z: Union[int, str] = 2, + color: str = "ntr", + layer: str = "X", + plot_method: str = "pv", + labels: Optional[list] = None, + values: Optional[list] = None, + cmap: Optional[str] = None, + theme: Optional[str] = None, + background: Optional[str] = None, + color_key: Union[Dict[str, str], List[str], None] = None, + color_key_cmap: Optional[str] = None, + use_smoothed: bool = True, + sym_c: bool = False, + smooth: bool = False, + save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", + save_kwargs: Dict[str, Any] = {}, + **kwargs, +): + """Plot an embedding as points with Pyvista. Currently only 3D input is supported. For 2D data, `scatters` is a + better alternative. + + The function will use the colors from matplotlib to keep consistence with other plotting functions. + + Args: + adata: an AnnData object. + basis: the reduced dimension stored in adata.obsm. The specific basis key will be constructed in the following + priority if exits: 1) specific layer input + basis 2) X_ + basis 3) basis. E.g. if basis is PCA, `scatters` + is going to look for 1) if specific layer is spliced, `spliced_pca` 2) `X_pca` (dynamo convention) 3) `pca`. + Defaults to "umap". + x: the column index of the low dimensional embedding for the x-axis. Defaults to 0. + y: the column index of the low dimensional embedding for the y-axis. Defaults to 1. + z: the column index of the low dimensional embedding for the z-axis. Defaults to 2. + color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr". + layer: the layer of data to use for the scatter plot. Defaults to "X". + labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for + coloring the points in the plot according to their label. Note that this option is mutually exclusive to the + `values` option. 
Defaults to None. + values: an array of values (assumed float or continuous), one for each sample. This will be used for coloring + the points in the plot according to a colorscale associated to the total range of values. Note that this + option is mutually exclusive to the `labels` option. Defaults to None. + theme: A color theme to use for plotting. A small set of predefined themes are provided which have relatively + good aesthetics. Defaults to None. + cmap: The name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed + this will be used for shading points according to density (largely only of relevance for very large + datasets). If values are passed this will be used for shading according the value. Note that if theme is + passed then this value will be overridden by the corresponding option of the theme. Defaults to None. + background: the color of the background. Usually this will be either 'white' or 'black', but any color name will + work. Ideally one wants to match this appropriately to the colors being used for points etc. This is one of + the things that themes handle for you. Note that if theme is passed then this value will be overridden by + the corresponding option of the theme. Defaults to None. + color_key: the method to assign colors to categoricals. This can either be an explicit dict mapping labels to + colors (as strings of form '#RRGGBB'), or an array like object providing one color for each distinct + category being provided in `labels`. Either way this mapping will be used to color points according to the + label. Note that if theme is passed then this value will be overridden by the corresponding option of the + theme. Defaults to None. + color_key_cmap: the name of a matplotlib colormap to use for categorical coloring. If an explicit `color_key` is + not given a color mapping for categories can be generated from the label list and selecting a matching list + of colors from the given colormap. 
Note that if theme is passed then this value will be overridden by the + corresponding option of the theme. Defaults to None. + use_smoothed: whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to + True. + sym_c: whether do you want to make the limits of continuous color to be symmetric, normally this should be used + for plotting velocity, jacobian, curl, divergence or other types of data with both positive or negative + values. Defaults to False. + smooth: whether do you want to further smooth data and how much smoothing do you want. If it is `False`, no + smoothing will be applied. If `True`, smoothing based on one-step diffusion of connectivity matrix + (`.uns['moment_cnn']`) will be applied. If a number larger than 1, smoothing will be based on `smooth` steps + of diffusion. + save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at + the same time. If "all", the figure will be saved, displayed and the associated axis and other object will + be return. Defaults to "show". + save_kwargs: A dictionary that will be passed to the saving function. By default, it is an empty dictionary + and the saving function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + "title": PyVista Export, "raster": True, "painter": True} as its parameters. Otherwise, you can provide a + dictionary that properly modify those keys according to your needs. Defaults to {}. + **kwargs: any other kwargs that would be passed to `Plotter.add_points()`. + + Returns: + If `save_show_or_return` is `save`, `show` or `both`, the function will return nothing but show or save the + figure. If `save_show_or_return` is `return`, the function will return the axis object(s) that contains the + figure. 
+ """ + + if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + elif plot_method == "plotly": + try: + import plotly.express as px + import plotly.graph_objects as go + from plotly.subplots import make_subplots + except ImportError: + raise ImportError("Please install plotly first.") + else: + raise NotImplementedError("Current plot method not supported.") + + if type(x) in [int, str]: + x = [x] + if type(y) in [int, str]: + y = [y] + if type(z) in [int, str]: + z = [z] + + # make x, y, z lists of list, where each list corresponds to one coordinate set + if ( + type(x) in [anndata._core.views.ArrayView, np.ndarray] + and type(y) in [anndata._core.views.ArrayView, np.ndarray] + and type(z) in [anndata._core.views.ArrayView, np.ndarray] + and len(x) == adata.n_obs + and len(y) == adata.n_obs + and len(z) == adata.n_obs + ): + x, y, z = [x], [y], [z] + + elif hasattr(x, "__len__") and hasattr(y, "__len__") and hasattr(z, "__len__"): + x, y, z = list(x), list(y), list(z) + + assert len(x) == len(y) and len(x) == len(z), "bug: x, y, z does not have the same shape." 
+ + if use_smoothed: + mapper = get_mapper() + + # check color, layer, basis -> convert to list + if type(color) is str: + color = [color] + if type(layer) is str: + layer = [layer] + if type(basis) is str: + basis = [basis] + + n_c, n_l, n_b, n_x, n_y, n_z = ( + 1 if color is None else len(color), + 1 if layer is None else len(layer), + 1 if basis is None else len(basis), + 1 if x is None else 1 if type(x) in [anndata._core.views.ArrayView, np.ndarray] else len(x), + 1 if y is None else 1 if type(y) in [anndata._core.views.ArrayView, np.ndarray] else len(y), + 1 if z is None else 1 if type(z) in [anndata._core.views.ArrayView, np.ndarray] else len(z), + ) + + total_panels, ncols = ( + n_c * n_l * n_b * n_x * n_y * n_z, + max([n_c, n_l, n_b, n_x, n_y, n_z]), + ) + + nrow, ncol = int(np.ceil(total_panels / ncols)), ncols + subplot_indices = [[i, j] for i in range(nrow) for j in range(ncol)] + cur_subplot = 0 + colors_list = [] + + if total_panels == 1: + pl = pv.Plotter() if plot_method == "pv" else make_subplots(rows=1, cols=1, specs=[[{"type": "scatter3d"}]]) + else: + pl = ( + pv.Plotter(shape=(nrow, ncol)) + if plot_method == "pv" + else + make_subplots(rows=nrow, cols=ncol, specs=[[{"type": "scatter3d"} for _ in range(ncol)] for _ in range(nrow)]) + ) + + def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: + """A helper function for plotting a specific basis/layer data + + Args: + cur_b: current basis + cur_l: current layer + """ + nonlocal background, adata, cmap, cur_subplot, sym_c + + if cur_l in ["acceleration", "curvature", "divergence", "velocity_S", "velocity_T"]: + cur_l_smoothed = cur_l + cmap, sym_c = "bwr", True # TODO maybe use other divergent color map in the future + else: + if use_smoothed: + cur_l_smoothed = cur_l if cur_l.startswith("M_") | cur_l.startswith("velocity") else mapper[cur_l] + if cur_l.startswith("velocity"): + cmap, sym_c = "bwr", True + + if cur_l + "_" + cur_b in adata.obsm.keys(): + prefix = cur_l + "_" + elif ("X_" + 
cur_b) in adata.obsm.keys(): + prefix = "X_" + elif cur_b in adata.obsm.keys(): + # special case for spatial for compatibility with other packages + prefix = "" + else: + raise ValueError("Please check if basis=%s exists in adata.obsm" % basis) + + basis_key = prefix + cur_b + main_info("plotting with basis key=%s" % basis_key, indent_level=2) + + for cur_c in color: + main_debug("coloring scatter of cur_c: %s" % str(cur_c)) + + _color = _get_adata_color_vec(adata, cur_l, cur_c) + + # select data rows based on stack color thresholding + is_numeric_color = np.issubdtype(_color.dtype, np.number) + if not is_numeric_color: + main_info( + "skip filtering %s by stack threshold when stacking color because it is not a numeric type" + % (cur_c), + indent_level=2, + ) + _labels, _values = None, None + + for cur_x, cur_y, cur_z in zip(x, y, z): # here x / y are arrays + main_debug("handling coordinates, cur_x: %s, cur_y: %s, cur_z: %s" % (cur_x, cur_y, cur_z)) + + points, cur_title = map_to_points( + adata, + axis_x=cur_x, + axis_y=cur_y, + axis_z=cur_z, + basis_key=basis_key, + cur_c=cur_c, + cur_b=cur_b, + cur_l_smoothed=cur_l_smoothed, + ) + + # https://stackoverflow.com/questions/4187185/how-can-i-check-if-my-python-object-is-a-number + # answer from Boris. 
+ is_not_continuous = not isinstance(_color[0], Number) or _color.dtype.name == "category" + + if is_not_continuous: + _labels = np.asarray(_color) if is_categorical_dtype(_color) else _color + if theme is None: + if background in ["#ffffff", "black"]: + _theme_ = "glasbey_dark" + else: + _theme_ = "glasbey_white" + else: + _theme_ = theme + else: + _values = _color + if theme is None: + if background in ["#ffffff", "black"]: + _theme_ = "inferno" if cur_l != "velocity" else "div_blue_black_red" + else: + _theme_ = "viridis" if not cur_l.startswith("velocity") else "div_blue_red" + else: + _theme_ = theme + + _cmap = _themes[_theme_]["cmap"] if cmap is None else cmap + + _color_key_cmap = _themes[_theme_]["color_key_cmap"] if color_key_cmap is None else color_key_cmap + background = _themes[_theme_]["background"] if background is None else background + + if labels is not None and values is not None: + raise ValueError("Conflicting options; only one of labels or values should be set") + + if labels is not None or values is not None: + _labels = None if labels is None else labels.copy() + _values = None if values is None else values.copy() + main_info("`Color` will be ignored because labels/values is provided.") + + if smooth and not is_not_continuous: + main_debug("smooth and not continuous") + knn = adata.obsp["moments_con"] + _values = ( + calc_1nd_moment(_values, knn)[0] + if smooth in [1, True] + else calc_1nd_moment(_values, knn**smooth)[0] + ) + + colors, color_type, _ = calculate_colors( + points.values, + labels=_labels, + values=_values, + cmap=_cmap, + color_key=color_key, + color_key_cmap=_color_key_cmap, + background=background, + sym_c=sym_c, + ) + + colors_list.append(colors) + + if plot_method == "pv": + if total_panels > 1: + pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) + + pvdataset = pv.PolyData(points.values) + pvdataset.point_data["colors"] = np.stack(colors) + pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, cmap=_cmap, **kwargs) + + if 
color_type == "labels": + type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(_labels, colors)} + type_color_pair = [[k, v] for k, v in type_color_dict.items()] + pl.add_legend(labels=type_color_pair) + else: + pl.add_scalar_bar() # TODO: fix the bug that scalar bar only works in the first plot + + pl.add_text(cur_title) + pl.add_axes(xlabel=points.columns[0], ylabel=points.columns[1], zlabel=points.columns[2]) + elif plot_method == "plotly": + + pl.add_trace( + go.Scatter3d( + x=points.iloc[:, 0], + y=points.iloc[:, 1], + z=points.iloc[:, 2], + mode="markers", + marker=dict( + color=colors, + ), + text=_labels if color_type == "labels" else _values, + **kwargs, + ), + row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, + ) + + pl.update_layout( + scene=dict( + xaxis_title=points.columns[0], + yaxis_title=points.columns[1], + zaxis_title=points.columns[2] + ), + ) + + cur_subplot += 1 + + for cur_b in basis: + for cur_l in layer: + main_debug("Plotting basis:%s, layer: %s" % (str(basis), str(layer))) + main_debug("colors: %s" % (str(color))) + _plot_basis_layer_pv(cur_b, cur_l) + + return save_pyvista_plotter( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) if plot_method == "pv" else save_plotly_figure( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) diff --git a/dynamo/plot/streamtube.py b/dynamo/plot/streamtube.py index a95e818a9..3e8d01527 100644 --- a/dynamo/plot/streamtube.py +++ b/dynamo/plot/streamtube.py @@ -33,7 +33,7 @@ def plot_3d_streamtube( save_show_or_return: Literal["save", "show", "return"] = "show", save_kwargs: Dict[str, Any] = {}, ): - """Plot a interative 3d streamtube plot via plotly. + """Plot an interative 3d streamtube plot via plotly. A streamtube is a tubular region surrounded by streamlines that form a closed loop. 
It's a continuous version of a streamtube plot (3D quiver plot) and can provide insight into flow data from natural systems. The color of tubes is @@ -89,10 +89,7 @@ def plot_3d_streamtube( else: _background = background - if is_gene_name(adata, color): - color_val = adata.obs_vector(k=color, layer=None) if layer == "X" else adata.obs_vector(k=color, layer=layer) - elif is_cell_anno_column(adata, color): - color_val = adata.obs_vector + color_val = adata.obs_vector(k=color, layer=None) if layer == "X" else adata.obs_vector(k=color, layer=layer) is_not_continous = not isinstance(color_val[0], Number) or color_val.dtype.name == "category" @@ -133,30 +130,41 @@ def plot_3d_streamtube( mapper = cm.ScalarMappable(norm=norm, cmap=_cmap) colors = _to_hex(mapper.to_rgba(values)) + if adata.obsm["X_" + basis].shape[1] < 3: + raise ValueError("Current basis has dimensions less than 3!") + + if "VecFld_" + basis not in adata.uns.keys(): + raise KeyError("Corresponding vector field not found! Please run VectorField() with current basis.") + X = adata.obsm["X_" + basis][:, dims] - grid_kwargs_dict = { - "density": None, - "smooth": None, - "n_neighbors": None, - "min_mass": None, - "autoscale": False, - "adjust_for_stream": True, - "V_threshold": None, - } - - X_grid, p_mass, neighs, weight = prepare_velocity_grid_data( - X, - [60, 60, 60], - density=grid_kwargs_dict["density"], - smooth=grid_kwargs_dict["smooth"], - n_neighbors=grid_kwargs_dict["n_neighbors"], - ) - from ..vectorfield.utils import vecfld_from_adata + if "grid" in adata.uns["VecFld_" + basis].keys() and "grid_V" in adata.uns["VecFld_" + basis].keys(): + X_grid = adata.uns["VecFld_" + basis]["grid"] + velocity_grid = adata.uns["VecFld_" + basis]["grid_V"] + else: + grid_kwargs_dict = { + "density": None, + "smooth": None, + "n_neighbors": None, + "min_mass": None, + "autoscale": False, + "adjust_for_stream": True, + "V_threshold": None, + } + + X_grid, p_mass, neighs, weight = prepare_velocity_grid_data( + X, + 
[60, 60, 60], + density=grid_kwargs_dict["density"], + smooth=grid_kwargs_dict["smooth"], + n_neighbors=grid_kwargs_dict["n_neighbors"], + ) + + from ..vectorfield.utils import vecfld_from_adata - VecFld, func = vecfld_from_adata(adata, basis="umap") + VecFld, func = vecfld_from_adata(adata, basis=basis) - velocity_grid = func(X_grid) + velocity_grid = func(X_grid) fig = go.Figure( data=go.Streamtube( @@ -167,14 +175,12 @@ def plot_3d_streamtube( v=velocity_grid[:, 1], w=velocity_grid[:, 2], starts=dict( - x=adata[labels == init_group, :].obsm["X_umap"][:125, 0], - y=adata[labels == init_group, :].obsm["X_umap"][:125, 1], - z=adata[labels == init_group, :].obsm["X_umap"][:125, 2], + x=adata[labels == init_group, :].obsm["X_" + basis][:125, 0], + y=adata[labels == init_group, :].obsm["X_" + basis][:125, 1], + z=adata[labels == init_group, :].obsm["X_" + basis][:125, 2], ), - sizeref=3000, colorscale="Portland", showscale=False, - maxdisplayed=3000, ) ) diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 3b60c753d..c47e6ba79 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -26,13 +26,16 @@ from ..vectorfield.topography import topography as _topology # , compute_separatrices from ..vectorfield.utils import vecfld_from_adata from ..vectorfield.vector_calculus import curl, divergence -from .scatters import docstrings, scatters +from .scatters import docstrings, scatters, scatters_interactive from .utils import ( _plot_traj, _select_font_color, default_quiver_args, quiver_autoscaler, + retrieve_plot_save_path, save_fig, + save_plotly_figure, + save_pyvista_plotter, set_arrow_alpha, set_stream_line_alpha, ) @@ -422,10 +425,11 @@ def plot_fixed_points( background: Optional[str] = None, save_show_or_return: Literal["save", "show", "return"] = "return", save_kwargs: Dict[str, Any] = {}, + plot_method: Literal["pv", "matplotlib"] = "matplotlib", ax: Optional[Axes] = None, **kwargs, ) -> Optional[Axes]: - """Plot fixed points 
stored in the VectorField2D class. + """Plot fixed points stored in the VectorField class. Args: vecfld: an instance of the vector_field class. @@ -444,7 +448,9 @@ def plot_fixed_points( and the save_fig function will use the {"path": None, "prefix": 'plot_fixed_points', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. - ax: the matplotlib axes used for plotting. Default is to use the current axis. Defaults to None. + plot_method: the method to plot 3D points. Options include `pv` (pyvista) and `matplotlib`. + ax: the matplotlib axes or pyvista plotter used for plotting. Default is to use the current axis. Defaults to + None. Returns: None would be returned by default. If `save_show_or_return` is set to be 'return', the Axes of the generated @@ -510,61 +516,144 @@ def plot_fixed_points( vecfld_dict["confidence"], ) - if ax is None: - ax = plt.gca() - cm = matplotlib.cm.get_cmap(_cmap) if type(_cmap) is str else _cmap - for i in range(len(Xss)): - cur_ftype = ftype[i] - marker_ = markers.MarkerStyle(marker=marker, fillstyle=filltype[int(cur_ftype + 1)]) - ax.scatter( - *Xss[i], - marker=marker_, - s=markersize, - c=c if confidence is None else np.array(cm(confidence[i])).reshape(1, -1), - edgecolor=_select_font_color(_background), - linewidths=1, - cmap=_cmap, - vmin=0, - zorder=5, - ) - txt = ax.text( - *Xss[i], - repr(i), - c=("black" if cur_ftype == -1 else "blue" if cur_ftype == 0 else "red"), - horizontalalignment="center", - verticalalignment="center", - zorder=6, - weight="bold", + colors = [c if confidence is None else np.array(cm(confidence[i])) for i in range(len(confidence))] + text_colors = ["black" if cur_ftype == -1 else "blue" if cur_ftype == 0 else "red" for cur_ftype in ftype] + + if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please 
install pyvista first.") + + emitting_indices = [index for index, color in enumerate(text_colors) if color == "red"] + unstable_indices = [index for index, color in enumerate(text_colors) if color == "blue"] + absorbing_indices = [index for index, color in enumerate(text_colors) if color == "black"] + fps_type_indices = [emitting_indices, unstable_indices, absorbing_indices] + + r, c = ax.shape[0], ax.shape[1] + subplot_indices = [[i, j] for i in range(r) for j in range(c)] + cur_subplot = 0 + + for i in range(r * c): + + if r * c != 1: + ax.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) + cur_subplot += 1 + + for indices in fps_type_indices: + points = pv.PolyData(Xss[indices]) + points.point_data["colors"] = np.array(colors)[indices] + points["Labels"] = [str(idx) for idx in indices] + + ax.add_points(points, render_points_as_spheres=True, rgba=True, point_size=15) + ax.add_point_labels( + points, + "Labels", + text_color=text_colors[indices[0]], + font_size=24, + shape_opacity=0, + show_points=False, + always_visible=True, + ) + + return save_pyvista_plotter( + pl=ax, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, ) - txt.set_path_effects( - [ - PathEffects.Stroke(linewidth=1.5, foreground=_background, alpha=0.8), - PathEffects.Normal(), - ] + elif plot_method == "plotly": + try: + import plotly.graph_objects as go + except ImportError: + raise ImportError("Please install plotly first.") + + r, c = ax._get_subplot_rows_columns() + r, c = list(r)[-1], list(c)[-1] + subplot_indices = [[i, j] for i in range(r) for j in range(c)] + cur_subplot = 0 + + for i in range(r * c): + ax.add_trace( + go.Scatter3d( + x=Xss[:, 0], + y=Xss[:, 1], + z=Xss[:, 2], + mode="markers+text", + marker=dict( + color=colors, + size=15, + ), + text=[str(i) for i in range(len(Xss))], + textfont=dict( + color=text_colors, + size=15, + ), + **kwargs, + ), + row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, + ) + 
+ return save_plotly_figure( + pl=ax, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, ) + else: + if ax is None: + ax = plt.gca() + + for i in range(len(Xss)): + cur_ftype = ftype[i] + marker_ = markers.MarkerStyle(marker=marker, fillstyle=filltype[int(cur_ftype + 1)]) + ax.scatter( + *Xss[i], + marker=marker_, + s=markersize, + c=c if confidence is None else np.array(cm(confidence[i])).reshape(1, -1), + edgecolor=_select_font_color(_background), + linewidths=1, + cmap=_cmap, + vmin=0, + zorder=5, + ) # TODO: Figure out the user warning that no data for colormapping provided via 'c'. + txt = ax.text( + *Xss[i], + repr(i), + c=("black" if cur_ftype == -1 else "blue" if cur_ftype == 0 else "red"), + horizontalalignment="center", + verticalalignment="center", + zorder=6, + weight="bold", + ) + txt.set_path_effects( + [ + PathEffects.Stroke(linewidth=1.5, foreground=_background, alpha=0.8), + PathEffects.Normal(), + ] + ) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "plot_fixed_points", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": "plot_fixed_points", + "dpi": None, + "ext": "pdf", + "transparent": True, + "close": True, + "verbose": True, + } + s_kwargs = update_dict(s_kwargs, save_kwargs) - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False + if save_show_or_return in ["both", "all"]: + s_kwargs["close"] = False - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + save_fig(**s_kwargs) + if save_show_or_return in ["show", "both", "all"]: + plt.tight_layout() + plt.show() + if save_show_or_return in ["return", "all"]: + return ax def plot_traj( @@ -1359,3 +1448,472 @@ def 
topography( plt.show() if save_show_or_return in ["return", "all"]: return axes_list if len(axes_list) > 1 else axes_list[0] + + +# TODO: Implement more `terms` like streamline and trajectory for 3D topography +@docstrings.with_indent(4) +def topography_3D( + adata: AnnData, + basis: str = "umap", + fps_basis: str = "umap", + x: int = 0, + y: int = 1, + z: int = 2, + color: str = "ntr", + layer: str = "X", + plot_method: Literal["pv", "matplotlib"] = "matplotlib", + highlights: Optional[list] = None, + labels: Optional[list] = None, + values: Optional[list] = None, + theme: Optional[ + Literal[ + "blue", + "red", + "green", + "inferno", + "fire", + "viridis", + "darkblue", + "darkred", + "darkgreen", + ] + ] = None, + cmap: Optional[str] = None, + color_key: Union[Dict[str, str], List[str], None] = None, + color_key_cmap: Optional[str] = None, + alpha: Optional[float] = None, + background: Optional[str] = "white", + ncols: int = 4, + pointsize: Optional[float] = None, + figsize: Tuple[float, float] = (6, 4), + show_legend: str = True, + use_smoothed: bool = True, + xlim: np.ndarray = None, + ylim: np.ndarray = None, + zlim: np.ndarray = None, + t: Optional[npt.ArrayLike] = None, + terms: List[str] = ["fixed_points"], + init_cells: List[int] = None, + init_states: np.ndarray = None, + quiver_source: Literal["raw", "reconstructed"] = "raw", + approx: bool = False, + markersize: float = 200, + marker_cmap: Optional[str] = None, + save_show_or_return: Literal["save", "show", "return"] = "show", + save_kwargs: Dict[str, Any] = {}, + aggregate: Optional[str] = None, + show_arrowed_spines: bool = False, + ax: Optional[Axes] = None, + sort: Literal["raw", "abs", "neg"] = "raw", + frontier: bool = False, + s_kwargs_dict: Dict[str, Any] = {}, + n: int = 25, +) -> Union[Axes, List[Axes], None]: + """Plot the topography of the reconstructed vector field in 3D space. + + Args: + adata: AnnData object that contains the reconstructed vector field. 
+ basis: The embedding data space that will be used to plot the topography. Defaults to `umap`. + fps_basis: The embedding data space that will be used to plot the fixed points. Defaults to `umap`. + x: The index of the first dimension of the embedding data space that will be used to plot the topography. + Defaults to 0. + y: The index of the second dimension of the embedding data space that will be used to plot the topography. + Defaults to 1. + z: The index of the third dimension of the embedding data space that will be used to plot the topography. + Defaults to 2. + color: The color of the topography. Defaults to `ntr`. + layer: The layer of the data that will be used to plot the topography. Defaults to `X`. + plot_method: The method that will be used to plot the topography. Defaults to `matplotlib`. + highlights: The list of gene names that will be used to highlight the gene expression on the topography. + Defaults to None. + labels: The list of gene names that will be used to label the gene expression on the topography. Defaults to + None. + values: The list of gene names that will be used to color the gene expression on the topography. Defaults to + None. + theme: The color theme that will be used to plot the topography. Defaults to None. + cmap: The name of a matplotlib colormap that will be used to color the topography. Defaults to None. + color_key: The color dictionary that will be used to color the topography. Defaults to None. + color_key_cmap: The name of a matplotlib colormap that will be used to color the color key. Defaults to None. + alpha: The transparency of the topography. Defaults to None. + background: The background color of the topography. Defaults to `white`. + ncols: The number of columns for the figure. Defaults to 4. + pointsize: The scale of the point size. Actual point cell size is calculated as + `500.0 / np.sqrt(adata.shape[0]) * pointsize`. Defaults to None. + figsize: The width and height of a figure. Defaults to (6, 4). 
+ show_legend: Whether to display a legend of the labels. Defaults to True. + use_smoothed: Whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to + True. + xlim: The range of x-coordinate. Defaults to None. + ylim: The range of y-coordinate. Defaults to None. + zlim: The range of z-coordinate. Defaults to None. + t: The length of the time period from which to predict cell state forward or backward over time. This is used + by the odeint function. Defaults to None. + terms: A list of plotting items to include in the final topography figure. ('streamline', 'nullcline', + 'fixed_points', 'separatrix', 'trajectory', 'quiver') are all the items that we can support. Defaults to + ["fixed_points"]. + init_cells: cell name or indices of the initial cell states for the historical or future cell state prediction + with numerical integration. If the names in init_cells are not found in adata.obs_names, they will be + treated as cell indices and must be integers. Defaults to None. + init_states: the initial cell states for the historical or future cell state prediction with numerical + integration. It can be either a one-dimensional array or N x 2 dimension array. `init_states` is + derived from init_cells when init_cells is not None and init_states is None. Defaults to None. + quiver_source: the data source that will be used to draw the quiver plot. If `init_cells` is provided, this will + be set to the projected RNA velocity before vector field reconstruction automatically. If `init_cells` is + not provided, this will be set to the velocity vectors calculated from the reconstructed vector field + function automatically. If quiver_source is `reconstructed`, the velocity vectors calculated from the + reconstructed vector field function will be used. Defaults to "raw". + approx: whether to use streamplot to draw the integration line from the init_state. Defaults to False. + markersize: the size of the marker. 
Defaults to 200. + marker_cmap: the name of a matplotlib colormap to use for coloring or shading the confidence of fixed points. If + None, the default color map will be set to viridis (inferno) when the background is white (black). Defaults + to None. + save_show_or_return: Whether to save, show or return the figure. Defaults to `show`. + save_kwargs: A dictionary that will be passed to the save_fig function. By default, it is an empty dictionary + and the save_fig function will use the {"path": None, "prefix": 'topography', "dpi": None, "ext": 'pdf', + "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a + dictionary that properly modifies those keys according to your needs. Defaults to {}. + aggregate: The column in adata.obs that will be used to aggregate data points. Defaults to None. + show_arrowed_spines: Whether to show a pair of arrowed spines representing the basis the scatter is currently + using. Defaults to False. + ax: The axis on which to make the plot. Defaults to None. + sort: The method to reorder data so that high-value points will be on top of background points. Can be one of + {'raw', 'abs', 'neg'}, i.e. sort by raw data, sort by absolute values or sort by negative values. Defaults + to "raw". + frontier: whether to add the frontier. Scatter plots can be enhanced by using transparency (alpha) in order to + show area of high density and multiple scatter plots can be used to delineate a frontier. See matplotlib + tips & tricks cheatsheet (https://github.com/matplotlib/cheatsheets). Originally inspired by figures from + scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. Defaults to False. + s_kwargs_dict: The dictionary of the scatter arguments. Defaults to {}. + n: Number of samples for calculating the fixed points. Defaults to 25. + + Returns: + None would be returned by default. 
If `save_show_or_return` is set to be 'return', the Axes of the generated + subplots would be returned. + """ + + logger = LoggerManager.gen_logger("dynamo-topography-plot") + logger.log_time() + + from matplotlib import rcParams + from matplotlib.colors import to_hex + + if type(color) == str: + color = [color] + + if alpha is None: + alpha = 0.8 if plot_method in ["pv", "plotly"] else 0.1 + + if background is None: + _background = rcParams.get("figure.facecolor") + _background = to_hex(_background) if type(_background) is tuple else _background + else: + _background = background + + terms = list(terms) if type(terms) is tuple else [terms] if type(terms) is str else terms + if approx: + if "streamline" not in terms: + terms.append("streamline") + if "trajectory" in terms: + terms = list(set(terms).difference("trajectory")) + + if init_cells is not None or init_states is not None: + terms.extend("trajectory") + + uns_key = "VecFld" if basis == "X" else "VecFld_" + basis + fps_uns_key = "VecFld" if fps_basis == "X" else "VecFld_" + fps_basis + + if uns_key not in adata.uns.keys(): + + if "velocity_" + basis not in adata.obsm_keys(): + logger.info( + f"velocity_{basis} is computed yet. " f"Projecting the velocity vector to {basis} basis now ...", + indent_level=1, + ) + cell_velocities(adata, basis=basis) + + logger.info( + f"Vector field for {basis} is not constructed. Constructing it now ...", + indent_level=1, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + if basis == fps_basis: + logger.info( + f"`basis` and `fps_basis` are all {basis}. Will also map topography ...", + indent_level=2, + ) + VectorField(adata, basis, map_topography=True, n=n) + else: + VectorField(adata, basis) + if fps_uns_key not in adata.uns.keys(): + if "velocity_" + basis not in adata.obsm_keys(): + logger.info( + f"velocity_{basis} is computed yet. 
" f"Projecting the velocity vector to {basis} basis now ...", + indent_level=1, + ) + cell_velocities(adata, basis=basis) + + logger.info( + f"Vector field for {fps_basis} is not constructed. " f"Constructing it and mapping its topography now ...", + indent_level=1, + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + VectorField(adata, fps_basis, map_topography=True, n=n) + + vecfld_dict, vecfld = vecfld_from_adata(adata, basis) + + fps_vecfld_dict, fps_vecfld = vecfld_from_adata(adata, fps_basis) + + # need to use "X_basis" to plot on the scatter point space + if "Xss" not in fps_vecfld_dict: + # if topology is not mapped for this basis, calculate it now. + logger.info( + f"Vector field for {fps_basis} is but its topography is not mapped. " f"Mapping topography now ...", + indent_level=1, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + _topology(adata, fps_basis, VecFld=None, n=n) + else: + if fps_vecfld_dict["Xss"].size > 0 and fps_vecfld_dict["Xss"].shape[1] > 3: + fps_vecfld_dict["X_basis"], fps_vecfld_dict["Xss"] = ( + vecfld_dict["X"][:, :3], + vecfld_dict["X"][fps_vecfld_dict["fp_ind"], :3], + ) + + xlim, ylim, zlim = ( + adata.uns[fps_uns_key]["xlim"] if xlim is None else xlim, + adata.uns[fps_uns_key]["ylim"] if ylim is None else ylim, + adata.uns[fps_uns_key]["zlim"] if zlim is None else zlim, + ) + + if xlim is None or ylim is None or zlim is None: + X_basis = vecfld_dict["X"][:, :3] + min_, max_ = X_basis.min(0), X_basis.max(0) + + xlim = [ + min_[0] - (max_[0] - min_[0]) * 0.1, + max_[0] + (max_[0] - min_[0]) * 0.1, + ] + ylim = [ + min_[1] - (max_[1] - min_[1]) * 0.1, + max_[1] + (max_[1] - min_[1]) * 0.1, + ] + zlim = [ + min_[2] - (max_[2] - min_[2]) * 0.1, + max_[2] + (max_[2] - min_[2]) * 0.1, + ] + + + if init_cells is not None: + if init_states is None: + intersect_cell_names = list(set(init_cells).intersection(adata.obs_names)) + _init_states = ( + adata.obsm["X_" + basis][init_cells, :] + if 
len(intersect_cell_names) == 0 + else adata[intersect_cell_names].obsm["X_" + basis].copy() + ) + V = ( + adata.obsm["velocity_" + basis][init_cells, :] + if len(intersect_cell_names) == 0 + else adata[intersect_cell_names].obsm["velocity_" + basis].copy() + ) + + init_states = _init_states + + if quiver_source == "reconstructed" or (init_states is not None and init_cells is None): + from ..tools.utils import vector_field_function + + V = vector_field_function(init_states, vecfld_dict, [0, 1]) + + if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + + pl, colors_list = scatters_interactive( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + labels=labels, + values=values, + cmap=cmap, + theme=theme, + background=background, + color_key=color_key, + color_key_cmap=color_key_cmap, + use_smoothed=use_smoothed, + save_show_or_return="return", + # style='points_gaussian', + opacity=alpha, + ) + + if "fixed_points" in terms: + pl = plot_fixed_points( + fps_vecfld, + fps_vecfld_dict, + background=_background, + ax=pl, + markersize=markersize, + cmap=marker_cmap, + plot_method="pv", + ) + + return save_pyvista_plotter( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) + elif plot_method == "plotly": + try: + import plotly.graph_objects as go + except ImportError: + raise ImportError("Please install plotly first.") + + pl, colors_list = scatters_interactive( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + plot_method="plotly", + labels=labels, + values=values, + cmap=cmap, + theme=theme, + background=background, + color_key=color_key, + color_key_cmap=color_key_cmap, + use_smoothed=use_smoothed, + save_show_or_return="return", + opacity=alpha, + ) + + if "fixed_points" in terms: + pl = plot_fixed_points( + fps_vecfld, + fps_vecfld_dict, + background=_background, + ax=pl, + 
markersize=markersize, + cmap=marker_cmap, + plot_method="plotly", + ) + + return save_plotly_figure( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) + else: + # plt.figure(facecolor=_background) + axes_list, color_list, font_color = scatters( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + highlights=highlights, + labels=labels, + values=values, + theme=theme, + cmap=cmap, + color_key=color_key, + color_key_cmap=color_key_cmap, + alpha=alpha, + background=_background, + ncols=ncols, + pointsize=pointsize, + figsize=figsize, + show_legend=show_legend, + use_smoothed=use_smoothed, + aggregate=aggregate, + show_arrowed_spines=show_arrowed_spines, + ax=ax, + sort=sort, + save_show_or_return="return", + frontier=frontier, + projection="3d", + **s_kwargs_dict, + return_all=True, + ) + + if type(axes_list) != list: + axes_list, color_list, font_color = ( + [axes_list], + [color_list], + [font_color], + ) + for i in range(len(axes_list)): + # ax = axes_list[i] + + axes_list[i].set_xlabel(basis + "_1") + axes_list[i].set_ylabel(basis + "_2") + axes_list[i].set_zlabel(basis + "_3") + # axes_list[i].set_aspect("equal") + + # Build the plot + axes_list[i].set_xlim(xlim) + axes_list[i].set_ylim(ylim) + axes_list[i].set_zlim(zlim) + + axes_list[i].set_facecolor(background) + + if t is None: + if vecfld_dict["grid_V"] is None: + max_t = np.max((np.diff(xlim), np.diff(ylim))) / np.min(np.abs(vecfld_dict["V"][:, :2])) + else: + max_t = np.max((np.diff(xlim), np.diff(ylim))) / np.min(np.abs(vecfld_dict["grid_V"])) + + t = np.linspace(0, max_t, 10 ** (np.min((int(np.log10(max_t)), 8)))) + + if "fixed_points" in terms: + axes_list[i] = plot_fixed_points( + fps_vecfld, + fps_vecfld_dict, + background=_background, + ax=axes_list[i], + markersize=markersize, + cmap=marker_cmap, + ) + + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": "topography", 
+ "dpi": None, + "ext": "pdf", + "transparent": True, + "close": True, + "verbose": True, + } + s_kwargs = update_dict(s_kwargs, save_kwargs) + + if save_show_or_return in ["both", "all"]: + s_kwargs["close"] = False + + save_fig(**s_kwargs) + if save_show_or_return in ["show", "both", "all"]: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + plt.tight_layout() + + plt.show() + if save_show_or_return in ["return", "all"]: + return axes_list if len(axes_list) > 1 else axes_list[0] \ No newline at end of file diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index 7076ea611..ca1ec1358 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -4,7 +4,7 @@ # import matplotlib.tri as tri import warnings -from typing import Optional +from typing import Dict, List, Optional, Tuple from warnings import warn import matplotlib @@ -19,7 +19,7 @@ from ..configuration import _themes from ..dynamo_logger import main_debug -from ..tools.utils import integrate_vf # integrate_vf_ivp +from ..tools.utils import integrate_vf, update_dict # integrate_vf_ivp # --------------------------------------------------------------------------------------------------- @@ -76,6 +76,214 @@ def _get_adata_color_vec(adata, layer, col): return np.array(_color).flatten() +def calculate_colors( + points, + ax=None, + labels=None, + values=None, + highlights=None, + cmap="Blues", + color_key=None, + color_key_cmap="Spectral", + background="white", + width=7, + height=5, + vmin=2, + vmax=98, + sort="raw", + sym_c=False, + projection=None, # default in matplotlib + **kwargs, +): + import matplotlib.pyplot as plt + from matplotlib.ticker import MaxNLocator + + dpi = plt.rcParams["figure.dpi"] + width, height = width * dpi, height * dpi + rasterized = kwargs["rasterized"] if "rasterized" in kwargs.keys() else None + # """Use matplotlib to plot points""" + # point_size = 500.0 / np.sqrt(points.shape[0]) + + legend_elements = None + + if ax is None: + dpi = 
plt.rcParams["figure.dpi"] + fig = plt.figure(figsize=(width / dpi, height / dpi)) + ax = fig.add_subplot(111, projection=projection) + + ax.set_facecolor(background) + + # Color by labels + if labels is not None: + main_debug("labels are not None, drawing by labels") + color_type = "labels" + + if labels.shape[0] != points.shape[0]: + raise ValueError( + "Labels must have a label for " + "each sample (size mismatch: {} {})".format(labels.shape[0], points.shape[0]) + ) + if color_key is None: + main_debug("color_key is None") + cmap = copy.copy(matplotlib.cm.get_cmap(color_key_cmap)) + cmap.set_bad("lightgray") + colors = None + + if highlights is None: + unique_labels = np.unique(labels) + num_labels = unique_labels.shape[0] + color_key = plt.get_cmap(color_key_cmap)(np.linspace(0, 1, num_labels)) + else: + if type(highlights) is str: + highlights = [highlights] + highlights.append("other") + unique_labels = np.array(highlights) + num_labels = unique_labels.shape[0] + color_key = _to_hex(plt.get_cmap(color_key_cmap)(np.linspace(0, 1, num_labels))) + color_key[-1] = "#bdbdbd" # lightgray hex code https://www.color-hex.com/color/d3d3d3 + + labels[[i not in highlights[:-1] for i in labels]] = "other" + points = pd.DataFrame(points) + points["label"] = pd.Categorical(labels) + + # reorder data so that highlighting points will be on top of background points + highlight_ids, background_ids = ( + points["label"] != "other", + points["label"] == "other", + ) + # reorder_data = points.copy(deep=True) + # ( + # reorder_data.loc[:(sum(background_ids) - 1), :], + # reorder_data.loc[sum(background_ids):, :], + # ) = (points.loc[background_ids, :].values, points.loc[highlight_ids, :].values) + points = pd.concat( + ( + points.loc[background_ids, :], + points.loc[highlight_ids, :], + ) + ).values + labels = points[:, 2] + + # WARNING: do not change the following line to "elif" during refactor + # This if-else branch is not logically parallel to the previous one. 
The following branch sets `colors`. + if isinstance(color_key, dict): + main_debug("color_key is a dict") + colors = pd.Series(labels).map(color_key).values + unique_labels = np.unique(labels) + legend_elements = [ + # Patch(facecolor=color_key[k], label=k) for k in unique_labels + Line2D( + [0], + [0], + marker="o", + color=color_key[k], + label=k, + linestyle="None", + ) + for k in unique_labels + ] + else: + main_debug("color_key is not None and not a dict") + unique_labels = np.unique(labels) + if len(color_key) < unique_labels.shape[0]: + raise ValueError("Color key must have enough colors for the number of labels") + + new_color_key = {k: color_key[i] for i, k in enumerate(unique_labels)} + legend_elements = [ + # Patch(facecolor=color_key[i], label=k) + Line2D( + [0], + [0], + marker="o", + color=color_key[i], + label=k, + linestyle="None", + ) + for i, k in enumerate(unique_labels) + ] + colors = pd.Series(labels).map(new_color_key) + + # Color by values + elif values is not None: + main_debug("drawing points by values") + color_type = "values" + cmap_ = copy.copy(matplotlib.cm.get_cmap(cmap)) + cmap_.set_bad("lightgray") + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + matplotlib.cm.register_cmap(name=cmap_.name, cmap=cmap_, override_builtin=True) + + if values.shape[0] != points.shape[0]: + raise ValueError( + "Values must have a value for " + "each sample (size mismatch: {} {})".format(values.shape[0], points.shape[0]) + ) + # reorder data so that high values points will be on top of background points + sorted_id = ( + np.argsort(abs(values)) if sort == "abs" else np.argsort(-values) if sort == "neg" else np.argsort(values) + ) + values, points = values[sorted_id], points[sorted_id, :] + + # if there are very few cells have expression, set the vmin/vmax only based on positive values to + # get rid of outliers + if np.nanmin(values) == 0: + n_pos_cells = sum(values > 0) + if 0 < n_pos_cells / len(values) < 0.02: + vmin = 0 if 
n_pos_cells == 1 else np.percentile(values[values > 0], 2) + vmax = np.nanmax(values) if n_pos_cells == 1 else np.percentile(values[values > 0], 98) + if vmin + vmax in [1, 100]: + vmin += 1e-12 + vmax += 1e-12 + + # if None: min/max from data + # if positive and sum up to 1, take fraction + # if positive and sum up to 100, take percentage + # otherwise take the data + _vmin = ( + np.nanmin(values) + if vmin is None + else np.nanpercentile(values, vmin * 100) + if (vmin + vmax == 1 and 0 <= vmin < vmax) + else np.nanpercentile(values, vmin) + if (vmin + vmax == 100 and 0 <= vmin < vmax) + else vmin + ) + _vmax = ( + np.nanmax(values) + if vmax is None + else np.nanpercentile(values, vmax * 100) + if (vmin + vmax == 1 and 0 <= vmin < vmax) + else np.nanpercentile(values, vmax) + if (vmin + vmax == 100 and 0 <= vmin < vmax) + else vmax + ) + + if sym_c and _vmin < 0 and _vmax > 0: + bounds = np.nanmax([np.abs(_vmin), _vmax]) + bounds = bounds * np.array([-1, 1]) + _vmin, _vmax = bounds + + + if "norm" in kwargs: + norm = kwargs["norm"] + else: + norm = matplotlib.colors.Normalize(vmin=_vmin, vmax=_vmax) + + mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) + mappable.set_array(values) + + cmap = matplotlib.cm.get_cmap(cmap) + colors = cmap(values) + # No color (just pick the midpoint of the cmap) + else: + main_debug("drawing points without color passed in args, using midpoint of the cmap") + color_type = "midpoint" + colors = plt.get_cmap(cmap)(0.5) + + return (colors, color_type, None) if color_type != "labels" else (colors.values, color_type, legend_elements) + + # --------------------------------------------------------------------------------------------------- # plotting utilities that borrowed from umap # link: https://github.com/lmcinnes/umap/blob/7e051d8f3c4adca90ca81eb45f6a9d1372c076cf/umap/plot.py @@ -216,7 +424,9 @@ def _matplotlib_points( if ax is None: dpi = plt.rcParams["figure.dpi"] fig = plt.figure(figsize=(width / dpi, height / dpi)) - 
ax = fig.add_subplot(111, projection=projection) + ax = fig.add_subplot( + 111, projection=projection, computed_zorder=False, + ) if projection == "3d" else fig.add_subplot(111, projection=projection) ax.set_facecolor(background) @@ -610,10 +820,12 @@ def _matplotlib_points( for i in unique_labels: if i == "other": continue - color_cnt = np.nanmedian(points[np.where(labels == i)[0], :2].astype("float"), 0) + if projection == "3d": + color_cnt = np.nanmedian(points[np.where(labels == i)[0], :3].astype("float"), 0) + else: + color_cnt = np.nanmedian(points[np.where(labels == i)[0], :2].astype("float"), 0) txt = ax.text( - color_cnt[0], - color_cnt[1], + *color_cnt, str(i), color=_select_font_color(font_color), zorder=1000, @@ -1447,6 +1659,141 @@ def save_fig( print("Done") +def retrieve_plot_save_path( + path: Optional[str] = None, + prefix: Optional[str] = None, + ext: str = "pdf", +) -> str: + """Retrieve the path to save dynamo plots. + + Args: + path: the path (and filename, without the extension) to save_fig the figure to. Defaults to None. + prefix: the prefix added to the figure name. This will be automatically set accordingly to the plotting function + used. Defaults to None. + ext: the file extension. This must be supported by the active matplotlib or pyvista backend. Most backends + support 'png', 'pdf', 'ps', 'eps', and 'svg'. Defaults to "pdf". + Returns: + The saving path. + """ + prefix = os.path.normpath(prefix) + if path is None: + path = os.getcwd() + "/" + + # Extract the directory and filename from the given path + directory = os.path.split(path)[0] + filename = os.path.split(path)[1] + if directory == "": + directory = "." + if filename == "": + filename = "dyn_savefig" + + # If the directory does not exist, create it + if not os.path.exists(directory): + os.makedirs(directory) + + # The final path to save_fig to + savepath = ( + os.path.join(directory, filename + "." 
def save_pyvista_plotter(
    pl,
    colors_list: Optional[List] = None,
    save_show_or_return: str = "show",
    save_kwargs: Optional[Dict] = None,
) -> Optional[Tuple]:
    """Save, show or return the pyvista.Plotter.

    Args:
        pl: target plotter object.
        colors_list: the list of color mappings that corresponds to the plotter. Defaults to None.
        save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at
            the same time. If "all", the figure will be saved, displayed and the plotter (and color list) will be
            returned. Defaults to "show".
        save_kwargs: a dictionary of overrides for the saving step. The defaults are {"path": None,
            "prefix": "scatters_pv", "ext": "pdf", "title": "PyVista Export", "raster": True, "painter": True}.
            "path", "prefix" and "ext" build the output path; "title", "raster" and "painter" are forwarded to
            `pl.save_graphic`. Defaults to None, which keeps all of these defaults.

    Raises:
        ImportError: if pyvista is not installed.

    Returns:
        If `save_show_or_return` is `return` or `all`: `(pl, colors_list)` when `colors_list` is non-empty, otherwise
        just `pl`. In every other mode nothing is returned.
    """
    try:
        import pyvista  # noqa: F401 -- imported only to check that pyvista is available
    except ImportError as err:
        raise ImportError("Please install pyvista first.") from err

    main_debug("show, return or save...")
    if save_show_or_return in ["save", "both", "all"]:
        s_kwargs = {
            "path": None,
            "prefix": "scatters_pv",
            "ext": "pdf",
            "title": "PyVista Export",
            "raster": True,
            "painter": True,
        }

        # `update_dict` is defined elsewhere in this module; hand it an empty dict rather than the None default so
        # that the "no overrides" case does not depend on it accepting None.
        s_kwargs = update_dict(s_kwargs, save_kwargs or {})

        saving_path = retrieve_plot_save_path(path=s_kwargs["path"], prefix=s_kwargs["prefix"], ext=s_kwargs["ext"])
        pl.save_graphic(saving_path, title=s_kwargs["title"], raster=s_kwargs["raster"], painter=s_kwargs["painter"])

    if save_show_or_return in ["show", "both", "all"]:
        pl.show()

    if save_show_or_return in ["return", "all"]:
        return (pl, colors_list) if colors_list else pl
def save_plotly_figure(
    pl,
    colors_list: Optional[List] = None,
    save_show_or_return: str = "show",
    save_kwargs: Optional[Dict] = None,
) -> Optional[Tuple]:
    """Save, show or return the plotly figure.

    Args:
        pl: target plotly object.
        colors_list: the list of color mappings that corresponds to the figure. Defaults to None.
        save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at
            the same time. If "all", the figure will be saved, displayed and the figure (and color list) will be
            returned. Defaults to "show".
        save_kwargs: a dictionary of overrides for the saving step. The defaults are {"path": None,
            "prefix": "scatters_plotly", "ext": "html"}, which together build the output path; the figure is always
            written with `pl.write_html`. Defaults to None, which keeps all of these defaults.

    Returns:
        If `save_show_or_return` is `return` or `all`: `(pl, colors_list)` when `colors_list` is non-empty, otherwise
        just `pl`. In every other mode nothing is returned.
    """

    main_debug("show, return or save...")
    if save_show_or_return in ["save", "both", "all"]:
        s_kwargs = {
            "path": None,
            "prefix": "scatters_plotly",
            "ext": "html",
        }

        # `update_dict` is defined elsewhere in this module; hand it an empty dict rather than the None default so
        # that the "no overrides" case does not depend on it accepting None.
        s_kwargs = update_dict(s_kwargs, save_kwargs or {})

        saving_path = retrieve_plot_save_path(path=s_kwargs["path"], prefix=s_kwargs["prefix"], ext=s_kwargs["ext"])
        pl.write_html(saving_path)

    if save_show_or_return in ["show", "both", "all"]:
        pl.show()

    if save_show_or_return in ["return", "all"]:
        return (pl, colors_list) if colors_list else pl
class VectorField3D(VectorField2D):
    """A class that represents a 3D vector field, which is a type of mathematical object that assigns a 3D vector to
    each point in a 3D space.

    The class is derived from the VectorField2D class. The vector field can be defined using a single function that
    returns the vector at each point, or by separate functions for the x, y and z components of the vector. Nullclines
    calculation are not supported for 3D vector space because of the computational complexity, so `NCz` is always
    None.
    """

    def __init__(
        self,
        func: Callable,
        func_vx: Optional[Callable] = None,
        func_vy: Optional[Callable] = None,
        func_vz: Optional[Callable] = None,
        X_data: Optional[np.ndarray] = None,
    ) -> None:
        """Initialize the VectorField3D object.

        Args:
            func: a function that takes an (n, 3) array of coordinates and returns an (n, 3) array of vectors
            func_vx: a function that takes an (n, 3) array of coordinates and returns an (n,) array of x components of
                the vectors, Defaults to None.
            func_vy: a function that takes an (n, 3) array of coordinates and returns an (n,) array of y components of
                the vectors, Defaults to None.
            func_vz: a function that takes an (n, 3) array of coordinates and returns an (n,) array of z components of
                the vectors. If None, the z component is read from the output of `func`. Defaults to None.
            X_data: forwarded unchanged to `VectorField2D.__init__`. Defaults to None.
        """
        super().__init__(func, func_vx, func_vy, X_data)

        def func_dim(x, func, dim):
            # Pick one component of the field: index a single vector directly, or take one column of a batch.
            y = func(x)
            if y.ndim == 1:
                y = y[dim]
            else:
                y = y[:, dim].flatten()
            return y

        if func_vz is None:
            # `self.func` is looked up at call time, i.e. whatever the parent class stored for the full field.
            self.fz = lambda x: func_dim(x, self.func, 2)
        else:
            self.fz = func_vz

        # Nullclines are never computed in 3D; the attribute exists so `output_to_dict` can export it.
        self.NCz = None

    def find_fixed_points_by_sampling(
        self,
        n: int,
        x_range: Tuple[float, float],
        y_range: Tuple[float, float],
        z_range: Tuple[float, float],
        lhs: Optional[bool] = True,
        tol_redundant: float = 1e-4,
    ) -> None:
        """Find fixed points by sampling the vector field within a specified range of coordinates.

        The found fixed points are added to `self.Xss`; nothing is returned.

        Args:
            n: the number of samples to take.
            x_range: a tuple of two floats specifying the range of x coordinates to sample.
            y_range: a tuple of two floats specifying the range of y coordinates to sample.
            z_range: a tuple of two floats specifying the range of z coordinates to sample.
            lhs: whether to use Latin Hypercube Sampling to generate the samples. Defaults to `True`.
            tol_redundant: the tolerance for removing redundant fixed points. Defaults to 1e-4.

        Raises:
            ValueError: if the search returns no fixed points.
        """
        if lhs:
            from ..tools.sampling import lhsclassic

            X0 = lhsclassic(n, 3)
        else:
            X0 = np.random.rand(n, 3)
        # Rescale each column of the samples to its axis range (assumes the samples lie in the unit cube, which holds
        # for np.random.rand; presumably also for lhsclassic -- confirm).
        X0[:, 0] = X0[:, 0] * (x_range[1] - x_range[0]) + x_range[0]
        X0[:, 1] = X0[:, 1] * (y_range[1] - y_range[0]) + y_range[0]
        X0[:, 2] = X0[:, 2] * (z_range[1] - z_range[0]) + z_range[0]
        X, J, _ = find_fixed_points(
            X0,
            self.func,
            domain=[x_range, y_range, z_range],
            tol_redundant=tol_redundant,
        )
        if X is None:
            raise ValueError("No fixed points found. Try to increase the number of samples n.")
        self.Xss.add_fixed_points(X, J, tol_redundant)

    def output_to_dict(self, dict_vf: Dict) -> Dict:
        """Output the vector field as a dictionary.

        Args:
            dict_vf: the dictionary to fill. It is updated in place with the keys "NCx", "NCy", "NCz", "Xss",
                "confidence" and "J".

        Returns:
            The same dictionary, containing nullclines, fixed points, confidence and jacobians.
        """
        dict_vf["NCx"] = self.NCx
        dict_vf["NCy"] = self.NCy
        dict_vf["NCz"] = self.NCz
        dict_vf["Xss"] = self.Xss.get_X()
        dict_vf["confidence"] = self.get_Xss_confidence()
        dict_vf["J"] = self.Xss.get_J()
        return dict_vf
Returns: A tuple consisting of the following elements: - - X_basis: an array of shape (n, 2) where n is the number of points in X. This is the subset of X consisting of the first two dimensions specified by dims. If X is not provided, X_basis is taken from the obsm attribute of adata using the key "X_" + basis. - - xlim, ylim: a tuple of floats specifying the limits of the x and y axes, respectively. These are computed based on the minimum and maximum values of X_basis. + - X_basis: an array of shape (n, 2) where n is the number of points in X. This is the subset of X consisting + of the first two dimensions specified by dims. If X is not provided, X_basis is taken from the obsm + attribute of adata using the key "X_" + basis. + - xlim, ylim, zlim: a tuple of floats specifying the limits of the x, y and z axes, respectively. These are + computed based on the minimum and maximum values of X_basis. - confidence: an array of shape (n, ) containing the confidence scores of the fixed points. - - NCx, NCy: arrays of shape (n, ) containing the x and y coordinates of the nullclines (lines where the derivative of the system is zero), respectively. - - Xss: an array of shape (n, k) where k is the number of dimensions of the system, containing the fixed points. + - NCx, NCy: arrays of shape (n, ) containing the x and y coordinates of the nullclines (lines where the + derivative of the system is zero), respectively. + - Xss: an array of shape (n, k) where k is the number of dimensions of the system, containing the fixed + points. - ftype: an array of shape (n, ) containing the types of fixed points (attractor, repeller, or saddle). - an array of shape (n, ) containing the indices of the fixed points in the original data. 
""" @@ -648,6 +754,7 @@ def util_topology( min_[1] - (max_[1] - min_[1]) * 0.1, max_[1] + (max_[1] - min_[1]) * 0.1, ] + zlim = None vecfld = VectorField2D(func, X_data=X_basis) vecfld.find_fixed_points_by_sampling(n, xlim, ylim) @@ -655,11 +762,35 @@ def util_topology( vecfld.compute_nullclines(xlim, ylim, find_new_fixed_points=True) NCx, NCy = vecfld.NCx, vecfld.NCy + Xss, ftype = vecfld.get_fixed_points(get_types=True) + confidence = vecfld.get_Xss_confidence() + elif X_basis.shape[1] == 3: + fp_ind = None + min_, max_ = X_basis.min(0), X_basis.max(0) + + xlim = [ + min_[0] - (max_[0] - min_[0]) * 0.1, + max_[0] + (max_[0] - min_[0]) * 0.1, + ] + ylim = [ + min_[1] - (max_[1] - min_[1]) * 0.1, + max_[1] + (max_[1] - min_[1]) * 0.1, + ] + zlim = [ + min_[2] - (max_[2] - min_[2]) * 0.1, + max_[2] + (max_[2] - min_[2]) * 0.1, + ] + + vecfld = VectorField3D(func, X_data=X_basis) + vecfld.find_fixed_points_by_sampling(n, xlim, ylim, zlim) + + NCx, NCy = None, None + Xss, ftype = vecfld.get_fixed_points(get_types=True) confidence = vecfld.get_Xss_confidence() else: fp_ind = None - xlim, ylim, confidence, NCx, NCy = None, None, None, None, None + xlim, ylim, zlim, confidence, NCx, NCy = None, None, None, None, None, None vecfld = BaseVectorField( X=VecFld["X"][VecFld["valid_ind"], :], V=VecFld["Y"][VecFld["valid_ind"], :], @@ -671,7 +802,7 @@ def util_topology( fp_ind = nearest_neighbors(Xss, vecfld.data["X"], 1).flatten() Xss = vecfld.data["X"][fp_ind] - return X_basis, xlim, ylim, confidence, NCx, NCy, Xss, ftype, fp_ind + return X_basis, xlim, ylim, zlim, confidence, NCx, NCy, Xss, ftype, fp_ind def topography( @@ -684,12 +815,13 @@ def topography( VecFld: Optional[VecFldDict] = None, **kwargs, ) -> AnnData: - """Map the topography of the single cell vector field in (first) two dimensions. + """Map the topography of the single cell vector field in (first) two or three dimensions. Args: adata: an AnnData object. 
basis: The reduced dimension embedding of cells to visualize. - layer: Which layer of the data will be used for vector field function reconstruction. This will be used in conjunction with X. + layer: Which layer of the data will be used for vector field function reconstruction. This will be used in + conjunction with X. X: Original data. Not used dims: The dimensions that will be used for vector field reconstruction. n: Number of samples for calculating the fixed points. @@ -699,7 +831,8 @@ def topography( Returns: `AnnData` object that is updated with the `VecFld` or 'VecFld_' + basis dictionary in the `uns` attribute. - The `VecFld2D` key stores an instance of the VectorField2D class which presumably has fixed points, nullcline, separatrix, computed and stored. + The `VecFld2D` key stores an instance of the VectorField2D class which presumably has fixed points, nullcline, + separatrix, computed and stored. """ if VecFld is None: @@ -722,6 +855,7 @@ def func(x): X_basis, xlim, ylim, + zlim, confidence, NCx, NCy, @@ -743,6 +877,7 @@ def func(x): { "xlim": xlim, "ylim": ylim, + "zlim": zlim, "X_data": X_basis, "Xss": Xss, "ftype": ftype, @@ -756,6 +891,7 @@ def func(x): adata.uns[vf_key] = { "xlim": xlim, "ylim": ylim, + "zlim": zlim, "X_data": X_basis, "Xss": Xss, "ftype": ftype,