|
1 | 1 | """ |
2 | 2 | box-basic: Basic Box Plot |
3 | | -Implementation for: bokeh |
4 | | -Variant: default |
5 | | -Python: 3.10+ |
| 3 | +Library: bokeh |
6 | 4 | """ |
7 | 5 |
|
8 | | -from typing import TYPE_CHECKING, Optional |
9 | | - |
10 | 6 | import numpy as np |
11 | 7 | import pandas as pd |
12 | | -from bokeh.models import ColumnDataSource, FixedTicker, Label, Whisker |
| 8 | +from bokeh.io import export_png |
| 9 | +from bokeh.models import ColumnDataSource, FixedTicker, Whisker |
13 | 10 | from bokeh.plotting import figure |
14 | 11 |
|
15 | 12 |
|
16 | | -if TYPE_CHECKING: |
17 | | - from bokeh.plotting import Figure |
18 | | - |
19 | | - |
20 | | -def create_plot( |
21 | | - data: pd.DataFrame, |
22 | | - values: str, |
23 | | - groups: str, |
24 | | - title: Optional[str] = None, |
25 | | - xlabel: Optional[str] = None, |
26 | | - ylabel: Optional[str] = None, |
27 | | - colors: Optional[list] = None, |
28 | | - width: int = 1600, |
29 | | - height: int = 900, |
30 | | - **kwargs, |
31 | | -) -> Figure: |
32 | | - """ |
33 | | - Create a basic box plot showing statistical distribution of multiple groups using bokeh. |
34 | | -
|
35 | | - Args: |
36 | | - data: Input DataFrame with required columns |
37 | | - values: Column name containing numeric values |
38 | | - groups: Column name containing group categories |
39 | | - title: Plot title (optional) |
40 | | - xlabel: Custom x-axis label (optional, defaults to groups column name) |
41 | | - ylabel: Custom y-axis label (optional, defaults to values column name) |
42 | | - colors: List of colors for each box (optional) |
43 | | - width: Figure width in pixels (default: 1600) |
44 | | - height: Figure height in pixels (default: 900) |
45 | | - **kwargs: Additional parameters |
46 | | -
|
47 | | - Returns: |
48 | | - Bokeh Figure object |
49 | | -
|
50 | | - Raises: |
51 | | - ValueError: If data is empty |
52 | | - KeyError: If required columns not found |
53 | | -
|
54 | | - Example: |
55 | | - >>> data = pd.DataFrame({ |
56 | | - ... 'Group': ['A', 'A', 'B', 'B', 'C', 'C'], |
57 | | - ... 'Value': [1, 2, 2, 3, 3, 4] |
58 | | - ... }) |
59 | | - >>> fig = create_plot(data, values='Value', groups='Group') |
60 | | - """ |
61 | | - # Input validation |
62 | | - if data.empty: |
63 | | - raise ValueError("Data cannot be empty") |
64 | | - |
65 | | - # Check required columns |
66 | | - for col in [values, groups]: |
67 | | - if col not in data.columns: |
68 | | - available = ", ".join(data.columns) |
69 | | - raise KeyError(f"Column '{col}' not found. Available columns: {available}") |
70 | | - |
71 | | - # Calculate box plot statistics for each group |
72 | | - group_names = sorted(data[groups].unique()) |
73 | | - n_groups = len(group_names) |
74 | | - |
75 | | - # Prepare data for box plot |
76 | | - stats = {"x": [], "q1": [], "q2": [], "q3": [], "upper": [], "lower": [], "group": []} |
77 | | - outliers = {"x": [], "y": []} |
78 | | - |
79 | | - for i, group in enumerate(group_names): |
80 | | - group_data = data[data[groups] == group][values].dropna() |
81 | | - |
82 | | - q1 = group_data.quantile(0.25) |
83 | | - q2 = group_data.quantile(0.5) |
84 | | - q3 = group_data.quantile(0.75) |
85 | | - iqr = q3 - q1 |
86 | | - upper = min(group_data.max(), q3 + 1.5 * iqr) |
87 | | - lower = max(group_data.min(), q1 - 1.5 * iqr) |
88 | | - |
89 | | - stats["x"].append(i) |
90 | | - stats["q1"].append(q1) |
91 | | - stats["q2"].append(q2) |
92 | | - stats["q3"].append(q3) |
93 | | - stats["upper"].append(upper) |
94 | | - stats["lower"].append(lower) |
95 | | - stats["group"].append(group) |
96 | | - |
97 | | - # Find outliers |
98 | | - outlier_data = group_data[(group_data < lower) | (group_data > upper)] |
99 | | - for val in outlier_data: |
100 | | - outliers["x"].append(i) |
101 | | - outliers["y"].append(val) |
102 | | - |
103 | | - # Set colors |
104 | | - if not colors: |
105 | | - from bokeh.palettes import Set2_8 |
106 | | - |
107 | | - colors = Set2_8[:n_groups] |
108 | | - |
109 | | - # Create figure with numeric x-axis |
110 | | - p = figure( |
111 | | - width=width, |
112 | | - height=height, |
113 | | - title=title or "Box Plot Distribution", |
114 | | - toolbar_location="above", |
115 | | - tools="pan,wheel_zoom,box_zoom,reset,save", |
| 13 | +# Data |
| 14 | +np.random.seed(42) |
| 15 | +data = pd.DataFrame( |
| 16 | + { |
| 17 | + "group": ["A"] * 50 + ["B"] * 50 + ["C"] * 50 + ["D"] * 50, |
| 18 | + "value": np.concatenate( |
| 19 | + [ |
| 20 | + np.random.normal(50, 10, 50), |
| 21 | + np.random.normal(60, 15, 50), |
| 22 | + np.random.normal(45, 8, 50), |
| 23 | + np.random.normal(70, 20, 50), |
| 24 | + ] |
| 25 | + ), |
| 26 | + } |
| 27 | +) |
| 28 | + |
| 29 | +# Calculate box plot statistics for each group |
| 30 | +group_names = sorted(data["group"].unique()) |
| 31 | +n_groups = len(group_names) |
| 32 | + |
| 33 | +stats = {"x": [], "q1": [], "q2": [], "q3": [], "upper": [], "lower": []} |
| 34 | +outliers = {"x": [], "y": []} |
| 35 | + |
| 36 | +for i, group in enumerate(group_names): |
| 37 | + group_data = data[data["group"] == group]["value"].dropna() |
| 38 | + |
| 39 | + q1 = group_data.quantile(0.25) |
| 40 | + q2 = group_data.quantile(0.50) |
| 41 | + q3 = group_data.quantile(0.75) |
| 42 | + iqr = q3 - q1 |
| 43 | + |
| 44 | + upper_fence = q3 + 1.5 * iqr |
| 45 | + lower_fence = q1 - 1.5 * iqr |
| 46 | + upper = group_data[group_data <= upper_fence].max() |
| 47 | + lower = group_data[group_data >= lower_fence].min() |
| 48 | + |
| 49 | + stats["x"].append(i) |
| 50 | + stats["q1"].append(q1) |
| 51 | + stats["q2"].append(q2) |
| 52 | + stats["q3"].append(q3) |
| 53 | + stats["upper"].append(upper) |
| 54 | + stats["lower"].append(lower) |
| 55 | + |
| 56 | + # Find outliers |
| 57 | + outlier_data = group_data[(group_data < lower_fence) | (group_data > upper_fence)] |
| 58 | + for val in outlier_data: |
| 59 | + outliers["x"].append(i) |
| 60 | + outliers["y"].append(val) |
| 61 | + |
| 62 | +# Create figure |
| 63 | +p = figure( |
| 64 | + width=4800, |
| 65 | + height=2700, |
| 66 | + title="Basic Box Plot", |
| 67 | + x_axis_label="Group", |
| 68 | + y_axis_label="Value", |
| 69 | + toolbar_location="above", |
| 70 | + tools="pan,wheel_zoom,box_zoom,reset,save", |
| 71 | +) |
| 72 | + |
| 73 | +source = ColumnDataSource(data=stats) |
| 74 | +box_width = 0.5 |
| 75 | +# Style guide color palette |
| 76 | +colors = ["#306998", "#FFD43B", "#DC2626", "#059669"] |
| 77 | + |
| 78 | +# Draw boxes (Q1 to Q3) |
| 79 | +for i, color in enumerate(colors): |
| 80 | + p.vbar( |
| 81 | + x=i, |
| 82 | + width=box_width, |
| 83 | + bottom=stats["q1"][i], |
| 84 | + top=stats["q3"][i], |
| 85 | + fill_color=color, |
| 86 | + fill_alpha=0.7, |
| 87 | + line_color="#333333", |
| 88 | + line_width=2, |
116 | 89 | ) |
117 | 90 |
|
118 | | - source = ColumnDataSource(data=stats) |
119 | | - |
120 | | - # Draw boxes (Q1 to Q3) |
121 | | - box_width = 0.5 |
122 | | - for i, color in enumerate(colors): |
123 | | - p.vbar( |
124 | | - x=i, |
125 | | - width=box_width, |
126 | | - bottom=stats["q1"][i], |
127 | | - top=stats["q3"][i], |
128 | | - fill_color=color, |
129 | | - line_color="black", |
130 | | - alpha=0.7, |
131 | | - ) |
132 | | - |
133 | | - # Draw median lines |
134 | | - for i in range(n_groups): |
135 | | - p.segment( |
136 | | - x0=i - box_width / 2, |
137 | | - y0=stats["q2"][i], |
138 | | - x1=i + box_width / 2, |
139 | | - y1=stats["q2"][i], |
140 | | - line_color="red", |
141 | | - line_width=2, |
142 | | - ) |
143 | | - |
144 | | - # Draw whiskers |
145 | | - upper_whisker = Whisker(base="x", upper="upper", lower="q3", source=source, line_color="black") |
146 | | - upper_whisker.upper_head.size = 10 |
147 | | - upper_whisker.lower_head.size = 0 |
148 | | - p.add_layout(upper_whisker) |
149 | | - |
150 | | - lower_whisker = Whisker(base="x", upper="q1", lower="lower", source=source, line_color="black") |
151 | | - lower_whisker.upper_head.size = 0 |
152 | | - lower_whisker.lower_head.size = 10 |
153 | | - p.add_layout(lower_whisker) |
154 | | - |
155 | | - # Draw outliers |
156 | | - if outliers["x"]: |
157 | | - outlier_source = ColumnDataSource(data=outliers) |
158 | | - p.scatter(x="x", y="y", source=outlier_source, size=8, color="red", alpha=0.5, line_color="black", line_width=1) |
159 | | - |
160 | | - # Set x-axis to show group names |
161 | | - p.xaxis.ticker = FixedTicker(ticks=list(range(n_groups))) |
162 | | - p.xaxis.major_label_overrides = dict(enumerate(group_names)) |
163 | | - |
164 | | - # Labels |
165 | | - p.xaxis.axis_label = xlabel or groups |
166 | | - p.yaxis.axis_label = ylabel or values |
167 | | - |
168 | | - # Styling |
169 | | - p.title.text_font_size = "14pt" |
170 | | - p.title.align = "center" |
171 | | - p.ygrid.grid_line_alpha = 0.3 |
172 | | - p.ygrid.grid_line_dash = [6, 4] |
173 | | - p.xgrid.visible = False |
174 | | - |
175 | | - # Add sample size annotations |
176 | | - group_counts = data.groupby(groups)[values].count() |
177 | | - y_min = data[values].min() |
178 | | - y_range = data[values].max() - y_min |
179 | | - for i, group in enumerate(group_names): |
180 | | - count = group_counts[group] |
181 | | - label = Label( |
182 | | - x=i, y=y_min - y_range * 0.08, text=f"n={count}", text_align="center", text_font_size="9pt", text_alpha=0.7 |
183 | | - ) |
184 | | - p.add_layout(label) |
185 | | - |
186 | | - return p |
187 | | - |
188 | | - |
189 | | -if __name__ == "__main__": |
190 | | - # Sample data for testing with different distributions per group |
191 | | - np.random.seed(42) |
192 | | - |
193 | | - data_dict = {"Group": [], "Value": []} |
194 | | - |
195 | | - # Group A: Normal distribution |
196 | | - group_a_data = np.random.normal(50, 10, 40) |
197 | | - group_a_data = np.append(group_a_data, [80, 85, 15]) |
198 | | - |
199 | | - # Group B: Normal distribution |
200 | | - group_b_data = np.random.normal(60, 15, 35) |
201 | | - group_b_data = np.append(group_b_data, [100, 10]) |
202 | | - |
203 | | - # Group C: Normal distribution |
204 | | - group_c_data = np.random.normal(45, 8, 45) |
205 | | - |
206 | | - # Group D: Skewed distribution |
207 | | - group_d_data = np.random.gamma(2, 2, 30) + 40 |
208 | | - group_d_data = np.append(group_d_data, [75, 78, 20]) |
209 | | - |
210 | | - # Combine all data |
211 | | - for group, values in zip( |
212 | | - ["Group A", "Group B", "Group C", "Group D"], |
213 | | - [group_a_data, group_b_data, group_c_data, group_d_data], |
214 | | - strict=False, |
215 | | - ): |
216 | | - data_dict["Group"].extend([group] * len(values)) |
217 | | - data_dict["Value"].extend(values) |
218 | | - |
219 | | - data = pd.DataFrame(data_dict) |
220 | | - |
221 | | - # Create plot |
222 | | - fig = create_plot( |
223 | | - data, |
224 | | - values="Value", |
225 | | - groups="Group", |
226 | | - title="Statistical Distribution Comparison Across Groups", |
227 | | - ylabel="Measurement Value", |
228 | | - xlabel="Categories", |
| 91 | +# Draw median lines |
| 92 | +for i in range(n_groups): |
| 93 | + p.segment( |
| 94 | + x0=i - box_width / 2, |
| 95 | + y0=stats["q2"][i], |
| 96 | + x1=i + box_width / 2, |
| 97 | + y1=stats["q2"][i], |
| 98 | + line_color="#333333", |
| 99 | + line_width=3, |
229 | 100 | ) |
230 | 101 |
|
231 | | - # Save as PNG using webdriver-manager for automatic chromedriver |
232 | | - from bokeh.io import export_png |
233 | | - from selenium import webdriver |
234 | | - from selenium.webdriver.chrome.options import Options |
235 | | - from selenium.webdriver.chrome.service import Service |
236 | | - from webdriver_manager.chrome import ChromeDriverManager |
237 | | - |
238 | | - chrome_options = Options() |
239 | | - chrome_options.add_argument("--headless") |
240 | | - chrome_options.add_argument("--no-sandbox") |
241 | | - chrome_options.add_argument("--disable-dev-shm-usage") |
242 | | - |
243 | | - service = Service(ChromeDriverManager().install()) |
244 | | - driver = webdriver.Chrome(service=service, options=chrome_options) |
| 102 | +# Draw whiskers |
| 103 | +upper_whisker = Whisker(base="x", upper="upper", lower="q3", source=source, line_color="#333333", line_width=2) |
| 104 | +upper_whisker.upper_head.size = 20 |
| 105 | +upper_whisker.upper_head.line_width = 2 |
| 106 | +upper_whisker.lower_head.size = 0 |
| 107 | +p.add_layout(upper_whisker) |
| 108 | + |
| 109 | +lower_whisker = Whisker(base="x", upper="q1", lower="lower", source=source, line_color="#333333", line_width=2) |
| 110 | +lower_whisker.upper_head.size = 0 |
| 111 | +lower_whisker.lower_head.size = 20 |
| 112 | +lower_whisker.lower_head.line_width = 2 |
| 113 | +p.add_layout(lower_whisker) |
| 114 | + |
| 115 | +# Draw outliers |
| 116 | +if outliers["x"]: |
| 117 | + outlier_source = ColumnDataSource(data=outliers) |
| 118 | + p.scatter( |
| 119 | + x="x", y="y", source=outlier_source, size=12, color="#DC2626", alpha=0.7, line_color="#333333", line_width=1 |
| 120 | + ) |
245 | 121 |
|
246 | | - export_png(fig, filename="plot.png", webdriver=driver) |
247 | | - driver.quit() |
248 | | - print("Plot saved to plot.png") |
| 122 | +# Set x-axis to show group names |
| 123 | +p.xaxis.ticker = FixedTicker(ticks=list(range(n_groups))) |
| 124 | +p.xaxis.major_label_overrides = dict(enumerate(group_names)) |
| 125 | + |
| 126 | +# Styling |
| 127 | +p.title.text_font_size = "20pt" |
| 128 | +p.title.align = "center" |
| 129 | +p.xaxis.axis_label_text_font_size = "20pt" |
| 130 | +p.yaxis.axis_label_text_font_size = "20pt" |
| 131 | +p.xaxis.major_label_text_font_size = "16pt" |
| 132 | +p.yaxis.major_label_text_font_size = "16pt" |
| 133 | +p.ygrid.grid_line_alpha = 0.3 |
| 134 | +p.xgrid.visible = False |
| 135 | + |
| 136 | +# Save as PNG (requires selenium + webdriver) |
| 137 | +export_png(p, filename="plot.png") |
0 commit comments