Skip to content

Commit 3965cfb

Browse files
feat(seaborn): implement box-basic
Simplify implementation to follow KISS principles: - Remove function wrapper, use sequential script style - Use style guide color palette - Follow seaborn 0.14+ API (hue with palette, legend=False) - Use spec-defined sample data
1 parent 27357ca commit 3965cfb

1 file changed

Lines changed: 43 additions & 179 deletions

File tree

Lines changed: 43 additions & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -1,189 +1,53 @@
11
"""
22
box-basic: Basic Box Plot
3-
Implementation for: seaborn
4-
Variant: default
5-
Python: 3.10+
3+
Library: seaborn
64
"""
75

8-
from typing import TYPE_CHECKING, Optional
9-
106
import matplotlib.pyplot as plt
117
import numpy as np
128
import pandas as pd
139
import seaborn as sns
1410

1511

16-
if TYPE_CHECKING:
17-
from matplotlib.figure import Figure
18-
19-
20-
def create_plot(
21-
data: pd.DataFrame,
22-
values: str,
23-
groups: str,
24-
title: Optional[str] = None,
25-
xlabel: Optional[str] = None,
26-
ylabel: Optional[str] = None,
27-
palette: Optional[str] = "Set2",
28-
figsize: tuple[float, float] = (16, 9),
29-
showfliers: bool = True,
30-
**kwargs,
31-
) -> Figure:
32-
"""
33-
Create a basic box plot showing statistical distribution of multiple groups using seaborn.
34-
35-
Args:
36-
data: Input DataFrame with required columns
37-
values: Column name containing numeric values
38-
groups: Column name containing group categories
39-
title: Plot title (optional)
40-
xlabel: Custom x-axis label (optional, defaults to groups column name)
41-
ylabel: Custom y-axis label (optional, defaults to values column name)
42-
palette: Color palette name for boxes (default: 'Set2')
43-
figsize: Figure size as (width, height) in inches (default: (16, 9))
44-
showfliers: Whether to show outliers (default: True)
45-
**kwargs: Additional parameters passed to seaborn boxplot function
46-
47-
Returns:
48-
Matplotlib Figure object
49-
50-
Raises:
51-
ValueError: If data is empty
52-
KeyError: If required columns not found
53-
54-
Example:
55-
>>> data = pd.DataFrame({
56-
... 'Group': ['A', 'A', 'B', 'B', 'C', 'C'],
57-
... 'Value': [1, 2, 2, 3, 3, 4]
58-
... })
59-
>>> fig = create_plot(data, values='Value', groups='Group')
60-
"""
61-
# Input validation
62-
if data.empty:
63-
raise ValueError("Data cannot be empty")
64-
65-
# Check required columns
66-
for col in [values, groups]:
67-
if col not in data.columns:
68-
available = ", ".join(data.columns)
69-
raise KeyError(f"Column '{col}' not found. Available columns: {available}")
70-
71-
# Create figure
72-
fig, ax = plt.subplots(figsize=figsize)
73-
74-
# Create boxplot with seaborn
75-
sns.boxplot(
76-
data=data,
77-
x=groups,
78-
y=values,
79-
hue=groups,
80-
palette=palette,
81-
ax=ax,
82-
showfliers=showfliers,
83-
width=0.7,
84-
linewidth=1.5,
85-
fliersize=6,
86-
legend=False,
87-
**kwargs,
88-
)
89-
90-
# Customize the appearance
91-
# Set median line color to be more visible
92-
for patch in ax.artists:
93-
# Get the current face color
94-
r, g, b, a = patch.get_facecolor()
95-
# Set the box face color with some transparency
96-
patch.set_facecolor((r, g, b, 0.7))
97-
# Set edge color
98-
patch.set_edgecolor("black")
99-
patch.set_linewidth(1.2)
100-
101-
# Style the median lines
102-
for line in ax.lines:
103-
# Median lines are the ones inside the boxes
104-
if line.get_linestyle() == "-" and line.get_marker() == "None":
105-
line.set_color("red")
106-
line.set_linewidth(2)
107-
108-
# Labels and title
109-
ax.set_xlabel(xlabel or groups)
110-
ax.set_ylabel(ylabel or values)
111-
112-
if title:
113-
ax.set_title(title, fontsize=14, fontweight="bold", pad=20)
114-
115-
# Grid for better readability
116-
ax.grid(True, axis="y", alpha=0.3, linestyle="--")
117-
ax.set_axisbelow(True)
118-
119-
# Rotate x-axis labels if there are many groups
120-
unique_groups = data[groups].nunique()
121-
if unique_groups > 5:
122-
plt.xticks(rotation=45, ha="right")
123-
124-
# Add some statistical annotations
125-
# Calculate and display the number of data points per group
126-
group_counts = data.groupby(groups)[values].count()
127-
y_bottom = ax.get_ylim()[0]
128-
for i, (_group_name, count) in enumerate(group_counts.items()):
129-
ax.text(i, y_bottom, f"n={count}", ha="center", va="top", fontsize=9, alpha=0.7)
130-
131-
# Apply seaborn style for better aesthetics
132-
sns.despine(ax=ax)
133-
134-
# Layout
135-
plt.tight_layout()
136-
137-
return fig
138-
139-
140-
if __name__ == "__main__":
141-
# Sample data for testing with different distributions per group
142-
np.random.seed(42) # For reproducibility
143-
144-
# Generate sample data with 4 groups
145-
data_dict = {"Group": [], "Value": []}
146-
147-
# Group A: Normal distribution, mean=50, std=10
148-
group_a_data = np.random.normal(50, 10, 40)
149-
# Add some outliers
150-
group_a_data = np.append(group_a_data, [80, 85, 15])
151-
152-
# Group B: Normal distribution, mean=60, std=15
153-
group_b_data = np.random.normal(60, 15, 35)
154-
# Add outliers
155-
group_b_data = np.append(group_b_data, [100, 10])
156-
157-
# Group C: Normal distribution, mean=45, std=8
158-
group_c_data = np.random.normal(45, 8, 45)
159-
160-
# Group D: Skewed distribution
161-
group_d_data = np.random.gamma(2, 2, 30) + 40
162-
# Add outliers
163-
group_d_data = np.append(group_d_data, [75, 78, 20])
164-
165-
# Combine all data
166-
for group, values in zip(
167-
["Group A", "Group B", "Group C", "Group D"],
168-
[group_a_data, group_b_data, group_c_data, group_d_data],
169-
strict=False,
170-
):
171-
data_dict["Group"].extend([group] * len(values))
172-
data_dict["Value"].extend(values)
173-
174-
data = pd.DataFrame(data_dict)
175-
176-
# Create plot
177-
fig = create_plot(
178-
data,
179-
values="Value",
180-
groups="Group",
181-
title="Statistical Distribution Comparison Across Groups",
182-
ylabel="Measurement Value",
183-
xlabel="Categories",
184-
palette="Set2",
185-
)
186-
187-
# Save for inspection
188-
plt.savefig("plot.png", dpi=300, bbox_inches="tight")
189-
print("Plot saved to plot.png")
12+
# Data
13+
np.random.seed(42)
14+
data = pd.DataFrame(
15+
{
16+
"group": ["A"] * 50 + ["B"] * 50 + ["C"] * 50 + ["D"] * 50,
17+
"value": np.concatenate(
18+
[
19+
np.random.normal(50, 10, 50),
20+
np.random.normal(60, 15, 50),
21+
np.random.normal(45, 8, 50),
22+
np.random.normal(70, 20, 50),
23+
]
24+
),
25+
}
26+
)
27+
28+
# Custom color palette using style guide colors
29+
colors = ["#306998", "#FFD43B", "#DC2626", "#059669"]
30+
31+
# Create plot
32+
fig, ax = plt.subplots(figsize=(16, 9))
33+
sns.boxplot(
34+
data=data,
35+
x="group",
36+
y="value",
37+
hue="group",
38+
palette=colors,
39+
legend=False,
40+
ax=ax,
41+
linewidth=2,
42+
flierprops={"marker": "o", "markersize": 8, "alpha": 0.7},
43+
)
44+
45+
# Labels and styling
46+
ax.set_xlabel("Group", fontsize=20)
47+
ax.set_ylabel("Value", fontsize=20)
48+
ax.set_title("Basic Box Plot", fontsize=20)
49+
ax.tick_params(axis="both", labelsize=16)
50+
ax.grid(True, alpha=0.3, axis="y")
51+
52+
plt.tight_layout()
53+
plt.savefig("plot.png", dpi=300, bbox_inches="tight")

0 commit comments

Comments
 (0)