From a2bec28c760f0e8b1e55111b42f268f1098994b8 Mon Sep 17 00:00:00 2001 From: Sichao25 Date: Tue, 1 Aug 2023 18:27:36 -0400 Subject: [PATCH 01/62] enable 3d option of cellwise vec --- dynamo/plot/scVectorField.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 164d27f81..2e4159736 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -754,6 +754,7 @@ def cell_wise_vectors( df = pd.DataFrame({"x": X[:, 0], "y": X[:, 1], "u": V[:, 0], "v": V[:, 1]}) elif projection == "3d": df = pd.DataFrame({"x": X[:, 0], "y": X[:, 1], "z": X[:, 2], "u": V[:, 0], "v": V[:, 1], "w": V[:, 2]}) + show_legend = None else: raise NotImplementedError("Projection method %s is not implemented" % projection) @@ -857,6 +858,7 @@ def cell_wise_vectors( **quiver_kwargs, ) elif projection == "3d": + cmap_3d = [element for element in color_list[i]] + [element for element in color_list[i] for _ in range(2)] ax.quiver( x0, x1, @@ -864,7 +866,7 @@ def cell_wise_vectors( v0, v1, v2, - # color=color_list[i], + color=cmap_3d, # facecolors=color_list[i], **quiver_3d_kwargs, ) From 3ed40615b1202bf0e84f45a8f32f8dd7eb18d7e2 Mon Sep 17 00:00:00 2001 From: Sichao25 Date: Tue, 1 Aug 2023 22:47:41 -0400 Subject: [PATCH 02/62] debug streamtube basis --- dynamo/plot/streamtube.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/dynamo/plot/streamtube.py b/dynamo/plot/streamtube.py index a95e818a9..22c9aeb8d 100644 --- a/dynamo/plot/streamtube.py +++ b/dynamo/plot/streamtube.py @@ -89,10 +89,7 @@ def plot_3d_streamtube( else: _background = background - if is_gene_name(adata, color): - color_val = adata.obs_vector(k=color, layer=None) if layer == "X" else adata.obs_vector(k=color, layer=layer) - elif is_cell_anno_column(adata, color): - color_val = adata.obs_vector + color_val = adata.obs_vector(k=color, layer=None) if layer == "X" else adata.obs_vector(k=color, layer=layer) is_not_continous = not isinstance(color_val[0], Number) or color_val.dtype.name == "category" @@ -154,7 +151,7 @@ def plot_3d_streamtube( from ..vectorfield.utils import vecfld_from_adata - VecFld, func = vecfld_from_adata(adata, basis="umap") + VecFld, func = vecfld_from_adata(adata, basis=basis) velocity_grid = func(X_grid) @@ -167,9 +164,9 @@ def plot_3d_streamtube( v=velocity_grid[:, 1], w=velocity_grid[:, 2], starts=dict( - x=adata[labels == init_group, :].obsm["X_umap"][:125, 0], - y=adata[labels == init_group, :].obsm["X_umap"][:125, 1], - z=adata[labels == init_group, :].obsm["X_umap"][:125, 2], + x=adata[labels == init_group, :].obsm["X_" + basis][:125, 0], + y=adata[labels == init_group, :].obsm["X_" + basis][:125, 1], + z=adata[labels == init_group, :].obsm["X_" + basis][:125, 2], ), sizeref=3000, colorscale="Portland", From a0af0d1917b14a6a9a76f283a4880531b785e4f3 Mon Sep 17 00:00:00 2001 From: Sichao25 Date: Tue, 8 Aug 2023 20:08:34 -0400 Subject: [PATCH 03/62] update 3D cell vector color map --- dynamo/plot/scVectorField.py | 108 ++++++++++++++++++++++++++--------- 1 file changed, 80 insertions(+), 28 deletions(-) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 2e4159736..f930c22e7 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -71,10 +71,11 @@ def cell_wise_vectors_3d( save_show_or_return: str = "show", save_kwargs: Dict[str, Any] = {}, quiver_3d_kwargs: Dict[str, Any] = { - "zorder": 3, - "length": 2, - "linewidth": 5, - "arrow_length_ratio": 5, + "linewidth": 1, + "edgecolors": "white", + "alpha": 1, + "length": 8, + "arrow_length_ratio": 1, "norm": cm.colors.Normalize(), "cmap": cm.PRGn, }, @@ -84,8 +85,34 @@ def cell_wise_vectors_3d( elev: Optional[float] = None, azim: Optional[float] = None, alpha: Optional[float] = None, - show_magnitude: bool = False, + show_magnitude: bool = True, titles: Optional[List[str]] = None, + highlights: Optional[list] = None, + labels: Optional[list] = None, + values: Optional[list] = None, + theme: Optional[ + Literal[ + "blue", + "red", + "green", + "inferno", + "fire", + "viridis", + "darkblue", + "darkred", + "darkgreen", + ] + ] = None, + cmap: Optional[str] = None, + color_key: Union[Dict[str, str], List[str], None] = None, + color_key_cmap: Optional[str] = None, + pointsize: Optional[float] = None, + use_smoothed: bool = True, + sort: Literal["raw", "abs", "neg"] = "raw", + aggregate: Optional[str] = None, + show_arrowed_spines: bool = False, + frontier: bool = False, + s_kwargs_dict: Dict[str, Any] = {}, **cell_wise_kwargs, ) -> np.ndarray: """Plot the velocity or acceleration vector of each cell. @@ -248,29 +275,47 @@ def add_axis_label(ax, labels): nrows += 1 ncols = min(ncols, len(color)) - figure, axes = plt.subplots(nrows, ncols, figsize=figsize, subplot_kw=dict(projection="3d")) - axes = np.array(axes) - axes_flatten = axes.flatten() + axes_list, color_list, _ = scatters( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + highlights=highlights, + labels=labels, + values=values, + theme=theme, + cmap=cmap, + color_key=color_key, + color_key_cmap=color_key_cmap, + background=background, + ncols=ncols, + pointsize=pointsize, + figsize=figsize, + show_legend=None, + use_smoothed=use_smoothed, + aggregate=aggregate, + show_arrowed_spines=show_arrowed_spines, + ax=ax, + sort=sort, + save_show_or_return="return", + frontier=frontier, + projection="3d", + **s_kwargs_dict, + return_all=True, + ) + + if type(axes_list) != list: + axes_list = [axes_list] + color_list = [color_list] for i in range(len(color)): - ax = axes_flatten[i] + ax = axes_list[i] ax.set_title(color[i]) - norm = quiver_3d_kwargs["norm"] - cmap = quiver_3d_kwargs["cmap"] - color_vec = _get_adata_color_vec(adata, layer=layer, col=color[i]) - assert len(color_vec) > 0, "color vector or data vector size is 0" - - # convet categorical string data colors to labels - if type(color_vec[0]) is str: - unique_vals, color_vec = np.unique(color_vec, return_inverse=True) - - color_vec = cmap(norm(color_vec)) - - # TODO due to matplotlib quiver3 impl, we need to add colors for arrow head segments - # TODO if matplotlib changes its detailed impl, we may not need the following line - color_vec = list(color_vec) + [element for element in list(color_vec) for _ in range(2)] - # color_vec = matplotlib.colors.to_rgba(color_vec, alpha=alpha) - main_debug("color vec len: " + str(len(color_vec))) + cmap_3d = [element for element in color_list[i]] + [element for element in color_list[i] for _ in range(2)] + main_debug("color vec len: " + str(len(cmap_3d))) ax.view_init(elev=elev, azim=azim) ax.quiver( x0, @@ -279,7 +324,7 @@ def add_axis_label(ax, labels): v0, v1, v2, - color=color_vec, + color=cmap_3d, # facecolors=color_vec, **quiver_3d_kwargs, ) @@ -306,7 +351,7 @@ def add_axis_label(ax, labels): if save_show_or_return in ["show", "both", "all"]: plt.show() if save_show_or_return in ["return", "all"]: - return axes + return axes_list def grid_vectors_3d(): @@ -800,7 +845,14 @@ def cell_wise_vectors( "zorder": 10, } quiver_kwargs = update_dict(quiver_kwargs, cell_wise_kwargs) - quiver_3d_kwargs = {"arrow_length_ratio": scale} + quiver_3d_kwargs = { + "linewidth": 1, + "edgecolors": "white", + "alpha": 1, + "length": 8, + "arrow_length_ratio": scale, + + } axes_list, color_list, _ = scatters( adata=adata, From ad0bcc3d6a3b34771a2cdefb563c9f014f56026b Mon Sep 17 00:00:00 2001 From: Sichao25 Date: Tue, 8 Aug 2023 20:17:45 -0400 Subject: [PATCH 04/62] add missing params docstr --- dynamo/plot/scVectorField.py | 39 ++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index f930c22e7..77b744384 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -159,6 +159,45 @@ def cell_wise_vectors_3d( alpha: the transparency of the colors. Defaults to None. show_magnitude: whether to show original values or normalize the data. Defaults to False. titles: the titles of the subplots. Defaults to None. + highlights: the color group that will be highlighted. If highligts is a list of lists, each list is relate to + each color element. Defaults to None. + labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for + coloring the points in the plot according to their label. Note that this option is mutually exclusive to the + `values` option. Defaults to None. + values: an array of values (assumed float or continuous), one for each sample. This will be used for coloring + the points in the plot according to a colorscale associated to the total range of values. Note that this + option is mutually exclusive to the `labels` option. Defaults to None. + theme: A color theme to use for plotting. A small set of predefined themes are provided which have relatively + good aesthetics. Available themes are: {'blue', 'red', 'green', 'inferno', 'fire', 'viridis', 'darkblue', + 'darkred', 'darkgreen'}. Defaults to None. + cmap: The name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed + this will be used for shading points according to density (largely only of relevance for very large + datasets). If values are passed this will be used for shading according the value. Note that if theme is + passed then this value will be overridden by the corresponding option of the theme. Defaults to None. + color_key: the method to assign colors to categoricals. This can either be an explicit dict mapping labels to + colors (as strings of form '#RRGGBB'), or an array like object providing one color for each distinct + category being provided in `labels`. Either way this mapping will be used to color points according to the + label. Note that if theme is passed then this value will be overridden by the corresponding option of the + theme. Defaults to None. + color_key_cmap: the name of a matplotlib colormap to use for categorical coloring. If an explicit `color_key` is + not given a color mapping for categories can be generated from the label list and selecting a matching list + of colors from the given colormap. Note that if theme is passed then this value will be overridden by the + corresponding option of the theme. Defaults to None. + pointsize: the scale of the point size. Actual point cell size is calculated as + `500.0 / np.sqrt(adata.shape[0]) * pointsize`. Defaults to None. + use_smoothed: whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to + True. + sort: the method to reorder data so that high values points will be on top of background points. Can be one of + {'raw', 'abs', 'neg'}, i.e. sorted by raw data, sort by absolute values or sort by negative values. Defaults + to "raw". + aggregate: the column in adata.obs that will be used to aggregate data points. Defaults to None. + show_arrowed_spines: whether to show a pair of arrowed spines representing the basis of the scatter is currently + using. Defaults to False. + frontier: whether to add the frontier. Scatter plots can be enhanced by using transparency (alpha) in order to + show area of high density and multiple scatter plots can be used to delineate a frontier. See matplotlib + tips & tricks cheatsheet (https://github.com/matplotlib/cheatsheets). Originally inspired by figures from + scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. Defaults to False. + s_kwargs_dict: any other kwargs that will be passed to `dynamo.pl.scatters`. Defaults to {}. Raises: ValueError: invalid `x`, `y`, or `z`. From 2130c6f3c13d6962b7fa81a61598e8cc90f7bb30 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 16 Aug 2023 19:06:01 -0400 Subject: [PATCH 05/62] create 3D vf class --- dynamo/vectorfield/topography.py | 38 ++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/dynamo/vectorfield/topography.py b/dynamo/vectorfield/topography.py index ee388e967..ad11dc9bf 100755 --- a/dynamo/vectorfield/topography.py +++ b/dynamo/vectorfield/topography.py @@ -610,6 +610,44 @@ def output_to_dict(self, dict_vf): return dict_vf +class VectorField3D: + def __init__( + self, + func: Callable, + func_vx: Optional[Callable] = None, + func_vy: Optional[Callable] = None, + func_vz: Optional[Callable] = None, + X_data: Optional[np.ndarray] = None, + ): + self.func = func + + def func_dim(x, func, dim): + y = func(x) + if y.ndim == 1: + y = y[dim] + else: + y = y[:, dim].flatten() + return y + + if func_vx is None: + self.fx = lambda x: func_dim(x, self.func, 0) + else: + self.fx = func_vx + if func_vy is None: + self.fy = lambda x: func_dim(x, self.func, 1) + else: + self.fy = func_vy + if func_vz is None: + self.fy = lambda x: func_dim(x, self.func, 2) + else: + self.fy = func_vz + self.Xss = FixedPoints() + self.X_data = X_data + self.NCx = None + self.NCy = None + self.NCz = None + + def util_topology( adata: AnnData, basis: str, From d5349a9d44f93db06bdabd7e2a375ce88e67494b Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 17 Aug 2023 16:36:45 -0400 Subject: [PATCH 06/62] 3d nullcline WIP --- dynamo/vectorfield/topography.py | 125 +++++++++++++++++++++++++++---- 1 file changed, 109 insertions(+), 16 deletions(-) diff --git a/dynamo/vectorfield/topography.py b/dynamo/vectorfield/topography.py index ad11dc9bf..a36723318 100755 --- a/dynamo/vectorfield/topography.py +++ b/dynamo/vectorfield/topography.py @@ -178,6 +178,45 @@ def compute_nullclines_2d( return NCx, NCy +def compute_nullclines_3d( + X0: Union[List, np.ndarray], + fdx: Callable, + fdy: Callable, + fdz: Callable, + x_range: List, + y_range: List, + z_range: List, + s_max: Optional[float] = None, + ds: Optional[float] = None, +) -> Tuple[List]: + if s_max is None: + s_max = 5 * ((x_range[1] - x_range[0]) + (y_range[1] - y_range[0]) + (z_range[1] - z_range[0])) + if ds is None: + ds = s_max / 1e3 + + NCx = [] + NCy = [] + NCz = [] + for x0 in X0: + # initialize tangent predictor + theta = np.random.rand() * 2 * np.pi + phi = np.random.rand() * 2 * np.pi + r = ds * 2 + v0 = [r * np.sin(theta) * np.cos(phi), r * np.sin(theta) * np.sin(phi), r * np.cos(theta)] + v0 /= np.linalg.norm(v0) + # nullcline continuation + NCx.append(continuation(x0, fdx, s_max, ds, v0=v0)) + NCx.append(continuation(x0, fdx, s_max, ds, v0=-v0)) + NCy.append(continuation(x0, fdy, s_max, ds, v0=v0)) + NCy.append(continuation(x0, fdy, s_max, ds, v0=-v0)) + NCz.append(continuation(x0, fdz, s_max, ds, v0=v0)) + NCz.append(continuation(x0, fdz, s_max, ds, v0=-v0)) + NCx = clip_curves(NCx, [x_range, y_range, z_range], ds * 10) + NCy = clip_curves(NCy, [x_range, y_range, z_range], ds * 10) + NCz = clip_curves(NCz, [x_range, y_range, z_range], ds * 10) + return NCx, NCy, NCz + + def compute_separatrices( Xss: np.ndarray, Js: np.ndarray, @@ -610,7 +649,7 @@ def output_to_dict(self, dict_vf): return dict_vf -class VectorField3D: +class VectorField3D(VectorField2D): def __init__( self, func: Callable, @@ -619,7 +658,7 @@ def __init__( func_vz: Optional[Callable] = None, X_data: Optional[np.ndarray] = None, ): - self.func = func + super().__init__(func, func_vx, func_vy, X_data) def func_dim(x, func, dim): y = func(x) @@ -629,24 +668,78 @@ def func_dim(x, func, dim): y = y[:, dim].flatten() return y - if func_vx is None: - self.fx = lambda x: func_dim(x, self.func, 0) - else: - self.fx = func_vx - if func_vy is None: - self.fy = lambda x: func_dim(x, self.func, 1) - else: - self.fy = func_vy if func_vz is None: - self.fy = lambda x: func_dim(x, self.func, 2) + self.fz = lambda x: func_dim(x, self.func, 2) else: - self.fy = func_vz - self.Xss = FixedPoints() - self.X_data = X_data - self.NCx = None - self.NCy = None + self.fz = func_vz + self.NCz = None + def find_fixed_points_by_sampling( + self, + n: int, + x_range: Tuple[float, float], + y_range: Tuple[float, float], + z_range: Tuple[float, float], + lhs: Optional[bool] = True, + tol_redundant: float = 1e-4, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + if lhs: + from ..tools.sampling import lhsclassic + + X0 = lhsclassic(n, 3) + else: + X0 = np.random.rand(n, 3) + X0[:, 0] = X0[:, 0] * (x_range[1] - x_range[0]) + x_range[0] + X0[:, 1] = X0[:, 1] * (y_range[1] - y_range[0]) + y_range[0] + X0[:, 2] = X0[:, 2] * (z_range[1] - z_range[0]) + z_range[0] + X, J, _ = find_fixed_points( + X0, + self.func, + domain=[x_range, y_range, z_range], + tol_redundant=tol_redundant, + ) + if X is None: + raise ValueError(f"No fixed points found. Try to increase the number of samples n.") + self.Xss.add_fixed_points(X, J, tol_redundant) + + def compute_nullclines( + self, + x_range: Tuple[float, float], + y_range: Tuple[float, float], + z_range: Tuple[float, float], + find_new_fixed_points: Optional[bool] = False, + tol_redundant: Optional[float] = 1e-4, + ): + # compute arguments + s_max = 5 * ((x_range[1] - x_range[0]) + (y_range[1] - y_range[0])) + ds = s_max / 1e3 + self.NCx, self.NCy, self.NCz = compute_nullclines_3d( + self.Xss.get_X(), + self.fx, + self.fy, + self.fz, + x_range, + y_range, + z_range, + s_max=s_max, + ds=ds, + ) + # if find_new_fixed_points: + # sample_interval = ds * 10 + # X, J = find_fixed_points_nullcline_3d(self.func, self.NCx, self.NCy, self.NCz, sample_interval, tol_redundant) + # outside = is_outside(X, [x_range, y_range]) + # self.Xss.add_fixed_points(X[~outside], J[~outside], tol_redundant) + + def output_to_dict(self, dict_vf): + dict_vf["NCx"] = self.NCx + dict_vf["NCy"] = self.NCy + dict_vf["NCz"] = self.NCz + dict_vf["Xss"] = self.Xss.get_X() + dict_vf["confidence"] = self.get_Xss_confidence() + dict_vf["J"] = self.Xss.get_J() + return dict_vf + def util_topology( adata: AnnData, From a33f5a08fac1c3cfb23da3150136e38c112531d1 Mon Sep 17 00:00:00 2001 From: sichao Date: Mon, 21 Aug 2023 11:30:59 -0400 Subject: [PATCH 07/62] add edge case exception --- dynamo/plot/streamtube.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dynamo/plot/streamtube.py b/dynamo/plot/streamtube.py index 22c9aeb8d..2b59312b7 100644 --- a/dynamo/plot/streamtube.py +++ b/dynamo/plot/streamtube.py @@ -130,6 +130,9 @@ def plot_3d_streamtube( mapper = cm.ScalarMappable(norm=norm, cmap=_cmap) colors = _to_hex(mapper.to_rgba(values)) + if adata.obsm["X_" + basis].shape[1] < 3: + raise ValueError("Current basis has dimensions less than 3!") + X = adata.obsm["X_" + basis][:, dims] grid_kwargs_dict = { "density": None, From 3425947876a0fdb58c44935a764acb985f975d25 Mon Sep 17 00:00:00 2001 From: sichao Date: Mon, 21 Aug 2023 15:01:13 -0400 Subject: [PATCH 08/62] add if statement to read grid data direclt --- dynamo/plot/streamtube.py | 48 +++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/dynamo/plot/streamtube.py b/dynamo/plot/streamtube.py index 2b59312b7..535a1728b 100644 --- a/dynamo/plot/streamtube.py +++ b/dynamo/plot/streamtube.py @@ -133,30 +133,38 @@ def plot_3d_streamtube( if adata.obsm["X_" + basis].shape[1] < 3: raise ValueError("Current basis has dimensions less than 3!") + if "VecFld_" + basis not in adata.uns.keys(): + raise KeyError("Corresponding vector field not found! Please run VectorField() with current basis.") + X = adata.obsm["X_" + basis][:, dims] - grid_kwargs_dict = { - "density": None, - "smooth": None, - "n_neighbors": None, - "min_mass": None, - "autoscale": False, - "adjust_for_stream": True, - "V_threshold": None, - } - - X_grid, p_mass, neighs, weight = prepare_velocity_grid_data( - X, - [60, 60, 60], - density=grid_kwargs_dict["density"], - smooth=grid_kwargs_dict["smooth"], - n_neighbors=grid_kwargs_dict["n_neighbors"], - ) - from ..vectorfield.utils import vecfld_from_adata + if "grid" in adata.uns["VecFld_" + basis].keys() and "grid_V" in adata.uns["VecFld_" + basis].keys(): + X_grid = adata.uns["VecFld_pca3"]["grid"] + velocity_grid = adata.uns["VecFld_pca3"]["grid_V"] + else: + grid_kwargs_dict = { + "density": None, + "smooth": None, + "n_neighbors": None, + "min_mass": None, + "autoscale": False, + "adjust_for_stream": True, + "V_threshold": None, + } + + X_grid, p_mass, neighs, weight = prepare_velocity_grid_data( + X, + [60, 60, 60], + density=grid_kwargs_dict["density"], + smooth=grid_kwargs_dict["smooth"], + n_neighbors=grid_kwargs_dict["n_neighbors"], + ) + + from ..vectorfield.utils import vecfld_from_adata - VecFld, func = vecfld_from_adata(adata, basis=basis) + VecFld, func = vecfld_from_adata(adata, basis=basis) - velocity_grid = func(X_grid) + velocity_grid = func(X_grid) fig = go.Figure( data=go.Streamtube( From a21e5e4c9a1fee6c16d668802c707a7c7ac755f1 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 30 Aug 2023 14:11:41 -0400 Subject: [PATCH 09/62] add 3D option to vf topography --- dynamo/vectorfield/topography.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/dynamo/vectorfield/topography.py b/dynamo/vectorfield/topography.py index a36723318..d0ec9efd2 100755 --- a/dynamo/vectorfield/topography.py +++ b/dynamo/vectorfield/topography.py @@ -793,6 +793,30 @@ def util_topology( vecfld.compute_nullclines(xlim, ylim, find_new_fixed_points=True) NCx, NCy = vecfld.NCx, vecfld.NCy + Xss, ftype = vecfld.get_fixed_points(get_types=True) + confidence = vecfld.get_Xss_confidence() + elif X_basis.shape[1] == 3: + fp_ind = None + min_, max_ = X_basis.min(0), X_basis.max(0) + + xlim = [ + min_[0] - (max_[0] - min_[0]) * 0.1, + max_[0] + (max_[0] - min_[0]) * 0.1, + ] + ylim = [ + min_[1] - (max_[1] - min_[1]) * 0.1, + max_[1] + (max_[1] - min_[1]) * 0.1, + ] + zlim = [ + min_[2] - (max_[2] - min_[2]) * 0.1, + max_[2] + (max_[2] - min_[2]) * 0.1, + ] + + vecfld = VectorField3D(func, X_data=X_basis) + vecfld.find_fixed_points_by_sampling(n, xlim, ylim, zlim) + + NCx, NCy = None, None + Xss, ftype = vecfld.get_fixed_points(get_types=True) confidence = vecfld.get_Xss_confidence() else: From 828fd12218e914de8ac2cf4e4de859ed35d46d59 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 30 Aug 2023 16:46:07 -0400 Subject: [PATCH 10/62] save zlim in 3d topography --- dynamo/vectorfield/topography.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dynamo/vectorfield/topography.py b/dynamo/vectorfield/topography.py index d0ec9efd2..e3cbbcf97 100755 --- a/dynamo/vectorfield/topography.py +++ b/dynamo/vectorfield/topography.py @@ -786,6 +786,7 @@ def util_topology( min_[1] - (max_[1] - min_[1]) * 0.1, max_[1] + (max_[1] - min_[1]) * 0.1, ] + zlim = None vecfld = VectorField2D(func, X_data=X_basis) vecfld.find_fixed_points_by_sampling(n, xlim, ylim) @@ -821,7 +822,7 @@ def util_topology( confidence = vecfld.get_Xss_confidence() else: fp_ind = None - xlim, ylim, confidence, NCx, NCy = None, None, None, None, None + xlim, ylim, zlim, confidence, NCx, NCy = None, None, None, None, None, None vecfld = BaseVectorField( X=VecFld["X"][VecFld["valid_ind"], :], V=VecFld["Y"][VecFld["valid_ind"], :], @@ -833,7 +834,7 @@ def util_topology( fp_ind = nearest_neighbors(Xss, vecfld.data["X"], 1).flatten() Xss = vecfld.data["X"][fp_ind] - return X_basis, xlim, ylim, confidence, NCx, NCy, Xss, ftype, fp_ind + return X_basis, xlim, ylim, zlim, confidence, NCx, NCy, Xss, ftype, fp_ind def topography( @@ -884,6 +885,7 @@ def func(x): X_basis, xlim, ylim, + zlim, confidence, NCx, NCy, @@ -905,6 +907,7 @@ def func(x): { "xlim": xlim, "ylim": ylim, + "zlim": zlim, "X_data": X_basis, "Xss": Xss, "ftype": ftype, @@ -918,6 +921,7 @@ def func(x): adata.uns[vf_key] = { "xlim": xlim, "ylim": ylim, + "zlim": zlim, "X_data": X_basis, "Xss": Xss, "ftype": ftype, From 9aef5e1c3a6db755f433d63ffea3949a1e929a04 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 30 Aug 2023 16:46:34 -0400 Subject: [PATCH 11/62] create 3d topography graph --- dynamo/plot/__init__.py | 2 + dynamo/plot/topography.py | 403 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 405 insertions(+) diff --git a/dynamo/plot/__init__.py b/dynamo/plot/__init__.py index 978f543cf..32ea82f2a 100755 --- a/dynamo/plot/__init__.py +++ b/dynamo/plot/__init__.py @@ -60,6 +60,7 @@ plot_separatrix, plot_traj, topography, + topography_3D, ) # from .theme import points @@ -115,6 +116,7 @@ "plot_separatrix", "plot_traj", "topography", + "topography_3D", "speed", "acceleration", "curl", diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 3b60c753d..34e7cf6be 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -1359,3 +1359,406 @@ def topography( plt.show() if save_show_or_return in ["return", "all"]: return axes_list if len(axes_list) > 1 else axes_list[0] + + +@docstrings.with_indent(4) +def topography_3D( + adata: AnnData, + basis: str = "umap", + fps_basis: str = "umap", + x: int = 0, + y: int = 1, + color: str = "ntr", + layer: str = "X", + highlights: Optional[list] = None, + labels: Optional[list] = None, + values: Optional[list] = None, + theme: Optional[ + Literal[ + "blue", + "red", + "green", + "inferno", + "fire", + "viridis", + "darkblue", + "darkred", + "darkgreen", + ] + ] = None, + cmap: Optional[str] = None, + color_key: Union[Dict[str, str], List[str], None] = None, + color_key_cmap: Optional[str] = None, + background: Optional[str] = "white", + ncols: int = 4, + pointsize: Optional[float] = None, + figsize: Tuple[float, float] = (6, 4), + show_legend: str = "on data", + use_smoothed: bool = True, + xlim: np.ndarray = None, + ylim: np.ndarray = None, + zlim: np.ndarray = None, + t: Optional[npt.ArrayLike] = None, + terms: List[str] = ["fixed_points"], + init_cells: List[int] = None, + init_states: np.ndarray = None, + quiver_source: Literal["raw", "reconstructed"] = "raw", + fate: Literal["history", "future", "both"] = "both", + approx: bool = False, + quiver_size: Optional[float] = None, + quiver_length: Optional[float] = None, + density: float = 1, + linewidth: float = 1, + streamline_color: Optional[str] = None, + streamline_alpha: float = 0.4, + color_start_points: Optional[str] = None, + markersize: float = 200, + marker_cmap: Optional[str] = None, + save_show_or_return: Literal["save", "show", "return"] = "show", + save_kwargs: Dict[str, Any] = {}, + aggregate: Optional[str] = None, + show_arrowed_spines: bool = False, + ax: Optional[Axes] = None, + sort: Literal["raw", "abs", "neg"] = "raw", + frontier: bool = False, + s_kwargs_dict: Dict[str, Any] = {}, + q_kwargs_dict: Dict[str, Any] = {}, + n: int = 25, + **streamline_kwargs_dict, +) -> Union[Axes, List[Axes], None]: + + from ..external.hodge import ddhodge + + logger = LoggerManager.gen_logger("dynamo-topography-plot") + logger.log_time() + + from matplotlib import rcParams + from matplotlib.colors import to_hex + + if type(color) == str: + color = [color] + + if background is None: + _background = rcParams.get("figure.facecolor") + _background = to_hex(_background) if type(_background) is tuple else _background + else: + _background = background + + terms = list(terms) if type(terms) is tuple else [terms] if type(terms) is str else terms + if approx: + if "streamline" not in terms: + terms.append("streamline") + if "trajectory" in terms: + terms = list(set(terms).difference("trajectory")) + + if init_cells is not None or init_states is not None: + terms.extend("trajectory") + + uns_key = "VecFld" if basis == "X" else "VecFld_" + basis + fps_uns_key = "VecFld" if fps_basis == "X" else "VecFld_" + fps_basis + + if uns_key not in adata.uns.keys(): + + if "velocity_" + basis not in adata.obsm_keys(): + logger.info( + f"velocity_{basis} is computed yet. " f"Projecting the velocity vector to {basis} basis now ...", + indent_level=1, + ) + cell_velocities(adata, basis=basis) + + logger.info( + f"Vector field for {basis} is not constructed. Constructing it now ...", + indent_level=1, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + if basis == fps_basis: + logger.info( + f"`basis` and `fps_basis` are all {basis}. Will also map topography ...", + indent_level=2, + ) + VectorField(adata, basis, map_topography=True, n=n) + else: + VectorField(adata, basis) + if fps_uns_key not in adata.uns.keys(): + if "velocity_" + basis not in adata.obsm_keys(): + logger.info( + f"velocity_{basis} is computed yet. " f"Projecting the velocity vector to {basis} basis now ...", + indent_level=1, + ) + cell_velocities(adata, basis=basis) + + logger.info( + f"Vector field for {fps_basis} is not constructed. " f"Constructing it and mapping its topography now ...", + indent_level=1, + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + VectorField(adata, fps_basis, map_topography=True, n=n) + # elif "VecFld2D" not in adata.uns[uns_key].keys(): + # with warnings.catch_warnings(): + # warnings.simplefilter("ignore") + # + # _topology(adata, basis, VecFld=None) + # elif "VecFld2D" in adata.uns[uns_key].keys() and type(adata.uns[uns_key]["VecFld2D"]) == str: + # with warnings.catch_warnings(): + # warnings.simplefilter("ignore") + # + # _topology(adata, basis, VecFld=None) + + vecfld_dict, vecfld = vecfld_from_adata(adata, basis) + + fps_vecfld_dict, fps_vecfld = vecfld_from_adata(adata, fps_basis) + + # need to use "X_basis" to plot on the scatter point space + if "Xss" not in fps_vecfld_dict: + # if topology is not mapped for this basis, calculate it now. + logger.info( + f"Vector field for {fps_basis} is but its topography is not mapped. " f"Mapping topography now ...", + indent_level=1, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + _topology(adata, fps_basis, VecFld=None, n=n) + else: + if fps_vecfld_dict["Xss"].size > 0 and fps_vecfld_dict["Xss"].shape[1] > 3: + fps_vecfld_dict["X_basis"], fps_vecfld_dict["Xss"] = ( + vecfld_dict["X"][:, :3], + vecfld_dict["X"][fps_vecfld_dict["fp_ind"], :3], + ) + + xlim, ylim, zlim = ( + adata.uns[fps_uns_key]["xlim"] if xlim is None else xlim, + adata.uns[fps_uns_key]["ylim"] if ylim is None else ylim, + adata.uns[fps_uns_key]["zlim"] if zlim is None else zlim, + ) + + if xlim is None or ylim is None or zlim is None: + X_basis = vecfld_dict["X"][:, :3] + min_, max_ = X_basis.min(0), X_basis.max(0) + + xlim = [ + min_[0] - (max_[0] - min_[0]) * 0.1, + max_[0] + (max_[0] - min_[0]) * 0.1, + ] + ylim = [ + min_[1] - (max_[1] - min_[1]) * 0.1, + max_[1] + (max_[1] - min_[1]) * 0.1, + ] + zlim = [ + min_[2] - (max_[2] - min_[2]) * 0.1, + max_[2] + (max_[2] - min_[2]) * 0.1, + ] + + + if init_cells is not None: + if init_states is None: + intersect_cell_names = list(set(init_cells).intersection(adata.obs_names)) + _init_states = ( + adata.obsm["X_" + basis][init_cells, :] + if len(intersect_cell_names) == 0 + else adata[intersect_cell_names].obsm["X_" + basis].copy() + ) + V = ( + adata.obsm["velocity_" + basis][init_cells, :] + if len(intersect_cell_names) == 0 + else adata[intersect_cell_names].obsm["velocity_" + basis].copy() + ) + + init_states = _init_states + + if quiver_source == "reconstructed" or (init_states is not None and init_cells is None): + from ..tools.utils import vector_field_function + + V = vector_field_function(init_states, vecfld_dict, [0, 1]) + + # plt.figure(facecolor=_background) + axes_list, color_list, font_color = scatters( + adata=adata, + basis=basis, + x=x, + y=y, + color=color, + layer=layer, + highlights=highlights, + labels=labels, + values=values, + theme=theme, + cmap=cmap, + color_key=color_key, + color_key_cmap=color_key_cmap, + background=_background, + ncols=ncols, + pointsize=pointsize, + figsize=figsize, + show_legend=show_legend, + use_smoothed=use_smoothed, + aggregate=aggregate, + show_arrowed_spines=show_arrowed_spines, + ax=ax, + sort=sort, + save_show_or_return="return", + frontier=frontier, + projection="3d", + **s_kwargs_dict, + return_all=True, + ) + + if type(axes_list) != list: + axes_list, color_list, font_color = ( + [axes_list], + [color_list], + [font_color], + ) + for i in range(len(axes_list)): + # ax = axes_list[i] + + axes_list[i].set_xlabel(basis + "_1") + axes_list[i].set_ylabel(basis + "_2") + axes_list[i].set_zlabel(basis + "_3") + # axes_list[i].set_aspect("equal") + + # Build the plot + axes_list[i].set_xlim(xlim) + axes_list[i].set_ylim(ylim) + axes_list[i].set_zlim(zlim) + + axes_list[i].set_facecolor(background) + + if t is None: + if vecfld_dict["grid_V"] is None: + max_t = np.max((np.diff(xlim), np.diff(ylim))) / np.min(np.abs(vecfld_dict["V"][:, :2])) + else: + max_t = np.max((np.diff(xlim), np.diff(ylim))) / np.min(np.abs(vecfld_dict["grid_V"])) + + t = np.linspace(0, max_t, 10 ** (np.min((int(np.log10(max_t)), 8)))) + + integration_direction = ( + "both" if fate == "both" else "forward" if fate == "future" else "backward" if fate == "history" else "both" + ) + + if "streamline" in terms: + if approx: + axes_list[i] = plot_flow_field( + vecfld, + xlim, + ylim, + background=_background, + start_points=init_states, + integration_direction=integration_direction, + density=density, + linewidth=linewidth, + streamline_color=streamline_color, + streamline_alpha=streamline_alpha, + color_start_points=color_start_points, + ax=axes_list[i], + **streamline_kwargs_dict, + ) + else: + axes_list[i] = plot_flow_field( + vecfld, + xlim, + ylim, + background=_background, + density=density, + linewidth=linewidth, + streamline_color=streamline_color, + streamline_alpha=streamline_alpha, + color_start_points=color_start_points, + ax=axes_list[i], + **streamline_kwargs_dict, + ) + + if "fixed_points" in terms: + axes_list[i] = plot_fixed_points( + fps_vecfld, + fps_vecfld_dict, + background=_background, + ax=axes_list[i], + markersize=markersize, + cmap=marker_cmap, + ) + + if "separatrices" in terms: + axes_list[i] = plot_separatrix(vecfld, xlim, ylim, t=t, background=_background, ax=axes_list[i]) + + if init_states is not None and "trajectory" in terms: + if not approx: + axes_list[i] = plot_traj( + vecfld.func, + init_states, + t, + background=_background, + integration_direction=integration_direction, + ax=axes_list[i], + ) + + # show quivers for the init_states cells + if init_states is not None and "quiver" in terms: + X = init_states + V /= 3 * quiver_autoscaler(X, V) + + df = pd.DataFrame({"x": X[:, 0], "y": X[:, 1], "u": V[:, 0], "v": V[:, 1]}) + + if quiver_size is None: + quiver_size = 1 + if _background in ["#ffffff", "black"]: + edgecolors = "white" + else: + edgecolors = "black" + + head_w, head_l, ax_l, scale = default_quiver_args(quiver_size, quiver_length) # + quiver_kwargs = { + "angles": "xy", + "scale": scale, + "scale_units": "xy", + "width": 0.0005, + "headwidth": head_w, + "headlength": head_l, + "headaxislength": ax_l, + "minshaft": 1, + "minlength": 1, + "pivot": "tail", + "linewidth": 0.1, + "edgecolors": edgecolors, + "alpha": 1, + "zorder": 7, + } + quiver_kwargs = update_dict(quiver_kwargs, q_kwargs_dict) + # axes_list[i].quiver(X_grid[:, 0], X_grid[:, 1], V_grid[:, 0], V_grid[:, 1], **quiver_kwargs) + axes_list[i].quiver( + df.iloc[:, 0], + df.iloc[:, 1], + df.iloc[:, 2], + df.iloc[:, 3], + **quiver_kwargs, + ) # color='red', facecolors='gray' + + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": "topography", + "dpi": None, + "ext": "pdf", + "transparent": True, + "close": True, + "verbose": True, + } + s_kwargs = update_dict(s_kwargs, save_kwargs) + + if save_show_or_return in ["both", "all"]: + s_kwargs["close"] = False + + save_fig(**s_kwargs) + if save_show_or_return in ["show", "both", "all"]: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + plt.tight_layout() + + plt.show() + if save_show_or_return in ["return", "all"]: + return axes_list if len(axes_list) > 1 else axes_list[0] \ No newline at end of file From 8ae98086388813c3a6c02a4ada596b8dad1d9d7e Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 30 Aug 2023 17:10:40 -0400 Subject: [PATCH 12/62] update docstring for the vf topography --- dynamo/vectorfield/topography.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/dynamo/vectorfield/topography.py b/dynamo/vectorfield/topography.py index e3cbbcf97..49ab400ff 100755 --- a/dynamo/vectorfield/topography.py +++ b/dynamo/vectorfield/topography.py @@ -758,17 +758,24 @@ def util_topology( basis: A string specifying the reduced dimension embedding to use for the computation. dims: A tuple of two integers specifying the dimensions of X to consider. func: A vector-valued function taking in coordinates and returning the vector field. - VecFld: `VecFldDict` TypedDict storing information about the vector field and SparseVFC-related parameters and computations. - X: an alternative to providing an `AnnData` object. Provide an np.ndarray from which `dims` are accessed, Defaults to None. + VecFld: `VecFldDict` TypedDict storing information about the vector field and SparseVFC-related parameters and + computations. + X: an alternative to providing an `AnnData` object. Provide a np.ndarray from which `dims` are accessed, + Defaults to None. n: An optional integer specifying the number of points to use for computing fixed points. Defaults to 25. Returns: A tuple consisting of the following elements: - - X_basis: an array of shape (n, 2) where n is the number of points in X. This is the subset of X consisting of the first two dimensions specified by dims. If X is not provided, X_basis is taken from the obsm attribute of adata using the key "X_" + basis. - - xlim, ylim: a tuple of floats specifying the limits of the x and y axes, respectively. These are computed based on the minimum and maximum values of X_basis. + - X_basis: an array of shape (n, 2) where n is the number of points in X. This is the subset of X consisting + of the first two dimensions specified by dims. If X is not provided, X_basis is taken from the obsm + attribute of adata using the key "X_" + basis. + - xlim, ylim, zlim: a tuple of floats specifying the limits of the x, y and z axes, respectively. These are + computed based on the minimum and maximum values of X_basis. - confidence: an array of shape (n, ) containing the confidence scores of the fixed points. - - NCx, NCy: arrays of shape (n, ) containing the x and y coordinates of the nullclines (lines where the derivative of the system is zero), respectively. - - Xss: an array of shape (n, k) where k is the number of dimensions of the system, containing the fixed points. + - NCx, NCy: arrays of shape (n, ) containing the x and y coordinates of the nullclines (lines where the + derivative of the system is zero), respectively. + - Xss: an array of shape (n, k) where k is the number of dimensions of the system, containing the fixed + points. - ftype: an array of shape (n, ) containing the types of fixed points (attractor, repeller, or saddle). - an array of shape (n, ) containing the indices of the fixed points in the original data. """ @@ -847,12 +854,13 @@ def topography( VecFld: Optional[VecFldDict] = None, **kwargs, ) -> AnnData: - """Map the topography of the single cell vector field in (first) two dimensions. + """Map the topography of the single cell vector field in (first) two or three dimensions. Args: adata: an AnnData object. basis: The reduced dimension embedding of cells to visualize. - layer: Which layer of the data will be used for vector field function reconstruction. This will be used in conjunction with X. + layer: Which layer of the data will be used for vector field function reconstruction. This will be used in + conjunction with X. X: Original data. Not used dims: The dimensions that will be used for vector field reconstruction. n: Number of samples for calculating the fixed points. @@ -862,7 +870,8 @@ def topography( Returns: `AnnData` object that is updated with the `VecFld` or 'VecFld_' + basis dictionary in the `uns` attribute. - The `VecFld2D` key stores an instance of the VectorField2D class which presumably has fixed points, nullcline, separatrix, computed and stored. + The `VecFld2D` key stores an instance of the VectorField2D class which presumably has fixed points, nullcline, + separatrix, computed and stored. """ if VecFld is None: From 194c0b885368524c8ac7f40ab28ce9862f546102 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 31 Aug 2023 16:08:05 -0400 Subject: [PATCH 13/62] add 3D option to remove_particles --- dynamo/movie/utils.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/dynamo/movie/utils.py b/dynamo/movie/utils.py index dfb0018b9..7fc2597ff 100644 --- a/dynamo/movie/utils.py +++ b/dynamo/movie/utils.py @@ -1,10 +1,18 @@ -from typing import Union +from typing import Optional, Union +import numpy as np -def remove_particles(pts: list, xlim: Union[tuple, list], ylim: Union[tuple, list]): + +def remove_particles( + pts: list, + xlim: Union[tuple, list], + ylim: Union[tuple, list], + zlim: Optional[Union[tuple, list]] = None, +): if len(pts) == 0: return [] outside_xlim = (pts[:, 0] < xlim[0]) | (pts[:, 0] > xlim[1]) outside_ylim = (pts[:, 1] < ylim[0]) | (pts[:, 1] > ylim[1]) - keep = ~(outside_xlim | outside_ylim) + outside_zlim = np.full(outside_xlim.shape, False) if zlim is None else (pts[:, 2] < ylim[0]) | (pts[:, 2] > ylim[1]) + keep = ~(outside_xlim | outside_ylim | outside_zlim) return pts[keep] From 7f444759bf2c66e3a17bcc9b03c231a51265fd19 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 31 Aug 2023 16:08:43 -0400 Subject: [PATCH 14/62] crete StreamFuncAnim --- dynamo/movie/__init__.py | 2 +- dynamo/movie/fate.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/dynamo/movie/__init__.py b/dynamo/movie/__init__.py index ff6c2f633..4f795f84a 100644 --- a/dynamo/movie/__init__.py +++ b/dynamo/movie/__init__.py @@ -1,4 +1,4 @@ """Mapping Vector Field of Single Cells """ -from .fate import StreamFuncAnim, animate_fates +from .fate import StreamFuncAnim, StreamFuncAnim3D, animate_fates diff --git a/dynamo/movie/fate.py b/dynamo/movie/fate.py index 8c5f908c4..aee7dc026 100755 --- a/dynamo/movie/fate.py +++ b/dynamo/movie/fate.py @@ -249,6 +249,45 @@ def update(self, frame): return (self.ln,) # return line so that blit works properly +class StreamFuncAnim3D(StreamFuncAnim): + def update(self, frame): + init_states = self.init_states + time_vec = self.time_vec + + pts = [i.tolist() for i in init_states] + + if frame == 0: + x, y, z = init_states.T + + for line in self.ax.get_lines(): + line.remove() + + (self.ln,) = self.ax.plot(x, y, z, "ro", zorder=20) + return (self.ln,) # return line so that blit works properly + else: + pts = [self.displace(cur_pts, time_vec[frame])[1].tolist() for cur_pts in pts] + pts = np.asarray(pts) + + pts = np.asarray(pts) + + pts = remove_particles(pts, self.xlim, self.ylim, self.zlim) + + x, y, z = np.asarray(pts).transpose() + + for line in self.ax.get_lines(): + line.remove() + + (self.ln,) = self.ax.plot(x, y, z, "ro", zorder=20) + + if self.time_scaler is not None: + vf_time = (time_vec[frame] - time_vec[frame - 1]) * self.time_scaler + self.ax.set_title("current vector field time is: {:12.2f}".format(vf_time)) + + # anim.event_source.interval = (time_vec[frame] - time_vec[frame - 1]) / 100 + + return (self.ln,) # return line so that blit works properly + + def animate_fates( adata, basis="umap", From c1e40fc9d12948d99421686f83f83a26e57e5c57 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 31 Aug 2023 16:56:12 -0400 Subject: [PATCH 15/62] add 3D option when initializing StreamFuncAnim --- dynamo/movie/fate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dynamo/movie/fate.py b/dynamo/movie/fate.py index aee7dc026..0a4b814c1 100755 --- a/dynamo/movie/fate.py +++ b/dynamo/movie/fate.py @@ -185,6 +185,8 @@ def __init__( M = M + 0.01 * np.abs(M - m) self.xlim = [m[0], M[0]] self.ylim = [m[1], M[1]] + if X_data.shape[1] == 3: + self.zlim = [m[2], M[2]] # self.ax.set_aspect("equal") self.color = color @@ -205,7 +207,7 @@ def __init__( self.fig = fig self.ax = ax - (self.ln,) = self.ax.plot([], [], "ro") + (self.ln,) = self.ax.plot([], [], "ro", zs=[]) if X_data.shape[1] == 3 else self.ax.plot([], [], "ro") def init_background(self): return (self.ln,) From ecf12f424e291d4eb58cbeae86f257bf928838b0 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 31 Aug 2023 16:56:50 -0400 Subject: [PATCH 16/62] create docstring for remove_particles --- dynamo/movie/utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/dynamo/movie/utils.py b/dynamo/movie/utils.py index 7fc2597ff..6d982943d 100644 --- a/dynamo/movie/utils.py +++ b/dynamo/movie/utils.py @@ -9,6 +9,18 @@ def remove_particles( ylim: Union[tuple, list], zlim: Optional[Union[tuple, list]] = None, ): + """Remove particles that fall outside specified coordinate ranges. + + Args: + pts: an array of points. + xlim: X-coordinate limits specified as a tuple or list of two values: (min_x, max_x). + ylim: Y-coordinate limits specified as a tuple or list of two values: (min_y, max_y). + zlim: Z-coordinate limits specified as a tuple or list of two values: (min_z, max_z). If not provided (default), + only 2D filtering based on xlim and ylim is performed. + + Returns: + An array of points that fall within the specified coordinate ranges. + """ if len(pts) == 0: return [] outside_xlim = (pts[:, 0] < xlim[0]) | (pts[:, 0] > xlim[1]) From b5049684fef435a14afc6215a4d9bee30661e547 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 31 Aug 2023 16:57:33 -0400 Subject: [PATCH 17/62] update typing of remove_particles --- dynamo/movie/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dynamo/movie/utils.py b/dynamo/movie/utils.py index 6d982943d..b7f55bd22 100644 --- a/dynamo/movie/utils.py +++ b/dynamo/movie/utils.py @@ -4,11 +4,11 @@ def remove_particles( - pts: list, + pts: np.ndarray, xlim: Union[tuple, list], ylim: Union[tuple, list], zlim: Optional[Union[tuple, list]] = None, -): +) -> np.ndarray: """Remove particles that fall outside specified coordinate ranges. Args: From 466a64302ee3fbb13b17e5dcab4ffe0da19409f3 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 31 Aug 2023 17:08:53 -0400 Subject: [PATCH 18/62] create docstr for StreamFuncAnim3D --- dynamo/movie/fate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dynamo/movie/fate.py b/dynamo/movie/fate.py index 0a4b814c1..a40592993 100755 --- a/dynamo/movie/fate.py +++ b/dynamo/movie/fate.py @@ -252,7 +252,9 @@ def update(self, frame): class StreamFuncAnim3D(StreamFuncAnim): + """The class of 3D animation instance for matplotlib FuncAnimation function.""" def update(self, frame): + """The function to call at each frame. Update the position of the line object in the animation.""" init_states = self.init_states time_vec = self.time_vec From 347b8ac81480bafd0e57b27cedbdd5ac9d5f2427 Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 1 Sep 2023 10:34:53 -0400 Subject: [PATCH 19/62] delete compute_nullclines_3d --- dynamo/vectorfield/topography.py | 39 -------------------------------- 1 file changed, 39 deletions(-) diff --git a/dynamo/vectorfield/topography.py b/dynamo/vectorfield/topography.py index 49ab400ff..c4f24ee83 100755 --- a/dynamo/vectorfield/topography.py +++ b/dynamo/vectorfield/topography.py @@ -178,45 +178,6 @@ def compute_nullclines_2d( return NCx, NCy -def compute_nullclines_3d( - X0: Union[List, np.ndarray], - fdx: Callable, - fdy: Callable, - fdz: Callable, - x_range: List, - y_range: List, - z_range: List, - s_max: Optional[float] = None, - ds: Optional[float] = None, -) -> Tuple[List]: - if s_max is None: - s_max = 5 * ((x_range[1] - x_range[0]) + (y_range[1] - y_range[0]) + (z_range[1] - z_range[0])) - if ds is None: - ds = s_max / 1e3 - - NCx = [] - NCy = [] - NCz = [] - for x0 in X0: - # initialize tangent predictor - theta = np.random.rand() * 2 * np.pi - phi = np.random.rand() * 2 * np.pi - r = ds * 2 - v0 = [r * np.sin(theta) * np.cos(phi), r * np.sin(theta) * np.sin(phi), r * np.cos(theta)] - v0 /= np.linalg.norm(v0) - # nullcline continuation - NCx.append(continuation(x0, fdx, s_max, ds, v0=v0)) - NCx.append(continuation(x0, fdx, s_max, ds, v0=-v0)) - NCy.append(continuation(x0, fdy, s_max, ds, v0=v0)) - NCy.append(continuation(x0, fdy, s_max, ds, v0=-v0)) - NCz.append(continuation(x0, fdz, s_max, ds, v0=v0)) - NCz.append(continuation(x0, fdz, s_max, ds, v0=-v0)) - NCx = clip_curves(NCx, [x_range, y_range, z_range], ds * 10) - NCy = clip_curves(NCy, [x_range, y_range, z_range], ds * 10) - NCz = clip_curves(NCz, [x_range, y_range, z_range], ds * 10) - return NCx, NCy, NCz - - def compute_separatrices( Xss: np.ndarray, Js: np.ndarray, From da4eb754a6b87b3e7d199ba640652e9ddc21a8e0 Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 1 Sep 2023 10:49:40 -0400 Subject: [PATCH 20/62] create docstr and update typing for VectorField3D --- dynamo/vectorfield/topography.py | 42 +++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/dynamo/vectorfield/topography.py b/dynamo/vectorfield/topography.py index c4f24ee83..e65841258 100755 --- a/dynamo/vectorfield/topography.py +++ b/dynamo/vectorfield/topography.py @@ -2,7 +2,7 @@ import datetime import os import warnings -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import anndata import numpy as np @@ -611,6 +611,13 @@ def output_to_dict(self, dict_vf): class VectorField3D(VectorField2D): + """A class that represents a 3D vector field, which is a type of mathematical object that assigns a 3D vector to + each point in a 3D space. + + The class is derived from the VectorField2D class. This vector field can be defined using a function that returns + the vector at each point, or by separate functions for the x and y components of the vector. Nullclines calculation + are not supported for 3D vector space because of the computational complexity. + """ def __init__( self, func: Callable, @@ -618,7 +625,19 @@ def __init__( func_vy: Optional[Callable] = None, func_vz: Optional[Callable] = None, X_data: Optional[np.ndarray] = None, - ): + ) -> None: + """Initialize the VectorField3D object. + + Args: + func: a function that takes an (n, 3) array of coordinates and returns an (n, 3) array of vectors + func_vx: a function that takes an (n, 3) array of coordinates and returns an (n,) array of x components of + the vectors, Defaults to None. + func_vy: a function that takes an (n, 3) array of coordinates and returns an (n,) array of y components of + the vectors, Defaults to None. + func_vz: a function that takes an (n, 3) array of coordinates and returns an (n,) array of z components of + the vectors, Defaults to None. + X_data: Defaults to None. + """ super().__init__(func, func_vx, func_vy, X_data) def func_dim(x, func, dim): @@ -644,7 +663,17 @@ def find_fixed_points_by_sampling( z_range: Tuple[float, float], lhs: Optional[bool] = True, tol_redundant: float = 1e-4, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> None: + """Find fixed points by sampling the vector field within a specified range of coordinates. + + Args: + n: the number of samples to take. + x_range: a tuple of two floats specifying the range of x coordinates to sample. + y_range: a tuple of two floats specifying the range of y coordinates to sample. + z_range: a tuple of two floats specifying the range of z coordinates to sample. + lhs: whether to use Latin Hypercube Sampling to generate the samples. Defaults to `True`. + tol_redundant: the tolerance for removing redundant fixed points. Defaults to 1e-4. + """ if lhs: from ..tools.sampling import lhsclassic @@ -692,7 +721,12 @@ def compute_nullclines( # outside = is_outside(X, [x_range, y_range]) # self.Xss.add_fixed_points(X[~outside], J[~outside], tol_redundant) - def output_to_dict(self, dict_vf): + def output_to_dict(self, dict_vf) -> Dict: + """Output the vector field as a dictionary. + + Returns: + A dictionary containing nullclines, fixed points, confidence and jacobians. + """ dict_vf["NCx"] = self.NCx dict_vf["NCy"] = self.NCy dict_vf["NCz"] = self.NCz From 5af50f75c90b3b68cab3992b6e19f44ab9d9e589 Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 1 Sep 2023 10:50:25 -0400 Subject: [PATCH 21/62] pass compute_nullclines in VectorField3D --- dynamo/vectorfield/topography.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/dynamo/vectorfield/topography.py b/dynamo/vectorfield/topography.py index e65841258..47d1dcde7 100755 --- a/dynamo/vectorfield/topography.py +++ b/dynamo/vectorfield/topography.py @@ -701,20 +701,20 @@ def compute_nullclines( find_new_fixed_points: Optional[bool] = False, tol_redundant: Optional[float] = 1e-4, ): - # compute arguments - s_max = 5 * ((x_range[1] - x_range[0]) + (y_range[1] - y_range[0])) - ds = s_max / 1e3 - self.NCx, self.NCy, self.NCz = compute_nullclines_3d( - self.Xss.get_X(), - self.fx, - self.fy, - self.fz, - x_range, - y_range, - z_range, - s_max=s_max, - ds=ds, - ) + pass + # s_max = 5 * ((x_range[1] - x_range[0]) + (y_range[1] - y_range[0])) + # ds = s_max / 1e3 + # self.NCx, self.NCy, self.NCz = compute_nullclines_3d( + # self.Xss.get_X(), + # self.fx, + # self.fy, + # self.fz, + # x_range, + # y_range, + # z_range, + # s_max=s_max, + # ds=ds, + # ) # if find_new_fixed_points: # sample_interval = ds * 10 # X, J = find_fixed_points_nullcline_3d(self.func, self.NCx, self.NCy, self.NCz, sample_interval, tol_redundant) From 1a6e72a11db9dd1a66e38e5e3e2314f2736c4a13 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 20 Sep 2023 11:42:51 -0400 Subject: [PATCH 22/62] debug 3d scatters --- dynamo/plot/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index 3c430aa5b..6c3065635 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -593,10 +593,12 @@ def _matplotlib_points( for i in unique_labels: if i == "other": continue - color_cnt = np.nanmedian(points[np.where(labels == i)[0], :2].astype("float"), 0) + if projection == "3d": + color_cnt = np.nanmedian(points[np.where(labels == i)[0], :3].astype("float"), 0) + else: + color_cnt = np.nanmedian(points[np.where(labels == i)[0], :2].astype("float"), 0) txt = ax.text( - color_cnt[0], - color_cnt[1], + *color_cnt, str(i), color=_select_font_color(font_color), zorder=1000, From d2a221bede030ef8900bfe7b2e7d14f7b4bf1584 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 20 Sep 2023 14:40:12 -0400 Subject: [PATCH 23/62] create pyvista 3d scatters --- dynamo/plot/__init__.py | 3 +- dynamo/plot/scatters.py | 271 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 270 insertions(+), 4 deletions(-) diff --git a/dynamo/plot/__init__.py b/dynamo/plot/__init__.py index 32ea82f2a..f8b620ca3 100755 --- a/dynamo/plot/__init__.py +++ b/dynamo/plot/__init__.py @@ -30,7 +30,7 @@ show_fraction, variance_explained, ) -from .scatters import scatters +from .scatters import scatters, scatters_pv from .scPotential import show_landscape from .sctransform import sctransform_plot_fit, plot_residual_var from .scVectorField import ( # , plot_LIC_gray @@ -81,6 +81,7 @@ "quiver_autoscaler", "save_fig", "scatters", + "scatters_pv", "basic_stats", "show_fraction", "feature_genes", diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 85f0d57d8..4c814b08f 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -13,9 +13,10 @@ import numpy as np import pandas as pd from anndata import AnnData -from matplotlib import patches +from matplotlib import patches, rcParams from matplotlib.axes import Axes from matplotlib.lines import Line2D +from matplotlib.colors import rgb2hex, to_hex from pandas.api.types import is_categorical_dtype from ..configuration import _themes, reset_rcParams @@ -264,8 +265,6 @@ def scatters( """ import matplotlib.pyplot as plt - from matplotlib import rcParams - from matplotlib.colors import rgb2hex, to_hex # 2d is not a projection in matplotlib, default is None (rectilinear) if projection == "2d": @@ -908,3 +907,269 @@ def _plot_basis_layer(cur_b, cur_l): return (axes_list, color_list, font_color) if total_panels > 1 else (ax, color_out, font_color) else: return axes_list if total_panels > 1 else ax + + +def scatters_pv( + adata: AnnData, + basis: str = "umap", + x: int = 0, + y: int = 1, + z: int = 2, + color: str = "ntr", + layer: str = "X", + highlights: Optional[list] = None, + labels: Optional[list] = None, + values: Optional[list] = None, + cmap: Optional[str] = None, + theme: Optional[str] = None, + background: Optional[str] = None, + color_key_cmap: Optional[str] = None, + use_smoothed: bool = True, + smooth: bool = False, + save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", + save_kwargs: Dict[str, Any] = {}, + **kwargs, +): + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + + if background is None: + _background = rcParams.get("figure.facecolor") + _background = to_hex(_background) if type(_background) is tuple else _background + # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) + else: + _background = background + + if type(x) in [int, str]: + x = [x] + if type(y) in [int, str]: + y = [y] + if type(z) in [int, str]: + z = [z] + + # make x, y, z lists of list, where each list corresponds to one coordinate set + if ( + type(x) in [anndata._core.views.ArrayView, np.ndarray] + and type(y) in [anndata._core.views.ArrayView, np.ndarray] + and type(z) in [anndata._core.views.ArrayView, np.ndarray] + and len(x) == adata.n_obs + and len(y) == adata.n_obs + and len(z) == adata.n_obs + ): + x, y, z = [x], [y], [z] + + elif hasattr(x, "__len__") and hasattr(y, "__len__") and hasattr(z, "__len__"): + x, y, z = list(x), list(y), list(z) + + assert len(x) == len(y) and len(x) == len(z), "bug: x, y, z does not have the same shape." + + if use_smoothed: + mapper = get_mapper() + + # check color, layer, basis -> convert to list + if type(color) is str: + color = [color] + if type(layer) is str: + layer = [layer] + if type(basis) is str: + basis = [basis] + + pl = pv.Plotter() + + def _plot_basis_layer_pv(cur_b, cur_l): + nonlocal _background, adata, cmap, x, y, z, labels, values + + if cur_l in ["acceleration", "curvature", "divergence", "velocity_S", "velocity_T"]: + cur_l_smoothed = cur_l + cmap, sym_c = "bwr", True # TODO maybe use other divergent color map in the future + else: + if use_smoothed: + cur_l_smoothed = cur_l if cur_l.startswith("M_") | cur_l.startswith("velocity") else mapper[cur_l] + if cur_l.startswith("velocity"): + cmap, sym_c = "bwr", True + + if cur_l + "_" + cur_b in adata.obsm.keys(): + prefix = cur_l + "_" + elif ("X_" + cur_b) in adata.obsm.keys(): + prefix = "X_" + elif cur_b in adata.obsm.keys(): + # special case for spatial for compatibility with other packages + prefix = "" + else: + raise ValueError("Please check if basis=%s exists in adata.obsm" % basis) + + basis_key = prefix + cur_b + main_info("plotting with basis key=%s" % basis_key, indent_level=2) + + for cur_c in color: + main_debug("coloring scatter of cur_c: %s" % str(cur_c)) + cur_title = cur_c + + _color = _get_adata_color_vec(adata, cur_l, cur_c) + + # select data rows based on stack color thresholding + is_numeric_color = np.issubdtype(_color.dtype, np.number) + if not is_numeric_color: + main_info( + "skip filtering %s by stack threshold when stacking color because it is not a numeric type" + % (cur_c), + indent_level=2, + ) + _values = values + _adata = adata + + for cur_x, cur_y, cur_z in zip(x, y, z): # here x / y are arrays + main_debug("handling coordinates, cur_x: %s, cur_y: %s, cur_z: %s" % (cur_x, cur_y, cur_z)) + if type(cur_x) is int and type(cur_y) is int and type(cur_z): + x_col_name = cur_b + "_0" + y_col_name = cur_b + "_1" + z_col_name = cur_b + "_2" + + points = pd.DataFrame( + { + x_col_name: _adata.obsm[basis_key][:, cur_x], + y_col_name: _adata.obsm[basis_key][:, cur_y], + z_col_name: _adata.obsm[basis_key][:, cur_z], + } + ) + points.columns = [x_col_name, y_col_name, z_col_name] + + elif is_gene_name(_adata, cur_x) and is_gene_name(_adata, cur_y) and is_gene_name(_adata, cur_z): + points = pd.DataFrame( + { + cur_x: _adata.obs_vector(k=cur_x, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur_x, layer=cur_l_smoothed), + cur_y: _adata.obs_vector(k=cur_y, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur_y, layer=cur_l_smoothed), + cur_z: _adata.obs_vector(k=cur_z, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur_z, layer=cur_l_smoothed), + } + ) + # points = points.loc[(points > 0).sum(1) > 1, :] + points.columns = [ + cur_x + " (" + cur_l_smoothed + ")", + cur_y + " (" + cur_l_smoothed + ")", + cur_z + " (" + cur_l_smoothed + ")", + ] + cur_title = cur_x + " VS " + cur_y + " VS " + cur_z + elif is_cell_anno_column(_adata, cur_x) and is_cell_anno_column(_adata, cur_y) and is_cell_anno_column(_adata, cur_z): + points = pd.DataFrame( + { + cur_x: _adata.obs_vector(cur_x), + cur_y: _adata.obs_vector(cur_y), + cur_z: _adata.obs_vector(cur_z), + } + ) + points.columns = [cur_x, cur_y, cur_z] + cur_title = cur_x + " VS " + cur_y + " VS " + cur_z + elif is_cell_anno_column(_adata, cur_x) and is_gene_name(_adata, cur_y): + points = pd.DataFrame( + { + cur_x: _adata.obs_vector(cur_x), + cur_y: _adata.obs_vector(k=cur_y, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur_y, layer=cur_l_smoothed), + } + ) + # points = points.loc[points.iloc[:, 1] > 0, :] + points.columns = [ + cur_x, + cur_y + " (" + cur_l_smoothed + ")", + ] + cur_title = cur_y + elif is_gene_name(_adata, cur_x) and is_cell_anno_column(_adata, cur_y): + points = pd.DataFrame( + { + cur_x: _adata.obs_vector(k=cur_x, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur_x, layer=cur_l_smoothed), + cur_y: _adata.obs_vector(cur_y), + } + ) + # points = points.loc[points.iloc[:, 0] > 0, :] + points.columns = [ + cur_x + " (" + cur_l_smoothed + ")", + cur_y, + ] + cur_title = cur_x + elif is_layer_keys(_adata, cur_x) and is_layer_keys(_adata, cur_y): + cur_x_, cur_y_ = ( + _adata[:, cur_b].layers[cur_x], + _adata[:, cur_b].layers[cur_y], + ) + points = pd.DataFrame({cur_x: flatten(cur_x_), cur_y: flatten(cur_y_)}) + # points = points.loc[points.iloc[:, 0] > 0, :] + points.columns = [cur_x, cur_y] + cur_title = cur_b + elif type(cur_x) in [anndata._core.views.ArrayView, np.ndarray] and type(cur_y) in [ + anndata._core.views.ArrayView, + np.ndarray, + ]: + points = pd.DataFrame({"x": flatten(cur_x), "y": flatten(cur_y)}) + points.columns = ["x", "y"] + cur_title = cur_b + else: + raise ValueError("Make sure your `x` and `y` are integers, gene names, column names in .obs, etc.") + + # https://stackoverflow.com/questions/4187185/how-can-i-check-if-my-python-object-is-a-number + # answer from Boris. + is_not_continuous = not isinstance(_color[0], Number) or _color.dtype.name == "category" + + if is_not_continuous: + labels = np.asarray(_color) if is_categorical_dtype(_color) else _color + if theme is None: + if _background in ["#ffffff", "black"]: + _theme_ = "glasbey_dark" + else: + _theme_ = "glasbey_white" + else: + _theme_ = theme + else: + _values = _color + if theme is None: + if _background in ["#ffffff", "black"]: + _theme_ = "inferno" if cur_l != "velocity" else "div_blue_black_red" + else: + _theme_ = "viridis" if not cur_l.startswith("velocity") else "div_blue_red" + else: + _theme_ = theme + + _cmap = _themes[_theme_]["cmap"] if cmap is None else cmap + + _color_key_cmap = _themes[_theme_]["color_key_cmap"] if color_key_cmap is None else color_key_cmap + _background = _themes[_theme_]["background"] if _background is None else _background + + if labels is not None and values is not None: + raise ValueError("Conflicting options; only one of labels or values should be set") + + if smooth and not is_not_continuous: + main_debug("smooth and not continuous") + knn = _adata.obsp["moments_con"] + values = ( + calc_1nd_moment(values, knn)[0] + if smooth in [1, True] + else calc_1nd_moment(values, knn**smooth)[0] + ) + + pvdataset = pv.PolyData(points.values) + pl.add_points(pvdataset.points, color='red') + # pl.show() + + for cur_b in basis: + for cur_l in layer: + main_debug("Plotting basis:%s, layer: %s" % (str(basis), str(layer))) + main_debug("colors: %s" % (str(color))) + _plot_basis_layer_pv(cur_b, cur_l) + + main_debug("show, return or save...") + if save_show_or_return in ["save", "both", "all"]: + pl.save_graphic(**save_kwargs) + if save_show_or_return in ["show", "both", "all"]: + pl.show() + if save_show_or_return in ["return", "all"]: + return pl \ No newline at end of file From c057c18640f410d76f96e8e0c30f591a3a2b8f4a Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 20 Sep 2023 16:12:01 -0400 Subject: [PATCH 24/62] enable colors selection in pv --- dynamo/plot/scatters.py | 33 ++++--- dynamo/plot/utils.py | 208 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 228 insertions(+), 13 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 4c814b08f..e0c723ce4 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -31,6 +31,7 @@ _matplotlib_points, _select_font_color, arrowed_spines, + calculate_colors, deaxis_all, despline_all, is_cell_anno_column, @@ -923,8 +924,10 @@ def scatters_pv( cmap: Optional[str] = None, theme: Optional[str] = None, background: Optional[str] = None, + color_key: Union[Dict[str, str], List[str], None] = None, color_key_cmap: Optional[str] = None, use_smoothed: bool = True, + sym_c: bool = False, smooth: bool = False, save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", save_kwargs: Dict[str, Any] = {}, @@ -935,13 +938,6 @@ def scatters_pv( except ImportError: raise ImportError("Please install pyvista first.") - if background is None: - _background = rcParams.get("figure.facecolor") - _background = to_hex(_background) if type(_background) is tuple else _background - # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) - else: - _background = background - if type(x) in [int, str]: x = [x] if type(y) in [int, str]: @@ -979,7 +975,7 @@ def scatters_pv( pl = pv.Plotter() def _plot_basis_layer_pv(cur_b, cur_l): - nonlocal _background, adata, cmap, x, y, z, labels, values + nonlocal background, adata, cmap, x, y, z, labels, sym_c, values if cur_l in ["acceleration", "curvature", "divergence", "velocity_S", "velocity_T"]: cur_l_smoothed = cur_l @@ -1123,7 +1119,7 @@ def _plot_basis_layer_pv(cur_b, cur_l): if is_not_continuous: labels = np.asarray(_color) if is_categorical_dtype(_color) else _color if theme is None: - if _background in ["#ffffff", "black"]: + if background in ["#ffffff", "black"]: _theme_ = "glasbey_dark" else: _theme_ = "glasbey_white" @@ -1132,7 +1128,7 @@ def _plot_basis_layer_pv(cur_b, cur_l): else: _values = _color if theme is None: - if _background in ["#ffffff", "black"]: + if background in ["#ffffff", "black"]: _theme_ = "inferno" if cur_l != "velocity" else "div_blue_black_red" else: _theme_ = "viridis" if not cur_l.startswith("velocity") else "div_blue_red" @@ -1142,7 +1138,7 @@ def _plot_basis_layer_pv(cur_b, cur_l): _cmap = _themes[_theme_]["cmap"] if cmap is None else cmap _color_key_cmap = _themes[_theme_]["color_key_cmap"] if color_key_cmap is None else color_key_cmap - _background = _themes[_theme_]["background"] if _background is None else _background + background = _themes[_theme_]["background"] if background is None else background if labels is not None and values is not None: raise ValueError("Conflicting options; only one of labels or values should be set") @@ -1156,9 +1152,20 @@ def _plot_basis_layer_pv(cur_b, cur_l): else calc_1nd_moment(values, knn**smooth)[0] ) + colors, _, _ = calculate_colors( + points, + labels=labels, + values=_values, + cmap=_cmap, + color_key=color_key, + color_key_cmap=_color_key_cmap, + background=background, + sym_c=sym_c, + ) + pvdataset = pv.PolyData(points.values) - pl.add_points(pvdataset.points, color='red') - # pl.show() + pl.add_points(pvdataset.points) + for cur_b in basis: for cur_l in layer: diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index 6c3065635..26d073398 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -59,6 +59,214 @@ def _get_adata_color_vec(adata, layer, col): return np.array(_color).flatten() +def calculate_colors( + points, + ax=None, + labels=None, + values=None, + highlights=None, + cmap="Blues", + color_key=None, + color_key_cmap="Spectral", + background="white", + width=7, + height=5, + vmin=2, + vmax=98, + sort="raw", + sym_c=False, + projection=None, # default in matplotlib + **kwargs, +): + import matplotlib.pyplot as plt + from matplotlib.ticker import MaxNLocator + + dpi = plt.rcParams["figure.dpi"] + width, height = width * dpi, height * dpi + rasterized = kwargs["rasterized"] if "rasterized" in kwargs.keys() else None + # """Use matplotlib to plot points""" + # point_size = 500.0 / np.sqrt(points.shape[0]) + + legend_elements = None + + if ax is None: + dpi = plt.rcParams["figure.dpi"] + fig = plt.figure(figsize=(width / dpi, height / dpi)) + ax = fig.add_subplot(111, projection=projection) + + ax.set_facecolor(background) + + # Color by labels + if labels is not None: + main_debug("labels are not None, drawing by labels") + color_type = "labels" + + if labels.shape[0] != points.shape[0]: + raise ValueError( + "Labels must have a label for " + "each sample (size mismatch: {} {})".format(labels.shape[0], points.shape[0]) + ) + if color_key is None: + main_debug("color_key is None") + cmap = copy.copy(matplotlib.cm.get_cmap(color_key_cmap)) + cmap.set_bad("lightgray") + colors = None + + if highlights is None: + unique_labels = np.unique(labels) + num_labels = unique_labels.shape[0] + color_key = plt.get_cmap(color_key_cmap)(np.linspace(0, 1, num_labels)) + else: + if type(highlights) is str: + highlights = [highlights] + highlights.append("other") + unique_labels = np.array(highlights) + num_labels = unique_labels.shape[0] + color_key = _to_hex(plt.get_cmap(color_key_cmap)(np.linspace(0, 1, num_labels))) + color_key[-1] = "#bdbdbd" # lightgray hex code https://www.color-hex.com/color/d3d3d3 + + labels[[i not in highlights[:-1] for i in labels]] = "other" + points = pd.DataFrame(points) + points["label"] = pd.Categorical(labels) + + # reorder data so that highlighting points will be on top of background points + highlight_ids, background_ids = ( + points["label"] != "other", + points["label"] == "other", + ) + # reorder_data = points.copy(deep=True) + # ( + # reorder_data.loc[:(sum(background_ids) - 1), :], + # reorder_data.loc[sum(background_ids):, :], + # ) = (points.loc[background_ids, :].values, points.loc[highlight_ids, :].values) + points = pd.concat( + ( + points.loc[background_ids, :], + points.loc[highlight_ids, :], + ) + ).values + labels = points[:, 2] + + # WARNING: do not change the following line to "elif" during refactor + # This if-else branch is not logically parallel to the previous one. The following branch sets `colors`. + if isinstance(color_key, dict): + main_debug("color_key is a dict") + colors = pd.Series(labels).map(color_key).values + unique_labels = np.unique(labels) + legend_elements = [ + # Patch(facecolor=color_key[k], label=k) for k in unique_labels + Line2D( + [0], + [0], + marker="o", + color=color_key[k], + label=k, + linestyle="None", + ) + for k in unique_labels + ] + else: + main_debug("color_key is not None and not a dict") + unique_labels = np.unique(labels) + if len(color_key) < unique_labels.shape[0]: + raise ValueError("Color key must have enough colors for the number of labels") + + new_color_key = {k: color_key[i] for i, k in enumerate(unique_labels)} + legend_elements = [ + # Patch(facecolor=color_key[i], label=k) + Line2D( + [0], + [0], + marker="o", + color=color_key[i], + label=k, + linestyle="None", + ) + for i, k in enumerate(unique_labels) + ] + colors = pd.Series(labels).map(new_color_key) + + # Color by values + elif values is not None: + main_debug("drawing points by values") + color_type = "values" + cmap_ = copy.copy(matplotlib.cm.get_cmap(cmap)) + cmap_.set_bad("lightgray") + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + matplotlib.cm.register_cmap(name=cmap_.name, cmap=cmap_, override_builtin=True) + + if values.shape[0] != points.shape[0]: + raise ValueError( + "Values must have a value for " + "each sample (size mismatch: {} {})".format(values.shape[0], points.shape[0]) + ) + # reorder data so that high values points will be on top of background points + sorted_id = ( + np.argsort(abs(values)) if sort == "abs" else np.argsort(-values) if sort == "neg" else np.argsort(values) + ) + values, points = values[sorted_id], points[sorted_id, :] + + # if there are very few cells have expression, set the vmin/vmax only based on positive values to + # get rid of outliers + if np.nanmin(values) == 0: + n_pos_cells = sum(values > 0) + if 0 < n_pos_cells / len(values) < 0.02: + vmin = 0 if n_pos_cells == 1 else np.percentile(values[values > 0], 2) + vmax = np.nanmax(values) if n_pos_cells == 1 else np.percentile(values[values > 0], 98) + if vmin + vmax in [1, 100]: + vmin += 1e-12 + vmax += 1e-12 + + # if None: min/max from data + # if positive and sum up to 1, take fraction + # if positive and sum up to 100, take percentage + # otherwise take the data + _vmin = ( + np.nanmin(values) + if vmin is None + else np.nanpercentile(values, vmin * 100) + if (vmin + vmax == 1 and 0 <= vmin < vmax) + else np.nanpercentile(values, vmin) + if (vmin + vmax == 100 and 0 <= vmin < vmax) + else vmin + ) + _vmax = ( + np.nanmax(values) + if vmax is None + else np.nanpercentile(values, vmax * 100) + if (vmin + vmax == 1 and 0 <= vmin < vmax) + else np.nanpercentile(values, vmax) + if (vmin + vmax == 100 and 0 <= vmin < vmax) + else vmax + ) + + if sym_c and _vmin < 0 and _vmax > 0: + bounds = np.nanmax([np.abs(_vmin), _vmax]) + bounds = bounds * np.array([-1, 1]) + _vmin, _vmax = bounds + + + if "norm" in kwargs: + norm = kwargs["norm"] + else: + norm = matplotlib.colors.Normalize(vmin=_vmin, vmax=_vmax) + + mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) + mappable.set_array(values) + + cmap = matplotlib.cm.get_cmap(cmap) + colors = cmap(values) + # No color (just pick the midpoint of the cmap) + else: + main_debug("drawing points without color passed in args, using midpoint of the cmap") + color_type = "midpoint" + colors = plt.get_cmap(cmap)(0.5) + + return (colors, color_type, None) if color_type != "labels" else (colors, color_type, legend_elements) + + # --------------------------------------------------------------------------------------------------- # plotting utilities that borrowed from umap # link: https://github.com/lmcinnes/umap/blob/7e051d8f3c4adca90ca81eb45f6a9d1372c076cf/umap/plot.py From 03e3142a539733f5611938c4733f3972278faf92 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 20 Sep 2023 17:23:42 -0400 Subject: [PATCH 25/62] debug colors selection --- dynamo/plot/scatters.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index e0c723ce4..7b0331e63 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1164,7 +1164,8 @@ def _plot_basis_layer_pv(cur_b, cur_l): ) pvdataset = pv.PolyData(points.values) - pl.add_points(pvdataset.points) + pvdataset.point_data["colors"] = np.stack(colors.values) + pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True) for cur_b in basis: From 199b0585b4db38f9b9210f27a6a7cc76cbe95b7e Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 20 Sep 2023 17:40:45 -0400 Subject: [PATCH 26/62] optimize pyvista saving --- dynamo/plot/scatters.py | 19 +++++++++++++++++-- dynamo/plot/utils.py | 42 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 7b0331e63..158babf86 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -38,6 +38,7 @@ is_gene_name, is_layer_keys, is_list_of_lists, + retrieve_plot_save_path, save_fig, ) @@ -1176,8 +1177,22 @@ def _plot_basis_layer_pv(cur_b, cur_l): main_debug("show, return or save...") if save_show_or_return in ["save", "both", "all"]: - pl.save_graphic(**save_kwargs) + s_kwargs = { + "path": None, + "prefix": "scatters_pv", + "ext": "pdf", + "title": 'PyVista Export', + "raster": True, + "painter": True, + } + + s_kwargs = update_dict(s_kwargs, save_kwargs) + + saving_path = retrieve_plot_save_path(path=s_kwargs["path"], prefix=s_kwargs["prefix"], ext=s_kwargs["ext"]) + pl.save_graphic(saving_path, title=s_kwargs["title"], raster=s_kwargs["raster"], painter=s_kwargs["painter"]) + if save_show_or_return in ["show", "both", "all"]: pl.show() + if save_show_or_return in ["return", "all"]: - return pl \ No newline at end of file + return pl diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index 26d073398..9bf20c233 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -1640,6 +1640,48 @@ def save_fig( print("Done") +def retrieve_plot_save_path( + path: Optional[str] = None, + prefix: Optional[str] = None, + ext: str = "pdf", +) -> str: + """Retrieve the path to save dynamo plots. + + Args: + path: the path (and filename, without the extension) to save_fig the figure to. Defaults to None. + prefix: the prefix added to the figure name. This will be automatically set accordingly to the plotting function + used. Defaults to None. + ext: the file extension. This must be supported by the active matplotlib or pyvista backend. Most backends + support 'png', 'pdf', 'ps', 'eps', and 'svg'. Defaults to "pdf". + Returns: + The saving path. + """ + prefix = os.path.normpath(prefix) + if path is None: + path = os.getcwd() + "/" + + # Extract the directory and filename from the given path + directory = os.path.split(path)[0] + filename = os.path.split(path)[1] + if directory == "": + directory = "." + if filename == "": + filename = "dyn_savefig" + + # If the directory does not exist, create it + if not os.path.exists(directory): + os.makedirs(directory) + + # The final path to save_fig to + savepath = ( + os.path.join(directory, filename + "." + ext) + if prefix is None + else os.path.join(directory, prefix + "_" + filename + "." + ext) + ) + + return savepath + + # --------------------------------------------------------------------------------------------------- def alpha_shape(x, y, alpha): # Start Using SHAPELY From 3cc392f16f1cdf028c6b872b823e16b467125ebb Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 21 Sep 2023 11:56:13 -0400 Subject: [PATCH 27/62] implement multiple input options for x, y, z --- dynamo/plot/scatters.py | 173 ++++++++++++++++++---------------------- 1 file changed, 78 insertions(+), 95 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 158babf86..2cfff3206 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -911,12 +911,78 @@ def _plot_basis_layer(cur_b, cur_l): return axes_list if total_panels > 1 else ax +def map_to_points(_adata: AnnData, axis_x: str, axis_y: str, axis_z: str, basis_key: str, cur_b: str, cur_l_smoothed: str): + + gene_title = [] + anno_title = [] + + def _map_cur_axis(cur): + nonlocal gene_title, anno_title + + if is_gene_name(_adata, cur): + points_df_data = (_adata.obs_vector(k=cur, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur, layer=cur_l_smoothed)) + points_column = cur + " (" + cur_l_smoothed + ")" + gene_title.append(cur) + elif is_cell_anno_column(_adata, cur): + points_df_data = _adata.obs_vector(cur) + points_column = cur + anno_title.append(cur) + elif is_layer_keys(_adata, cur): + points_df_data = _adata[:, cur_b].layers[cur] + points_column = flatten(points_df_data) + else: + raise ValueError("Make sure your `x`, `y` and `z` are integers, gene names, column names in .obs, etc.") + + return points_df_data, points_column + + if type(axis_x) is int and type(axis_y) is int and type(axis_z): + x_col_name = cur_b + "_0" + y_col_name = cur_b + "_1" + z_col_name = cur_b + "_2" + + points = pd.DataFrame( + { + x_col_name: _adata.obsm[basis_key][:, axis_x], + y_col_name: _adata.obsm[basis_key][:, axis_y], + z_col_name: _adata.obsm[basis_key][:, axis_z], + } + ) + points.columns = [x_col_name, y_col_name, z_col_name] + elif type(axis_x) in [anndata._core.views.ArrayView, np.ndarray] and type(axis_y) in [ + anndata._core.views.ArrayView, + np.ndarray, + ]: + points = pd.DataFrame({"x": flatten(axis_x), "y": flatten(axis_y), "x": flatten(axis_z)}) + points.columns = ["x", "y", "z"] + else: + x_points_df_data, x_points_column = _map_cur_axis(axis_x) + y_points_df_data, y_points_column = _map_cur_axis(axis_y) + z_points_df_data, z_points_column = _map_cur_axis(axis_z) + points = pd.DataFrame({ + axis_x: x_points_df_data, + axis_y: y_points_df_data, + axis_z: z_points_df_data, + }) + points.columns = [x_points_column, y_points_column, z_points_column] + + if len(gene_title) != 0: + cur_title = " VS ".join(gene_title) + elif len(anno_title) == 3: + cur_title = " VS ".join(anno_title) + else: + cur_title = cur_b + + return points, cur_title + + def scatters_pv( adata: AnnData, basis: str = "umap", - x: int = 0, - y: int = 1, - z: int = 2, + x: Union[int, str] = 0, + y: Union[int, str] = 1, + z: Union[int, str] = 2, color: str = "ntr", layer: str = "X", highlights: Optional[list] = None, @@ -1019,99 +1085,16 @@ def _plot_basis_layer_pv(cur_b, cur_l): for cur_x, cur_y, cur_z in zip(x, y, z): # here x / y are arrays main_debug("handling coordinates, cur_x: %s, cur_y: %s, cur_z: %s" % (cur_x, cur_y, cur_z)) - if type(cur_x) is int and type(cur_y) is int and type(cur_z): - x_col_name = cur_b + "_0" - y_col_name = cur_b + "_1" - z_col_name = cur_b + "_2" - - points = pd.DataFrame( - { - x_col_name: _adata.obsm[basis_key][:, cur_x], - y_col_name: _adata.obsm[basis_key][:, cur_y], - z_col_name: _adata.obsm[basis_key][:, cur_z], - } - ) - points.columns = [x_col_name, y_col_name, z_col_name] - elif is_gene_name(_adata, cur_x) and is_gene_name(_adata, cur_y) and is_gene_name(_adata, cur_z): - points = pd.DataFrame( - { - cur_x: _adata.obs_vector(k=cur_x, layer=None) - if cur_l_smoothed == "X" - else _adata.obs_vector(k=cur_x, layer=cur_l_smoothed), - cur_y: _adata.obs_vector(k=cur_y, layer=None) - if cur_l_smoothed == "X" - else _adata.obs_vector(k=cur_y, layer=cur_l_smoothed), - cur_z: _adata.obs_vector(k=cur_z, layer=None) - if cur_l_smoothed == "X" - else _adata.obs_vector(k=cur_z, layer=cur_l_smoothed), - } - ) - # points = points.loc[(points > 0).sum(1) > 1, :] - points.columns = [ - cur_x + " (" + cur_l_smoothed + ")", - cur_y + " (" + cur_l_smoothed + ")", - cur_z + " (" + cur_l_smoothed + ")", - ] - cur_title = cur_x + " VS " + cur_y + " VS " + cur_z - elif is_cell_anno_column(_adata, cur_x) and is_cell_anno_column(_adata, cur_y) and is_cell_anno_column(_adata, cur_z): - points = pd.DataFrame( - { - cur_x: _adata.obs_vector(cur_x), - cur_y: _adata.obs_vector(cur_y), - cur_z: _adata.obs_vector(cur_z), - } - ) - points.columns = [cur_x, cur_y, cur_z] - cur_title = cur_x + " VS " + cur_y + " VS " + cur_z - elif is_cell_anno_column(_adata, cur_x) and is_gene_name(_adata, cur_y): - points = pd.DataFrame( - { - cur_x: _adata.obs_vector(cur_x), - cur_y: _adata.obs_vector(k=cur_y, layer=None) - if cur_l_smoothed == "X" - else _adata.obs_vector(k=cur_y, layer=cur_l_smoothed), - } - ) - # points = points.loc[points.iloc[:, 1] > 0, :] - points.columns = [ - cur_x, - cur_y + " (" + cur_l_smoothed + ")", - ] - cur_title = cur_y - elif is_gene_name(_adata, cur_x) and is_cell_anno_column(_adata, cur_y): - points = pd.DataFrame( - { - cur_x: _adata.obs_vector(k=cur_x, layer=None) - if cur_l_smoothed == "X" - else _adata.obs_vector(k=cur_x, layer=cur_l_smoothed), - cur_y: _adata.obs_vector(cur_y), - } - ) - # points = points.loc[points.iloc[:, 0] > 0, :] - points.columns = [ - cur_x + " (" + cur_l_smoothed + ")", - cur_y, - ] - cur_title = cur_x - elif is_layer_keys(_adata, cur_x) and is_layer_keys(_adata, cur_y): - cur_x_, cur_y_ = ( - _adata[:, cur_b].layers[cur_x], - _adata[:, cur_b].layers[cur_y], - ) - points = pd.DataFrame({cur_x: flatten(cur_x_), cur_y: flatten(cur_y_)}) - # points = points.loc[points.iloc[:, 0] > 0, :] - points.columns = [cur_x, cur_y] - cur_title = cur_b - elif type(cur_x) in [anndata._core.views.ArrayView, np.ndarray] and type(cur_y) in [ - anndata._core.views.ArrayView, - np.ndarray, - ]: - points = pd.DataFrame({"x": flatten(cur_x), "y": flatten(cur_y)}) - points.columns = ["x", "y"] - cur_title = cur_b - else: - raise ValueError("Make sure your `x` and `y` are integers, gene names, column names in .obs, etc.") + points, cur_title = map_to_points( + _adata, + axis_x=cur_x, + axis_y=cur_y, + axis_z=cur_z, + basis_key=basis_key, + cur_b=cur_b, + cur_l_smoothed=cur_l_smoothed, + ) # https://stackoverflow.com/questions/4187185/how-can-i-check-if-my-python-object-is-a-number # answer from Boris. From 4e08d76cb20f28847117f2c8e2043508486b315f Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 21 Sep 2023 13:34:41 -0400 Subject: [PATCH 28/62] add pv scatters legend, title, axes --- dynamo/plot/scatters.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 2cfff3206..7da0b428c 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1151,6 +1151,12 @@ def _plot_basis_layer_pv(cur_b, cur_l): pvdataset.point_data["colors"] = np.stack(colors.values) pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True) + type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(labels, colors.values)} + type_color_pair = [[k, v] for k, v in type_color_dict.items()] + pl.add_legend(labels=type_color_pair) + + pl.add_text(cur_title) + pl.add_axes(xlabel=points.columns[0], ylabel=points.columns[1], zlabel=points.columns[2]) for cur_b in basis: for cur_l in layer: From 6c029c874db978f3589b9710d09fd5e291887fb0 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 21 Sep 2023 14:11:49 -0400 Subject: [PATCH 29/62] add docstr and typing for scatters_pv --- dynamo/plot/scatters.py | 104 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 100 insertions(+), 4 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 7da0b428c..9c1af63ff 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -911,12 +911,41 @@ def _plot_basis_layer(cur_b, cur_l): return axes_list if total_panels > 1 else ax -def map_to_points(_adata: AnnData, axis_x: str, axis_y: str, axis_z: str, basis_key: str, cur_b: str, cur_l_smoothed: str): +def map_to_points( + _adata: AnnData, + axis_x: str, + axis_y: str, + axis_z: str, + basis_key: str, + cur_b: str, + cur_l_smoothed: str, +) -> Tuple[pd.DataFrame, str]: + """A helper function to map the given axis to corresponding coordinates in current embedding space. + Args: + _adata: an AnnData object. + axis_x: the column index of the low dimensional embedding for the x-axis in current space. + axis_y: the column index of the low dimensional embedding for the y-axis in current space. + axis_z: the column index of the low dimensional embedding for the z-axis in current space. + basis_key: the basis key constructed by current basis and layer. + cur_b: the current basis key representing the reduced dimension. + cur_l_smoothed: the smoothed layer of data to use. + + Returns: + The 3D DataFrame with coordinates of each sample and the title of the plot. + """ gene_title = [] anno_title = [] - def _map_cur_axis(cur): + def _map_cur_axis(cur: str) -> Tuple[np.ndarray, str]: + """A helper function to map an axis. + + Args: + cur: the current axis to map. + + Returns: + The coordinates and the column names. + """ nonlocal gene_title, anno_title if is_gene_name(_adata, cur): @@ -1000,6 +1029,67 @@ def scatters_pv( save_kwargs: Dict[str, Any] = {}, **kwargs, ): + """Plot an embedding as points with Pyvista. Currently only 3D input is supported. For 2D data, `scatters` is a + better alternative. + + The function will use the colors from matplotlib to keep consistence with other plotting functions. + + Args: + adata: an AnnData object. + basis: the reduced dimension stored in adata.obsm. The specific basis key will be constructed in the following + priority if exits: 1) specific layer input + basis 2) X_ + basis 3) basis. E.g. if basis is PCA, `scatters` + is going to look for 1) if specific layer is spliced, `spliced_pca` 2) `X_pca` (dynamo convention) 3) `pca`. + Defaults to "umap". + x: the column index of the low dimensional embedding for the x-axis. Defaults to 0. + y: the column index of the low dimensional embedding for the y-axis. Defaults to 1. + z: the column index of the low dimensional embedding for the z-axis. Defaults to 2. + color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr". + layer: the layer of data to use for the scatter plot. Defaults to "X". + highlights: the color group that will be highlighted. If highlights is a list of lists, each list is relate to + each color element. Defaults to None. + labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for + coloring the points in the plot according to their label. Note that this option is mutually exclusive to the + `values` option. Defaults to None. + values: an array of values (assumed float or continuous), one for each sample. This will be used for coloring + the points in the plot according to a colorscale associated to the total range of values. Note that this + option is mutually exclusive to the `labels` option. Defaults to None. + theme: A color theme to use for plotting. A small set of predefined themes are provided which have relatively + good aesthetics. Defaults to None. + cmap: The name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed + this will be used for shading points according to density (largely only of relevance for very large + datasets). If values are passed this will be used for shading according the value. Note that if theme is + passed then this value will be overridden by the corresponding option of the theme. Defaults to None. + background: the color of the background. Usually this will be either 'white' or 'black', but any color name will + work. Ideally one wants to match this appropriately to the colors being used for points etc. This is one of + the things that themes handle for you. Note that if theme is passed then this value will be overridden by + the corresponding option of the theme. Defaults to None. + color_key: the method to assign colors to categoricals. This can either be an explicit dict mapping labels to + colors (as strings of form '#RRGGBB'), or an array like object providing one color for each distinct + category being provided in `labels`. Either way this mapping will be used to color points according to the + label. Note that if theme is passed then this value will be overridden by the corresponding option of the + theme. Defaults to None. + color_key_cmap: the name of a matplotlib colormap to use for categorical coloring. If an explicit `color_key` is + not given a color mapping for categories can be generated from the label list and selecting a matching list + of colors from the given colormap. Note that if theme is passed then this value will be overridden by the + corresponding option of the theme. Defaults to None. + use_smoothed: whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to + True. + sym_c: whether do you want to make the limits of continuous color to be symmetric, normally this should be used + for plotting velocity, jacobian, curl, divergence or other types of data with both positive or negative + values. Defaults to False. + smooth: whether do you want to further smooth data and how much smoothing do you want. If it is `False`, no + smoothing will be applied. If `True`, smoothing based on one-step diffusion of connectivity matrix + (`.uns['moment_cnn']`) will be applied. If a number larger than 1, smoothing will be based on `smooth` steps + of diffusion. + save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at + the same time. If "all", the figure will be saved, displayed and the associated axis and other object will + be return. Defaults to "show". + save_kwargs: A dictionary that will be passed to the saving function. By default, it is an empty dictionary + and the saving function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + "title": PyVista Export, "raster": True, "painter": True} as its parameters. Otherwise, you can provide a + dictionary that properly modify those keys according to your needs. Defaults to {}. + **kwargs: any other kwargs that would be passed to `Plotter.add_points()`. + """ try: import pyvista as pv except ImportError: @@ -1041,7 +1131,13 @@ def scatters_pv( pl = pv.Plotter() - def _plot_basis_layer_pv(cur_b, cur_l): + def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: + """A helper function for plotting a specific basis/layer data + + Args: + cur_b: current basis + cur_l: current layer + """ nonlocal background, adata, cmap, x, y, z, labels, sym_c, values if cur_l in ["acceleration", "curvature", "divergence", "velocity_S", "velocity_T"]: @@ -1149,7 +1245,7 @@ def _plot_basis_layer_pv(cur_b, cur_l): pvdataset = pv.PolyData(points.values) pvdataset.point_data["colors"] = np.stack(colors.values) - pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True) + pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, **kwargs) type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(labels, colors.values)} type_color_pair = [[k, v] for k, v in type_color_dict.items()] From ffdadd59425212d3ea48fb77cefc476a416ce082 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 21 Sep 2023 15:36:43 -0400 Subject: [PATCH 30/62] implement subplots --- dynamo/plot/scatters.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 9c1af63ff..0928976b4 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1129,7 +1129,28 @@ def scatters_pv( if type(basis) is str: basis = [basis] - pl = pv.Plotter() + n_c, n_l, n_b, n_x, n_y, n_z = ( + 1 if color is None else len(color), + 1 if layer is None else len(layer), + 1 if basis is None else len(basis), + 1 if x is None else 1 if type(x) in [anndata._core.views.ArrayView, np.ndarray] else len(x), + 1 if y is None else 1 if type(y) in [anndata._core.views.ArrayView, np.ndarray] else len(y), + 1 if z is None else 1 if type(z) in [anndata._core.views.ArrayView, np.ndarray] else len(z), + ) + + total_panels, ncols = ( + n_c * n_l * n_b * n_x * n_y * n_z, + max([n_c, n_l, n_b, n_x, n_y, n_z]), + ) + + nrow, ncol = int(np.ceil(total_panels / ncols)), ncols + subplot_indices = [[i, j] for i in range(nrow) for j in range(ncol)] + cur_subplot = 0 + + if total_panels == 1: + pl = pv.Plotter() + else: + pl = pv.Plotter(shape=(nrow, ncol)) def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: """A helper function for plotting a specific basis/layer data @@ -1138,7 +1159,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: cur_b: current basis cur_l: current layer """ - nonlocal background, adata, cmap, x, y, z, labels, sym_c, values + nonlocal background, adata, cmap, cur_subplot, x, y, z, labels, sym_c, values if cur_l in ["acceleration", "curvature", "divergence", "velocity_S", "velocity_T"]: cur_l_smoothed = cur_l @@ -1243,6 +1264,9 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: sym_c=sym_c, ) + if total_panels > 1: + pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) + pvdataset = pv.PolyData(points.values) pvdataset.point_data["colors"] = np.stack(colors.values) pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, **kwargs) @@ -1254,6 +1278,8 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: pl.add_text(cur_title) pl.add_axes(xlabel=points.columns[0], ylabel=points.columns[1], zlabel=points.columns[2]) + cur_subplot += 1 + for cur_b in basis: for cur_l in layer: main_debug("Plotting basis:%s, layer: %s" % (str(basis), str(layer))) From 0a0d1adc412eab209f1d3d03d79e3b00c602afca Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 21 Sep 2023 15:40:19 -0400 Subject: [PATCH 31/62] debug title for integer axes input --- dynamo/plot/scatters.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 0928976b4..a9d4f855d 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -917,6 +917,7 @@ def map_to_points( axis_y: str, axis_z: str, basis_key: str, + cur_c: str, cur_b: str, cur_l_smoothed: str, ) -> Tuple[pd.DataFrame, str]: @@ -928,6 +929,7 @@ def map_to_points( axis_y: the column index of the low dimensional embedding for the y-axis in current space. axis_z: the column index of the low dimensional embedding for the z-axis in current space. basis_key: the basis key constructed by current basis and layer. + cur_c: the current key to color the data. cur_b: the current basis key representing the reduced dimension. cur_l_smoothed: the smoothed layer of data to use. @@ -979,6 +981,10 @@ def _map_cur_axis(cur: str) -> Tuple[np.ndarray, str]: } ) points.columns = [x_col_name, y_col_name, z_col_name] + + cur_title = cur_c + + return points, cur_title elif type(axis_x) in [anndata._core.views.ArrayView, np.ndarray] and type(axis_y) in [ anndata._core.views.ArrayView, np.ndarray, @@ -1185,7 +1191,6 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: for cur_c in color: main_debug("coloring scatter of cur_c: %s" % str(cur_c)) - cur_title = cur_c _color = _get_adata_color_vec(adata, cur_l, cur_c) @@ -1209,6 +1214,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: axis_y=cur_y, axis_z=cur_z, basis_key=basis_key, + cur_c=cur_c, cur_b=cur_b, cur_l_smoothed=cur_l_smoothed, ) From 62adedfa8ed55b857f2d1a80ec1723cec1e2a871 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 21 Sep 2023 16:34:56 -0400 Subject: [PATCH 32/62] create pv 3d vectors plot --- dynamo/plot/scVectorField.py | 206 +++++++++++++++++++++++------------ 1 file changed, 134 insertions(+), 72 deletions(-) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 77b744384..6c7760cbc 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -27,11 +27,12 @@ from ..tools.utils import update_dict from ..vectorfield.topography import VectorField from ..vectorfield.utils import vecfld_from_adata -from .scatters import docstrings, scatters +from .scatters import docstrings, scatters, scatters_pv from .utils import ( _get_adata_color_vec, default_quiver_args, quiver_autoscaler, + retrieve_plot_save_path, save_fig, set_arrow_alpha, set_stream_line_alpha, @@ -61,6 +62,7 @@ def cell_wise_vectors_3d( V: Union[np.ndarray, spmatrix] = None, color: Union[str, List[str]] = None, layer: str = "X", + plot_method: str = "pv", background: Optional[str] = "white", ncols: int = 4, figsize: Tuple[float] = (6, 4), @@ -211,6 +213,12 @@ def cell_wise_vectors_3d( from matplotlib import rcParams from matplotlib.colors import to_hex + if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + def add_axis_label(ax, labels): ax.set_xlabel(labels[0]) ax.set_ylabel(labels[1]) @@ -314,83 +322,137 @@ def add_axis_label(ax, labels): nrows += 1 ncols = min(ncols, len(color)) - axes_list, color_list, _ = scatters( - adata=adata, - basis=basis, - x=x, - y=y, - z=z, - color=color, - layer=layer, - highlights=highlights, - labels=labels, - values=values, - theme=theme, - cmap=cmap, - color_key=color_key, - color_key_cmap=color_key_cmap, - background=background, - ncols=ncols, - pointsize=pointsize, - figsize=figsize, - show_legend=None, - use_smoothed=use_smoothed, - aggregate=aggregate, - show_arrowed_spines=show_arrowed_spines, - ax=ax, - sort=sort, - save_show_or_return="return", - frontier=frontier, - projection="3d", - **s_kwargs_dict, - return_all=True, - ) + if plot_method == "pv": + pl = scatters_pv( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + highlights=highlights, + labels=labels, + values=values, + cmap=cmap, + theme=theme, + background=background, + color_key=color_key, + color_key_cmap=color_key_cmap, + use_smoothed=use_smoothed, + save_show_or_return="return", + render_points_as_spheres=True, + ) + point_cloud = pv.PolyData(np.column_stack((x0.values, x1.values, x2.values))) + point_cloud['vectors'] = np.column_stack((v0.values, v1.values, v2.values)) - if type(axes_list) != list: - axes_list = [axes_list] - color_list = [color_list] + arrows = point_cloud.glyph( + orient='vectors', + factor=3.5, + ) + pl.add_mesh(arrows, color='lightblue') + + main_debug("show, return or save...") + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": "scatters_pv", + "ext": "pdf", + "title": 'PyVista Export', + "raster": True, + "painter": True, + } - for i in range(len(color)): - ax = axes_list[i] - ax.set_title(color[i]) - cmap_3d = [element for element in color_list[i]] + [element for element in color_list[i] for _ in range(2)] - main_debug("color vec len: " + str(len(cmap_3d))) - ax.view_init(elev=elev, azim=azim) - ax.quiver( - x0, - x1, - x2, - v0, - v1, - v2, - color=cmap_3d, - # facecolors=color_vec, - **quiver_3d_kwargs, + s_kwargs = update_dict(s_kwargs, save_kwargs) + + saving_path = retrieve_plot_save_path(path=s_kwargs["path"], prefix=s_kwargs["prefix"], ext=s_kwargs["ext"]) + pl.save_graphic(saving_path, title=s_kwargs["title"], raster=s_kwargs["raster"], + painter=s_kwargs["painter"]) + + if save_show_or_return in ["show", "both", "all"]: + pl.show() + + if save_show_or_return in ["return", "all"]: + return pl + + else: + axes_list, color_list, _ = scatters( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + highlights=highlights, + labels=labels, + values=values, + theme=theme, + cmap=cmap, + color_key=color_key, + color_key_cmap=color_key_cmap, + background=background, + ncols=ncols, + pointsize=pointsize, + figsize=figsize, + show_legend=None, + use_smoothed=use_smoothed, + aggregate=aggregate, + show_arrowed_spines=show_arrowed_spines, + ax=ax, + sort=sort, + save_show_or_return="return", + frontier=frontier, + projection="3d", + **s_kwargs_dict, + return_all=True, ) - ax.set_title(titles[i]) - ax.set_facecolor(background) - add_axis_label(ax, axis_labels) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "cell_wise_vectors_3d", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) + if type(axes_list) != list: + axes_list = [axes_list] + color_list = [color_list] - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False + for i in range(len(color)): + ax = axes_list[i] + ax.set_title(color[i]) + cmap_3d = [element for element in color_list[i]] + [element for element in color_list[i] for _ in range(2)] + main_debug("color vec len: " + str(len(cmap_3d))) + ax.view_init(elev=elev, azim=azim) + ax.quiver( + x0, + x1, + x2, + v0, + v1, + v2, + color=cmap_3d, + # facecolors=color_vec, + **quiver_3d_kwargs, + ) + ax.set_title(titles[i]) + ax.set_facecolor(background) + add_axis_label(ax, axis_labels) + + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": "cell_wise_vectors_3d", + "dpi": None, + "ext": "pdf", + "transparent": True, + "close": True, + "verbose": True, + } + s_kwargs = update_dict(s_kwargs, save_kwargs) - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.show() - if save_show_or_return in ["return", "all"]: - return axes_list + if save_show_or_return in ["both", "all"]: + s_kwargs["close"] = False + + save_fig(**s_kwargs) + if save_show_or_return in ["show", "both", "all"]: + plt.show() + if save_show_or_return in ["return", "all"]: + return axes_list def grid_vectors_3d(): From 1df5adb7300f74a437d5e9f0977981589d9d6d20 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 21 Sep 2023 16:46:59 -0400 Subject: [PATCH 33/62] color the pyvista 3d vector --- dynamo/plot/scVectorField.py | 5 +++-- dynamo/plot/scatters.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 6c7760cbc..0b11773a7 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -323,7 +323,7 @@ def add_axis_label(ax, labels): ncols = min(ncols, len(color)) if plot_method == "pv": - pl = scatters_pv( + pl, colors_list = scatters_pv( adata=adata, basis=basis, x=x, @@ -345,12 +345,13 @@ def add_axis_label(ax, labels): ) point_cloud = pv.PolyData(np.column_stack((x0.values, x1.values, x2.values))) point_cloud['vectors'] = np.column_stack((v0.values, v1.values, v2.values)) + point_cloud.point_data["colors"] = np.stack(colors_list[0].values) arrows = point_cloud.glyph( orient='vectors', factor=3.5, ) - pl.add_mesh(arrows, color='lightblue') + pl.add_mesh(arrows, scalars="colors", preference='point', rgb=True) main_debug("show, return or save...") if save_show_or_return in ["save", "both", "all"]: diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index a9d4f855d..068962b23 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1152,6 +1152,7 @@ def scatters_pv( nrow, ncol = int(np.ceil(total_panels / ncols)), ncols subplot_indices = [[i, j] for i in range(nrow) for j in range(ncol)] cur_subplot = 0 + colors_list = [] if total_panels == 1: pl = pv.Plotter() @@ -1269,6 +1270,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: background=background, sym_c=sym_c, ) + colors_list.append(colors) if total_panels > 1: pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) @@ -1312,4 +1314,4 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: pl.show() if save_show_or_return in ["return", "all"]: - return pl + return pl, colors_list From dbcaf0e8a63512efa45cdfc0c35164e93d81c128 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 21 Sep 2023 18:18:31 -0400 Subject: [PATCH 34/62] debug scatters_pv color with multiple colors --- dynamo/plot/scatters.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 068962b23..5c0c4ab7e 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1203,7 +1203,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: % (cur_c), indent_level=2, ) - _values = values + _labels, _values = None, None _adata = adata for cur_x, cur_y, cur_z in zip(x, y, z): # here x / y are arrays @@ -1225,7 +1225,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: is_not_continuous = not isinstance(_color[0], Number) or _color.dtype.name == "category" if is_not_continuous: - labels = np.asarray(_color) if is_categorical_dtype(_color) else _color + _labels = np.asarray(_color) if is_categorical_dtype(_color) else _color if theme is None: if background in ["#ffffff", "black"]: _theme_ = "glasbey_dark" @@ -1251,6 +1251,11 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: if labels is not None and values is not None: raise ValueError("Conflicting options; only one of labels or values should be set") + if labels is not None or values is not None: + _labels = labels + _values = values + main_info("`Color` will be ignored because labels/values is provided.") + if smooth and not is_not_continuous: main_debug("smooth and not continuous") knn = _adata.obsp["moments_con"] @@ -1259,10 +1264,11 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: if smooth in [1, True] else calc_1nd_moment(values, knn**smooth)[0] ) + _values = values - colors, _, _ = calculate_colors( - points, - labels=labels, + colors, color_type, _ = calculate_colors( + points.values, + labels=_labels, values=_values, cmap=_cmap, color_key=color_key, @@ -1276,12 +1282,18 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) pvdataset = pv.PolyData(points.values) - pvdataset.point_data["colors"] = np.stack(colors.values) - pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, **kwargs) - type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(labels, colors.values)} - type_color_pair = [[k, v] for k, v in type_color_dict.items()] - pl.add_legend(labels=type_color_pair) + if color_type == "labels": + type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(_labels, colors.values)} + type_color_pair = [[k, v] for k, v in type_color_dict.items()] + pl.add_legend(labels=type_color_pair) + colors_list.append(colors.values) + pvdataset.point_data["colors"] = np.stack(colors.values) + else: + colors_list.append(colors) + pvdataset.point_data["colors"] = np.stack(colors) + + pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, **kwargs) pl.add_text(cur_title) pl.add_axes(xlabel=points.columns[0], ylabel=points.columns[1], zlabel=points.columns[2]) From 389435da60b28df537f63f1a9441b862c23bb7c2 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 21 Sep 2023 18:36:10 -0400 Subject: [PATCH 35/62] implement multiple plots in 3d pv vecters --- dynamo/plot/scVectorField.py | 24 ++++++++++++++++++------ dynamo/plot/scatters.py | 1 - 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 0b11773a7..e3a95e7ee 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -343,15 +343,27 @@ def add_axis_label(ax, labels): save_show_or_return="return", render_points_as_spheres=True, ) + point_cloud = pv.PolyData(np.column_stack((x0.values, x1.values, x2.values))) point_cloud['vectors'] = np.column_stack((v0.values, v1.values, v2.values)) - point_cloud.point_data["colors"] = np.stack(colors_list[0].values) - arrows = point_cloud.glyph( - orient='vectors', - factor=3.5, - ) - pl.add_mesh(arrows, scalars="colors", preference='point', rgb=True) + r, c = pl.shape[0], pl.shape[1] + subplot_indices = [[i, j] for i in range(r) for j in range(c)] + cur_subplot = 0 + + for i in range(len(color)): + point_cloud.point_data["colors"] = np.stack(colors_list[i]) + + arrows = point_cloud.glyph( + orient='vectors', + factor=3.5, + ) + + if r * c != 1: + pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) + cur_subplot += 1 + + pl.add_mesh(arrows, scalars="colors", preference='point', rgb=True) main_debug("show, return or save...") if save_show_or_return in ["save", "both", "all"]: diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 5c0c4ab7e..09b43f875 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1276,7 +1276,6 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: background=background, sym_c=sym_c, ) - colors_list.append(colors) if total_panels > 1: pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) From aaa16123ac49c2838cf4fea857853cb68233026d Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 22 Sep 2023 11:44:14 -0400 Subject: [PATCH 36/62] optimize scatters pv color --- dynamo/plot/scatters.py | 30 +++++++++++++----------------- dynamo/plot/utils.py | 2 +- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 09b43f875..8e7c97756 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1166,7 +1166,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: cur_b: current basis cur_l: current layer """ - nonlocal background, adata, cmap, cur_subplot, x, y, z, labels, sym_c, values + nonlocal background, adata, cmap, cur_subplot, sym_c if cur_l in ["acceleration", "curvature", "divergence", "velocity_S", "velocity_T"]: cur_l_smoothed = cur_l @@ -1204,13 +1204,12 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: indent_level=2, ) _labels, _values = None, None - _adata = adata for cur_x, cur_y, cur_z in zip(x, y, z): # here x / y are arrays main_debug("handling coordinates, cur_x: %s, cur_y: %s, cur_z: %s" % (cur_x, cur_y, cur_z)) points, cur_title = map_to_points( - _adata, + adata, axis_x=cur_x, axis_y=cur_y, axis_z=cur_z, @@ -1252,19 +1251,18 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: raise ValueError("Conflicting options; only one of labels or values should be set") if labels is not None or values is not None: - _labels = labels - _values = values + _labels = labels.copy() + _values = values.copy() main_info("`Color` will be ignored because labels/values is provided.") if smooth and not is_not_continuous: main_debug("smooth and not continuous") - knn = _adata.obsp["moments_con"] - values = ( - calc_1nd_moment(values, knn)[0] + knn = adata.obsp["moments_con"] + _values = ( + calc_1nd_moment(_values, knn)[0] if smooth in [1, True] - else calc_1nd_moment(values, knn**smooth)[0] + else calc_1nd_moment(_values, knn**smooth)[0] ) - _values = values colors, color_type, _ = calculate_colors( points.values, @@ -1280,19 +1278,17 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: if total_panels > 1: pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) + colors_list.append(colors) pvdataset = pv.PolyData(points.values) + pvdataset.point_data["colors"] = np.stack(colors) + pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, cmap=_cmap, **kwargs) if color_type == "labels": - type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(_labels, colors.values)} + type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(_labels, colors)} type_color_pair = [[k, v] for k, v in type_color_dict.items()] pl.add_legend(labels=type_color_pair) - colors_list.append(colors.values) - pvdataset.point_data["colors"] = np.stack(colors.values) else: - colors_list.append(colors) - pvdataset.point_data["colors"] = np.stack(colors) - - pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, **kwargs) + pl.add_scalar_bar() pl.add_text(cur_title) pl.add_axes(xlabel=points.columns[0], ylabel=points.columns[1], zlabel=points.columns[2]) diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index 9bf20c233..cc688a572 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -264,7 +264,7 @@ def calculate_colors( color_type = "midpoint" colors = plt.get_cmap(cmap)(0.5) - return (colors, color_type, None) if color_type != "labels" else (colors, color_type, legend_elements) + return (colors, color_type, None) if color_type != "labels" else (colors.values, color_type, legend_elements) # --------------------------------------------------------------------------------------------------- From 610b7c9c92d74ca7437d64c387b4251699bbcfe8 Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 22 Sep 2023 14:05:05 -0400 Subject: [PATCH 37/62] create todo --- dynamo/plot/scatters.py | 2 +- dynamo/plot/topography.py | 104 +------------------------------------- 2 files changed, 3 insertions(+), 103 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 8e7c97756..6ac665ff8 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1288,7 +1288,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: type_color_pair = [[k, v] for k, v in type_color_dict.items()] pl.add_legend(labels=type_color_pair) else: - pl.add_scalar_bar() + pl.add_scalar_bar() # TODO: fix the bug that scalar bar only works in the first plot pl.add_text(cur_title) pl.add_axes(xlabel=points.columns[0], ylabel=points.columns[1], zlabel=points.columns[2]) diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 34e7cf6be..41652a95e 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -527,7 +527,7 @@ def plot_fixed_points( cmap=_cmap, vmin=0, zorder=5, - ) + ) # TODO: Figure out the user warning that no data for colormapping provided via 'c'. txt = ax.text( *Xss[i], repr(i), @@ -1361,6 +1361,7 @@ def topography( return axes_list if len(axes_list) > 1 else axes_list[0] +# TODO: Implement more `terms` like streamline and trajectory for 3D topography @docstrings.with_indent(4) def topography_3D( adata: AnnData, @@ -1498,16 +1499,6 @@ def topography_3D( warnings.simplefilter("ignore") VectorField(adata, fps_basis, map_topography=True, n=n) - # elif "VecFld2D" not in adata.uns[uns_key].keys(): - # with warnings.catch_warnings(): - # warnings.simplefilter("ignore") - # - # _topology(adata, basis, VecFld=None) - # elif "VecFld2D" in adata.uns[uns_key].keys() and type(adata.uns[uns_key]["VecFld2D"]) == str: - # with warnings.catch_warnings(): - # warnings.simplefilter("ignore") - # - # _topology(adata, basis, VecFld=None) vecfld_dict, vecfld = vecfld_from_adata(adata, basis) @@ -1636,42 +1627,6 @@ def topography_3D( t = np.linspace(0, max_t, 10 ** (np.min((int(np.log10(max_t)), 8)))) - integration_direction = ( - "both" if fate == "both" else "forward" if fate == "future" else "backward" if fate == "history" else "both" - ) - - if "streamline" in terms: - if approx: - axes_list[i] = plot_flow_field( - vecfld, - xlim, - ylim, - background=_background, - start_points=init_states, - integration_direction=integration_direction, - density=density, - linewidth=linewidth, - streamline_color=streamline_color, - streamline_alpha=streamline_alpha, - color_start_points=color_start_points, - ax=axes_list[i], - **streamline_kwargs_dict, - ) - else: - axes_list[i] = plot_flow_field( - vecfld, - xlim, - ylim, - background=_background, - density=density, - linewidth=linewidth, - streamline_color=streamline_color, - streamline_alpha=streamline_alpha, - color_start_points=color_start_points, - ax=axes_list[i], - **streamline_kwargs_dict, - ) - if "fixed_points" in terms: axes_list[i] = plot_fixed_points( fps_vecfld, @@ -1682,61 +1637,6 @@ def topography_3D( cmap=marker_cmap, ) - if "separatrices" in terms: - axes_list[i] = plot_separatrix(vecfld, xlim, ylim, t=t, background=_background, ax=axes_list[i]) - - if init_states is not None and "trajectory" in terms: - if not approx: - axes_list[i] = plot_traj( - vecfld.func, - init_states, - t, - background=_background, - integration_direction=integration_direction, - ax=axes_list[i], - ) - - # show quivers for the init_states cells - if init_states is not None and "quiver" in terms: - X = init_states - V /= 3 * quiver_autoscaler(X, V) - - df = pd.DataFrame({"x": X[:, 0], "y": X[:, 1], "u": V[:, 0], "v": V[:, 1]}) - - if quiver_size is None: - quiver_size = 1 - if _background in ["#ffffff", "black"]: - edgecolors = "white" - else: - edgecolors = "black" - - head_w, head_l, ax_l, scale = default_quiver_args(quiver_size, quiver_length) # - quiver_kwargs = { - "angles": "xy", - "scale": scale, - "scale_units": "xy", - "width": 0.0005, - "headwidth": head_w, - "headlength": head_l, - "headaxislength": ax_l, - "minshaft": 1, - "minlength": 1, - "pivot": "tail", - "linewidth": 0.1, - "edgecolors": edgecolors, - "alpha": 1, - "zorder": 7, - } - quiver_kwargs = update_dict(quiver_kwargs, q_kwargs_dict) - # axes_list[i].quiver(X_grid[:, 0], X_grid[:, 1], V_grid[:, 0], V_grid[:, 1], **quiver_kwargs) - axes_list[i].quiver( - df.iloc[:, 0], - df.iloc[:, 1], - df.iloc[:, 2], - df.iloc[:, 3], - **quiver_kwargs, - ) # color='red', facecolors='gray' - if save_show_or_return in ["save", "both", "all"]: s_kwargs = { "path": None, From 3a63363bb378f72ca842bf6bfdb26cee2596cac2 Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 22 Sep 2023 15:19:04 -0400 Subject: [PATCH 38/62] create pyvista option in 3d topography --- dynamo/plot/topography.py | 225 ++++++++++++++++++++++++-------------- 1 file changed, 141 insertions(+), 84 deletions(-) diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 41652a95e..87fe08c8a 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -26,12 +26,13 @@ from ..vectorfield.topography import topography as _topology # , compute_separatrices from ..vectorfield.utils import vecfld_from_adata from ..vectorfield.vector_calculus import curl, divergence -from .scatters import docstrings, scatters +from .scatters import docstrings, scatters, scatters_pv from .utils import ( _plot_traj, _select_font_color, default_quiver_args, quiver_autoscaler, + retrieve_plot_save_path, save_fig, set_arrow_alpha, set_stream_line_alpha, @@ -1369,8 +1370,10 @@ def topography_3D( fps_basis: str = "umap", x: int = 0, y: int = 1, + z: int = 2, color: str = "ntr", layer: str = "X", + plot_method: str = "matplotlib", highlights: Optional[list] = None, labels: Optional[list] = None, values: Optional[list] = None, @@ -1566,99 +1569,153 @@ def topography_3D( V = vector_field_function(init_states, vecfld_dict, [0, 1]) - # plt.figure(facecolor=_background) - axes_list, color_list, font_color = scatters( - adata=adata, - basis=basis, - x=x, - y=y, - color=color, - layer=layer, - highlights=highlights, - labels=labels, - values=values, - theme=theme, - cmap=cmap, - color_key=color_key, - color_key_cmap=color_key_cmap, - background=_background, - ncols=ncols, - pointsize=pointsize, - figsize=figsize, - show_legend=show_legend, - use_smoothed=use_smoothed, - aggregate=aggregate, - show_arrowed_spines=show_arrowed_spines, - ax=ax, - sort=sort, - save_show_or_return="return", - frontier=frontier, - projection="3d", - **s_kwargs_dict, - return_all=True, - ) - - if type(axes_list) != list: - axes_list, color_list, font_color = ( - [axes_list], - [color_list], - [font_color], + if plot_method == "pv": + pl, colors_list = scatters_pv( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + highlights=highlights, + labels=labels, + values=values, + cmap=cmap, + theme=theme, + background=background, + color_key=color_key, + color_key_cmap=color_key_cmap, + use_smoothed=use_smoothed, + save_show_or_return="return", + style='points_gaussian', + opacity=0.5, ) - for i in range(len(axes_list)): - # ax = axes_list[i] - axes_list[i].set_xlabel(basis + "_1") - axes_list[i].set_ylabel(basis + "_2") - axes_list[i].set_zlabel(basis + "_3") - # axes_list[i].set_aspect("equal") + r, c = pl.shape[0], pl.shape[1] + subplot_indices = [[i, j] for i in range(r) for j in range(c)] + cur_subplot = 0 + + for i in range(len(color)): + if r * c != 1: + pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) + cur_subplot += 1 + + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": "scatters_pv", + "ext": "pdf", + "title": 'PyVista Export', + "raster": True, + "painter": True, + } - # Build the plot - axes_list[i].set_xlim(xlim) - axes_list[i].set_ylim(ylim) - axes_list[i].set_zlim(zlim) + s_kwargs = update_dict(s_kwargs, save_kwargs) - axes_list[i].set_facecolor(background) + saving_path = retrieve_plot_save_path(path=s_kwargs["path"], prefix=s_kwargs["prefix"], ext=s_kwargs["ext"]) + pl.save_graphic(saving_path, title=s_kwargs["title"], raster=s_kwargs["raster"], painter=s_kwargs["painter"]) - if t is None: - if vecfld_dict["grid_V"] is None: - max_t = np.max((np.diff(xlim), np.diff(ylim))) / np.min(np.abs(vecfld_dict["V"][:, :2])) - else: - max_t = np.max((np.diff(xlim), np.diff(ylim))) / np.min(np.abs(vecfld_dict["grid_V"])) + if save_show_or_return in ["show", "both", "all"]: + pl.show() - t = np.linspace(0, max_t, 10 ** (np.min((int(np.log10(max_t)), 8)))) + if save_show_or_return in ["return", "all"]: + return pl, colors_list + else: + # plt.figure(facecolor=_background) + axes_list, color_list, font_color = scatters( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + highlights=highlights, + labels=labels, + values=values, + theme=theme, + cmap=cmap, + color_key=color_key, + color_key_cmap=color_key_cmap, + background=_background, + ncols=ncols, + pointsize=pointsize, + figsize=figsize, + show_legend=show_legend, + use_smoothed=use_smoothed, + aggregate=aggregate, + show_arrowed_spines=show_arrowed_spines, + ax=ax, + sort=sort, + save_show_or_return="return", + frontier=frontier, + projection="3d", + **s_kwargs_dict, + return_all=True, + ) - if "fixed_points" in terms: - axes_list[i] = plot_fixed_points( - fps_vecfld, - fps_vecfld_dict, - background=_background, - ax=axes_list[i], - markersize=markersize, - cmap=marker_cmap, + if type(axes_list) != list: + axes_list, color_list, font_color = ( + [axes_list], + [color_list], + [font_color], ) + for i in range(len(axes_list)): + # ax = axes_list[i] - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "topography", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) + axes_list[i].set_xlabel(basis + "_1") + axes_list[i].set_ylabel(basis + "_2") + axes_list[i].set_zlabel(basis + "_3") + # axes_list[i].set_aspect("equal") - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False + # Build the plot + axes_list[i].set_xlim(xlim) + axes_list[i].set_ylim(ylim) + axes_list[i].set_zlim(zlim) - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") + axes_list[i].set_facecolor(background) - plt.tight_layout() + if t is None: + if vecfld_dict["grid_V"] is None: + max_t = np.max((np.diff(xlim), np.diff(ylim))) / np.min(np.abs(vecfld_dict["V"][:, :2])) + else: + max_t = np.max((np.diff(xlim), np.diff(ylim))) / np.min(np.abs(vecfld_dict["grid_V"])) - plt.show() - if save_show_or_return in ["return", "all"]: - return axes_list if len(axes_list) > 1 else axes_list[0] \ No newline at end of file + t = np.linspace(0, max_t, 10 ** (np.min((int(np.log10(max_t)), 8)))) + + if "fixed_points" in terms: + axes_list[i] = plot_fixed_points( + fps_vecfld, + fps_vecfld_dict, + background=_background, + ax=axes_list[i], + markersize=markersize, + cmap=marker_cmap, + ) + + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": "topography", + "dpi": None, + "ext": "pdf", + "transparent": True, + "close": True, + "verbose": True, + } + s_kwargs = update_dict(s_kwargs, save_kwargs) + + if save_show_or_return in ["both", "all"]: + s_kwargs["close"] = False + + save_fig(**s_kwargs) + if save_show_or_return in ["show", "both", "all"]: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + plt.tight_layout() + + plt.show() + if save_show_or_return in ["return", "all"]: + return axes_list if len(axes_list) > 1 else axes_list[0] \ No newline at end of file From 8519c3efd44850422c91b9a6c5d5c51c87c4758a Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 22 Sep 2023 16:14:47 -0400 Subject: [PATCH 39/62] implement pyvista fixed points plot in topography --- dynamo/plot/topography.py | 180 +++++++++++++++++++++++++------------- 1 file changed, 119 insertions(+), 61 deletions(-) diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 87fe08c8a..5828d555f 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -423,6 +423,7 @@ def plot_fixed_points( background: Optional[str] = None, save_show_or_return: Literal["save", "show", "return"] = "return", save_kwargs: Dict[str, Any] = {}, + plot_method: str = "matplotlib", ax: Optional[Axes] = None, **kwargs, ) -> Optional[Axes]: @@ -457,6 +458,12 @@ def plot_fixed_points( from matplotlib import markers, rcParams from matplotlib.colors import to_hex + if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + if background is None: _background = rcParams.get("figure.facecolor") _background = to_hex(_background) if type(_background) is tuple else _background @@ -511,61 +518,104 @@ def plot_fixed_points( vecfld_dict["confidence"], ) - if ax is None: - ax = plt.gca() - cm = matplotlib.cm.get_cmap(_cmap) if type(_cmap) is str else _cmap - for i in range(len(Xss)): - cur_ftype = ftype[i] - marker_ = markers.MarkerStyle(marker=marker, fillstyle=filltype[int(cur_ftype + 1)]) - ax.scatter( - *Xss[i], - marker=marker_, - s=markersize, - c=c if confidence is None else np.array(cm(confidence[i])).reshape(1, -1), - edgecolor=_select_font_color(_background), - linewidths=1, - cmap=_cmap, - vmin=0, - zorder=5, - ) # TODO: Figure out the user warning that no data for colormapping provided via 'c'. - txt = ax.text( - *Xss[i], - repr(i), - c=("black" if cur_ftype == -1 else "blue" if cur_ftype == 0 else "red"), - horizontalalignment="center", - verticalalignment="center", - zorder=6, - weight="bold", - ) - txt.set_path_effects( - [ - PathEffects.Stroke(linewidth=1.5, foreground=_background, alpha=0.8), - PathEffects.Normal(), - ] - ) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "plot_fixed_points", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) + if plot_method == "pv": - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False + points = pv.PolyData(Xss) + colors = [c if confidence is None else np.array(cm(confidence[i])).reshape(1, -1) for i in range(len(confidence))] + points.point_data["colors"] = colors - save_fig(**s_kwargs) - if save_show_or_return in ["show", "both", "all"]: - plt.tight_layout() - plt.show() - if save_show_or_return in ["return", "all"]: - return ax + points["Labels"] = [str(i) for i in range(points.n_points)] + + r, c = ax.shape[0], ax.shape[1] + subplot_indices = [[i, j] for i in range(r) for j in range(c)] + cur_subplot = 0 + + for i in range(r * c): + + if r * c != 1: + ax.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) + cur_subplot += 1 + + ax.add_points(points, render_points_as_spheres=True, rgba=True, point_size=15) + ax.add_point_labels(points, "Labels", font_size=36, show_points=False) # TODO: only work for the first plot + + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": "scatters_pv", + "ext": "pdf", + "title": 'PyVista Export', + "raster": True, + "painter": True, + } + + s_kwargs = update_dict(s_kwargs, save_kwargs) + + saving_path = retrieve_plot_save_path(path=s_kwargs["path"], prefix=s_kwargs["prefix"], ext=s_kwargs["ext"]) + ax.save_graphic(saving_path, title=s_kwargs["title"], raster=s_kwargs["raster"], painter=s_kwargs["painter"]) + + if save_show_or_return in ["show", "both", "all"]: + ax.show() + + if save_show_or_return in ["return", "all"]: + return ax + else: + if ax is None: + ax = plt.gca() + + for i in range(len(Xss)): + cur_ftype = ftype[i] + marker_ = markers.MarkerStyle(marker=marker, fillstyle=filltype[int(cur_ftype + 1)]) + ax.scatter( + *Xss[i], + marker=marker_, + s=markersize, + c=c if confidence is None else np.array(cm(confidence[i])).reshape(1, -1), + edgecolor=_select_font_color(_background), + linewidths=1, + cmap=_cmap, + vmin=0, + zorder=5, + ) # TODO: Figure out the user warning that no data for colormapping provided via 'c'. + txt = ax.text( + *Xss[i], + repr(i), + c=("black" if cur_ftype == -1 else "blue" if cur_ftype == 0 else "red"), + horizontalalignment="center", + verticalalignment="center", + zorder=6, + weight="bold", + ) + txt.set_path_effects( + [ + PathEffects.Stroke(linewidth=1.5, foreground=_background, alpha=0.8), + PathEffects.Normal(), + ] + ) + + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": "plot_fixed_points", + "dpi": None, + "ext": "pdf", + "transparent": True, + "close": True, + "verbose": True, + } + s_kwargs = update_dict(s_kwargs, save_kwargs) + + if save_show_or_return in ["both", "all"]: + s_kwargs["close"] = False + + save_fig(**s_kwargs) + if save_show_or_return in ["show", "both", "all"]: + plt.tight_layout() + plt.show() + if save_show_or_return in ["return", "all"]: + return ax def plot_traj( @@ -1439,6 +1489,12 @@ def topography_3D( from matplotlib import rcParams from matplotlib.colors import to_hex + if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + if type(color) == str: color = [color] @@ -1588,18 +1644,20 @@ def topography_3D( color_key_cmap=color_key_cmap, use_smoothed=use_smoothed, save_show_or_return="return", - style='points_gaussian', - opacity=0.5, + # style='points_gaussian', + opacity=0.8, ) - r, c = pl.shape[0], pl.shape[1] - subplot_indices = [[i, j] for i in range(r) for j in range(c)] - cur_subplot = 0 - - for i in range(len(color)): - if r * c != 1: - pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) - cur_subplot += 1 + if "fixed_points" in terms: + pl = plot_fixed_points( + fps_vecfld, + fps_vecfld_dict, + background=_background, + ax=pl, + markersize=markersize, + cmap=marker_cmap, + plot_method="pv", + ) if save_show_or_return in ["save", "both", "all"]: s_kwargs = { From 5a7264db2e679c357076cc9145b00419e80c5fd5 Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 22 Sep 2023 16:36:01 -0400 Subject: [PATCH 40/62] update docstr --- dynamo/plot/scVectorField.py | 3 ++- dynamo/plot/topography.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index e3a95e7ee..64f3ee3d9 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -62,7 +62,7 @@ def cell_wise_vectors_3d( V: Union[np.ndarray, spmatrix] = None, color: Union[str, List[str]] = None, layer: str = "X", - plot_method: str = "pv", + plot_method: Literal["pv", "matplotlib"] = "pv", background: Optional[str] = "white", ncols: int = 4, figsize: Tuple[float] = (6, 4), @@ -134,6 +134,7 @@ def cell_wise_vectors_3d( V: the velocity array. If None, the array would be determined by `vkey` provided. Defaults to None. color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr". layer: the layer of data to use for the scatter plot. Defaults to "X". + plot_method: the method to plot 3D vectors. Options include `pv` (pyvista) and `matplotlib`. background: the background color of the figure. Defaults to "white". ncols: the number of sub-plot columns. Defaults to 4. figsize: the size of each sub-plot panel. Defaults to (6, 4). diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 5828d555f..19d9a2f03 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -423,7 +423,7 @@ def plot_fixed_points( background: Optional[str] = None, save_show_or_return: Literal["save", "show", "return"] = "return", save_kwargs: Dict[str, Any] = {}, - plot_method: str = "matplotlib", + plot_method: Literal["pv", "matplotlib"] = "matplotlib", ax: Optional[Axes] = None, **kwargs, ) -> Optional[Axes]: @@ -446,7 +446,9 @@ def plot_fixed_points( and the save_fig function will use the {"path": None, "prefix": 'plot_fixed_points', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. - ax: the matplotlib axes used for plotting. Default is to use the current axis. Defaults to None. + plot_method: the method to plot 3D points. Options include `pv` (pyvista) and `matplotlib`. + ax: the matplotlib axes or pyvista plotter used for plotting. Default is to use the current axis. Defaults to + None. Returns: None would be returned by default. If `save_show_or_return` is set to be 'return', the Axes of the generated @@ -1423,7 +1425,7 @@ def topography_3D( z: int = 2, color: str = "ntr", layer: str = "X", - plot_method: str = "matplotlib", + plot_method: Literal["pv", "matplotlib"] = "matplotlib", highlights: Optional[list] = None, labels: Optional[list] = None, values: Optional[list] = None, From 3260a350b9a546cf5e09bb2b6678509ad8ac6b7b Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 22 Sep 2023 17:54:09 -0400 Subject: [PATCH 41/62] create helper class for saving pyvista plotter --- dynamo/plot/scVectorField.py | 28 ++++--------------- dynamo/plot/scatters.py | 28 +++++-------------- dynamo/plot/topography.py | 52 ++++++++-------------------------- dynamo/plot/utils.py | 54 ++++++++++++++++++++++++++++++++++-- 4 files changed, 77 insertions(+), 85 deletions(-) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 64f3ee3d9..2bde45040 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -34,6 +34,7 @@ quiver_autoscaler, retrieve_plot_save_path, save_fig, + save_pyvista_plotter, set_arrow_alpha, set_stream_line_alpha, ) @@ -366,28 +367,11 @@ def add_axis_label(ax, labels): pl.add_mesh(arrows, scalars="colors", preference='point', rgb=True) - main_debug("show, return or save...") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "scatters_pv", - "ext": "pdf", - "title": 'PyVista Export', - "raster": True, - "painter": True, - } - - s_kwargs = update_dict(s_kwargs, save_kwargs) - - saving_path = retrieve_plot_save_path(path=s_kwargs["path"], prefix=s_kwargs["prefix"], ext=s_kwargs["ext"]) - pl.save_graphic(saving_path, title=s_kwargs["title"], raster=s_kwargs["raster"], - painter=s_kwargs["painter"]) - - if save_show_or_return in ["show", "both", "all"]: - pl.show() - - if save_show_or_return in ["return", "all"]: - return pl + return save_pyvista_plotter( + pl=pl, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) else: axes_list, color_list, _ = scatters( diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 6ac665ff8..e9dde5368 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -40,6 +40,7 @@ is_list_of_lists, retrieve_plot_save_path, save_fig, + save_pyvista_plotter, ) docstrings = DocstringProcessor() @@ -1301,24 +1302,9 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: main_debug("colors: %s" % (str(color))) _plot_basis_layer_pv(cur_b, cur_l) - main_debug("show, return or save...") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "scatters_pv", - "ext": "pdf", - "title": 'PyVista Export', - "raster": True, - "painter": True, - } - - s_kwargs = update_dict(s_kwargs, save_kwargs) - - saving_path = retrieve_plot_save_path(path=s_kwargs["path"], prefix=s_kwargs["prefix"], ext=s_kwargs["ext"]) - pl.save_graphic(saving_path, title=s_kwargs["title"], raster=s_kwargs["raster"], painter=s_kwargs["painter"]) - - if save_show_or_return in ["show", "both", "all"]: - pl.show() - - if save_show_or_return in ["return", "all"]: - return pl, colors_list + return save_pyvista_plotter( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 19d9a2f03..70dc9ee72 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -34,6 +34,7 @@ quiver_autoscaler, retrieve_plot_save_path, save_fig, + save_pyvista_plotter, set_arrow_alpha, set_stream_line_alpha, ) @@ -543,26 +544,11 @@ def plot_fixed_points( ax.add_points(points, render_points_as_spheres=True, rgba=True, point_size=15) ax.add_point_labels(points, "Labels", font_size=36, show_points=False) # TODO: only work for the first plot - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "scatters_pv", - "ext": "pdf", - "title": 'PyVista Export', - "raster": True, - "painter": True, - } - - s_kwargs = update_dict(s_kwargs, save_kwargs) - - saving_path = retrieve_plot_save_path(path=s_kwargs["path"], prefix=s_kwargs["prefix"], ext=s_kwargs["ext"]) - ax.save_graphic(saving_path, title=s_kwargs["title"], raster=s_kwargs["raster"], painter=s_kwargs["painter"]) - - if save_show_or_return in ["show", "both", "all"]: - ax.show() - - if save_show_or_return in ["return", "all"]: - return ax + return save_pyvista_plotter( + pl=ax, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) else: if ax is None: ax = plt.gca() @@ -1661,26 +1647,12 @@ def topography_3D( plot_method="pv", ) - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "scatters_pv", - "ext": "pdf", - "title": 'PyVista Export', - "raster": True, - "painter": True, - } - - s_kwargs = update_dict(s_kwargs, save_kwargs) - - saving_path = retrieve_plot_save_path(path=s_kwargs["path"], prefix=s_kwargs["prefix"], ext=s_kwargs["ext"]) - pl.save_graphic(saving_path, title=s_kwargs["title"], raster=s_kwargs["raster"], painter=s_kwargs["painter"]) - - if save_show_or_return in ["show", "both", "all"]: - pl.show() - - if save_show_or_return in ["return", "all"]: - return pl, colors_list + return save_pyvista_plotter( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) else: # plt.figure(facecolor=_background) axes_list, color_list, font_color = scatters( diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index cc688a572..82bdbe10f 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -4,7 +4,7 @@ # import matplotlib.tri as tri import warnings -from typing import Optional +from typing import Dict, List, Optional, Tuple from warnings import warn import matplotlib @@ -19,7 +19,7 @@ from ..configuration import _themes from ..dynamo_logger import main_debug -from ..tools.utils import integrate_vf # integrate_vf_ivp +from ..tools.utils import integrate_vf, update_dict # integrate_vf_ivp # --------------------------------------------------------------------------------------------------- @@ -1682,6 +1682,56 @@ def retrieve_plot_save_path( return savepath +def save_pyvista_plotter( + pl, + colors_list: Optional[List] = None, + save_show_or_return: str = "show", + save_kwargs: Optional[Dict] = None, +) -> Optional[Tuple]: + """Save, show or return the pyvista.Plotter. + + Args: + pl: target plotter object. + colors_list: corresponding the list of colors mapping. + save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at + the same time. If "all", the figure will be saved, displayed and the associated axis and other object will + be return. Defaults to "show". + save_kwargs: A dictionary that will be passed to the saving function. By default, it is an empty dictionary + and the saving function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + "title": PyVista Export, "raster": True, "painter": True} as its parameters. Otherwise, you can provide a + dictionary that properly modify those keys according to your needs. Defaults to {}. + + Returns: + If `save_show_or_return` is `return` or `all`, the plotter object and list of color mapping will be returned. + """ + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + + main_debug("show, return or save...") + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": "scatters_pv", + "ext": "pdf", + "title": 'PyVista Export', + "raster": True, + "painter": True, + } + + s_kwargs = update_dict(s_kwargs, save_kwargs) + + saving_path = retrieve_plot_save_path(path=s_kwargs["path"], prefix=s_kwargs["prefix"], ext=s_kwargs["ext"]) + pl.save_graphic(saving_path, title=s_kwargs["title"], raster=s_kwargs["raster"], painter=s_kwargs["painter"]) + + if save_show_or_return in ["show", "both", "all"]: + pl.show() + + if save_show_or_return in ["return", "all"]: + return (pl, colors_list) if colors_list else pl + + # --------------------------------------------------------------------------------------------------- def alpha_shape(x, y, alpha): # Start Using SHAPELY From b7c3b29ab962f14e3ac81fc68a1da7346140bf5d Mon Sep 17 00:00:00 2001 From: sichao Date: Mon, 25 Sep 2023 15:06:04 -0400 Subject: [PATCH 42/62] create BaseAnim class --- dynamo/movie/fate.py | 208 ++++++++++++++++++++++++++++++++----------- 1 file changed, 156 insertions(+), 52 deletions(-) diff --git a/dynamo/movie/fate.py b/dynamo/movie/fate.py index a40592993..249a0b6ff 100755 --- a/dynamo/movie/fate.py +++ b/dynamo/movie/fate.py @@ -1,6 +1,11 @@ import warnings from typing import Optional, Union +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + import matplotlib import numpy as np from anndata import AnnData @@ -12,9 +17,7 @@ from .utils import remove_particles -class StreamFuncAnim: - """Animating cell fate commitment prediction via reconstructed vector field function.""" - +class BaseAnim: def __init__( self, adata: AnnData, @@ -76,56 +79,8 @@ def __init__( ------- A class that contains .fig attribute and .update, .init_background that can be used to produce an animation of the prediction of cell fate commitment. - - Examples 1 - ---------- - >>> from matplotlib import animation - >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] - >>> fate_progenitor = progenitor - >>> info_genes = adata.var_names[adata.var.use_for_transition] - >>> dyn.pd.fate(adata, basis='umap', init_cells=fate_progenitor, interpolation_num=100, direction='forward', - ... inverse_transform=False, average=False) - >>> instance = dyn.mv.StreamFuncAnim(adata=adata, fig=None, ax=None) - >>> anim = animation.FuncAnimation(instance.fig, instance.update, init_func=instance.init_background, - ... frames=np.arange(100), interval=100, blit=True) - >>> from IPython.core.display import display, HTML - >>> HTML(anim.to_jshtml()) # embedding to jupyter notebook. - >>> anim.save('fate_ani.gif',writer="imagemagick") # save as gif file. - - Examples 2 - ---------- - >>> from matplotlib import animation - >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] - >>> fate_progenitor = progenitor - >>> info_genes = adata.var_names[adata.var.use_for_transition] - >>> dyn.pd.fate(adata, basis='umap', init_cells=fate_progenitor, interpolation_num=100, direction='forward', - ... inverse_transform=False, average=False) - >>> fig, ax = plt.subplots() - >>> ax = dyn.pl.topography(adata_old, color='time', ax=ax, save_show_or_return='return', color_key_cmap='viridis') - >>> ax.set_xlim(xlim) - >>> ax.set_ylim(ylim) - >>> instance = dyn.mv.StreamFuncAnim(adata=adata, fig=fig, ax=ax) - >>> anim = animation.FuncAnimation(fig, instance.update, init_func=instance.init_background, - ... frames=np.arange(100), interval=100, blit=True) - >>> from IPython.core.display import display, HTML - >>> HTML(anim.to_jshtml()) # embedding to jupyter notebook. - >>> anim.save('fate_ani.gif',writer="imagemagick") # save as gif file. - - Examples 3 - ---------- - >>> from matplotlib import animation - >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] - >>> fate_progenitor = progenitor - >>> info_genes = adata.var_names[adata.var.use_for_transition] - >>> dyn.pd.fate(adata, basis='umap', init_cells=fate_progenitor, interpolation_num=100, direction='forward', - ... inverse_transform=False, average=False) - >>> dyn.mv.animate_fates(adata) - - See also:: :func:`animate_fates` """ - import matplotlib.pyplot as plt - self.adata = adata self.basis = basis self.fp_basis = basis if fp_basis is None else fp_basis @@ -192,6 +147,136 @@ def __init__( self.color = color self.frame_color = frame_color + +class StreamFuncAnim(BaseAnim): + """Animating cell fate commitment prediction via reconstructed vector field function.""" + + def __init__( + self, + adata: AnnData, + basis: str = "umap", + fp_basis: Union[str, None] = None, + dims: Optional[list] = None, + n_steps: int = 100, + cell_states: Union[int, list, None] = None, + color: str = "ntr", + fig: Optional[matplotlib.figure.Figure] = None, + ax: matplotlib.axes.Axes = None, + logspace: bool = False, + max_time: Optional[float] = None, + frame_color=None, + ): + """Animating cell fate commitment prediction via reconstructed vector field function. + + This class creates necessary components to produce an animation that describes the exact speed of a set of cells + at each time point, its movement in gene expression and the long range trajectory predicted by the reconstructed + vector field. Thus it provides intuitive visual understanding of the RNA velocity, speed, acceleration, and cell + fate commitment in action. + + This function is originally inspired by https://tonysyu.github.io/animating-particles-in-a-flow.html and relies on + animation module from matplotlib. Note that you may need to install `imagemagick` in order to properly show or save + the animation. See for example, http://louistiao.me/posts/notebooks/save-matplotlib-animations-as-gifs/ for more + details. + + Parameters + ---------- + adata: :class:`~anndata.AnnData` + AnnData object that already went through the fate prediction. + basis: `str` or None (default: `umap`) + The embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the reconstructed + trajectory will be projected back to high dimensional space via the `inverse_transform` function. + space. + fps_basis: `str` or None (default: `None`) + The basis that will be used for identifying or retrieving fixed points. Note that if `fps_basis` is + different from `basis`, the nearest cells of the fixed point from the `fps_basis` will be found and used to + visualize the position of the fixed point on `basis` embedding. + dims: `list` or `None` (default: `None') + The dimensions of low embedding space where cells will be drawn and it should corresponds to the space + fate prediction take place. + n_steps: `int` (default: `100`) + The number of times steps (frames) fate prediction will take. + cell_states: `int`, `list` or `None` (default: `None`) + The number of cells state that will be randomly selected (if `int`), the indices of the cells states (if + `list`) or all cell states which fate prediction executed (if `None`) + fig: `matplotlib.figure.Figure` or None (default: `None`) + The figure that will contain both the background and animated components. + ax: `matplotlib.Axis` (optional, default `None`) + The matplotlib axes object that will be used as background plot of the vector field animation. If `ax` + is None, `topography(adata, basis=basis, color=color, ax=ax, save_show_or_return='return')` will be used + to create an axes. + logspace: `bool` (default: `False`) + Whether or to sample time points linearly on log space. If not, the sorted unique set of all time points + from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time points. + + Returns + ------- + A class that contains .fig attribute and .update, .init_background that can be used to produce an animation + of the prediction of cell fate commitment. + + Examples 1 + ---------- + >>> from matplotlib import animation + >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] + >>> fate_progenitor = progenitor + >>> info_genes = adata.var_names[adata.var.use_for_transition] + >>> dyn.pd.fate(adata, basis='umap', init_cells=fate_progenitor, interpolation_num=100, direction='forward', + ... inverse_transform=False, average=False) + >>> instance = dyn.mv.StreamFuncAnim(adata=adata, fig=None, ax=None) + >>> anim = animation.FuncAnimation(instance.fig, instance.update, init_func=instance.init_background, + ... frames=np.arange(100), interval=100, blit=True) + >>> from IPython.core.display import display, HTML + >>> HTML(anim.to_jshtml()) # embedding to jupyter notebook. + >>> anim.save('fate_ani.gif',writer="imagemagick") # save as gif file. + + Examples 2 + ---------- + >>> from matplotlib import animation + >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] + >>> fate_progenitor = progenitor + >>> info_genes = adata.var_names[adata.var.use_for_transition] + >>> dyn.pd.fate(adata, basis='umap', init_cells=fate_progenitor, interpolation_num=100, direction='forward', + ... inverse_transform=False, average=False) + >>> fig, ax = plt.subplots() + >>> ax = dyn.pl.topography(adata_old, color='time', ax=ax, save_show_or_return='return', color_key_cmap='viridis') + >>> ax.set_xlim(xlim) + >>> ax.set_ylim(ylim) + >>> instance = dyn.mv.StreamFuncAnim(adata=adata, fig=fig, ax=ax) + >>> anim = animation.FuncAnimation(fig, instance.update, init_func=instance.init_background, + ... frames=np.arange(100), interval=100, blit=True) + >>> from IPython.core.display import display, HTML + >>> HTML(anim.to_jshtml()) # embedding to jupyter notebook. + >>> anim.save('fate_ani.gif',writer="imagemagick") # save as gif file. + + Examples 3 + ---------- + >>> from matplotlib import animation + >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] + >>> fate_progenitor = progenitor + >>> info_genes = adata.var_names[adata.var.use_for_transition] + >>> dyn.pd.fate(adata, basis='umap', init_cells=fate_progenitor, interpolation_num=100, direction='forward', + ... inverse_transform=False, average=False) + >>> dyn.mv.animate_fates(adata) + + See also:: :func:`animate_fates` + """ + + import matplotlib.pyplot as plt + + super().__init__( + adata=adata, + basis=basis, + fp_basis=fp_basis, + dims=dims, + n_steps=n_steps, + cell_states=cell_states, + color=color, + fig=fig, + ax=ax, + logspace=logspace, + max_time=max_time, + frame_color=frame_color, + ) + # Animation objects must create `fig` and `ax` attributes. if ax is None or fig is None: self.fig, self.ax = plt.subplots() @@ -207,7 +292,7 @@ def __init__( self.fig = fig self.ax = ax - (self.ln,) = self.ax.plot([], [], "ro", zs=[]) if X_data.shape[1] == 3 else self.ax.plot([], [], "ro") + (self.ln,) = self.ax.plot([], [], "ro", zs=[]) if len(dims) == 3 else self.ax.plot([], [], "ro") def init_background(self): return (self.ln,) @@ -415,3 +500,22 @@ def animate_fates( HTML(anim.to_jshtml()) # embedding to jupyter notebook. else: anim + + +def animate_fates_pv( + adata, + basis="umap", + dims=None, + n_steps=100, + cell_states=None, + color="ntr", + logspace=False, + max_time=None, + frame_color=None, + interval=100, + blit=True, + save_show_or_return="show", + save_kwargs={}, + **kwargs, +): + pass From ac48efea5964a1a9469d29243eccc1c15bb5f305 Mon Sep 17 00:00:00 2001 From: sichao Date: Mon, 25 Sep 2023 15:48:37 -0400 Subject: [PATCH 43/62] create PyvistaAnim class --- dynamo/movie/fate.py | 70 ++++++++++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 22 deletions(-) diff --git a/dynamo/movie/fate.py b/dynamo/movie/fate.py index 249a0b6ff..8a6df84c6 100755 --- a/dynamo/movie/fate.py +++ b/dynamo/movie/fate.py @@ -12,7 +12,7 @@ from scipy.integrate import odeint from ..dynamo_logger import main_info, main_tqdm, main_warning -from ..plot.topography import topography +from ..plot.topography import topography, topography_3D from ..vectorfield.scVectorField import SvcVectorField from .utils import remove_particles @@ -27,8 +27,6 @@ def __init__( n_steps: int = 100, cell_states: Union[int, list, None] = None, color: str = "ntr", - fig: Optional[matplotlib.figure.Figure] = None, - ax: matplotlib.axes.Axes = None, logspace: bool = False, max_time: Optional[float] = None, frame_color=None, @@ -270,8 +268,6 @@ def __init__( n_steps=n_steps, cell_states=cell_states, color=color, - fig=fig, - ax=ax, logspace=logspace, max_time=max_time, frame_color=frame_color, @@ -502,20 +498,50 @@ def animate_fates( anim -def animate_fates_pv( - adata, - basis="umap", - dims=None, - n_steps=100, - cell_states=None, - color="ntr", - logspace=False, - max_time=None, - frame_color=None, - interval=100, - blit=True, - save_show_or_return="show", - save_kwargs={}, - **kwargs, -): - pass +class PyvistaAnim(BaseAnim): + def __init__( + self, + adata: AnnData, + basis: str = "umap", + fp_basis: Union[str, None] = None, + dims: Optional[list] = None, + n_steps: int = 100, + cell_states: Union[int, list, None] = None, + color: str = "ntr", + pl=None, + logspace: bool = False, + max_time: Optional[float] = None, + frame_color=None, + filename: str = "fate_ani.gif", + ): + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + + super().__init__( + adata=adata, + basis=basis, + fp_basis=fp_basis, + dims=dims, + n_steps=n_steps, + cell_states=cell_states, + color=color, + logspace=logspace, + max_time=max_time, + frame_color=frame_color, + ) + + self.filename = filename + + if pl is None: + self.pl = topography_3D( + self.adata, + basis=self.basis, + fps_basis=self.fp_basis, + color=self.color, + ax=self.pl, + save_show_or_return="return", + ) + else: + self.pl = pl From f94d70f246e74efaa176478f3ba892c1cce6f6f2 Mon Sep 17 00:00:00 2001 From: sichao Date: Tue, 26 Sep 2023 11:51:54 -0400 Subject: [PATCH 44/62] pyvista animate WIP --- dynamo/movie/__init__.py | 2 +- dynamo/movie/fate.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/dynamo/movie/__init__.py b/dynamo/movie/__init__.py index 4f795f84a..ee227bb7c 100644 --- a/dynamo/movie/__init__.py +++ b/dynamo/movie/__init__.py @@ -1,4 +1,4 @@ """Mapping Vector Field of Single Cells """ -from .fate import StreamFuncAnim, StreamFuncAnim3D, animate_fates +from .fate import PyvistaAnim, StreamFuncAnim, StreamFuncAnim3D, animate_fates diff --git a/dynamo/movie/fate.py b/dynamo/movie/fate.py index 8a6df84c6..d7017d40d 100755 --- a/dynamo/movie/fate.py +++ b/dynamo/movie/fate.py @@ -505,14 +505,14 @@ def __init__( basis: str = "umap", fp_basis: Union[str, None] = None, dims: Optional[list] = None, - n_steps: int = 100, + n_steps: int = 15, cell_states: Union[int, list, None] = None, color: str = "ntr", pl=None, logspace: bool = False, max_time: Optional[float] = None, frame_color=None, - filename: str = "fate_ani.gif", + filename: str = "fate_animation.gif", ): try: import pyvista as pv @@ -545,3 +545,28 @@ def __init__( ) else: self.pl = pl + + self.n_steps = n_steps + + def animate(self): + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + + pts = [i.tolist() for i in self.init_states] + + self.pl.open_gif(self.filename) + + mesh = pv.PolyData(self.init_states) + self.pl.add_mesh(mesh, color="red", render_points_as_spheres=True) + + for frame in range(0, self.n_steps): + pts = [self.displace(cur_pts, self.time_vec[frame])[1].tolist() for cur_pts in pts] + pts = np.asarray(pts) + pts = remove_particles(pts, self.xlim, self.ylim, self.zlim) + mesh.points = pv.PolyData(np.asarray(pts)).points + + self.pl.write_frame() # TODO: debug this + + self.pl.close() From 604308207e746f9355d16dc129c7b88fc6baaaff Mon Sep 17 00:00:00 2001 From: sichao Date: Tue, 26 Sep 2023 17:32:25 -0400 Subject: [PATCH 45/62] debug streamtube 3d --- dynamo/plot/streamtube.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dynamo/plot/streamtube.py b/dynamo/plot/streamtube.py index 535a1728b..7c7ba7171 100644 --- a/dynamo/plot/streamtube.py +++ b/dynamo/plot/streamtube.py @@ -139,8 +139,8 @@ def plot_3d_streamtube( X = adata.obsm["X_" + basis][:, dims] if "grid" in adata.uns["VecFld_" + basis].keys() and "grid_V" in adata.uns["VecFld_" + basis].keys(): - X_grid = adata.uns["VecFld_pca3"]["grid"] - velocity_grid = adata.uns["VecFld_pca3"]["grid_V"] + X_grid = adata.uns["VecFld_" + basis]["grid"] + velocity_grid = adata.uns["VecFld_" + basis]["grid_V"] else: grid_kwargs_dict = { "density": None, From 3d8badbe6509c0a74fc6d549ce8d8f90dd5c387e Mon Sep 17 00:00:00 2001 From: sichao Date: Tue, 26 Sep 2023 17:47:56 -0400 Subject: [PATCH 46/62] tune streamtube parameters --- dynamo/plot/streamtube.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dynamo/plot/streamtube.py b/dynamo/plot/streamtube.py index 7c7ba7171..fcacf3c58 100644 --- a/dynamo/plot/streamtube.py +++ b/dynamo/plot/streamtube.py @@ -179,10 +179,8 @@ def plot_3d_streamtube( y=adata[labels == init_group, :].obsm["X_" + basis][:125, 1], z=adata[labels == init_group, :].obsm["X_" + basis][:125, 2], ), - sizeref=3000, colorscale="Portland", showscale=False, - maxdisplayed=3000, ) ) From c030fd680e6106cdfe598cfc33dcfcb357d2a455 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 27 Sep 2023 16:40:18 -0400 Subject: [PATCH 47/62] create plotly saving function --- dynamo/plot/utils.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index 82bdbe10f..3b18a2045 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -1732,6 +1732,49 @@ def save_pyvista_plotter( return (pl, colors_list) if colors_list else pl +def save_plotly_figure( + pl, + colors_list: Optional[List] = None, + save_show_or_return: str = "show", + save_kwargs: Optional[Dict] = None, +) -> Optional[Tuple]: + """Save, show or return the plotly figure. + + Args: + pl: target plotly object. + colors_list: corresponding the list of colors mapping. + save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at + the same time. If "all", the figure will be saved, displayed and the associated axis and other object will + be return. Defaults to "show". + save_kwargs: A dictionary that will be passed to the saving function. By default, it is an empty dictionary + and the saving function will use the {"path": None, "prefix": 'scatter', "ext": 'html'} as its parameters. + Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults + to {}. + + Returns: + If `save_show_or_return` is `return` or `all`, the figure and list of color mapping will be returned. + """ + + main_debug("show, return or save...") + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": "scatters_plotly", + "ext": "html", + } + + s_kwargs = update_dict(s_kwargs, save_kwargs) + + saving_path = retrieve_plot_save_path(path=s_kwargs["path"], prefix=s_kwargs["prefix"], ext=s_kwargs["ext"]) + pl.write_html(saving_path) + + if save_show_or_return in ["show", "both", "all"]: + pl.show() + + if save_show_or_return in ["return", "all"]: + return (pl, colors_list) if colors_list else pl + + # --------------------------------------------------------------------------------------------------- def alpha_shape(x, y, alpha): # Start Using SHAPELY From 382e3e6d947924c6a0596d86f8d08406dce5a9d9 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 27 Sep 2023 16:40:39 -0400 Subject: [PATCH 48/62] add plotly 3d scatters options --- dynamo/plot/scatters.py | 90 +++++++++++++++++++++++++++++++---------- 1 file changed, 69 insertions(+), 21 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index e9dde5368..b6d8f01d0 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -40,6 +40,7 @@ is_list_of_lists, retrieve_plot_save_path, save_fig, + save_plotly_figure, save_pyvista_plotter, ) @@ -1021,6 +1022,7 @@ def scatters_pv( z: Union[int, str] = 2, color: str = "ntr", layer: str = "X", + plot_method: str = "pv", highlights: Optional[list] = None, labels: Optional[list] = None, values: Optional[list] = None, @@ -1097,10 +1099,21 @@ def scatters_pv( dictionary that properly modify those keys according to your needs. Defaults to {}. **kwargs: any other kwargs that would be passed to `Plotter.add_points()`. """ - try: - import pyvista as pv - except ImportError: - raise ImportError("Please install pyvista first.") + + if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + elif plot_method == "plotly": + try: + import plotly.express as px + import plotly.graph_objects as go + from plotly.subplots import make_subplots + except ImportError: + raise ImportError("Please install plotly first.") + else: + raise NotImplementedError("Current plot method not supported.") if type(x) in [int, str]: x = [x] @@ -1156,9 +1169,14 @@ def scatters_pv( colors_list = [] if total_panels == 1: - pl = pv.Plotter() + pl = pv.Plotter() if plot_method == "pv" else make_subplots(rows=1, cols=1, specs=[[{"type": "scatter3d"}]]) else: - pl = pv.Plotter(shape=(nrow, ncol)) + pl = ( + pv.Plotter(shape=(nrow, ncol)) + if plot_method == "pv" + else + make_subplots(rows=nrow, cols=ncol, specs=[[{"type": "scatter3d"} for _ in range(ncol)] for _ in range(nrow)]) + ) def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: """A helper function for plotting a specific basis/layer data @@ -1276,23 +1294,48 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: sym_c=sym_c, ) - if total_panels > 1: - pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) - colors_list.append(colors) - pvdataset = pv.PolyData(points.values) - pvdataset.point_data["colors"] = np.stack(colors) - pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, cmap=_cmap, **kwargs) - - if color_type == "labels": - type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(_labels, colors)} - type_color_pair = [[k, v] for k, v in type_color_dict.items()] - pl.add_legend(labels=type_color_pair) - else: - pl.add_scalar_bar() # TODO: fix the bug that scalar bar only works in the first plot - pl.add_text(cur_title) - pl.add_axes(xlabel=points.columns[0], ylabel=points.columns[1], zlabel=points.columns[2]) + if plot_method == "pv": + if total_panels > 1: + pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) + + pvdataset = pv.PolyData(points.values) + pvdataset.point_data["colors"] = np.stack(colors) + pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, cmap=_cmap, **kwargs) + + if color_type == "labels": + type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(_labels, colors)} + type_color_pair = [[k, v] for k, v in type_color_dict.items()] + pl.add_legend(labels=type_color_pair) + else: + pl.add_scalar_bar() # TODO: fix the bug that scalar bar only works in the first plot + + pl.add_text(cur_title) + pl.add_axes(xlabel=points.columns[0], ylabel=points.columns[1], zlabel=points.columns[2]) + elif plot_method == "plotly": + + pl.add_trace( + go.Scatter3d( + x=points.iloc[:, 0], + y=points.iloc[:, 1], + z=points.iloc[:, 2], + mode="markers", + marker=dict( + color=colors, + ), + text=_labels if color_type == "labels" else _values, + ), + row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, + ) + + pl.update_layout( + scene=dict( + xaxis_title=points.columns[0], + yaxis_title=points.columns[1], + zaxis_title=points.columns[2] + ), + ) cur_subplot += 1 @@ -1307,4 +1350,9 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: colors_list=colors_list, save_show_or_return=save_show_or_return, save_kwargs=save_kwargs, + ) if plot_method == "pv" else save_plotly_figure( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, ) From 0e0337bf4fd0c993e1ecb09242db3e2f7acf9219 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 27 Sep 2023 16:41:30 -0400 Subject: [PATCH 49/62] rename scatters_pv to scatters_interactive --- dynamo/plot/__init__.py | 4 ++-- dynamo/plot/scVectorField.py | 4 ++-- dynamo/plot/scatters.py | 2 +- dynamo/plot/topography.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dynamo/plot/__init__.py b/dynamo/plot/__init__.py index f8b620ca3..937c57833 100755 --- a/dynamo/plot/__init__.py +++ b/dynamo/plot/__init__.py @@ -30,7 +30,7 @@ show_fraction, variance_explained, ) -from .scatters import scatters, scatters_pv +from .scatters import scatters, scatters_interactive from .scPotential import show_landscape from .sctransform import sctransform_plot_fit, plot_residual_var from .scVectorField import ( # , plot_LIC_gray @@ -81,7 +81,7 @@ "quiver_autoscaler", "save_fig", "scatters", - "scatters_pv", + "scatters_interactive", "basic_stats", "show_fraction", "feature_genes", diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 2bde45040..83b386c6d 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -27,7 +27,7 @@ from ..tools.utils import update_dict from ..vectorfield.topography import VectorField from ..vectorfield.utils import vecfld_from_adata -from .scatters import docstrings, scatters, scatters_pv +from .scatters import docstrings, scatters, scatters_interactive from .utils import ( _get_adata_color_vec, default_quiver_args, @@ -325,7 +325,7 @@ def add_axis_label(ax, labels): ncols = min(ncols, len(color)) if plot_method == "pv": - pl, colors_list = scatters_pv( + pl, colors_list = scatters_interactive( adata=adata, basis=basis, x=x, diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index b6d8f01d0..595eabce4 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1014,7 +1014,7 @@ def _map_cur_axis(cur: str) -> Tuple[np.ndarray, str]: return points, cur_title -def scatters_pv( +def scatters_interactive( adata: AnnData, basis: str = "umap", x: Union[int, str] = 0, diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 70dc9ee72..734100acd 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -26,7 +26,7 @@ from ..vectorfield.topography import topography as _topology # , compute_separatrices from ..vectorfield.utils import vecfld_from_adata from ..vectorfield.vector_calculus import curl, divergence -from .scatters import docstrings, scatters, scatters_pv +from .scatters import docstrings, scatters, scatters_interactive from .utils import ( _plot_traj, _select_font_color, @@ -1614,7 +1614,7 @@ def topography_3D( V = vector_field_function(init_states, vecfld_dict, [0, 1]) if plot_method == "pv": - pl, colors_list = scatters_pv( + pl, colors_list = scatters_interactive( adata=adata, basis=basis, x=x, From 7a6d4ea48408bca55485418721c932ab56dac2b2 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 27 Sep 2023 18:23:45 -0400 Subject: [PATCH 50/62] add plotly cone option --- dynamo/plot/scVectorField.py | 62 ++++++++++++++++++++++++++++++++++++ dynamo/plot/scatters.py | 1 + 2 files changed, 63 insertions(+) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 83b386c6d..2340035c9 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -34,6 +34,7 @@ quiver_autoscaler, retrieve_plot_save_path, save_fig, + save_plotly_figure, save_pyvista_plotter, set_arrow_alpha, set_stream_line_alpha, @@ -220,6 +221,13 @@ def cell_wise_vectors_3d( import pyvista as pv except ImportError: raise ImportError("Please install pyvista first.") + elif plot_method == "plotly": + try: + import plotly.graph_objects as go + except ImportError: + raise ImportError("Please install plotly first.") + else: + raise NotImplementedError("Current plot method not supported.") def add_axis_label(ax, labels): ax.set_xlabel(labels[0]) @@ -373,6 +381,60 @@ def add_axis_label(ax, labels): save_kwargs=save_kwargs, ) + elif plot_method == "plotly": + pl, colors_list = scatters_interactive( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + plot_method = "plotly", + highlights=highlights, + labels=labels, + values=values, + cmap=cmap, + theme=theme, + background=background, + color_key=color_key, + color_key_cmap=color_key_cmap, + use_smoothed=use_smoothed, + save_show_or_return="return", + opacity=0.5, + ) + + r, c = pl._get_subplot_rows_columns() + subplot_indices = [[i, j] for i in range(list(r)[-1]) for j in range(list(c)[-1])] + cur_subplot = 0 + + for i in range(len(color)): + # colors = [[index, "rgb({},{},{})".format(int(row[0] * 255), int(row[1] * 255), int(row[2] * 255))] for index, row in enumerate(colors_list[i])] + + pl.add_trace( + go.Cone( + x=x0.values, + y=x1.values, + z=x2.values, + u=v0.values, + v=v1.values, + w=v2.values, + colorscale='Blues', + # colorscale=colors, + sizemode="absolute", + sizeref=1, + ), + row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, + ) + + cur_subplot += 1 + + return save_plotly_figure( + pl=pl, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) + else: axes_list, color_list, _ = scatters( adata=adata, diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 595eabce4..ca93259aa 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1325,6 +1325,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: color=colors, ), text=_labels if color_type == "labels" else _values, + **kwargs, ), row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, ) From 4dbba0b4386c93d097a9f8247bd94cefc0a0f48e Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 28 Sep 2023 15:16:08 -0400 Subject: [PATCH 51/62] implement plotly option in 3d topography --- dynamo/plot/topography.py | 81 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 734100acd..185107eda 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -34,6 +34,7 @@ quiver_autoscaler, retrieve_plot_save_path, save_fig, + save_plotly_figure, save_pyvista_plotter, set_arrow_alpha, set_stream_line_alpha, @@ -522,11 +523,11 @@ def plot_fixed_points( ) cm = matplotlib.cm.get_cmap(_cmap) if type(_cmap) is str else _cmap + colors = [c if confidence is None else np.array(cm(confidence[i])) for i in range(len(confidence))] if plot_method == "pv": points = pv.PolyData(Xss) - colors = [c if confidence is None else np.array(cm(confidence[i])).reshape(1, -1) for i in range(len(confidence))] points.point_data["colors"] = colors points["Labels"] = [str(i) for i in range(points.n_points)] @@ -549,6 +550,39 @@ def plot_fixed_points( save_show_or_return=save_show_or_return, save_kwargs=save_kwargs, ) + elif plot_method == "plotly": + try: + import plotly.graph_objects as go + except ImportError: + raise ImportError("Please install plotly first.") + + r, c = ax._get_subplot_rows_columns() + r, c = list(r)[-1], list(c)[-1] + subplot_indices = [[i, j] for i in range(r) for j in range(c)] + cur_subplot = 0 + + for i in range(r * c): + ax.add_trace( + go.Scatter3d( + x=Xss[:, 0], + y=Xss[:, 1], + z=Xss[:, 2], + mode="markers+text", + marker=dict( + color=colors, + size=15, + ), + text=[str(i) for i in range(len(Xss))], + **kwargs, + ), + row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, + ) + + return save_plotly_figure( + pl=ax, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) else: if ax is None: ax = plt.gca() @@ -1653,6 +1687,51 @@ def topography_3D( save_show_or_return=save_show_or_return, save_kwargs=save_kwargs, ) + elif plot_method == "plotly": + try: + import plotly.graph_objects as go + except ImportError: + raise ImportError("Please install plotly first.") + + pl, colors_list = scatters_interactive( + adata=adata, + basis=basis, + x=x, + y=y, + z=z, + color=color, + layer=layer, + plot_method="plotly", + highlights=highlights, + labels=labels, + values=values, + cmap=cmap, + theme=theme, + background=background, + color_key=color_key, + color_key_cmap=color_key_cmap, + use_smoothed=use_smoothed, + save_show_or_return="return", + opacity=0.8, + ) + + if "fixed_points" in terms: + pl = plot_fixed_points( + fps_vecfld, + fps_vecfld_dict, + background=_background, + ax=pl, + markersize=markersize, + cmap=marker_cmap, + plot_method="plotly", + ) + + return save_plotly_figure( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) else: # plt.figure(facecolor=_background) axes_list, color_list, font_color = scatters( From 79be07ae1a618c9dfd941e3736e3561a489d8691 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 28 Sep 2023 15:26:07 -0400 Subject: [PATCH 52/62] reorganize import statements --- dynamo/plot/scVectorField.py | 23 ++++++++++------------- dynamo/plot/topography.py | 21 +++++++++------------ 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 2340035c9..86a84d7af 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -216,19 +216,6 @@ def cell_wise_vectors_3d( from matplotlib import rcParams from matplotlib.colors import to_hex - if plot_method == "pv": - try: - import pyvista as pv - except ImportError: - raise ImportError("Please install pyvista first.") - elif plot_method == "plotly": - try: - import plotly.graph_objects as go - except ImportError: - raise ImportError("Please install plotly first.") - else: - raise NotImplementedError("Current plot method not supported.") - def add_axis_label(ax, labels): ax.set_xlabel(labels[0]) ax.set_ylabel(labels[1]) @@ -333,6 +320,11 @@ def add_axis_label(ax, labels): ncols = min(ncols, len(color)) if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + pl, colors_list = scatters_interactive( adata=adata, basis=basis, @@ -382,6 +374,11 @@ def add_axis_label(ax, labels): ) elif plot_method == "plotly": + try: + import plotly.graph_objects as go + except ImportError: + raise ImportError("Please install plotly first.") + pl, colors_list = scatters_interactive( adata=adata, basis=basis, diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 185107eda..fb6ebc41c 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -462,12 +462,6 @@ def plot_fixed_points( from matplotlib import markers, rcParams from matplotlib.colors import to_hex - if plot_method == "pv": - try: - import pyvista as pv - except ImportError: - raise ImportError("Please install pyvista first.") - if background is None: _background = rcParams.get("figure.facecolor") _background = to_hex(_background) if type(_background) is tuple else _background @@ -526,6 +520,10 @@ def plot_fixed_points( colors = [c if confidence is None else np.array(cm(confidence[i])) for i in range(len(confidence))] if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") points = pv.PolyData(Xss) points.point_data["colors"] = colors @@ -1511,12 +1509,6 @@ def topography_3D( from matplotlib import rcParams from matplotlib.colors import to_hex - if plot_method == "pv": - try: - import pyvista as pv - except ImportError: - raise ImportError("Please install pyvista first.") - if type(color) == str: color = [color] @@ -1648,6 +1640,11 @@ def topography_3D( V = vector_field_function(init_states, vecfld_dict, [0, 1]) if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + pl, colors_list = scatters_interactive( adata=adata, basis=basis, From 4407a3de8ed282b6ef0c0b20aac876e593418f02 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 28 Sep 2023 17:34:23 -0400 Subject: [PATCH 53/62] create plotly animation --- dynamo/movie/__init__.py | 2 +- dynamo/movie/fate.py | 100 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/dynamo/movie/__init__.py b/dynamo/movie/__init__.py index ee227bb7c..731de9415 100644 --- a/dynamo/movie/__init__.py +++ b/dynamo/movie/__init__.py @@ -1,4 +1,4 @@ """Mapping Vector Field of Single Cells """ -from .fate import PyvistaAnim, StreamFuncAnim, StreamFuncAnim3D, animate_fates +from .fate import PlotlyAnim, PyvistaAnim, StreamFuncAnim, StreamFuncAnim3D, animate_fates diff --git a/dynamo/movie/fate.py b/dynamo/movie/fate.py index d7017d40d..a044a3e3b 100755 --- a/dynamo/movie/fate.py +++ b/dynamo/movie/fate.py @@ -570,3 +570,103 @@ def animate(self): self.pl.write_frame() # TODO: debug this self.pl.close() + + +class PlotlyAnim(BaseAnim): + def __init__( + self, + adata: AnnData, + basis: str = "umap", + fp_basis: Union[str, None] = None, + dims: Optional[list] = None, + n_steps: int = 15, + cell_states: Union[int, list, None] = None, + color: str = "ntr", + pl=None, + logspace: bool = False, + max_time: Optional[float] = None, + frame_color=None, + filename: str = "fate_animation.gif", + ): + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + + super().__init__( + adata=adata, + basis=basis, + fp_basis=fp_basis, + dims=dims, + n_steps=n_steps, + cell_states=cell_states, + color=color, + logspace=logspace, + max_time=max_time, + frame_color=frame_color, + ) + + self.filename = filename + + if pl is None: + self.pl = topography_3D( + self.adata, + basis=self.basis, + fps_basis=self.fp_basis, + color=self.color, + plot_method="plotly", + ax=self.pl, + save_show_or_return="return", + ) + else: + self.pl = pl + + self.n_steps = n_steps + self.pts_history = [] + + def calculate_pts_history(self): + pts = [i.tolist() for i in self.init_states] + + self.pts_history.append(pts) + + for frame in range(0, self.n_steps): + pts = [self.displace(cur_pts, self.time_vec[frame])[1].tolist() for cur_pts in pts] + pts = np.asarray(pts) + pts = remove_particles(pts, self.xlim, self.ylim, self.zlim) + self.pts_history.append(np.asarray(pts)) + + def animate(self): + try: + import plotly.graph_objects as go + except ImportError: + raise ImportError("Please install plotly first.") + + if len(self.pts_history) == 0: + self.calculate_pts_history() + + fig = go.Figure( + data=self.pl, + layout=go.Layout(title="Moving Frenet Frame Along a Planar Curve", + updatemenus=[dict(type="buttons", + buttons=[dict(label="Play", + method="animate", + args=[None])])]), + frames=[ + go.Frame( + data=[ + go.Scatter3d( + x=self.pts_history[k][:, 0], + y=self.pts_history[k][:, 1], + z=self.pts_history[k][:, 2], + mode="markers", + marker=dict( + color="red", + size=20, + ), + ) + ] + ) for k in range(1, self.n_steps) + ] + ) + + fig.show() From dfdbaa0a63c189d5525ff87e4d88a17955a801dc Mon Sep 17 00:00:00 2001 From: sichao Date: Mon, 16 Oct 2023 14:03:49 -0400 Subject: [PATCH 54/62] debug pyvista 3D animation --- dynamo/movie/fate.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/dynamo/movie/fate.py b/dynamo/movie/fate.py index a044a3e3b..f9ed14c38 100755 --- a/dynamo/movie/fate.py +++ b/dynamo/movie/fate.py @@ -508,6 +508,7 @@ def __init__( n_steps: int = 15, cell_states: Union[int, list, None] = None, color: str = "ntr", + point_size: float = 15, pl=None, logspace: bool = False, max_time: Optional[float] = None, @@ -547,6 +548,7 @@ def __init__( self.pl = pl self.n_steps = n_steps + self.point_size = point_size def animate(self): try: @@ -558,16 +560,20 @@ def animate(self): self.pl.open_gif(self.filename) - mesh = pv.PolyData(self.init_states) - self.pl.add_mesh(mesh, color="red", render_points_as_spheres=True) + pts = [self.displace(cur_pts, self.time_vec[0])[1].tolist() for cur_pts in pts] + pts = np.asarray(pts) + pts = remove_particles(pts, self.xlim, self.ylim, self.zlim) - for frame in range(0, self.n_steps): + mesh = pv.PolyData(pts) + self.pl.add_mesh(mesh, color="red", render_points_as_spheres=True, point_size=self.point_size) + + for frame in range(1, self.n_steps): pts = [self.displace(cur_pts, self.time_vec[frame])[1].tolist() for cur_pts in pts] pts = np.asarray(pts) - pts = remove_particles(pts, self.xlim, self.ylim, self.zlim) - mesh.points = pv.PolyData(np.asarray(pts)).points + # pts = remove_particles(pts, self.xlim, self.ylim, self.zlim) + mesh.points = pv.PolyData(pts).points - self.pl.write_frame() # TODO: debug this + self.pl.write_frame() self.pl.close() From c7119dc47bdd23b513cdc65c8aa052f68dff255f Mon Sep 17 00:00:00 2001 From: sichao Date: Mon, 16 Oct 2023 14:49:26 -0400 Subject: [PATCH 55/62] distinguish fps color for pyvista --- dynamo/plot/topography.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index fb6ebc41c..d9276a46f 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -525,10 +525,11 @@ def plot_fixed_points( except ImportError: raise ImportError("Please install pyvista first.") - points = pv.PolyData(Xss) - points.point_data["colors"] = colors - - points["Labels"] = [str(i) for i in range(points.n_points)] + text_colors = ["black" if cur_ftype == -1 else "blue" if cur_ftype == 0 else "red" for cur_ftype in ftype] + emitting_indices = [index for index, color in enumerate(text_colors) if color == "red"] + unstable_indices = [index for index, color in enumerate(text_colors) if color == "blue"] + absorbing_indices = [index for index, color in enumerate(text_colors) if color == "black"] + fps_type_indices = [emitting_indices, unstable_indices, absorbing_indices] r, c = ax.shape[0], ax.shape[1] subplot_indices = [[i, j] for i in range(r) for j in range(c)] @@ -540,8 +541,21 @@ def plot_fixed_points( ax.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) cur_subplot += 1 - ax.add_points(points, render_points_as_spheres=True, rgba=True, point_size=15) - ax.add_point_labels(points, "Labels", font_size=36, show_points=False) # TODO: only work for the first plot + for indices in fps_type_indices: + points = pv.PolyData(Xss[indices]) + points.point_data["colors"] = np.array(colors)[indices] + points["Labels"] = [str(idx) for idx in indices] + + ax.add_points(points, render_points_as_spheres=True, rgba=True, point_size=15) + ax.add_point_labels( + points, + "Labels", + text_color=text_colors[indices[0]], + font_size=24, + shape_opacity=0, + show_points=False, + always_visible=True, + ) return save_pyvista_plotter( pl=ax, From 587a73b28f697cf86863bf184de0931754af90db Mon Sep 17 00:00:00 2001 From: sichao Date: Mon, 16 Oct 2023 15:05:18 -0400 Subject: [PATCH 56/62] support cone color input and add todo --- dynamo/plot/scVectorField.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 86a84d7af..94c26cfba 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -107,6 +107,7 @@ def cell_wise_vectors_3d( "darkgreen", ] ] = None, + plotly_color: str = "Reds", cmap: Optional[str] = None, color_key: Union[Dict[str, str], List[str], None] = None, color_key_cmap: Optional[str] = None, @@ -175,6 +176,8 @@ def cell_wise_vectors_3d( theme: A color theme to use for plotting. A small set of predefined themes are provided which have relatively good aesthetics. Available themes are: {'blue', 'red', 'green', 'inferno', 'fire', 'viridis', 'darkblue', 'darkred', 'darkgreen'}. Defaults to None. + plotly_color: the color of the Plotly Cone plot. It must be an array containing arrays mapping a normalized + value to a rgb, rgba, hex, hsl, hsv, or named color string. cmap: The name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed this will be used for shading points according to density (largely only of relevance for very large datasets). If values are passed this will be used for shading according the value. Note that if theme is @@ -387,7 +390,7 @@ def add_axis_label(ax, labels): z=z, color=color, layer=layer, - plot_method = "plotly", + plot_method="plotly", highlights=highlights, labels=labels, values=values, @@ -416,13 +419,13 @@ def add_axis_label(ax, labels): u=v0.values, v=v1.values, w=v2.values, - colorscale='Blues', + colorscale=plotly_color, # colorscale=colors, sizemode="absolute", sizeref=1, ), row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, - ) + ) # TODO: implement customized color for individual cone cur_subplot += 1 From bd1492cd5a43c756548df9b77a94126fe1aca129 Mon Sep 17 00:00:00 2001 From: sichao Date: Mon, 16 Oct 2023 15:57:11 -0400 Subject: [PATCH 57/62] debug plotly topogrgphy --- dynamo/plot/topography.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index d9276a46f..a7b967cb9 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -518,6 +518,7 @@ def plot_fixed_points( cm = matplotlib.cm.get_cmap(_cmap) if type(_cmap) is str else _cmap colors = [c if confidence is None else np.array(cm(confidence[i])) for i in range(len(confidence))] + text_colors = ["black" if cur_ftype == -1 else "blue" if cur_ftype == 0 else "red" for cur_ftype in ftype] if plot_method == "pv": try: @@ -525,7 +526,6 @@ def plot_fixed_points( except ImportError: raise ImportError("Please install pyvista first.") - text_colors = ["black" if cur_ftype == -1 else "blue" if cur_ftype == 0 else "red" for cur_ftype in ftype] emitting_indices = [index for index, color in enumerate(text_colors) if color == "red"] unstable_indices = [index for index, color in enumerate(text_colors) if color == "blue"] absorbing_indices = [index for index, color in enumerate(text_colors) if color == "black"] @@ -585,6 +585,10 @@ def plot_fixed_points( size=15, ), text=[str(i) for i in range(len(Xss))], + textfont=dict( + color=text_colors, + size=15, + ), **kwargs, ), row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, From 4114af786c8f045541fcb0b3be905efcf7862fea Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 18 Oct 2023 14:54:39 -0400 Subject: [PATCH 58/62] debug zorder in matplotlib 3d --- dynamo/plot/topography.py | 11 ++++++++--- dynamo/plot/utils.py | 4 +++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index a7b967cb9..725a22b1b 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -1481,11 +1481,12 @@ def topography_3D( cmap: Optional[str] = None, color_key: Union[Dict[str, str], List[str], None] = None, color_key_cmap: Optional[str] = None, + alpha: Optional[float] = None, background: Optional[str] = "white", ncols: int = 4, pointsize: Optional[float] = None, figsize: Tuple[float, float] = (6, 4), - show_legend: str = "on data", + show_legend: str = True, use_smoothed: bool = True, xlim: np.ndarray = None, ylim: np.ndarray = None, @@ -1530,6 +1531,9 @@ def topography_3D( if type(color) == str: color = [color] + if alpha is None: + alpha = 0.8 if plot_method in ["pv", "plotly"] else 0.1 + if background is None: _background = rcParams.get("figure.facecolor") _background = to_hex(_background) if type(_background) is tuple else _background @@ -1682,7 +1686,7 @@ def topography_3D( use_smoothed=use_smoothed, save_show_or_return="return", # style='points_gaussian', - opacity=0.8, + opacity=alpha, ) if "fixed_points" in terms: @@ -1727,7 +1731,7 @@ def topography_3D( color_key_cmap=color_key_cmap, use_smoothed=use_smoothed, save_show_or_return="return", - opacity=0.8, + opacity=alpha, ) if "fixed_points" in terms: @@ -1764,6 +1768,7 @@ def topography_3D( cmap=cmap, color_key=color_key, color_key_cmap=color_key_cmap, + alpha=alpha, background=_background, ncols=ncols, pointsize=pointsize, diff --git a/dynamo/plot/utils.py b/dynamo/plot/utils.py index 3b18a2045..962a3bfa6 100755 --- a/dynamo/plot/utils.py +++ b/dynamo/plot/utils.py @@ -407,7 +407,9 @@ def _matplotlib_points( if ax is None: dpi = plt.rcParams["figure.dpi"] fig = plt.figure(figsize=(width / dpi, height / dpi)) - ax = fig.add_subplot(111, projection=projection) + ax = fig.add_subplot( + 111, projection=projection, computed_zorder=False, + ) if projection == "3d" else fig.add_subplot(111, projection=projection) ax.set_facecolor(background) From 2c263d5ac691dd5b8c613669aed8b1065096406a Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 20 Oct 2023 17:55:24 -0400 Subject: [PATCH 59/62] debug 2D anmation and add docstr --- dynamo/movie/fate.py | 331 ++++++++++++++++++++++--------------------- 1 file changed, 172 insertions(+), 159 deletions(-) diff --git a/dynamo/movie/fate.py b/dynamo/movie/fate.py index f9ed14c38..814b654f7 100755 --- a/dynamo/movie/fate.py +++ b/dynamo/movie/fate.py @@ -18,6 +18,13 @@ class BaseAnim: + """Base class for animating cell fate commitment prediction via reconstructed vector field function. + + This class creates necessary components to produce an animation that describes the exact speed of a set of cells + at each time point, its movement in gene expression and the long range trajectory predicted by the reconstructed + vector field. Thus, it provides intuitive visual understanding of the RNA velocity, speed, acceleration, and cell + fate commitment in action. + """ def __init__( self, adata: AnnData, @@ -29,52 +36,30 @@ def __init__( color: str = "ntr", logspace: bool = False, max_time: Optional[float] = None, - frame_color=None, ): - """Animating cell fate commitment prediction via reconstructed vector field function. - - This class creates necessary components to produce an animation that describes the exact speed of a set of cells - at each time point, its movement in gene expression and the long range trajectory predicted by the reconstructed - vector field. Thus it provides intuitive visual understanding of the RNA velocity, speed, acceleration, and cell - fate commitment in action. - - This function is originally inspired by https://tonysyu.github.io/animating-particles-in-a-flow.html and relies on - animation module from matplotlib. Note that you may need to install `imagemagick` in order to properly show or save - the animation. See for example, http://louistiao.me/posts/notebooks/save-matplotlib-animations-as-gifs/ for more - details. - - Parameters - ---------- - adata: :class:`~anndata.AnnData` - AnnData object that already went through the fate prediction. - basis: `str` or None (default: `umap`) - The embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the reconstructed - trajectory will be projected back to high dimensional space via the `inverse_transform` function. - space. - fps_basis: `str` or None (default: `None`) - The basis that will be used for identifying or retrieving fixed points. Note that if `fps_basis` is - different from `basis`, the nearest cells of the fixed point from the `fps_basis` will be found and used to - visualize the position of the fixed point on `basis` embedding. - dims: `list` or `None` (default: `None') - The dimensions of low embedding space where cells will be drawn and it should corresponds to the space + """Construct a class that can be used to animate cell fate commitment prediction via reconstructed vector field + function. + + Args: + adata: annData object that already went through the fate prediction. + basis: the embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the + reconstructed trajectory will be projected back to high dimensional space via the `inverse_transform` + function space. + fp_basis: the basis that will be used for identifying or retrieving fixed points. Note that if `fps_basis` + is different from `basis`, the nearest cells of the fixed point from the `fps_basis` will be found and + used to visualize the position of the fixed point on `basis` embedding. + dims: the dimensions of low embedding space where cells will be drawn, and it should correspond to the space fate prediction take place. - n_steps: `int` (default: `100`) - The number of times steps (frames) fate prediction will take. - cell_states: `int`, `list` or `None` (default: `None`) - The number of cells state that will be randomly selected (if `int`), the indices of the cells states (if - `list`) or all cell states which fate prediction executed (if `None`) - fig: `matplotlib.figure.Figure` or None (default: `None`) - The figure that will contain both the background and animated components. - ax: `matplotlib.Axis` (optional, default `None`) - The matplotlib axes object that will be used as background plot of the vector field animation. If `ax` - is None, `topography(adata, basis=basis, color=color, ax=ax, save_show_or_return='return')` will be used - to create an axes. - logspace: `bool` (default: `False`) - Whether or to sample time points linearly on log space. If not, the sorted unique set of all time points - from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time points. - - Returns - ------- + n_steps: the number of times steps (frames) fate prediction will take. + cell_states: the number of cells state that will be randomly selected (if `int`), the indices of the cells + states (if `list`) or all cell states which fate prediction executed (if `None`) + color: the key of the data that will be used to color the embedding. + logspace: `whether or to sample time points linearly on log space. If not, the sorted unique set of all-time + points from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time + points. + max_time: the maximum time that will be used to scale the time vector. + + Returns: A class that contains .fig attribute and .update, .init_background that can be used to produce an animation of the prediction of cell fate commitment. """ @@ -143,75 +128,17 @@ def __init__( # self.ax.set_aspect("equal") self.color = color - self.frame_color = frame_color class StreamFuncAnim(BaseAnim): - """Animating cell fate commitment prediction via reconstructed vector field function.""" + """The class for animating cell fate commitment prediction with matplotlib. - def __init__( - self, - adata: AnnData, - basis: str = "umap", - fp_basis: Union[str, None] = None, - dims: Optional[list] = None, - n_steps: int = 100, - cell_states: Union[int, list, None] = None, - color: str = "ntr", - fig: Optional[matplotlib.figure.Figure] = None, - ax: matplotlib.axes.Axes = None, - logspace: bool = False, - max_time: Optional[float] = None, - frame_color=None, - ): - """Animating cell fate commitment prediction via reconstructed vector field function. - - This class creates necessary components to produce an animation that describes the exact speed of a set of cells - at each time point, its movement in gene expression and the long range trajectory predicted by the reconstructed - vector field. Thus it provides intuitive visual understanding of the RNA velocity, speed, acceleration, and cell - fate commitment in action. - - This function is originally inspired by https://tonysyu.github.io/animating-particles-in-a-flow.html and relies on - animation module from matplotlib. Note that you may need to install `imagemagick` in order to properly show or save - the animation. See for example, http://louistiao.me/posts/notebooks/save-matplotlib-animations-as-gifs/ for more - details. - - Parameters - ---------- - adata: :class:`~anndata.AnnData` - AnnData object that already went through the fate prediction. - basis: `str` or None (default: `umap`) - The embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the reconstructed - trajectory will be projected back to high dimensional space via the `inverse_transform` function. - space. - fps_basis: `str` or None (default: `None`) - The basis that will be used for identifying or retrieving fixed points. Note that if `fps_basis` is - different from `basis`, the nearest cells of the fixed point from the `fps_basis` will be found and used to - visualize the position of the fixed point on `basis` embedding. - dims: `list` or `None` (default: `None') - The dimensions of low embedding space where cells will be drawn and it should corresponds to the space - fate prediction take place. - n_steps: `int` (default: `100`) - The number of times steps (frames) fate prediction will take. - cell_states: `int`, `list` or `None` (default: `None`) - The number of cells state that will be randomly selected (if `int`), the indices of the cells states (if - `list`) or all cell states which fate prediction executed (if `None`) - fig: `matplotlib.figure.Figure` or None (default: `None`) - The figure that will contain both the background and animated components. - ax: `matplotlib.Axis` (optional, default `None`) - The matplotlib axes object that will be used as background plot of the vector field animation. If `ax` - is None, `topography(adata, basis=basis, color=color, ax=ax, save_show_or_return='return')` will be used - to create an axes. - logspace: `bool` (default: `False`) - Whether or to sample time points linearly on log space. If not, the sorted unique set of all time points - from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time points. - - Returns - ------- - A class that contains .fig attribute and .update, .init_background that can be used to produce an animation - of the prediction of cell fate commitment. + This function is originally inspired by https://tonysyu.github.io/animating-particles-in-a-flow.html and relies on + animation module from matplotlib. Note that you may need to install `imagemagick` in order to properly show or save + the animation. See for example, http://louistiao.me/posts/notebooks/save-matplotlib-animations-as-gifs/ for more + details. - Examples 1 + Examples 1 ---------- >>> from matplotlib import animation >>> progenitor = adata.obs_names[adata.obs.clusters == 'cluster_1'] @@ -256,6 +183,50 @@ def __init__( >>> dyn.mv.animate_fates(adata) See also:: :func:`animate_fates` + """ + + def __init__( + self, + adata: AnnData, + basis: str = "umap", + fp_basis: Union[str, None] = None, + dims: Optional[list] = None, + n_steps: int = 100, + cell_states: Union[int, list, None] = None, + color: str = "ntr", + fig: Optional[matplotlib.figure.Figure] = None, + ax: matplotlib.axes.Axes = None, + logspace: bool = False, + max_time: Optional[float] = None, + ): + """Construct the StreamFuncAnim class. + + Args: + adata: annData object that already went through the fate prediction. + basis: the embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the + reconstructed trajectory will be projected back to high dimensional space via the `inverse_transform` + function space. + fp_basis: the basis that will be used for identifying or retrieving fixed points. Note that if `fps_basis` + is different from `basis`, the nearest cells of the fixed point from the `fps_basis` will be found and + used to visualize the position of the fixed point on `basis` embedding. + dims: the dimensions of low embedding space where cells will be drawn, and it should correspond to the space + fate prediction take place. + n_steps: the number of times steps (frames) fate prediction will take. + cell_states: the number of cells state that will be randomly selected (if `int`), the indices of the cells + states (if `list`) or all cell states which fate prediction executed (if `None`) + color: the key of the data that will be used to color the embedding. + fig: the figure that will contain both the background and animated components. + ax: the matplotlib axes object that will be used as background plot of the vector field animation. If `ax` + is None, `topography(adata, basis=basis, color=color, ax=ax, save_show_or_return='return')` will be used + to create an axes. + logspace: `whether or to sample time points linearly on log space. If not, the sorted unique set of all-time + points from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time + points. + max_time: the maximum time that will be used to scale the time vector. + + Returns: + A class that contains .fig attribute and .update, .init_background that can be used to produce an animation + of the prediction of cell fate commitment. """ import matplotlib.pyplot as plt @@ -270,7 +241,6 @@ def __init__( color=color, logspace=logspace, max_time=max_time, - frame_color=frame_color, ) # Animation objects must create `fig` and `ax` attributes. @@ -288,13 +258,14 @@ def __init__( self.fig = fig self.ax = ax - (self.ln,) = self.ax.plot([], [], "ro", zs=[]) if len(dims) == 3 else self.ax.plot([], [], "ro") + (self.ln,) = self.ax.plot([], [], "ro", zs=[]) if dims is not None and len(dims) == 3 else self.ax.plot([], [], "ro") def init_background(self): + """Initialize background of the animation.""" return (self.ln,) def update(self, frame): - """Update locations of "particles" in flow on each frame frame.""" + """Update locations of "particles" in flow on each frame.""" init_states = self.init_states time_vec = self.time_vec @@ -384,7 +355,6 @@ def animate_fates( ax=None, logspace=False, max_time=None, - frame_color=None, interval=100, blit=True, save_show_or_return="show", @@ -395,7 +365,7 @@ def animate_fates( This class creates necessary components to produce an animation that describes the exact speed of a set of cells at each time point, its movement in gene expression and the long range trajectory predicted by the reconstructed - vector field. Thus it provides intuitive visual understanding of the RNA velocity, speed, acceleration, and cell + vector field. Thus, it provides intuitive visual understanding of the RNA velocity, speed, acceleration, and cell fate commitment in action. This function is originally inspired by https://tonysyu.github.io/animating-particles-in-a-flow.html and relies on @@ -403,49 +373,37 @@ def animate_fates( the animation. See for example, http://louistiao.me/posts/notebooks/save-matplotlib-animations-as-gifs/ for more details. - Parameters - ---------- - adata: :class:`~anndata.AnnData` - AnnData object that already went through the fate prediction. - basis: `str` or None (default: `None`) - The embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the reconstructed - trajectory will be projected back to high dimensional space via the `inverse_transform` function. - space. - dims: `list` or `None` (default: `None') - The dimensions of low embedding space where cells will be drawn and it should corresponds to the space + Args: + adata: annData object that already went through the fate prediction. + basis: the embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the + reconstructed trajectory will be projected back to high dimensional space via the `inverse_transform` + function space. + dims: the dimensions of low embedding space where cells will be drawn, and it should correspond to the space fate prediction take place. - n_steps: `int` (default: `100`) - The number of times steps (frames) fate prediction will take. - cell_states: `int`, `list` or `None` (default: `None`) - The number of cells state that will be randomly selected (if `int`), the indices of the cells states (if - `list`) or all cell states which fate prediction executed (if `None`) - fig: `matplotlib.figure.Figure` or None (default: `None`) - The figure that will contain both the background and animated components. - ax: `matplotlib.Axis` (optional, default `None`) - The matplotlib axes object that will be used as background plot of the vector field animation. If `ax` - is None, `topography(adata, basis=basis, color=color, ax=ax, save_show_or_return='return')` will be used - to create an axes. - logspace: `bool` (default: `False`) - Whether or to sample time points linearly on log space. If not, the sorted unique set of all time points - from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time points. - interval: `float` (default: `200`) - Delay between frames in milliseconds. - blit: `bool` (default: `False`) - Whether blitting is used to optimize drawing. Note: when using blitting, any animated artists will be drawn + n_steps: the number of times steps (frames) fate prediction will take. + cell_states: the number of cells state that will be randomly selected (if `int`), the indices of the cells + states (if `list`) or all cell states which fate prediction executed (if `None`) + color: the key of the data that will be used to color the embedding. + fig: the figure that will contain both the background and animated components. + ax: the matplotlib axes object that will be used as background plot of the vector field animation. If `ax` + is None, `topography(adata, basis=basis, color=color, ax=ax, save_show_or_return='return')` will be used + to create an axes. + logspace: `whether or to sample time points linearly on log space. If not, the sorted unique set of all-time + points from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time + points. + max_time: the maximum time that will be used to scale the time vector. + interval: delay between frames in milliseconds. + blit: whether blitting is used to optimize drawing. Note: when using blitting, any animated artists will be drawn according to their zorder; however, they will be drawn on top of any previous artists, regardless of their zorder. - save_show_or_return: `str` {'save', 'show', 'return'} (default: `save`) - Whether to save, show or return the figure. By default a gif will be used. - save_kwargs: `dict` (default: `{}`) - A dictionary that will passed to the anim.save. By default it is an empty dictionary and the save_fig function - will use the {"filename": 'fate_ani.gif', "writer": "imagemagick"} as its parameters. Otherwise you can - provide a dictionary that properly modify those keys according to your needs. see + save_show_or_return: whether to save, show or return the figure. By default, a gif will be used. + save_kwargs: a dictionary that will be passed to the anim.save. By default, it is an empty dictionary and the + save_fig function will use the {"filename": 'fate_ani.gif', "writer": "imagemagick"} as its parameters. + Otherwise, you can provide a dictionary that properly modify those keys according to your needs. see https://matplotlib.org/api/_as_gen/matplotlib.animation.Animation.save.html for more details. - kwargs: - Additional arguments passed to animation.FuncAnimation. + kwargs: additional arguments passed to animation.FuncAnimation. - Returns - ------- + Returns: Nothing but produce an animation that will be embedded to jupyter notebook or saved to disk. Examples 1 @@ -474,7 +432,6 @@ def animate_fates( ax=ax, logspace=logspace, max_time=max_time, - frame_color=frame_color, ) anim = animation.FuncAnimation( @@ -499,6 +456,7 @@ def animate_fates( class PyvistaAnim(BaseAnim): + """The class for animating cell fate commitment prediction with pyvista.""" def __init__( self, adata: AnnData, @@ -512,9 +470,38 @@ def __init__( pl=None, logspace: bool = False, max_time: Optional[float] = None, - frame_color=None, filename: str = "fate_animation.gif", ): + """Construct the PyvistaAnim class. + + Args: + adata: annData object that already went through the fate prediction. + basis: the embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the + reconstructed trajectory will be projected back to high dimensional space via the `inverse_transform` + function space. + fp_basis: the basis that will be used for identifying or retrieving fixed points. Note that if `fps_basis` + is different from `basis`, the nearest cells of the fixed point from the `fps_basis` will be found and + used to visualize the position of the fixed point on `basis` embedding. + dims: the dimensions of low embedding space where cells will be drawn, and it should correspond to the space + fate prediction take place. + n_steps: the number of times steps (frames) fate prediction will take. + cell_states: the number of cells state that will be randomly selected (if `int`), the indices of the cells + states (if `list`) or all cell states which fate prediction executed (if `None`) + color: the key of the data that will be used to color the embedding. + point_size: the size of the points that will be used to draw the cells. + pl: the pyvista plotter object that will be used to draw the cells. If `pl` is None, `topography_3D(adata, + basis=basis, fps_basis=fp_basis, color=color, ax=pl, save_show_or_return='return')` will be used to + create a plotter. + logspace: `whether or to sample time points linearly on log space. If not, the sorted unique set of all-time + points from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time + points. + max_time: the maximum time that will be used to scale the time vector. + filename: the name of the gif file that will be saved to disk. + + Returns: + A class that contains .animate, that can be used to produce a gif of the prediction of cell fate + commitment. + """ try: import pyvista as pv except ImportError: @@ -530,7 +517,6 @@ def __init__( color=color, logspace=logspace, max_time=max_time, - frame_color=frame_color, ) self.filename = filename @@ -551,6 +537,7 @@ def __init__( self.point_size = point_size def animate(self): + """Animate the cell fate commitment prediction.""" try: import pyvista as pv except ImportError: @@ -579,6 +566,7 @@ def animate(self): class PlotlyAnim(BaseAnim): + """The class for animating cell fate commitment prediction with plotly.""" def __init__( self, adata: AnnData, @@ -591,9 +579,35 @@ def __init__( pl=None, logspace: bool = False, max_time: Optional[float] = None, - frame_color=None, - filename: str = "fate_animation.gif", ): + """Construct the PlotlyAnim class. + + Args: + adata: annData object that already went through the fate prediction. + basis: the embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the + reconstructed trajectory will be projected back to high dimensional space via the `inverse_transform` + function space. + fp_basis: the basis that will be used for identifying or retrieving fixed points. Note that if `fps_basis` + is different from `basis`, the nearest cells of the fixed point from the `fps_basis` will be found and + used to visualize the position of the fixed point on `basis` embedding. + dims: the dimensions of low embedding space where cells will be drawn, and it should correspond to the space + fate prediction take place. + n_steps: the number of times steps (frames) fate prediction will take. + cell_states: the number of cells state that will be randomly selected (if `int`), the indices of the cells + states (if `list`) or all cell states which fate prediction executed (if `None`) + color: the key of the data that will be used to color the embedding. + pl: the plotly figure object that will be used to draw the cells. If `pl` is None, `topography_3D(adata, + basis=basis, fps_basis=fp_basis, color=color, ax=pl, save_show_or_return='return')` will be used to + create a plotter. + logspace: `whether or to sample time points linearly on log space. If not, the sorted unique set of all-time + points from all cell states' fate prediction will be used and then evenly sampled up to `n_steps` time + points. + max_time: the maximum time that will be used to scale the time vector. + + Returns: + A class that contains .animate, that can be used to produce a gif of the prediction of cell fate + commitment. + """ try: import pyvista as pv except ImportError: @@ -609,11 +623,8 @@ def __init__( color=color, logspace=logspace, max_time=max_time, - frame_color=frame_color, ) - self.filename = filename - if pl is None: self.pl = topography_3D( self.adata, @@ -631,6 +642,7 @@ def __init__( self.pts_history = [] def calculate_pts_history(self): + """Calculate the history of the cell states.""" pts = [i.tolist() for i in self.init_states] self.pts_history.append(pts) @@ -642,6 +654,7 @@ def calculate_pts_history(self): self.pts_history.append(np.asarray(pts)) def animate(self): + """Animate the cell fate commitment prediction.""" try: import plotly.graph_objects as go except ImportError: From e5705974f87f19504142d81cc537f7e581ace28a Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 20 Oct 2023 17:59:30 -0400 Subject: [PATCH 60/62] remove unnecessary code --- dynamo/plot/topography.py | 12 ------------ dynamo/vectorfield/topography.py | 27 --------------------------- 2 files changed, 39 deletions(-) diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 725a22b1b..c731638ff 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -1496,15 +1496,7 @@ def topography_3D( init_cells: List[int] = None, init_states: np.ndarray = None, quiver_source: Literal["raw", "reconstructed"] = "raw", - fate: Literal["history", "future", "both"] = "both", approx: bool = False, - quiver_size: Optional[float] = None, - quiver_length: Optional[float] = None, - density: float = 1, - linewidth: float = 1, - streamline_color: Optional[str] = None, - streamline_alpha: float = 0.4, - color_start_points: Optional[str] = None, markersize: float = 200, marker_cmap: Optional[str] = None, save_show_or_return: Literal["save", "show", "return"] = "show", @@ -1515,13 +1507,9 @@ def topography_3D( sort: Literal["raw", "abs", "neg"] = "raw", frontier: bool = False, s_kwargs_dict: Dict[str, Any] = {}, - q_kwargs_dict: Dict[str, Any] = {}, n: int = 25, - **streamline_kwargs_dict, ) -> Union[Axes, List[Axes], None]: - from ..external.hodge import ddhodge - logger = LoggerManager.gen_logger("dynamo-topography-plot") logger.log_time() diff --git a/dynamo/vectorfield/topography.py b/dynamo/vectorfield/topography.py index 47d1dcde7..26e51a123 100755 --- a/dynamo/vectorfield/topography.py +++ b/dynamo/vectorfield/topography.py @@ -693,33 +693,6 @@ def find_fixed_points_by_sampling( raise ValueError(f"No fixed points found. Try to increase the number of samples n.") self.Xss.add_fixed_points(X, J, tol_redundant) - def compute_nullclines( - self, - x_range: Tuple[float, float], - y_range: Tuple[float, float], - z_range: Tuple[float, float], - find_new_fixed_points: Optional[bool] = False, - tol_redundant: Optional[float] = 1e-4, - ): - pass - # s_max = 5 * ((x_range[1] - x_range[0]) + (y_range[1] - y_range[0])) - # ds = s_max / 1e3 - # self.NCx, self.NCy, self.NCz = compute_nullclines_3d( - # self.Xss.get_X(), - # self.fx, - # self.fy, - # self.fz, - # x_range, - # y_range, - # z_range, - # s_max=s_max, - # ds=ds, - # ) - # if find_new_fixed_points: - # sample_interval = ds * 10 - # X, J = find_fixed_points_nullcline_3d(self.func, self.NCx, self.NCy, self.NCz, sample_interval, tol_redundant) - # outside = is_outside(X, [x_range, y_range]) - # self.Xss.add_fixed_points(X[~outside], J[~outside], tol_redundant) def output_to_dict(self, dict_vf) -> Dict: """Output the vector field as a dictionary. From 4efeff06e259d4c0b33b2e34f316719a33d4d3ae Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 20 Oct 2023 18:12:36 -0400 Subject: [PATCH 61/62] update docstr --- dynamo/plot/scatters.py | 8 ++-- dynamo/plot/streamtube.py | 2 +- dynamo/plot/topography.py | 83 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 88 insertions(+), 5 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index ca93259aa..c85ad1d46 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1023,7 +1023,6 @@ def scatters_interactive( color: str = "ntr", layer: str = "X", plot_method: str = "pv", - highlights: Optional[list] = None, labels: Optional[list] = None, values: Optional[list] = None, cmap: Optional[str] = None, @@ -1054,8 +1053,6 @@ def scatters_interactive( z: the column index of the low dimensional embedding for the z-axis. Defaults to 2. color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr". layer: the layer of data to use for the scatter plot. Defaults to "X". - highlights: the color group that will be highlighted. If highlights is a list of lists, each list is relate to - each color element. Defaults to None. labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for coloring the points in the plot according to their label. Note that this option is mutually exclusive to the `values` option. Defaults to None. @@ -1098,6 +1095,11 @@ def scatters_interactive( "title": PyVista Export, "raster": True, "painter": True} as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your needs. Defaults to {}. **kwargs: any other kwargs that would be passed to `Plotter.add_points()`. + + Returns: + If `save_show_or_return` is `save`, `show` or `both`, the function will return nothing but show or save the + figure. If `save_show_or_return` is `return`, the function will return the axis object(s) that contains the + figure. """ if plot_method == "pv": diff --git a/dynamo/plot/streamtube.py b/dynamo/plot/streamtube.py index fcacf3c58..3e8d01527 100644 --- a/dynamo/plot/streamtube.py +++ b/dynamo/plot/streamtube.py @@ -33,7 +33,7 @@ def plot_3d_streamtube( save_show_or_return: Literal["save", "show", "return"] = "show", save_kwargs: Dict[str, Any] = {}, ): - """Plot a interative 3d streamtube plot via plotly. + """Plot an interative 3d streamtube plot via plotly. A streamtube is a tubular region surrounded by streamlines that form a closed loop. It's a continuous version of a streamtube plot (3D quiver plot) and can provide insight into flow data from natural systems. The color of tubes is diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index c731638ff..4c55baf5c 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -429,7 +429,7 @@ def plot_fixed_points( ax: Optional[Axes] = None, **kwargs, ) -> Optional[Axes]: - """Plot fixed points stored in the VectorField2D class. + """Plot fixed points stored in the VectorField class. Args: vecfld: an instance of the vector_field class. @@ -1509,6 +1509,87 @@ def topography_3D( s_kwargs_dict: Dict[str, Any] = {}, n: int = 25, ) -> Union[Axes, List[Axes], None]: + """Plot the topography of the reconstructed vector field in 3D space. + + Args: + adata: AnnData object that contains the reconstructed vector field. + basis: The embedding data space that will be used to plot the topography. Defaults to `umap`. + fps_basis: The embedding data space that will be used to plot the fixed points. Defaults to `umap`. + x: The index of the first dimension of the embedding data space that will be used to plot the topography. + Defaults to 0. + y: The index of the second dimension of the embedding data space that will be used to plot the topography. + Defaults to 1. + z: The index of the third dimension of the embedding data space that will be used to plot the topography. + Defaults to 2. + color: The color of the topography. Defaults to `ntr`. + layer: The layer of the data that will be used to plot the topography. Defaults to `X`. + plot_method: The method that will be used to plot the topography. Defaults to `matplotlib`. + highlights: The list of gene names that will be used to highlight the gene expression on the topography. + Defaults to None. + labels: The list of gene names that will be used to label the gene expression on the topography. Defaults to + None. + values: The list of gene names that will be used to color the gene expression on the topography. Defaults to + None. + theme: The color theme that will be used to plot the topography. Defaults to None. + cmap: The name of a matplotlib colormap that will be used to color the topography. Defaults to None. + color_key: The color dictionary that will be used to color the topography. Defaults to None. + color_key_cmap: The name of a matplotlib colormap that will be used to color the color key. Defaults to None. + alpha: The transparency of the topography. Defaults to None. + background: The background color of the topography. Defaults to `white`. + ncols: The number of columns for the figure. Defaults to 4. + pointsize: The scale of the point size. Actual point cell size is calculated as + `500.0 / np.sqrt(adata.shape[0]) * pointsize`. Defaults to None. + figsize: The width and height of a figure. Defaults to (6, 4). + show_legend: Whether to display a legend of the labels. Defaults to `on data`. + use_smoothed: Whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to + True. + xlim: The range of x-coordinate. Defaults to None. + ylim: The range of y-coordinate. Defaults to None. + zlim: The range of z-coordinate. Defaults to None. + t: The length of the time period from which to predict cell state forward or backward over time. This is used + by the odeint function. Defaults to None. + terms: A list of plotting items to include in the final topography figure. ('streamline', 'nullcline', + 'fixed_points', 'separatrix', 'trajectory', 'quiver') are all the items that we can support. Defaults to + ["streamline", "fixed_points"]. + init_cells: cell name or indices of the initial cell states for the historical or future cell state prediction + with numerical integration. If the names in init_cells are not find in the adata.obs_name, it will be + treated as cell indices and must be integers. Defaults to None. + init_states: the initial cell states for the historical or future cell state prediction with numerical + integration. It can be either a one-dimensional array or N x 2 dimension array. The `init_state` will be + replaced to that defined by init_cells if init_cells are not None. Defaults to None. + quiver_source: the data source that will be used to draw the quiver plot. If `init_cells` is provided, this will + set to be the projected RNA velocity before vector field reconstruction automatically. If `init_cells` is + not provided, this will set to be the velocity vectors calculated from the reconstructed vector field + function automatically. If quiver_source is `reconstructed`, the velocity vectors calculated from the + reconstructed vector field function will be used. Defaults to "raw". + approx: whether to use streamplot to draw the integration line from the init_state. Defaults to False. + markersize: the size of the marker. Defaults to 200. + marker_cmap: the name of a matplotlib colormap to use for coloring or shading the confidence of fixed points. If + None, the default color map will set to be viridis (inferno) when the background is white (black). Defaults + to None. + save_show_or_return: Whether to save, show or return the figure. Defaults to `show`. + save_kwargs: A dictionary that will be passed to the save_fig function. By default, it is an empty dictionary + and the save_fig function will use the {"path": None, "prefix": 'topography', "dpi": None, "ext": 'pdf', + "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a + dictionary that properly modify those keys according to your needs. Defaults to {}. + aggregate: The column in adata.obs that will be used to aggregate data points. Defaults to None. + show_arrowed_spines: Whether to show a pair of arrowed spines representing the basis of the scatter is currently + using. Defaults to False. + ax: The axis on which to make the plot. Defaults to None. + sort: The method to reorder data so that high values points will be on top of background points. Can be one of + {'raw', 'abs', 'neg'}, i.e. sorted by raw data, sort by absolute values or sort by negative values. Defaults + to "raw". Defaults to "raw". + frontier: whether to add the frontier. Scatter plots can be enhanced by using transparency (alpha) in order to + show area of high density and multiple scatter plots can be used to delineate a frontier. See matplotlib + tips & tricks cheatsheet (https://github.com/matplotlib/cheatsheets). Originally inspired by figures from + scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. Defaults to False. + s_kwargs_dict: The dictionary of the scatter arguments. Defaults to {}. + n: Number of samples for calculating the fixed points. + + Returns: + None would be returned by default. If `save_show_or_return` is set to be 'return', the Axes of the generated + subplots would be returned. + """ logger = LoggerManager.gen_logger("dynamo-topography-plot") logger.log_time() From 07afdb4f27e8425e9deedde092322f3c6738343a Mon Sep 17 00:00:00 2001 From: sichao Date: Tue, 31 Oct 2023 15:19:49 -0400 Subject: [PATCH 62/62] debug and remove not implemented function from init --- dynamo/movie/__init__.py | 2 +- dynamo/plot/scVectorField.py | 2 -- dynamo/plot/topography.py | 2 -- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/dynamo/movie/__init__.py b/dynamo/movie/__init__.py index 731de9415..ee227bb7c 100644 --- a/dynamo/movie/__init__.py +++ b/dynamo/movie/__init__.py @@ -1,4 +1,4 @@ """Mapping Vector Field of Single Cells """ -from .fate import PlotlyAnim, PyvistaAnim, StreamFuncAnim, StreamFuncAnim3D, animate_fates +from .fate import PyvistaAnim, StreamFuncAnim, StreamFuncAnim3D, animate_fates diff --git a/dynamo/plot/scVectorField.py b/dynamo/plot/scVectorField.py index 94c26cfba..0d7460af2 100755 --- a/dynamo/plot/scVectorField.py +++ b/dynamo/plot/scVectorField.py @@ -336,7 +336,6 @@ def add_axis_label(ax, labels): z=z, color=color, layer=layer, - highlights=highlights, labels=labels, values=values, cmap=cmap, @@ -391,7 +390,6 @@ def add_axis_label(ax, labels): color=color, layer=layer, plot_method="plotly", - highlights=highlights, labels=labels, values=values, cmap=cmap, diff --git a/dynamo/plot/topography.py b/dynamo/plot/topography.py index 4c55baf5c..c47e6ba79 100755 --- a/dynamo/plot/topography.py +++ b/dynamo/plot/topography.py @@ -1744,7 +1744,6 @@ def topography_3D( z=z, color=color, layer=layer, - highlights=highlights, labels=labels, values=values, cmap=cmap, @@ -1790,7 +1789,6 @@ def topography_3D( color=color, layer=layer, plot_method="plotly", - highlights=highlights, labels=labels, values=values, cmap=cmap,