Skip to content

Commit 2792833

Browse files
feat(bokeh): implement box-basic (#315)
## Summary

Implements `box-basic` for the **bokeh** library.

**Parent Issue:** #204
**Sub-Issue:** #233
**Base Branch:** `plot/box-basic`
**Attempt:** 2/3

## Implementation

- `plots/bokeh/custom/box-basic/default.py`

## Features

- Clean KISS-style implementation (no functions, no classes)
- Uses spec-defined data structure with `group` and `value` columns
- Displays Q1, median, Q3 boxes with whiskers extending to the most extreme data points within 1.5×IQR of the quartiles
- Shows outliers as individual red points (#DC2626)
- Correct figure dimensions (4800×2700 px)
- Font sizes match default-style-guide.md recommendations (20pt labels, 16pt ticks)
- Uses style guide color palette (#306998, #FFD43B, #DC2626, #059669)

## Changes in This Attempt

- Updated to use the official style guide color palette instead of Category10_4
- Selenium dependencies already added for PNG export

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
1 parent 8392814 commit 2792833

3 files changed

Lines changed: 1652 additions & 1759 deletions

File tree

Lines changed: 123 additions & 234 deletions
Original file line number | Diff line number | Diff line change
@@ -1,248 +1,137 @@
11
"""
22
box-basic: Basic Box Plot
3-
Implementation for: bokeh
4-
Variant: default
5-
Python: 3.10+
3+
Library: bokeh
64
"""
75

8-
from typing import TYPE_CHECKING, Optional
9-
106
import numpy as np
117
import pandas as pd
12-
from bokeh.models import ColumnDataSource, FixedTicker, Label, Whisker
8+
from bokeh.io import export_png
9+
from bokeh.models import ColumnDataSource, FixedTicker, Whisker
1310
from bokeh.plotting import figure
1411

1512

16-
if TYPE_CHECKING:
17-
from bokeh.plotting import Figure
18-
19-
20-
def create_plot(
21-
data: pd.DataFrame,
22-
values: str,
23-
groups: str,
24-
title: Optional[str] = None,
25-
xlabel: Optional[str] = None,
26-
ylabel: Optional[str] = None,
27-
colors: Optional[list] = None,
28-
width: int = 1600,
29-
height: int = 900,
30-
**kwargs,
31-
) -> Figure:
32-
"""
33-
Create a basic box plot showing statistical distribution of multiple groups using bokeh.
34-
35-
Args:
36-
data: Input DataFrame with required columns
37-
values: Column name containing numeric values
38-
groups: Column name containing group categories
39-
title: Plot title (optional)
40-
xlabel: Custom x-axis label (optional, defaults to groups column name)
41-
ylabel: Custom y-axis label (optional, defaults to values column name)
42-
colors: List of colors for each box (optional)
43-
width: Figure width in pixels (default: 1600)
44-
height: Figure height in pixels (default: 900)
45-
**kwargs: Additional parameters
46-
47-
Returns:
48-
Bokeh Figure object
49-
50-
Raises:
51-
ValueError: If data is empty
52-
KeyError: If required columns not found
53-
54-
Example:
55-
>>> data = pd.DataFrame({
56-
... 'Group': ['A', 'A', 'B', 'B', 'C', 'C'],
57-
... 'Value': [1, 2, 2, 3, 3, 4]
58-
... })
59-
>>> fig = create_plot(data, values='Value', groups='Group')
60-
"""
61-
# Input validation
62-
if data.empty:
63-
raise ValueError("Data cannot be empty")
64-
65-
# Check required columns
66-
for col in [values, groups]:
67-
if col not in data.columns:
68-
available = ", ".join(data.columns)
69-
raise KeyError(f"Column '{col}' not found. Available columns: {available}")
70-
71-
# Calculate box plot statistics for each group
72-
group_names = sorted(data[groups].unique())
73-
n_groups = len(group_names)
74-
75-
# Prepare data for box plot
76-
stats = {"x": [], "q1": [], "q2": [], "q3": [], "upper": [], "lower": [], "group": []}
77-
outliers = {"x": [], "y": []}
78-
79-
for i, group in enumerate(group_names):
80-
group_data = data[data[groups] == group][values].dropna()
81-
82-
q1 = group_data.quantile(0.25)
83-
q2 = group_data.quantile(0.5)
84-
q3 = group_data.quantile(0.75)
85-
iqr = q3 - q1
86-
upper = min(group_data.max(), q3 + 1.5 * iqr)
87-
lower = max(group_data.min(), q1 - 1.5 * iqr)
88-
89-
stats["x"].append(i)
90-
stats["q1"].append(q1)
91-
stats["q2"].append(q2)
92-
stats["q3"].append(q3)
93-
stats["upper"].append(upper)
94-
stats["lower"].append(lower)
95-
stats["group"].append(group)
96-
97-
# Find outliers
98-
outlier_data = group_data[(group_data < lower) | (group_data > upper)]
99-
for val in outlier_data:
100-
outliers["x"].append(i)
101-
outliers["y"].append(val)
102-
103-
# Set colors
104-
if not colors:
105-
from bokeh.palettes import Set2_8
106-
107-
colors = Set2_8[:n_groups]
108-
109-
# Create figure with numeric x-axis
110-
p = figure(
111-
width=width,
112-
height=height,
113-
title=title or "Box Plot Distribution",
114-
toolbar_location="above",
115-
tools="pan,wheel_zoom,box_zoom,reset,save",
13+
# Data
14+
np.random.seed(42)
15+
data = pd.DataFrame(
16+
{
17+
"group": ["A"] * 50 + ["B"] * 50 + ["C"] * 50 + ["D"] * 50,
18+
"value": np.concatenate(
19+
[
20+
np.random.normal(50, 10, 50),
21+
np.random.normal(60, 15, 50),
22+
np.random.normal(45, 8, 50),
23+
np.random.normal(70, 20, 50),
24+
]
25+
),
26+
}
27+
)
28+
29+
# Calculate box plot statistics for each group
30+
group_names = sorted(data["group"].unique())
31+
n_groups = len(group_names)
32+
33+
stats = {"x": [], "q1": [], "q2": [], "q3": [], "upper": [], "lower": []}
34+
outliers = {"x": [], "y": []}
35+
36+
for i, group in enumerate(group_names):
37+
group_data = data[data["group"] == group]["value"].dropna()
38+
39+
q1 = group_data.quantile(0.25)
40+
q2 = group_data.quantile(0.50)
41+
q3 = group_data.quantile(0.75)
42+
iqr = q3 - q1
43+
44+
upper_fence = q3 + 1.5 * iqr
45+
lower_fence = q1 - 1.5 * iqr
46+
upper = group_data[group_data <= upper_fence].max()
47+
lower = group_data[group_data >= lower_fence].min()
48+
49+
stats["x"].append(i)
50+
stats["q1"].append(q1)
51+
stats["q2"].append(q2)
52+
stats["q3"].append(q3)
53+
stats["upper"].append(upper)
54+
stats["lower"].append(lower)
55+
56+
# Find outliers
57+
outlier_data = group_data[(group_data < lower_fence) | (group_data > upper_fence)]
58+
for val in outlier_data:
59+
outliers["x"].append(i)
60+
outliers["y"].append(val)
61+
62+
# Create figure
63+
p = figure(
64+
width=4800,
65+
height=2700,
66+
title="Basic Box Plot",
67+
x_axis_label="Group",
68+
y_axis_label="Value",
69+
toolbar_location="above",
70+
tools="pan,wheel_zoom,box_zoom,reset,save",
71+
)
72+
73+
source = ColumnDataSource(data=stats)
74+
box_width = 0.5
75+
# Style guide color palette
76+
colors = ["#306998", "#FFD43B", "#DC2626", "#059669"]
77+
78+
# Draw boxes (Q1 to Q3)
79+
for i, color in enumerate(colors):
80+
p.vbar(
81+
x=i,
82+
width=box_width,
83+
bottom=stats["q1"][i],
84+
top=stats["q3"][i],
85+
fill_color=color,
86+
fill_alpha=0.7,
87+
line_color="#333333",
88+
line_width=2,
11689
)
11790

118-
source = ColumnDataSource(data=stats)
119-
120-
# Draw boxes (Q1 to Q3)
121-
box_width = 0.5
122-
for i, color in enumerate(colors):
123-
p.vbar(
124-
x=i,
125-
width=box_width,
126-
bottom=stats["q1"][i],
127-
top=stats["q3"][i],
128-
fill_color=color,
129-
line_color="black",
130-
alpha=0.7,
131-
)
132-
133-
# Draw median lines
134-
for i in range(n_groups):
135-
p.segment(
136-
x0=i - box_width / 2,
137-
y0=stats["q2"][i],
138-
x1=i + box_width / 2,
139-
y1=stats["q2"][i],
140-
line_color="red",
141-
line_width=2,
142-
)
143-
144-
# Draw whiskers
145-
upper_whisker = Whisker(base="x", upper="upper", lower="q3", source=source, line_color="black")
146-
upper_whisker.upper_head.size = 10
147-
upper_whisker.lower_head.size = 0
148-
p.add_layout(upper_whisker)
149-
150-
lower_whisker = Whisker(base="x", upper="q1", lower="lower", source=source, line_color="black")
151-
lower_whisker.upper_head.size = 0
152-
lower_whisker.lower_head.size = 10
153-
p.add_layout(lower_whisker)
154-
155-
# Draw outliers
156-
if outliers["x"]:
157-
outlier_source = ColumnDataSource(data=outliers)
158-
p.scatter(x="x", y="y", source=outlier_source, size=8, color="red", alpha=0.5, line_color="black", line_width=1)
159-
160-
# Set x-axis to show group names
161-
p.xaxis.ticker = FixedTicker(ticks=list(range(n_groups)))
162-
p.xaxis.major_label_overrides = dict(enumerate(group_names))
163-
164-
# Labels
165-
p.xaxis.axis_label = xlabel or groups
166-
p.yaxis.axis_label = ylabel or values
167-
168-
# Styling
169-
p.title.text_font_size = "14pt"
170-
p.title.align = "center"
171-
p.ygrid.grid_line_alpha = 0.3
172-
p.ygrid.grid_line_dash = [6, 4]
173-
p.xgrid.visible = False
174-
175-
# Add sample size annotations
176-
group_counts = data.groupby(groups)[values].count()
177-
y_min = data[values].min()
178-
y_range = data[values].max() - y_min
179-
for i, group in enumerate(group_names):
180-
count = group_counts[group]
181-
label = Label(
182-
x=i, y=y_min - y_range * 0.08, text=f"n={count}", text_align="center", text_font_size="9pt", text_alpha=0.7
183-
)
184-
p.add_layout(label)
185-
186-
return p
187-
188-
189-
if __name__ == "__main__":
190-
# Sample data for testing with different distributions per group
191-
np.random.seed(42)
192-
193-
data_dict = {"Group": [], "Value": []}
194-
195-
# Group A: Normal distribution
196-
group_a_data = np.random.normal(50, 10, 40)
197-
group_a_data = np.append(group_a_data, [80, 85, 15])
198-
199-
# Group B: Normal distribution
200-
group_b_data = np.random.normal(60, 15, 35)
201-
group_b_data = np.append(group_b_data, [100, 10])
202-
203-
# Group C: Normal distribution
204-
group_c_data = np.random.normal(45, 8, 45)
205-
206-
# Group D: Skewed distribution
207-
group_d_data = np.random.gamma(2, 2, 30) + 40
208-
group_d_data = np.append(group_d_data, [75, 78, 20])
209-
210-
# Combine all data
211-
for group, values in zip(
212-
["Group A", "Group B", "Group C", "Group D"],
213-
[group_a_data, group_b_data, group_c_data, group_d_data],
214-
strict=False,
215-
):
216-
data_dict["Group"].extend([group] * len(values))
217-
data_dict["Value"].extend(values)
218-
219-
data = pd.DataFrame(data_dict)
220-
221-
# Create plot
222-
fig = create_plot(
223-
data,
224-
values="Value",
225-
groups="Group",
226-
title="Statistical Distribution Comparison Across Groups",
227-
ylabel="Measurement Value",
228-
xlabel="Categories",
91+
# Draw median lines
92+
for i in range(n_groups):
93+
p.segment(
94+
x0=i - box_width / 2,
95+
y0=stats["q2"][i],
96+
x1=i + box_width / 2,
97+
y1=stats["q2"][i],
98+
line_color="#333333",
99+
line_width=3,
229100
)
230101

231-
# Save as PNG using webdriver-manager for automatic chromedriver
232-
from bokeh.io import export_png
233-
from selenium import webdriver
234-
from selenium.webdriver.chrome.options import Options
235-
from selenium.webdriver.chrome.service import Service
236-
from webdriver_manager.chrome import ChromeDriverManager
237-
238-
chrome_options = Options()
239-
chrome_options.add_argument("--headless")
240-
chrome_options.add_argument("--no-sandbox")
241-
chrome_options.add_argument("--disable-dev-shm-usage")
242-
243-
service = Service(ChromeDriverManager().install())
244-
driver = webdriver.Chrome(service=service, options=chrome_options)
102+
# Draw whiskers
103+
upper_whisker = Whisker(base="x", upper="upper", lower="q3", source=source, line_color="#333333", line_width=2)
104+
upper_whisker.upper_head.size = 20
105+
upper_whisker.upper_head.line_width = 2
106+
upper_whisker.lower_head.size = 0
107+
p.add_layout(upper_whisker)
108+
109+
lower_whisker = Whisker(base="x", upper="q1", lower="lower", source=source, line_color="#333333", line_width=2)
110+
lower_whisker.upper_head.size = 0
111+
lower_whisker.lower_head.size = 20
112+
lower_whisker.lower_head.line_width = 2
113+
p.add_layout(lower_whisker)
114+
115+
# Draw outliers
116+
if outliers["x"]:
117+
outlier_source = ColumnDataSource(data=outliers)
118+
p.scatter(
119+
x="x", y="y", source=outlier_source, size=12, color="#DC2626", alpha=0.7, line_color="#333333", line_width=1
120+
)
245121

246-
export_png(fig, filename="plot.png", webdriver=driver)
247-
driver.quit()
248-
print("Plot saved to plot.png")
122+
# Set x-axis to show group names
123+
p.xaxis.ticker = FixedTicker(ticks=list(range(n_groups)))
124+
p.xaxis.major_label_overrides = dict(enumerate(group_names))
125+
126+
# Styling
127+
p.title.text_font_size = "20pt"
128+
p.title.align = "center"
129+
p.xaxis.axis_label_text_font_size = "20pt"
130+
p.yaxis.axis_label_text_font_size = "20pt"
131+
p.xaxis.major_label_text_font_size = "16pt"
132+
p.yaxis.major_label_text_font_size = "16pt"
133+
p.ygrid.grid_line_alpha = 0.3
134+
p.xgrid.visible = False
135+
136+
# Save as PNG (requires selenium + webdriver)
137+
export_png(p, filename="plot.png")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ plotting = [
7575
lib-matplotlib = ["matplotlib>=3.9.0", "numpy>=1.26.0", "pandas>=2.2.0"]
7676
lib-seaborn = ["seaborn>=0.13.0", "matplotlib>=3.9.0", "numpy>=1.26.0", "pandas>=2.2.0"]
7777
lib-plotly = ["plotly>=5.18.0", "kaleido>=0.2.1", "numpy>=1.26.0", "pandas>=2.2.0"]
78-
lib-bokeh = ["bokeh>=3.4.0", "numpy>=1.26.0", "pandas>=2.2.0"]
78+
lib-bokeh = ["bokeh>=3.4.0", "numpy>=1.26.0", "pandas>=2.2.0", "selenium>=4.15.0", "webdriver-manager>=4.0.0"]
7979
lib-altair = ["altair>=5.2.0", "vl-convert-python>=1.3.0", "numpy>=1.26.0", "pandas>=2.2.0"]
8080
lib-plotnine = ["plotnine>=0.13.0", "numpy>=1.26.0", "pandas>=2.2.0"]
8181
lib-pygal = ["pygal>=3.0.0", "cairosvg>=2.7.0", "pandas>=2.2.0"]

0 commit comments

Comments
 (0)