Skip to content

Commit 7d38448

Browse files
feat: enhance box plot generation and export functionality
- Refactor box plot data preparation for improved clarity - Implement outlier detection and visualization - Update export method to save plots as PNG using Selenium - Remove HTML export and related code
1 parent e820c80 commit 7d38448

9 files changed

Lines changed: 248 additions & 139 deletions

File tree

plots/altair/boxplot/box-basic/default.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,6 @@ def create_plot(
174174
xlabel="Categories",
175175
)
176176

177-
# Save for inspection
178-
chart.save("plot.html")
179-
print("Interactive plot saved to plot.html")
180-
181-
# Also save as PNG
177+
# Save as PNG
182178
chart.save("plot.png", scale_factor=2.0)
183-
print("Static plot saved to plot.png")
179+
print("Plot saved to plot.png")

plots/bokeh/custom/box-basic/default.py

Lines changed: 80 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99

1010
import numpy as np
1111
import pandas as pd
12-
from bokeh.models import ColumnDataSource
13-
from bokeh.plotting import figure, output_file, save
14-
12+
from bokeh.models import ColumnDataSource, FixedTicker, Label, Whisker
13+
from bokeh.plotting import figure
1514

1615
if TYPE_CHECKING:
1716
from bokeh.plotting import Figure
@@ -70,146 +69,143 @@ def create_plot(
7069

7170
# Calculate box plot statistics for each group
7271
group_names = sorted(data[groups].unique())
72+
n_groups = len(group_names)
73+
74+
# Prepare data for box plot
75+
stats = {"x": [], "q1": [], "q2": [], "q3": [], "upper": [], "lower": [], "group": []}
76+
outliers = {"x": [], "y": []}
7377

74-
# Prepare data structures for box plot components
75-
box_data = {
76-
"groups": [],
77-
"q1": [],
78-
"q2": [],
79-
"q3": [],
80-
"upper": [],
81-
"lower": [],
82-
"outliers_x": [],
83-
"outliers_y": [],
84-
}
85-
86-
for group in group_names:
78+
for i, group in enumerate(group_names):
8779
group_data = data[data[groups] == group][values].dropna()
8880

8981
q1 = group_data.quantile(0.25)
90-
q2 = group_data.quantile(0.5) # median
82+
q2 = group_data.quantile(0.5)
9183
q3 = group_data.quantile(0.75)
9284
iqr = q3 - q1
9385
upper = min(group_data.max(), q3 + 1.5 * iqr)
9486
lower = max(group_data.min(), q1 - 1.5 * iqr)
9587

88+
stats["x"].append(i)
89+
stats["q1"].append(q1)
90+
stats["q2"].append(q2)
91+
stats["q3"].append(q3)
92+
stats["upper"].append(upper)
93+
stats["lower"].append(lower)
94+
stats["group"].append(group)
95+
9696
# Find outliers
97-
outliers = group_data[(group_data < lower) | (group_data > upper)]
97+
outlier_data = group_data[(group_data < lower) | (group_data > upper)]
98+
for val in outlier_data:
99+
outliers["x"].append(i)
100+
outliers["y"].append(val)
98101

99-
box_data["groups"].append(group)
100-
box_data["q1"].append(q1)
101-
box_data["q2"].append(q2)
102-
box_data["q3"].append(q3)
103-
box_data["upper"].append(upper)
104-
box_data["lower"].append(lower)
102+
# Set colors
103+
if not colors:
104+
from bokeh.palettes import Set2_8
105105

106-
# Add outliers
107-
for outlier in outliers:
108-
box_data["outliers_x"].append(group)
109-
box_data["outliers_y"].append(outlier)
106+
colors = Set2_8[:n_groups]
110107

111-
# Create figure
108+
# Create figure with numeric x-axis
112109
p = figure(
113-
x_range=group_names,
114110
width=width,
115111
height=height,
116112
title=title or "Box Plot Distribution",
117113
toolbar_location="above",
118114
tools="pan,wheel_zoom,box_zoom,reset,save",
119115
)
120116

121-
# Set colors
122-
if not colors:
123-
from bokeh.palettes import Set2_8
124-
125-
colors = Set2_8[: len(group_names)]
117+
source = ColumnDataSource(data=stats)
126118

127-
# Draw boxes (Q1 to Q3) for each group
128-
for i, group in enumerate(group_names):
129-
idx = box_data["groups"].index(group)
130-
131-
# Box from Q1 to Q3
119+
# Draw boxes (Q1 to Q3)
120+
box_width = 0.5
121+
for i, color in enumerate(colors):
132122
p.vbar(
133-
x=group,
134-
width=0.5,
135-
bottom=box_data["q1"][idx],
136-
top=box_data["q3"][idx],
137-
fill_color=colors[i % len(colors)],
123+
x=i,
124+
width=box_width,
125+
bottom=stats["q1"][i],
126+
top=stats["q3"][i],
127+
fill_color=color,
138128
line_color="black",
139129
alpha=0.7,
140130
)
141131

142-
# Median line
143-
p.line(x=[i - 0.25, i + 0.25], y=[box_data["q2"][idx], box_data["q2"][idx]], line_color="red", line_width=2)
144-
145-
# Upper whisker
146-
p.line(x=[i, i], y=[box_data["q3"][idx], box_data["upper"][idx]], line_color="black", line_width=1)
147-
148-
# Upper whisker cap
149-
p.line(
150-
x=[i - 0.1, i + 0.1], y=[box_data["upper"][idx], box_data["upper"][idx]], line_color="black", line_width=1.5
132+
# Draw median lines
133+
for i in range(n_groups):
134+
p.segment(
135+
x0=i - box_width / 2,
136+
y0=stats["q2"][i],
137+
x1=i + box_width / 2,
138+
y1=stats["q2"][i],
139+
line_color="red",
140+
line_width=2,
151141
)
152142

153-
# Lower whisker
154-
p.line(x=[i, i], y=[box_data["q1"][idx], box_data["lower"][idx]], line_color="black", line_width=1)
155-
156-
# Lower whisker cap
157-
p.line(
158-
x=[i - 0.1, i + 0.1], y=[box_data["lower"][idx], box_data["lower"][idx]], line_color="black", line_width=1.5
143+
# Draw whiskers
144+
upper_whisker = Whisker(base="x", upper="upper", lower="q3", source=source, line_color="black")
145+
upper_whisker.upper_head.size = 10
146+
upper_whisker.lower_head.size = 0
147+
p.add_layout(upper_whisker)
148+
149+
lower_whisker = Whisker(base="x", upper="q1", lower="lower", source=source, line_color="black")
150+
lower_whisker.upper_head.size = 0
151+
lower_whisker.lower_head.size = 10
152+
p.add_layout(lower_whisker)
153+
154+
# Draw outliers
155+
if outliers["x"]:
156+
outlier_source = ColumnDataSource(data=outliers)
157+
p.scatter(
158+
x="x", y="y", source=outlier_source, size=8, color="red", alpha=0.5, line_color="black", line_width=1
159159
)
160160

161-
# Draw outliers using ColumnDataSource (required for categorical x-axis)
162-
if box_data["outliers_x"]:
163-
outlier_source = ColumnDataSource(data={"x": box_data["outliers_x"], "y": box_data["outliers_y"]})
164-
p.scatter(x="x", y="y", source=outlier_source, size=8, color="red", alpha=0.5, line_color="black", line_width=1)
161+
# Set x-axis to show group names
162+
p.xaxis.ticker = FixedTicker(ticks=list(range(n_groups)))
163+
p.xaxis.major_label_overrides = {i: name for i, name in enumerate(group_names)}
165164

166-
# Styling
165+
# Labels
167166
p.xaxis.axis_label = xlabel or groups
168167
p.yaxis.axis_label = ylabel or values
169168

169+
# Styling
170170
p.title.text_font_size = "14pt"
171171
p.title.align = "center"
172-
173-
# Grid
174172
p.ygrid.grid_line_alpha = 0.3
175173
p.ygrid.grid_line_dash = [6, 4]
176174
p.xgrid.visible = False
177175

178176
# Add sample size annotations
179177
group_counts = data.groupby(groups)[values].count()
180-
for i, (_group, count) in enumerate(group_counts.items()):
181-
y_position = data[values].min() - (data[values].max() - data[values].min()) * 0.05
182-
from bokeh.models import Label
183-
184-
label = Label(x=i, y=y_position, text=f"n={count}", text_align="center", text_font_size="9pt", text_alpha=0.7)
178+
y_min = data[values].min()
179+
y_range = data[values].max() - y_min
180+
for i, group in enumerate(group_names):
181+
count = group_counts[group]
182+
label = Label(
183+
x=i, y=y_min - y_range * 0.08, text=f"n={count}", text_align="center", text_font_size="9pt", text_alpha=0.7
184+
)
185185
p.add_layout(label)
186186

187187
return p
188188

189189

190190
if __name__ == "__main__":
191191
# Sample data for testing with different distributions per group
192-
np.random.seed(42) # For reproducibility
192+
np.random.seed(42)
193193

194-
# Generate sample data with 4 groups
195194
data_dict = {"Group": [], "Value": []}
196195

197-
# Group A: Normal distribution, mean=50, std=10
196+
# Group A: Normal distribution
198197
group_a_data = np.random.normal(50, 10, 40)
199-
# Add some outliers
200198
group_a_data = np.append(group_a_data, [80, 85, 15])
201199

202-
# Group B: Normal distribution, mean=60, std=15
200+
# Group B: Normal distribution
203201
group_b_data = np.random.normal(60, 15, 35)
204-
# Add outliers
205202
group_b_data = np.append(group_b_data, [100, 10])
206203

207-
# Group C: Normal distribution, mean=45, std=8
204+
# Group C: Normal distribution
208205
group_c_data = np.random.normal(45, 8, 45)
209206

210207
# Group D: Skewed distribution
211208
group_d_data = np.random.gamma(2, 2, 30) + 40
212-
# Add outliers
213209
group_d_data = np.append(group_d_data, [75, 78, 20])
214210

215211
# Combine all data
@@ -233,16 +229,8 @@ def create_plot(
233229
xlabel="Categories",
234230
)
235231

236-
# Save for inspection
237-
output_file("plot.html")
238-
save(fig)
239-
print("Interactive plot saved to plot.html")
240-
241-
# Also export as PNG if possible
242-
try:
243-
from bokeh.io import export_png
232+
# Save as PNG
233+
from bokeh.io import export_png
244234

245-
export_png(fig, filename="plot.png")
246-
print("Static plot saved to plot.png")
247-
except ImportError:
248-
print("Note: Install 'selenium' and 'pillow' to export PNG images")
235+
export_png(fig, filename="plot.png")
236+
print("Plot saved to plot.png")

plots/highcharts/boxplot/box-basic/default.py

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -149,21 +149,28 @@ def create_plot(
149149
chart.options.chart = {"type": "boxplot", "height": height, "backgroundColor": "white"}
150150

151151
# Add box plot series
152-
chart.add_series(BoxPlotSeries.from_array(data=box_data, name="Distribution", colorByPoint=True))
152+
box_series = BoxPlotSeries()
153+
box_series.data = box_data
154+
box_series.name = "Distribution"
155+
box_series.color_by_point = True
156+
chart.add_series(box_series)
153157

154158
# Add outliers as scatter series if any exist
155159
if outliers_data:
156160
from highcharts_core.options.series.scatter import ScatterSeries
157161

158-
chart.add_series(
159-
ScatterSeries.from_array(
160-
data=outliers_data,
161-
name="Outliers",
162-
color="rgba(255, 0, 0, 0.5)",
163-
marker={"fillColor": "rgba(255, 0, 0, 0.5)", "lineWidth": 1, "lineColor": "#000000", "radius": 4},
164-
tooltip={"pointFormat": "Outlier: <b>{point.y}</b>"},
165-
)
166-
)
162+
scatter_series = ScatterSeries()
163+
scatter_series.data = outliers_data
164+
scatter_series.name = "Outliers"
165+
scatter_series.color = "rgba(255, 0, 0, 0.5)"
166+
scatter_series.marker = {
167+
"fillColor": "rgba(255, 0, 0, 0.5)",
168+
"lineWidth": 1,
169+
"lineColor": "#000000",
170+
"radius": 4,
171+
}
172+
scatter_series.tooltip = {"pointFormat": "Outlier: <b>{point.y}</b>"}
173+
chart.add_series(scatter_series)
167174

168175
# Legend
169176
chart.options.legend = {
@@ -222,31 +229,45 @@ def create_plot(
222229
xlabel="Categories",
223230
)
224231

225-
# Export to HTML
226-
html_str = chart.to_js_literal()
232+
# Export to PNG via Selenium screenshot
233+
import tempfile
234+
import time
235+
from pathlib import Path
236+
237+
from selenium import webdriver
238+
from selenium.webdriver.chrome.options import Options
227239

228-
# Create HTML file
240+
# Generate HTML content
241+
html_str = chart.to_js_literal()
229242
html_content = f"""<!DOCTYPE html>
230243
<html>
231244
<head>
232245
<meta charset="utf-8">
233-
<title>Box Plot - Highcharts</title>
234246
<script src="https://code.highcharts.com/highcharts.js"></script>
235247
<script src="https://code.highcharts.com/highcharts-more.js"></script>
236248
</head>
237-
<body>
238-
<div id="container" style="width: 100%; height: 600px;"></div>
239-
<script>
240-
{html_str}
241-
</script>
249+
<body style="margin:0;">
250+
<div id="container" style="width: 1000px; height: 600px;"></div>
251+
<script>{html_str}</script>
242252
</body>
243253
</html>"""
244254

245-
with open("plot.html", "w") as f:
255+
# Write temp HTML and take screenshot
256+
with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False) as f:
246257
f.write(html_content)
247-
248-
print("Interactive plot saved to plot.html")
249-
250-
# Note about PNG export
251-
print("Note: Highcharts requires a license for commercial use")
252-
print("For static image export, use Highcharts Export Server or phantomjs")
258+
temp_path = f.name
259+
260+
chrome_options = Options()
261+
chrome_options.add_argument("--headless")
262+
chrome_options.add_argument("--no-sandbox")
263+
chrome_options.add_argument("--disable-dev-shm-usage")
264+
chrome_options.add_argument("--window-size=1000,600")
265+
266+
driver = webdriver.Chrome(options=chrome_options)
267+
driver.get(f"file://{temp_path}")
268+
time.sleep(1) # Wait for chart to render
269+
driver.save_screenshot("plot.png")
270+
driver.quit()
271+
272+
Path(temp_path).unlink() # Clean up temp file
273+
print("Plot saved to plot.png")

plots/plotly/box/box-basic/default.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,6 @@ def create_plot(
201201
xlabel="Categories",
202202
)
203203

204-
# Save for inspection
205-
fig.write_html("plot.html")
204+
# Save as PNG
206205
fig.write_image("plot.png", width=1000, height=600, scale=2)
207-
print("Interactive plot saved to plot.html")
208-
print("Static plot saved to plot.png")
206+
print("Plot saved to plot.png")

plots/plotnine/boxplot/box-basic/default.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def create_plot(
9090
width=0.6,
9191
**kwargs,
9292
)
93-
+ scale_fill_brewer(type="qual", palette=fill_palette, guide=False) # Hide legend
93+
+ scale_fill_brewer(type="qual", palette=fill_palette, guide=None) # Hide legend
9494
+ labs(title=title or "Box Plot Distribution", x=xlabel or groups, y=ylabel or values)
9595
+ theme_minimal()
9696
+ theme(

0 commit comments

Comments
 (0)