|
1 | 1 | """ |
2 | 2 | box-basic: Basic Box Plot |
3 | | -Implementation for: seaborn |
4 | | -Variant: default |
5 | | -Python: 3.10+ |
| 3 | +Library: seaborn |
6 | 4 | """ |
7 | 5 |
|
8 | | -from typing import TYPE_CHECKING, Optional |
9 | | - |
10 | 6 | import matplotlib.pyplot as plt |
11 | 7 | import numpy as np |
12 | 8 | import pandas as pd |
13 | 9 | import seaborn as sns |
14 | 10 |
|
15 | 11 |
|
16 | | -if TYPE_CHECKING: |
17 | | - from matplotlib.figure import Figure |
18 | | - |
19 | | - |
def create_plot(
    data: pd.DataFrame,
    values: str,
    groups: str,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    palette: Optional[str] = "Set2",
    figsize: tuple[float, float] = (16, 9),
    showfliers: bool = True,
    **kwargs,
) -> "Figure":
    """
    Create a basic box plot of ``values`` split by ``groups`` using seaborn.

    Args:
        data: Input DataFrame containing both columns.
        values: Column name containing numeric values.
        groups: Column name containing group categories.
        title: Plot title (optional; no title is set when falsy).
        xlabel: Custom x-axis label (defaults to the ``groups`` column name).
        ylabel: Custom y-axis label (defaults to the ``values`` column name).
        palette: Color palette passed to seaborn (default: 'Set2').
        figsize: Figure size as (width, height) in inches (default: (16, 9)).
        showfliers: Whether to draw outlier markers (default: True).
        **kwargs: Additional parameters forwarded to ``sns.boxplot``. A
            ``medianprops`` dict, if given, is merged over the default red,
            2pt median style.

    Returns:
        The Matplotlib Figure that holds the plot.

    Raises:
        ValueError: If ``data`` is empty.
        KeyError: If ``values`` or ``groups`` is not a column of ``data``.

    Example:
        >>> data = pd.DataFrame({
        ...     'Group': ['A', 'A', 'B', 'B', 'C', 'C'],
        ...     'Value': [1, 2, 2, 3, 3, 4]
        ... })
        >>> fig = create_plot(data, values='Value', groups='Group')
    """
    # NOTE: the return annotation is a string because ``Figure`` is only
    # imported under TYPE_CHECKING; an unquoted name would be looked up when
    # the ``def`` statement runs and fail on eagerly-evaluating interpreters.

    # Input validation
    if data.empty:
        raise ValueError("Data cannot be empty")

    for col in (values, groups):
        if col not in data.columns:
            # Column labels are not guaranteed to be strings; stringify them
            # so the intended KeyError is raised instead of a join TypeError.
            available = ", ".join(map(str, data.columns))
            raise KeyError(f"Column '{col}' not found. Available columns: {available}")

    fig, ax = plt.subplots(figsize=figsize)

    # Style the medians through ``medianprops`` rather than by scanning
    # ``ax.lines`` afterwards: whiskers and caps are also solid, marker-less
    # lines, so a post-hoc linestyle/marker filter cannot tell them apart.
    median_style = {
        "color": "red",
        "linewidth": 2,
        **(kwargs.pop("medianprops", None) or {}),
    }

    sns.boxplot(
        data=data,
        x=groups,
        y=values,
        hue=groups,
        palette=palette,
        ax=ax,
        showfliers=showfliers,
        width=0.7,
        linewidth=1.5,
        fliersize=6,
        legend=False,
        medianprops=median_style,
        **kwargs,
    )

    # Soften the box fill and outline each box in black. Current Matplotlib
    # registers the box patches in ``ax.patches``; ``ax.artists`` is kept as
    # a fallback for older releases that stored them there.
    for box in list(ax.patches) or list(ax.artists):
        red, green, blue, _alpha = box.get_facecolor()
        box.set_facecolor((red, green, blue, 0.7))
        box.set_edgecolor("black")
        box.set_linewidth(1.2)

    # Labels and title
    ax.set_xlabel(xlabel or groups)
    ax.set_ylabel(ylabel or values)
    if title:
        ax.set_title(title, fontsize=14, fontweight="bold", pad=20)

    # Horizontal grid drawn beneath the boxes for readability
    ax.grid(True, axis="y", alpha=0.3, linestyle="--")
    ax.set_axisbelow(True)

    # Slant the tick labels once there are more than five groups
    if data[groups].nunique() > 5:
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

    _annotate_group_counts(ax, data, values, groups)

    sns.despine(ax=ax)
    fig.tight_layout()

    return fig


