Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 48 additions & 124 deletions plots/matplotlib/scatter/scatter-color-groups/default.py
Original file line number Diff line number Diff line change
@@ -1,135 +1,59 @@
"""
scatter-color-groups: Scatter Plot with Color Groups
Implementation for: matplotlib
Variant: default
Python: 3.10+
Library: matplotlib
"""

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


if TYPE_CHECKING:
from matplotlib.figure import Figure


def create_plot(
data: pd.DataFrame,
x: str,
y: str,
group: str,
figsize: tuple[float, float] = (16, 9),
alpha: float = 0.7,
size: float = 50,
title: str | None = None,
xlabel: str | None = None,
ylabel: str | None = None,
palette: str = "Set1",
**kwargs,
) -> "Figure":
"""
Create a scatter plot with points colored by categorical groups.

Visualizes data points in a 2D x-y space with distinct colors for each
categorical group, showing separate "color clouds" for different categories.

Args:
data: Input DataFrame with required columns
x: Column name for x-axis values
y: Column name for y-axis values
group: Column name for categorical grouping and coloring
figsize: Figure size as (width, height) tuple (default: (16, 9))
alpha: Transparency level for points (default: 0.7)
size: Point size (default: 50)
title: Plot title (default: None)
xlabel: Custom x-axis label (default: uses column name)
ylabel: Custom y-axis label (default: uses column name)
palette: Matplotlib colormap or seaborn palette name (default: "Set1")
**kwargs: Additional parameters passed to scatter plot

Returns:
Matplotlib Figure object

Raises:
ValueError: If data is empty
KeyError: If required columns not found in data

Example:
>>> import pandas as pd
>>> data = pd.DataFrame({
... 'x': [1, 2, 3, 4, 5, 6],
... 'y': [2, 4, 3, 5, 6, 4],
... 'group': ['A', 'A', 'B', 'B', 'C', 'C']
... })
>>> fig = create_plot(data, 'x', 'y', 'group')
>>> plt.savefig('plot.png')
"""
# Input validation
if data.empty:
raise ValueError("Data cannot be empty")

# Check required columns
required_cols = [x, y, group]
for col in required_cols:
if col not in data.columns:
available = ", ".join(data.columns)
raise KeyError(f"Column '{col}' not found in data. Available columns: {available}")

# Create figure and axis
fig, ax = plt.subplots(figsize=figsize)

# Get unique groups and create color mapping
groups = data[group].unique()

# Get colors from palette
try:
cmap = plt.get_cmap(palette)
colors = [cmap(i / max(len(groups) - 1, 1)) for i in range(len(groups))]
except (ValueError, AttributeError):
# Fallback to tab10 if palette not found
cmap = plt.get_cmap("tab10")
colors = [cmap(i % 10) for i in range(len(groups))]

# Plot each group with a different color
for idx, group_val in enumerate(groups):
group_data = data[data[group] == group_val]
ax.scatter(group_data[x], group_data[y], label=str(group_val), alpha=alpha, s=size, color=colors[idx], **kwargs)

# Set labels
ax.set_xlabel(xlabel or x, fontsize=11)
ax.set_ylabel(ylabel or y, fontsize=11)

# Add title if provided
if title:
ax.set_title(title, fontsize=12, fontweight="bold", pad=15)

# Add legend
ax.legend(title=group, loc="best", framealpha=0.9)

# Add subtle grid
ax.grid(True, alpha=0.3, linestyle="--")

# Layout
plt.tight_layout()

return fig


if __name__ == "__main__":
    # Data - Iris-like dataset: three synthetic species, each drawn from its
    # own normal distribution; the fixed seed makes the output reproducible.
    np.random.seed(42)
    n_per_group = 50

    data = pd.DataFrame(
        {
            "sepal_length": np.concatenate(
                [
                    np.random.normal(5.0, 0.35, n_per_group),
                    np.random.normal(5.9, 0.50, n_per_group),
                    np.random.normal(6.6, 0.60, n_per_group),
                ]
            ),
            "sepal_width": np.concatenate(
                [
                    np.random.normal(3.4, 0.38, n_per_group),
                    np.random.normal(2.8, 0.30, n_per_group),
                    np.random.normal(3.0, 0.30, n_per_group),
                ]
            ),
            "species": ["setosa"] * n_per_group + ["versicolor"] * n_per_group + ["virginica"] * n_per_group,
        }
    )

    # Color palette (colorblind safe from style guide)
    colors = ["#306998", "#FFD43B", "#DC2626"]
    species = data["species"].unique()
    color_map = {sp: colors[i] for i, sp in enumerate(species)}

    # Create plot: one scatter call per species so each gets a legend entry
    fig, ax = plt.subplots(figsize=(16, 9))

    for species_name in species:
        subset = data[data["species"] == species_name]
        ax.scatter(
            subset["sepal_length"],
            subset["sepal_width"],
            c=color_map[species_name],
            label=species_name.capitalize(),
            alpha=0.7,
            s=80,
            edgecolors="white",
            linewidths=0.5,
        )

    # Labels and styling
    ax.set_xlabel("Sepal Length (cm)", fontsize=20)
    ax.set_ylabel("Sepal Width (cm)", fontsize=20)
    ax.set_title("Iris Species by Sepal Dimensions", fontsize=20)
    ax.tick_params(axis="both", labelsize=16)
    ax.legend(title="Species", fontsize=16, title_fontsize=16)
    ax.grid(True, alpha=0.3)

    # Save for inspection: lay out first, then write the file exactly once
    plt.tight_layout()
    plt.savefig("plot.png", dpi=300, bbox_inches="tight")
    print("Plot saved to plot.png")
Loading