99
1010import numpy as np
1111import pandas as pd
12- from bokeh .models import ColumnDataSource
13- from bokeh .plotting import figure , output_file , save
14-
12+ from bokeh .models import ColumnDataSource , FixedTicker , Label , Whisker
13+ from bokeh .plotting import figure
1514
1615if TYPE_CHECKING :
1716 from bokeh .plotting import Figure
@@ -70,146 +69,143 @@ def create_plot(
7069
7170 # Calculate box plot statistics for each group
7271 group_names = sorted (data [groups ].unique ())
72+ n_groups = len (group_names )
73+
74+ # Prepare data for box plot
75+ stats = {"x" : [], "q1" : [], "q2" : [], "q3" : [], "upper" : [], "lower" : [], "group" : []}
76+ outliers = {"x" : [], "y" : []}
7377
74- # Prepare data structures for box plot components
75- box_data = {
76- "groups" : [],
77- "q1" : [],
78- "q2" : [],
79- "q3" : [],
80- "upper" : [],
81- "lower" : [],
82- "outliers_x" : [],
83- "outliers_y" : [],
84- }
85-
86- for group in group_names :
78+ for i , group in enumerate (group_names ):
8779 group_data = data [data [groups ] == group ][values ].dropna ()
8880
8981 q1 = group_data .quantile (0.25 )
90- q2 = group_data .quantile (0.5 ) # median
82+ q2 = group_data .quantile (0.5 )
9183 q3 = group_data .quantile (0.75 )
9284 iqr = q3 - q1
9385 upper = min (group_data .max (), q3 + 1.5 * iqr )
9486 lower = max (group_data .min (), q1 - 1.5 * iqr )
9587
88+ stats ["x" ].append (i )
89+ stats ["q1" ].append (q1 )
90+ stats ["q2" ].append (q2 )
91+ stats ["q3" ].append (q3 )
92+ stats ["upper" ].append (upper )
93+ stats ["lower" ].append (lower )
94+ stats ["group" ].append (group )
95+
9696 # Find outliers
97- outliers = group_data [(group_data < lower ) | (group_data > upper )]
97+ outlier_data = group_data [(group_data < lower ) | (group_data > upper )]
98+ for val in outlier_data :
99+ outliers ["x" ].append (i )
100+ outliers ["y" ].append (val )
98101
99- box_data ["groups" ].append (group )
100- box_data ["q1" ].append (q1 )
101- box_data ["q2" ].append (q2 )
102- box_data ["q3" ].append (q3 )
103- box_data ["upper" ].append (upper )
104- box_data ["lower" ].append (lower )
102+ # Set colors
103+ if not colors :
104+ from bokeh .palettes import Set2_8
105105
106- # Add outliers
107- for outlier in outliers :
108- box_data ["outliers_x" ].append (group )
109- box_data ["outliers_y" ].append (outlier )
106+ colors = Set2_8 [:n_groups ]
110107
111- # Create figure
108+ # Create figure with numeric x-axis
112109 p = figure (
113- x_range = group_names ,
114110 width = width ,
115111 height = height ,
116112 title = title or "Box Plot Distribution" ,
117113 toolbar_location = "above" ,
118114 tools = "pan,wheel_zoom,box_zoom,reset,save" ,
119115 )
120116
121- # Set colors
122- if not colors :
123- from bokeh .palettes import Set2_8
124-
125- colors = Set2_8 [: len (group_names )]
117+ source = ColumnDataSource (data = stats )
126118
127- # Draw boxes (Q1 to Q3) for each group
128- for i , group in enumerate (group_names ):
129- idx = box_data ["groups" ].index (group )
130-
131- # Box from Q1 to Q3
119+ # Draw boxes (Q1 to Q3)
120+ box_width = 0.5
121+ for i , color in enumerate (colors ):
132122 p .vbar (
133- x = group ,
134- width = 0.5 ,
135- bottom = box_data ["q1" ][idx ],
136- top = box_data ["q3" ][idx ],
137- fill_color = colors [ i % len ( colors )] ,
123+ x = i ,
124+ width = box_width ,
125+ bottom = stats ["q1" ][i ],
126+ top = stats ["q3" ][i ],
127+ fill_color = color ,
138128 line_color = "black" ,
139129 alpha = 0.7 ,
140130 )
141131
142- # Median line
143- p . line ( x = [ i - 0.25 , i + 0.25 ], y = [ box_data [ "q2" ][ idx ], box_data [ "q2" ][ idx ]], line_color = "red" , line_width = 2 )
144-
145- # Upper whisker
146- p . line ( x = [ i , i ], y = [ box_data [ "q3 " ][idx ], box_data [ "upper" ][ idx ]], line_color = "black" , line_width = 1 )
147-
148- # Upper whisker cap
149- p . line (
150- x = [ i - 0.1 , i + 0.1 ], y = [ box_data [ "upper" ][ idx ], box_data [ "upper" ][ idx ]], line_color = "black" , line_width = 1.5
132+ # Draw median lines
133+ for i in range ( n_groups ):
134+ p . segment (
135+ x0 = i - box_width / 2 ,
136+ y0 = stats [ "q2 " ][i ],
137+ x1 = i + box_width / 2 ,
138+ y1 = stats [ "q2" ][ i ],
139+ line_color = "red" ,
140+ line_width = 2 ,
151141 )
152142
153- # Lower whisker
154- p .line (x = [i , i ], y = [box_data ["q1" ][idx ], box_data ["lower" ][idx ]], line_color = "black" , line_width = 1 )
155-
156- # Lower whisker cap
157- p .line (
158- x = [i - 0.1 , i + 0.1 ], y = [box_data ["lower" ][idx ], box_data ["lower" ][idx ]], line_color = "black" , line_width = 1.5
143+ # Draw whiskers
144+ upper_whisker = Whisker (base = "x" , upper = "upper" , lower = "q3" , source = source , line_color = "black" )
145+ upper_whisker .upper_head .size = 10
146+ upper_whisker .lower_head .size = 0
147+ p .add_layout (upper_whisker )
148+
149+ lower_whisker = Whisker (base = "x" , upper = "q1" , lower = "lower" , source = source , line_color = "black" )
150+ lower_whisker .upper_head .size = 0
151+ lower_whisker .lower_head .size = 10
152+ p .add_layout (lower_whisker )
153+
154+ # Draw outliers
155+ if outliers ["x" ]:
156+ outlier_source = ColumnDataSource (data = outliers )
157+ p .scatter (
158+ x = "x" , y = "y" , source = outlier_source , size = 8 , color = "red" , alpha = 0.5 , line_color = "black" , line_width = 1
159159 )
160160
161- # Draw outliers using ColumnDataSource (required for categorical x-axis)
162- if box_data ["outliers_x" ]:
163- outlier_source = ColumnDataSource (data = {"x" : box_data ["outliers_x" ], "y" : box_data ["outliers_y" ]})
164- p .scatter (x = "x" , y = "y" , source = outlier_source , size = 8 , color = "red" , alpha = 0.5 , line_color = "black" , line_width = 1 )
161+ # Set x-axis to show group names
162+ p .xaxis .ticker = FixedTicker (ticks = list (range (n_groups )))
163+ p .xaxis .major_label_overrides = {i : name for i , name in enumerate (group_names )}
165164
166- # Styling
165+ # Labels
167166 p .xaxis .axis_label = xlabel or groups
168167 p .yaxis .axis_label = ylabel or values
169168
169+ # Styling
170170 p .title .text_font_size = "14pt"
171171 p .title .align = "center"
172-
173- # Grid
174172 p .ygrid .grid_line_alpha = 0.3
175173 p .ygrid .grid_line_dash = [6 , 4 ]
176174 p .xgrid .visible = False
177175
178176 # Add sample size annotations
179177 group_counts = data .groupby (groups )[values ].count ()
180- for i , (_group , count ) in enumerate (group_counts .items ()):
181- y_position = data [values ].min () - (data [values ].max () - data [values ].min ()) * 0.05
182- from bokeh .models import Label
183-
184- label = Label (x = i , y = y_position , text = f"n={ count } " , text_align = "center" , text_font_size = "9pt" , text_alpha = 0.7 )
178+ y_min = data [values ].min ()
179+ y_range = data [values ].max () - y_min
180+ for i , group in enumerate (group_names ):
181+ count = group_counts [group ]
182+ label = Label (
183+ x = i , y = y_min - y_range * 0.08 , text = f"n={ count } " , text_align = "center" , text_font_size = "9pt" , text_alpha = 0.7
184+ )
185185 p .add_layout (label )
186186
187187 return p
188188
189189
190190if __name__ == "__main__" :
191191 # Sample data for testing with different distributions per group
192- np .random .seed (42 ) # For reproducibility
192+ np .random .seed (42 )
193193
194- # Generate sample data with 4 groups
195194 data_dict = {"Group" : [], "Value" : []}
196195
197- # Group A: Normal distribution, mean=50, std=10
196+ # Group A: Normal distribution
198197 group_a_data = np .random .normal (50 , 10 , 40 )
199- # Add some outliers
200198 group_a_data = np .append (group_a_data , [80 , 85 , 15 ])
201199
202- # Group B: Normal distribution, mean=60, std=15
200+ # Group B: Normal distribution
203201 group_b_data = np .random .normal (60 , 15 , 35 )
204- # Add outliers
205202 group_b_data = np .append (group_b_data , [100 , 10 ])
206203
207- # Group C: Normal distribution, mean=45, std=8
204+ # Group C: Normal distribution
208205 group_c_data = np .random .normal (45 , 8 , 45 )
209206
210207 # Group D: Skewed distribution
211208 group_d_data = np .random .gamma (2 , 2 , 30 ) + 40
212- # Add outliers
213209 group_d_data = np .append (group_d_data , [75 , 78 , 20 ])
214210
215211 # Combine all data
@@ -233,16 +229,8 @@ def create_plot(
233229 xlabel = "Categories" ,
234230 )
235231
236- # Save for inspection
237- output_file ("plot.html" )
238- save (fig )
239- print ("Interactive plot saved to plot.html" )
240-
241- # Also export as PNG if possible
242- try :
243- from bokeh .io import export_png
232+ # Save as PNG
233+ from bokeh .io import export_png
244234
245- export_png (fig , filename = "plot.png" )
246- print ("Static plot saved to plot.png" )
247- except ImportError :
248- print ("Note: Install 'selenium' and 'pillow' to export PNG images" )
235+ export_png (fig , filename = "plot.png" )
236+ print ("Plot saved to plot.png" )
0 commit comments