Skip to content

Commit 00ba29a

Browse files
claude[bot], github-actions[bot], MarkusNeusinger
authored
feat(matplotlib): implement scatter-color-groups (#356)
## Summary

Implements `scatter-color-groups` for the **matplotlib** library.

**Parent Issue:** #208
**Sub-Issue:** #238
**Base Branch:** `plot/scatter-color-groups`
**Attempt:** 1/3

## Implementation

- `plots/matplotlib/scatter/scatter-color-groups/default.py`

## Details

- Uses a synthetic iris-like dataset to demonstrate categorical color groups (species)
- Colorblind-safe palette from style guide (#306998, #FFD43B, #DC2626)
- Proper axis labels, title, legend with appropriate font sizes
- Simple KISS-style script (no functions/classes)

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Markus Neusinger <2921697+MarkusNeusinger@users.noreply.github.com>
1 parent c2f45a1 commit 00ba29a

1 file changed

Lines changed: 48 additions & 124 deletions

File tree

  • plots/matplotlib/scatter/scatter-color-groups
Lines changed: 48 additions & 124 deletions
Original file line number | Diff line number | Diff line change
@@ -1,135 +1,59 @@
11
"""
22
scatter-color-groups: Scatter Plot with Color Groups
3-
Implementation for: matplotlib
4-
Variant: default
5-
Python: 3.10+
3+
Library: matplotlib
64
"""
75

8-
from typing import TYPE_CHECKING
9-
106
import matplotlib.pyplot as plt
7+
import numpy as np
118
import pandas as pd
129

1310

14-
if TYPE_CHECKING:
15-
from matplotlib.figure import Figure
16-
17-
18-
def create_plot(
19-
data: pd.DataFrame,
20-
x: str,
21-
y: str,
22-
group: str,
23-
figsize: tuple[float, float] = (16, 9),
24-
alpha: float = 0.7,
25-
size: float = 50,
26-
title: str | None = None,
27-
xlabel: str | None = None,
28-
ylabel: str | None = None,
29-
palette: str = "Set1",
30-
**kwargs,
31-
) -> "Figure":
32-
"""
33-
Create a scatter plot with points colored by categorical groups.
34-
35-
Visualizes data points in a 2D x-y space with distinct colors for each
36-
categorical group, showing separate "color clouds" for different categories.
37-
38-
Args:
39-
data: Input DataFrame with required columns
40-
x: Column name for x-axis values
41-
y: Column name for y-axis values
42-
group: Column name for categorical grouping and coloring
43-
figsize: Figure size as (width, height) tuple (default: (16, 9))
44-
alpha: Transparency level for points (default: 0.7)
45-
size: Point size (default: 50)
46-
title: Plot title (default: None)
47-
xlabel: Custom x-axis label (default: uses column name)
48-
ylabel: Custom y-axis label (default: uses column name)
49-
palette: Matplotlib colormap or seaborn palette name (default: "Set1")
50-
**kwargs: Additional parameters passed to scatter plot
51-
52-
Returns:
53-
Matplotlib Figure object
54-
55-
Raises:
56-
ValueError: If data is empty
57-
KeyError: If required columns not found in data
58-
59-
Example:
60-
>>> import pandas as pd
61-
>>> data = pd.DataFrame({
62-
... 'x': [1, 2, 3, 4, 5, 6],
63-
... 'y': [2, 4, 3, 5, 6, 4],
64-
... 'group': ['A', 'A', 'B', 'B', 'C', 'C']
65-
... })
66-
>>> fig = create_plot(data, 'x', 'y', 'group')
67-
>>> plt.savefig('plot.png')
68-
"""
69-
# Input validation
70-
if data.empty:
71-
raise ValueError("Data cannot be empty")
72-
73-
# Check required columns
74-
required_cols = [x, y, group]
75-
for col in required_cols:
76-
if col not in data.columns:
77-
available = ", ".join(data.columns)
78-
raise KeyError(f"Column '{col}' not found in data. Available columns: {available}")
79-
80-
# Create figure and axis
81-
fig, ax = plt.subplots(figsize=figsize)
82-
83-
# Get unique groups and create color mapping
84-
groups = data[group].unique()
85-
86-
# Get colors from palette
87-
try:
88-
cmap = plt.get_cmap(palette)
89-
colors = [cmap(i / max(len(groups) - 1, 1)) for i in range(len(groups))]
90-
except (ValueError, AttributeError):
91-
# Fallback to tab10 if palette not found
92-
cmap = plt.get_cmap("tab10")
93-
colors = [cmap(i % 10) for i in range(len(groups))]
94-
95-
# Plot each group with a different color
96-
for idx, group_val in enumerate(groups):
97-
group_data = data[data[group] == group_val]
98-
ax.scatter(group_data[x], group_data[y], label=str(group_val), alpha=alpha, s=size, color=colors[idx], **kwargs)
99-
100-
# Set labels
101-
ax.set_xlabel(xlabel or x, fontsize=11)
102-
ax.set_ylabel(ylabel or y, fontsize=11)
103-
104-
# Add title if provided
105-
if title:
106-
ax.set_title(title, fontsize=12, fontweight="bold", pad=15)
107-
108-
# Add legend
109-
ax.legend(title=group, loc="best", framealpha=0.9)
110-
111-
# Add subtle grid
112-
ax.grid(True, alpha=0.3, linestyle="--")
113-
114-
# Layout
115-
plt.tight_layout()
116-
117-
return fig
118-
119-
120-
if __name__ == "__main__":
121-
# Sample data for testing
122-
data = pd.DataFrame(
123-
{
124-
"x": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5],
125-
"y": [2, 4, 3, 5, 6, 4, 7, 8, 9, 10, 3, 5, 4, 6, 7, 5],
126-
"group": ["A", "A", "A", "A", "A", "A", "B", "B", "B", "B", "C", "C", "C", "C", "C", "C"],
127-
}
11+
# Data - Iris-like dataset
12+
np.random.seed(42)
13+
n_per_group = 50
14+
15+
data = pd.DataFrame({
16+
"sepal_length": np.concatenate([
17+
np.random.normal(5.0, 0.35, n_per_group),
18+
np.random.normal(5.9, 0.50, n_per_group),
19+
np.random.normal(6.6, 0.60, n_per_group),
20+
]),
21+
"sepal_width": np.concatenate([
22+
np.random.normal(3.4, 0.38, n_per_group),
23+
np.random.normal(2.8, 0.30, n_per_group),
24+
np.random.normal(3.0, 0.30, n_per_group),
25+
]),
26+
"species": ["setosa"] * n_per_group + ["versicolor"] * n_per_group + ["virginica"] * n_per_group,
27+
})
28+
29+
# Color palette (colorblind safe from style guide)
30+
colors = ["#306998", "#FFD43B", "#DC2626"]
31+
species = data["species"].unique()
32+
color_map = {sp: colors[i] for i, sp in enumerate(species)}
33+
34+
# Create plot
35+
fig, ax = plt.subplots(figsize=(16, 9))
36+
37+
for species_name in species:
38+
subset = data[data["species"] == species_name]
39+
ax.scatter(
40+
subset["sepal_length"],
41+
subset["sepal_width"],
42+
c=color_map[species_name],
43+
label=species_name.capitalize(),
44+
alpha=0.7,
45+
s=80,
46+
edgecolors="white",
47+
linewidths=0.5,
12848
)
12949

130-
# Create plot
131-
fig = create_plot(data, "x", "y", "group", title="Scatter Plot with Color Groups")
50+
# Labels and styling
51+
ax.set_xlabel("Sepal Length (cm)", fontsize=20)
52+
ax.set_ylabel("Sepal Width (cm)", fontsize=20)
53+
ax.set_title("Iris Species by Sepal Dimensions", fontsize=20)
54+
ax.tick_params(axis="both", labelsize=16)
55+
ax.legend(title="Species", fontsize=16, title_fontsize=16)
56+
ax.grid(True, alpha=0.3)
13257

133-
# Save for inspection
134-
plt.savefig("plot.png", dpi=300, bbox_inches="tight")
135-
print("Plot saved to plot.png")
58+
plt.tight_layout()
59+
plt.savefig("plot.png", dpi=300, bbox_inches="tight")

0 commit comments

Comments
 (0)