Skip to content

Commit faadc59

Browse files
feat(matplotlib): implement scatter-color-groups
Add a scatter plot with color groups implementation for matplotlib. It uses the iris dataset to demonstrate categorical coloring of data points by species, with the colorblind-safe palette from the style guide.
1 parent 27357ca commit faadc59

1 file changed

Lines changed: 35 additions & 127 deletions

File tree

  • plots/matplotlib/scatter/scatter-color-groups
Lines changed: 35 additions & 127 deletions
Original file line number | Diff line number | Diff line change
@@ -1,135 +1,43 @@
11
"""
22
scatter-color-groups: Scatter Plot with Color Groups
3-
Implementation for: matplotlib
4-
Variant: default
5-
Python: 3.10+
3+
Library: matplotlib
64
"""
75

8-
from typing import TYPE_CHECKING
9-
106
import matplotlib.pyplot as plt
11-
import pandas as pd
12-
13-
14-
if TYPE_CHECKING:
15-
from matplotlib.figure import Figure
16-
17-
18-
def create_plot(
19-
data: pd.DataFrame,
20-
x: str,
21-
y: str,
22-
group: str,
23-
figsize: tuple[float, float] = (16, 9),
24-
alpha: float = 0.7,
25-
size: float = 50,
26-
title: str | None = None,
27-
xlabel: str | None = None,
28-
ylabel: str | None = None,
29-
palette: str = "Set1",
30-
**kwargs,
31-
) -> "Figure":
32-
"""
33-
Create a scatter plot with points colored by categorical groups.
34-
35-
Visualizes data points in a 2D x-y space with distinct colors for each
36-
categorical group, showing separate "color clouds" for different categories.
37-
38-
Args:
39-
data: Input DataFrame with required columns
40-
x: Column name for x-axis values
41-
y: Column name for y-axis values
42-
group: Column name for categorical grouping and coloring
43-
figsize: Figure size as (width, height) tuple (default: (16, 9))
44-
alpha: Transparency level for points (default: 0.7)
45-
size: Point size (default: 50)
46-
title: Plot title (default: None)
47-
xlabel: Custom x-axis label (default: uses column name)
48-
ylabel: Custom y-axis label (default: uses column name)
49-
palette: Matplotlib colormap or seaborn palette name (default: "Set1")
50-
**kwargs: Additional parameters passed to scatter plot
51-
52-
Returns:
53-
Matplotlib Figure object
54-
55-
Raises:
56-
ValueError: If data is empty
57-
KeyError: If required columns not found in data
58-
59-
Example:
60-
>>> import pandas as pd
61-
>>> data = pd.DataFrame({
62-
... 'x': [1, 2, 3, 4, 5, 6],
63-
... 'y': [2, 4, 3, 5, 6, 4],
64-
... 'group': ['A', 'A', 'B', 'B', 'C', 'C']
65-
... })
66-
>>> fig = create_plot(data, 'x', 'y', 'group')
67-
>>> plt.savefig('plot.png')
68-
"""
69-
# Input validation
70-
if data.empty:
71-
raise ValueError("Data cannot be empty")
72-
73-
# Check required columns
74-
required_cols = [x, y, group]
75-
for col in required_cols:
76-
if col not in data.columns:
77-
available = ", ".join(data.columns)
78-
raise KeyError(f"Column '{col}' not found in data. Available columns: {available}")
79-
80-
# Create figure and axis
81-
fig, ax = plt.subplots(figsize=figsize)
82-
83-
# Get unique groups and create color mapping
84-
groups = data[group].unique()
85-
86-
# Get colors from palette
87-
try:
88-
cmap = plt.get_cmap(palette)
89-
colors = [cmap(i / max(len(groups) - 1, 1)) for i in range(len(groups))]
90-
except (ValueError, AttributeError):
91-
# Fallback to tab10 if palette not found
92-
cmap = plt.get_cmap("tab10")
93-
colors = [cmap(i % 10) for i in range(len(groups))]
94-
95-
# Plot each group with a different color
96-
for idx, group_val in enumerate(groups):
97-
group_data = data[data[group] == group_val]
98-
ax.scatter(group_data[x], group_data[y], label=str(group_val), alpha=alpha, s=size, color=colors[idx], **kwargs)
99-
100-
# Set labels
101-
ax.set_xlabel(xlabel or x, fontsize=11)
102-
ax.set_ylabel(ylabel or y, fontsize=11)
103-
104-
# Add title if provided
105-
if title:
106-
ax.set_title(title, fontsize=12, fontweight="bold", pad=15)
107-
108-
# Add legend
109-
ax.legend(title=group, loc="best", framealpha=0.9)
110-
111-
# Add subtle grid
112-
ax.grid(True, alpha=0.3, linestyle="--")
113-
114-
# Layout
115-
plt.tight_layout()
116-
117-
return fig
118-
119-
120-
if __name__ == "__main__":
121-
# Sample data for testing
122-
data = pd.DataFrame(
123-
{
124-
"x": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5],
125-
"y": [2, 4, 3, 5, 6, 4, 7, 8, 9, 10, 3, 5, 4, 6, 7, 5],
126-
"group": ["A", "A", "A", "A", "A", "A", "B", "B", "B", "B", "C", "C", "C", "C", "C", "C"],
127-
}
7+
import seaborn as sns
8+
9+
10+
# Data
11+
data = sns.load_dataset("iris")
12+
13+
# Color palette (colorblind safe from style guide)
14+
colors = ["#306998", "#FFD43B", "#DC2626"]
15+
species = data["species"].unique()
16+
color_map = {sp: colors[i] for i, sp in enumerate(species)}
17+
18+
# Create plot
19+
fig, ax = plt.subplots(figsize=(16, 9))
20+
21+
for species_name in species:
22+
subset = data[data["species"] == species_name]
23+
ax.scatter(
24+
subset["sepal_length"],
25+
subset["sepal_width"],
26+
c=color_map[species_name],
27+
label=species_name.capitalize(),
28+
alpha=0.7,
29+
s=80,
30+
edgecolors="white",
31+
linewidths=0.5,
12832
)
12933

130-
# Create plot
131-
fig = create_plot(data, "x", "y", "group", title="Scatter Plot with Color Groups")
34+
# Labels and styling
35+
ax.set_xlabel("Sepal Length (cm)", fontsize=20)
36+
ax.set_ylabel("Sepal Width (cm)", fontsize=20)
37+
ax.set_title("Iris Species by Sepal Dimensions", fontsize=20)
38+
ax.tick_params(axis="both", labelsize=16)
39+
ax.legend(title="Species", fontsize=16, title_fontsize=16)
40+
ax.grid(True, alpha=0.3)
13241

133-
# Save for inspection
134-
plt.savefig("plot.png", dpi=300, bbox_inches="tight")
135-
print("Plot saved to plot.png")
42+
plt.tight_layout()
43+
plt.savefig("plot.png", dpi=300, bbox_inches="tight")

0 commit comments

Comments
 (0)