diff --git a/plots/matplotlib/scatter/scatter-color-groups/default.py b/plots/matplotlib/scatter/scatter-color-groups/default.py
index 02fef6ecce..7f45e2c195 100644
--- a/plots/matplotlib/scatter/scatter-color-groups/default.py
+++ b/plots/matplotlib/scatter/scatter-color-groups/default.py
@@ -1,135 +1,59 @@
 """
 scatter-color-groups: Scatter Plot with Color Groups
-Implementation for: matplotlib
-Variant: default
-Python: 3.10+
+Library: matplotlib
 """
 
-from typing import TYPE_CHECKING
-
 import matplotlib.pyplot as plt
+import numpy as np
 import pandas as pd
 
 
-if TYPE_CHECKING:
-    from matplotlib.figure import Figure
-
-
-def create_plot(
-    data: pd.DataFrame,
-    x: str,
-    y: str,
-    group: str,
-    figsize: tuple[float, float] = (16, 9),
-    alpha: float = 0.7,
-    size: float = 50,
-    title: str | None = None,
-    xlabel: str | None = None,
-    ylabel: str | None = None,
-    palette: str = "Set1",
-    **kwargs,
-) -> "Figure":
-    """
-    Create a scatter plot with points colored by categorical groups.
-
-    Visualizes data points in a 2D x-y space with distinct colors for each
-    categorical group, showing separate "color clouds" for different categories.
-
-    Args:
-        data: Input DataFrame with required columns
-        x: Column name for x-axis values
-        y: Column name for y-axis values
-        group: Column name for categorical grouping and coloring
-        figsize: Figure size as (width, height) tuple (default: (16, 9))
-        alpha: Transparency level for points (default: 0.7)
-        size: Point size (default: 50)
-        title: Plot title (default: None)
-        xlabel: Custom x-axis label (default: uses column name)
-        ylabel: Custom y-axis label (default: uses column name)
-        palette: Matplotlib colormap or seaborn palette name (default: "Set1")
-        **kwargs: Additional parameters passed to scatter plot
-
-    Returns:
-        Matplotlib Figure object
-
-    Raises:
-        ValueError: If data is empty
-        KeyError: If required columns not found in data
-
-    Example:
-        >>> import pandas as pd
-        >>> data = pd.DataFrame({
-        ...     'x': [1, 2, 3, 4, 5, 6],
-        ...     'y': [2, 4, 3, 5, 6, 4],
-        ...     'group': ['A', 'A', 'B', 'B', 'C', 'C']
-        ... })
-        >>> fig = create_plot(data, 'x', 'y', 'group')
-        >>> plt.savefig('plot.png')
-    """
-    # Input validation
-    if data.empty:
-        raise ValueError("Data cannot be empty")
-
-    # Check required columns
-    required_cols = [x, y, group]
-    for col in required_cols:
-        if col not in data.columns:
-            available = ", ".join(data.columns)
-            raise KeyError(f"Column '{col}' not found in data. Available columns: {available}")
-
-    # Create figure and axis
-    fig, ax = plt.subplots(figsize=figsize)
-
-    # Get unique groups and create color mapping
-    groups = data[group].unique()
-
-    # Get colors from palette
-    try:
-        cmap = plt.get_cmap(palette)
-        colors = [cmap(i / max(len(groups) - 1, 1)) for i in range(len(groups))]
-    except (ValueError, AttributeError):
-        # Fallback to tab10 if palette not found
-        cmap = plt.get_cmap("tab10")
-        colors = [cmap(i % 10) for i in range(len(groups))]
-
-    # Plot each group with a different color
-    for idx, group_val in enumerate(groups):
-        group_data = data[data[group] == group_val]
-        ax.scatter(group_data[x], group_data[y], label=str(group_val), alpha=alpha, s=size, color=colors[idx], **kwargs)
-
-    # Set labels
-    ax.set_xlabel(xlabel or x, fontsize=11)
-    ax.set_ylabel(ylabel or y, fontsize=11)
-
-    # Add title if provided
-    if title:
-        ax.set_title(title, fontsize=12, fontweight="bold", pad=15)
-
-    # Add legend
-    ax.legend(title=group, loc="best", framealpha=0.9)
-
-    # Add subtle grid
-    ax.grid(True, alpha=0.3, linestyle="--")
-
-    # Layout
-    plt.tight_layout()
-
-    return fig
-
-
-if __name__ == "__main__":
-    # Sample data for testing
-    data = pd.DataFrame(
-        {
-            "x": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5],
-            "y": [2, 4, 3, 5, 6, 4, 7, 8, 9, 10, 3, 5, 4, 6, 7, 5],
-            "group": ["A", "A", "A", "A", "A", "A", "B", "B", "B", "B", "C", "C", "C", "C", "C", "C"],
-        }
+# Data - Iris-like dataset
+np.random.seed(42)
+n_per_group = 50
+
+data = pd.DataFrame({
+    "sepal_length": np.concatenate([
+        np.random.normal(5.0, 0.35, n_per_group),
+        np.random.normal(5.9, 0.50, n_per_group),
+        np.random.normal(6.6, 0.60, n_per_group),
+    ]),
+    "sepal_width": np.concatenate([
+        np.random.normal(3.4, 0.38, n_per_group),
+        np.random.normal(2.8, 0.30, n_per_group),
+        np.random.normal(3.0, 0.30, n_per_group),
+    ]),
+    "species": ["setosa"] * n_per_group + ["versicolor"] * n_per_group + ["virginica"] * n_per_group,
+})
+
+# Color palette (colorblind safe from style guide)
+colors = ["#306998", "#FFD43B", "#DC2626"]
+species = data["species"].unique()
+color_map = {sp: colors[i] for i, sp in enumerate(species)}
+
+# Create plot
+fig, ax = plt.subplots(figsize=(16, 9))
+
+for species_name in species:
+    subset = data[data["species"] == species_name]
+    ax.scatter(
+        subset["sepal_length"],
+        subset["sepal_width"],
+        c=color_map[species_name],
+        label=species_name.capitalize(),
+        alpha=0.7,
+        s=80,
+        edgecolors="white",
+        linewidths=0.5,
     )
 
-    # Create plot
-    fig = create_plot(data, "x", "y", "group", title="Scatter Plot with Color Groups")
+# Labels and styling
+ax.set_xlabel("Sepal Length (cm)", fontsize=20)
+ax.set_ylabel("Sepal Width (cm)", fontsize=20)
+ax.set_title("Iris Species by Sepal Dimensions", fontsize=20)
+ax.tick_params(axis="both", labelsize=16)
+ax.legend(title="Species", fontsize=16, title_fontsize=16)
+ax.grid(True, alpha=0.3)
 
-    # Save for inspection
-    plt.savefig("plot.png", dpi=300, bbox_inches="tight")
-    print("Plot saved to plot.png")
+plt.tight_layout()
+plt.savefig("plot.png", dpi=300, bbox_inches="tight")