def _annotate_group_counts(ax, data: pd.DataFrame, values: str, groups: str) -> None:
    """Write an ``n=<count>`` label at the bottom of the axes under each box."""
    counts = data.groupby(groups)[values].count()
    y_bottom = ax.get_ylim()[0]

    # ``groupby`` sorts its keys, which need not be the order the boxes were
    # drawn in. Match counts to boxes through the tick labels when they line
    # up one-to-one with the group names; otherwise keep positional order.
    count_by_label = {str(name): count for name, count in counts.items()}
    ticks = [
        (position, label.get_text())
        for position, label in zip(ax.get_xticks(), ax.get_xticklabels())
    ]
    labels_match = len(count_by_label) == len(counts) and {
        text for _, text in ticks
    } == set(count_by_label)

    if labels_match:
        placements = [(position, count_by_label[text]) for position, text in ticks]
    else:
        placements = list(enumerate(counts))

    for position, count in placements:
        ax.text(
            position,
            y_bottom,
            f"n={count}",
            ha="center",
            va="top",
            fontsize=9,
            alpha=0.7,
        )
138 | | - |
139 | | - |
if __name__ == "__main__":
    # Demo run: four synthetic groups with different spreads, three of them
    # with hand-placed outliers, rendered through create_plot().
    np.random.seed(42)  # fixed seed so the demo figure is reproducible

    # The draws below consume the seeded global RNG in sequence, so their
    # order determines the generated numbers.
    sample_a = np.append(np.random.normal(50, 10, 40), [80, 85, 15])  # mean 50, std 10
    sample_b = np.append(np.random.normal(60, 15, 35), [100, 10])  # mean 60, std 15
    sample_c = np.random.normal(45, 8, 45)  # mean 45, std 8, no extra outliers
    sample_d = np.append(np.random.gamma(2, 2, 30) + 40, [75, 78, 20])  # skewed

    samples = {
        "Group A": sample_a,
        "Group B": sample_b,
        "Group C": sample_c,
        "Group D": sample_d,
    }

    # Flatten into long format: one row per observation, tagged with its group.
    labels = [name for name, sample in samples.items() for _ in range(len(sample))]
    measurements = [value for sample in samples.values() for value in sample]
    data = pd.DataFrame({"Group": labels, "Value": measurements})

    fig = create_plot(
        data,
        values="Value",
        groups="Group",
        title="Statistical Distribution Comparison Across Groups",
        ylabel="Measurement Value",
        xlabel="Categories",
        palette="Set2",
    )

    # Write the figure to disk for inspection
    plt.savefig("plot.png", dpi=300, bbox_inches="tight")
    print("Plot saved to plot.png")
# Data: four groups of 50 normally distributed samples each.
np.random.seed(42)  # fixed seed keeps the rendered figure reproducible

group_names = ["A", "B", "C", "D"]
# (mean, standard deviation) per group, in the same order as group_names.
# Samples are drawn sequentially from the seeded global RNG, so this order
# determines the generated numbers.
distributions = [(50, 10), (60, 15), (45, 8), (70, 20)]
samples_per_group = 50

group_column = [name for name in group_names for _ in range(samples_per_group)]
value_column = np.concatenate(
    [np.random.normal(mean, std, samples_per_group) for mean, std in distributions]
)
data = pd.DataFrame({"group": group_column, "value": value_column})

# Custom color palette using style guide colors
colors = ["#306998", "#FFD43B", "#DC2626", "#059669"]

# Create plot
fig, ax = plt.subplots(figsize=(16, 9))
outlier_style = {"marker": "o", "markersize": 8, "alpha": 0.7}
sns.boxplot(
    data=data,
    x="group",
    y="value",
    hue="group",  # hue mirrors x so each box takes its own palette color
    palette=colors,
    legend=False,
    ax=ax,
    linewidth=2,
    flierprops=outlier_style,
)

# Labels and styling
heading_font = {"fontsize": 20}
ax.set_xlabel("Group", **heading_font)
ax.set_ylabel("Value", **heading_font)
ax.set_title("Basic Box Plot", **heading_font)
ax.tick_params(axis="both", labelsize=16)
ax.grid(True, alpha=0.3, axis="y")  # horizontal grid lines only

plt.tight_layout()
plt.savefig("plot.png", dpi=300, bbox_inches="tight")
0 commit comments