|
1 | 1 | """ |
2 | 2 | scatter-color-groups: Scatter Plot with Color Groups |
3 | | -Implementation for: matplotlib |
4 | | -Variant: default |
5 | | -Python: 3.10+ |
| 3 | +Library: matplotlib |
6 | 4 | """ |
7 | 5 |
|
8 | | -from typing import TYPE_CHECKING |
9 | | - |
10 | 6 | import matplotlib.pyplot as plt |
| 7 | +import numpy as np |
11 | 8 | import pandas as pd |
12 | 9 |
|
13 | 10 |
|
14 | | -if TYPE_CHECKING: |
15 | | - from matplotlib.figure import Figure |
16 | | - |
17 | | - |
18 | | -def create_plot( |
19 | | - data: pd.DataFrame, |
20 | | - x: str, |
21 | | - y: str, |
22 | | - group: str, |
23 | | - figsize: tuple[float, float] = (16, 9), |
24 | | - alpha: float = 0.7, |
25 | | - size: float = 50, |
26 | | - title: str | None = None, |
27 | | - xlabel: str | None = None, |
28 | | - ylabel: str | None = None, |
29 | | - palette: str = "Set1", |
30 | | - **kwargs, |
31 | | -) -> "Figure": |
32 | | - """ |
33 | | - Create a scatter plot with points colored by categorical groups. |
34 | | -
|
35 | | - Visualizes data points in a 2D x-y space with distinct colors for each |
36 | | - categorical group, showing separate "color clouds" for different categories. |
37 | | -
|
38 | | - Args: |
39 | | - data: Input DataFrame with required columns |
40 | | - x: Column name for x-axis values |
41 | | - y: Column name for y-axis values |
42 | | - group: Column name for categorical grouping and coloring |
43 | | - figsize: Figure size as (width, height) tuple (default: (16, 9)) |
44 | | - alpha: Transparency level for points (default: 0.7) |
45 | | - size: Point size (default: 50) |
46 | | - title: Plot title (default: None) |
47 | | - xlabel: Custom x-axis label (default: uses column name) |
48 | | - ylabel: Custom y-axis label (default: uses column name) |
49 | | - palette: Matplotlib colormap or seaborn palette name (default: "Set1") |
50 | | - **kwargs: Additional parameters passed to scatter plot |
51 | | -
|
52 | | - Returns: |
53 | | - Matplotlib Figure object |
54 | | -
|
55 | | - Raises: |
56 | | - ValueError: If data is empty |
57 | | - KeyError: If required columns not found in data |
58 | | -
|
59 | | - Example: |
60 | | - >>> import pandas as pd |
61 | | - >>> data = pd.DataFrame({ |
62 | | - ... 'x': [1, 2, 3, 4, 5, 6], |
63 | | - ... 'y': [2, 4, 3, 5, 6, 4], |
64 | | - ... 'group': ['A', 'A', 'B', 'B', 'C', 'C'] |
65 | | - ... }) |
66 | | - >>> fig = create_plot(data, 'x', 'y', 'group') |
67 | | - >>> plt.savefig('plot.png') |
68 | | - """ |
69 | | - # Input validation |
70 | | - if data.empty: |
71 | | - raise ValueError("Data cannot be empty") |
72 | | - |
73 | | - # Check required columns |
74 | | - required_cols = [x, y, group] |
75 | | - for col in required_cols: |
76 | | - if col not in data.columns: |
77 | | - available = ", ".join(data.columns) |
78 | | - raise KeyError(f"Column '{col}' not found in data. Available columns: {available}") |
79 | | - |
80 | | - # Create figure and axis |
81 | | - fig, ax = plt.subplots(figsize=figsize) |
82 | | - |
83 | | - # Get unique groups and create color mapping |
84 | | - groups = data[group].unique() |
85 | | - |
86 | | - # Get colors from palette |
87 | | - try: |
88 | | - cmap = plt.get_cmap(palette) |
89 | | - colors = [cmap(i / max(len(groups) - 1, 1)) for i in range(len(groups))] |
90 | | - except (ValueError, AttributeError): |
91 | | - # Fallback to tab10 if palette not found |
92 | | - cmap = plt.get_cmap("tab10") |
93 | | - colors = [cmap(i % 10) for i in range(len(groups))] |
94 | | - |
95 | | - # Plot each group with a different color |
96 | | - for idx, group_val in enumerate(groups): |
97 | | - group_data = data[data[group] == group_val] |
98 | | - ax.scatter(group_data[x], group_data[y], label=str(group_val), alpha=alpha, s=size, color=colors[idx], **kwargs) |
99 | | - |
100 | | - # Set labels |
101 | | - ax.set_xlabel(xlabel or x, fontsize=11) |
102 | | - ax.set_ylabel(ylabel or y, fontsize=11) |
103 | | - |
104 | | - # Add title if provided |
105 | | - if title: |
106 | | - ax.set_title(title, fontsize=12, fontweight="bold", pad=15) |
107 | | - |
108 | | - # Add legend |
109 | | - ax.legend(title=group, loc="best", framealpha=0.9) |
110 | | - |
111 | | - # Add subtle grid |
112 | | - ax.grid(True, alpha=0.3, linestyle="--") |
113 | | - |
114 | | - # Layout |
115 | | - plt.tight_layout() |
116 | | - |
117 | | - return fig |
118 | | - |
119 | | - |
120 | | -if __name__ == "__main__": |
121 | | - # Sample data for testing |
122 | | - data = pd.DataFrame( |
123 | | - { |
124 | | - "x": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5], |
125 | | - "y": [2, 4, 3, 5, 6, 4, 7, 8, 9, 10, 3, 5, 4, 6, 7, 5], |
126 | | - "group": ["A", "A", "A", "A", "A", "A", "B", "B", "B", "B", "C", "C", "C", "C", "C", "C"], |
127 | | - } |
| 11 | +# Data - Iris-like dataset |
| 12 | +np.random.seed(42) |
| 13 | +n_per_group = 50 |
| 14 | + |
| 15 | +data = pd.DataFrame({ |
| 16 | + "sepal_length": np.concatenate([ |
| 17 | + np.random.normal(5.0, 0.35, n_per_group), |
| 18 | + np.random.normal(5.9, 0.50, n_per_group), |
| 19 | + np.random.normal(6.6, 0.60, n_per_group), |
| 20 | + ]), |
| 21 | + "sepal_width": np.concatenate([ |
| 22 | + np.random.normal(3.4, 0.38, n_per_group), |
| 23 | + np.random.normal(2.8, 0.30, n_per_group), |
| 24 | + np.random.normal(3.0, 0.30, n_per_group), |
| 25 | + ]), |
| 26 | + "species": ["setosa"] * n_per_group + ["versicolor"] * n_per_group + ["virginica"] * n_per_group, |
| 27 | +}) |
| 28 | + |
| 29 | +# Color palette (colorblind safe from style guide) |
| 30 | +colors = ["#306998", "#FFD43B", "#DC2626"] |
| 31 | +species = data["species"].unique() |
| 32 | +color_map = {sp: colors[i] for i, sp in enumerate(species)} |
| 33 | + |
| 34 | +# Create plot |
| 35 | +fig, ax = plt.subplots(figsize=(16, 9)) |
| 36 | + |
| 37 | +for species_name in species: |
| 38 | + subset = data[data["species"] == species_name] |
| 39 | + ax.scatter( |
| 40 | + subset["sepal_length"], |
| 41 | + subset["sepal_width"], |
| 42 | + c=color_map[species_name], |
| 43 | + label=species_name.capitalize(), |
| 44 | + alpha=0.7, |
| 45 | + s=80, |
| 46 | + edgecolors="white", |
| 47 | + linewidths=0.5, |
128 | 48 | ) |
129 | 49 |
|
130 | | - # Create plot |
131 | | - fig = create_plot(data, "x", "y", "group", title="Scatter Plot with Color Groups") |
| 50 | +# Labels and styling |
| 51 | +ax.set_xlabel("Sepal Length (cm)", fontsize=20) |
| 52 | +ax.set_ylabel("Sepal Width (cm)", fontsize=20) |
| 53 | +ax.set_title("Iris Species by Sepal Dimensions", fontsize=20) |
| 54 | +ax.tick_params(axis="both", labelsize=16) |
| 55 | +ax.legend(title="Species", fontsize=16, title_fontsize=16) |
| 56 | +ax.grid(True, alpha=0.3) |
132 | 57 |
|
133 | | - # Save for inspection |
134 | | - plt.savefig("plot.png", dpi=300, bbox_inches="tight") |
135 | | - print("Plot saved to plot.png") |
| 58 | +plt.tight_layout() |
| 59 | +plt.savefig("plot.png", dpi=300, bbox_inches="tight") |
0 commit comments