Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions src/scanpy/plotting/_tools/scatterplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,8 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915
elif sort_order and color_type == "cat":
# Null points go on bottom
order = np.argsort(~pd.isnull(color_source_vector), kind="stable")
# Set orders
if isinstance(size, np.ndarray):
size = np.array(size)[order]
# `size` is not a loop variable, so don’t overwrite it here
size_plot = size[order] if isinstance(size, np.ndarray) else size
color_source_vector = color_source_vector[order]
color_vector = color_vector[order]
coords = basis_values[:, dims][order, :]
Expand Down Expand Up @@ -348,11 +347,10 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915
)
else:
scatter = (
partial(ax.scatter, s=size, plotnonfinite=True)
partial(ax.scatter, s=size_plot, plotnonfinite=True)
if scale_factor is None
else partial(
circles, s=size, ax=ax, scale_factor=scale_factor
) # size in circles is radius
# size in circles is radius
else partial(circles, s=size_plot, ax=ax, scale_factor=scale_factor)
)

if add_outline:
Expand All @@ -366,7 +364,7 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915
# with some transparency.

bg_width, gap_width = outline_width
point = np.sqrt(size)
point = np.sqrt(size_plot)
gap_size = (point + (point * gap_width) * 2) ** 2
bg_size = (np.sqrt(gap_size) + (point * bg_width) * 2) ** 2
# the default black and white colors can be changes using
Expand Down
32 changes: 32 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1875,6 +1875,38 @@ def test_umap_mask_no_modification():
pd.testing.assert_series_equal(pbmc.obs["louvain"], data_copy)


def test_scatter_size_not_mutated_across_panels():
"""Per-point size array must not be cumulatively reordered across subplots.

Regression test for https://github.com/scverse/scanpy/issues/4024
Uses 3 panels (categorical + two continuous) to exercise cumulative reorder.
"""
pbmc = pbmc3k_processed()
rng = np.random.default_rng(0)
sizes = rng.uniform(10, 200, size=pbmc.n_obs)
sizes_original = sizes.copy()

axes = sc.pl.umap(
pbmc,
color=["louvain", "n_genes", "n_counts"],
size=sizes,
show=False,
)

# The input array must not be modified
np.testing.assert_array_equal(sizes, sizes_original)

# Each panel must plot the correct per-point sizes (just reordered by
# z-order, not cumulatively scrambled). Sorting makes the comparison
# independent of z-ordering.
expected_sorted = np.sort(sizes)
for ax in axes:
plotted = ax.collections[0].get_sizes()
np.testing.assert_allclose(np.sort(plotted), expected_sorted)

plt.close()


def test_string_mask(tmp_path, check_same_image):
"""Check that the same mask given as string or bool array provides the same result."""
pbmc = pbmc3k_processed()
Expand Down
Loading