1818running_total = sum (c for c in changes [:- 1 ] if c is not None )
1919changes [- 1 ] = running_total # Set final total
2020
21- # Define colors
21+ # Define colorblind-safe colors (avoid red-green)
2222TOTAL_COLOR = "#306998" # Python Blue for totals
23- INCREASE_COLOR = "#4CAF50 " # Green for increases
24- DECREASE_COLOR = "#E53935 " # Red for decreases
25- CONNECTOR_COLOR = "#888888 " # Gray for connecting lines
23+ INCREASE_COLOR = "#306998 " # Python Blue for increases (same as totals for this data)
24+ DECREASE_COLOR = "#FF9800 " # Orange for decreases (colorblind-safe)
25+ CONNECTOR_COLOR = "#666666 " # Gray for connecting lines
2626
27- # Custom style for waterfall chart - colors match series order
27+ # Custom style for waterfall chart
2828custom_style = Style (
2929 background = "white" ,
3030 plot_background = "white" ,
3131 foreground = "#333333" ,
3232 foreground_strong = "#333333" ,
3333 foreground_subtle = "#666666" ,
34- colors = ("rgba(255,255,255,0)" , TOTAL_COLOR , INCREASE_COLOR , DECREASE_COLOR , CONNECTOR_COLOR ),
34+ colors = ("rgba(255,255,255,0)" , TOTAL_COLOR , DECREASE_COLOR ),
3535 title_font_size = 48 ,
3636 label_font_size = 36 ,
3737 major_label_font_size = 36 ,
3838 value_font_size = 32 ,
3939 value_label_font_size = 32 ,
4040 legend_font_size = 36 ,
41+ guide_stroke_color = "#cccccc" ,
42+ major_guide_stroke_color = "#999999" ,
4143)
4244
4345# Build the waterfall data structure
44- # Each bar needs: base (invisible portion), visible height, and color
46+ # Each bar needs: base (invisible portion), visible height, and type
4547bar_data = []
4648cumulative = 0
4749
5153
5254 if is_first or is_last :
5355 # Total bars start from 0
54- bar_data .append ({"category" : cat , "base" : 0 , "height" : val , "color " : TOTAL_COLOR , "value" : val })
56+ bar_data .append ({"category" : cat , "base" : 0 , "height" : val , "type " : "total" , "value" : val })
5557 if is_first :
5658 cumulative = val
5759 else :
58- # Change bars
60+ # Change bars - negative values stack downward from cumulative
5961 if val >= 0 :
60- bar_data .append ({"category" : cat , "base" : cumulative , "height" : val , "color " : INCREASE_COLOR , "value" : val })
62+ bar_data .append ({"category" : cat , "base" : cumulative , "height" : val , "type " : "increase" , "value" : val })
6163 else :
6264 bar_data .append (
63- {"category" : cat , "base" : cumulative + val , "height" : abs (val ), "color " : DECREASE_COLOR , "value" : val }
65+ {"category" : cat , "base" : cumulative + val , "height" : abs (val ), "type " : "decrease" , "value" : val }
6466 )
6567 cumulative += val
6668
67-
68- # Custom value formatter - shows absolute height (labels handle signs separately)
69- def format_value (x ):
70- """Format value for display."""
71- if x is None or abs (x ) < 0.01 :
72- return ""
73- return f"${ x :,.0f} "
74-
75-
7669# Create a stacked bar chart - first stack is invisible base, second is visible bar
7770chart = pygal .StackedBar (
7871 width = 4800 ,
@@ -85,8 +78,7 @@ def format_value(x):
8578 legend_at_bottom = True ,
8679 legend_box_size = 30 ,
8780 print_values = False ,
88- print_labels = True , # Use labels instead of values for proper sign display
89- value_formatter = format_value ,
81+ print_labels = True ,
9082 show_y_guides = True ,
9183 show_x_guides = False ,
9284 margin = 50 ,
@@ -99,229 +91,111 @@ def format_value(x):
9991# Create the base (invisible) series and colored bar series
10092base_series = []
10193total_series = []
102- increase_series = []
10394decrease_series = []
95+
10496# Track cumulative values for connector lines
10597connector_levels = []
10698
10799for bar in bar_data :
108100 base_series .append ({"value" : bar ["base" ], "color" : "rgba(255,255,255,0)" })
109101
110- # Get the original change value for proper display
111- original_value = bar ["value" ]
112-
113- if bar ["color" ] == TOTAL_COLOR :
102+ if bar ["type" ] == "total" :
114103 # Format totals with positive values
115104 total_series .append ({"value" : bar ["height" ], "color" : TOTAL_COLOR , "label" : f"${ bar ['height' ]:,.0f} " })
116- increase_series .append ({"value" : None })
117105 decrease_series .append ({"value" : None })
118106 connector_levels .append (bar ["height" ])
119- elif bar ["color" ] == INCREASE_COLOR :
120- total_series .append ({"value" : None })
121- # Positive changes show with positive label
122- increase_series .append ({"value" : bar ["height" ], "color" : INCREASE_COLOR , "label" : f"+${ original_value :,.0f} " })
123- decrease_series .append ({"value" : None })
124- connector_levels .append (bar ["base" ] + bar ["height" ])
125107 else :
108+ # This is a decrease (all intermediate bars in this data are decreases)
126109 total_series .append ({"value" : None })
127- increase_series .append ({"value" : None })
128- # Negative changes show with negative sign
129110 decrease_series .append (
130- {"value" : bar ["height" ], "color" : DECREASE_COLOR , "label" : f"-${ abs (original_value ):,.0f} " }
111+ {"value" : bar ["height" ], "color" : DECREASE_COLOR , "label" : f"-${ abs (bar [ 'value' ] ):,.0f} " }
131112 )
132- connector_levels .append (bar ["base" ]) # Top of decrease bar after the drop
133-
134- # Add series - base is invisible spacer (no legend entry)
135- # Check if we have any increases to show in legend
136- has_increases = any (s .get ("value" ) for s in increase_series )
113+ connector_levels .append (bar ["base" ])
137114
115+ # Add series - base is invisible spacer
138116chart .add ("" , base_series , show_dots = False , stroke = False )
139117chart .add ("Total" , total_series )
140- if has_increases :
141- chart .add ("Increase" , increase_series )
142118chart .add ("Decrease" , decrease_series )
143119
144- # Render the base SVG
120+ # Render base SVG
145121base_svg = chart .render ().decode ("utf-8" )
146122
147- # Create connector lines by injecting SVG elements
148- # Parse the SVG to find bar positions and add horizontal connector lines
149- # Extract y-axis scaling from the chart to calculate line positions
150-
151- # Find the plot area boundaries from the SVG
152- # The y-axis needs to be scaled: find min/max y values and their pixel positions
153- y_max = max (bar ["base" ] + bar ["height" ] for bar in bar_data )
154- y_min = 0
155-
156- # Look for the plot area group and calculate bar positions
157- # Pygal uses specific class names for the plot area
158- # We'll add connector lines as a new group after the bars
159-
160- # Calculate approximate bar center x positions based on category count
161- num_bars = len (categories )
162-
163- # Create connector line SVG elements
164- # Connector lines go from the top of one bar to the start level of the next bar
165- connector_lines = []
166- for i in range (num_bars - 1 ):
167- # Each connector goes from current bar top to next bar's starting cumulative level
168- current_top = connector_levels [i ]
169- # Use current top as the horizontal line level (connecting to next bar)
170- connector_lines .append ((i , current_top ))
171-
172- # Alternative approach: Use secondary_range or custom rendering
173- # For a clean solution, render connector lines as a line series overlay
174-
175- # Render the HTML with embedded connector line visualization
176- # Add connector data as a secondary visualization in the HTML output
177- html_content = f"""<!DOCTYPE html>
178- <html>
179- <head>
180- <meta charset="utf-8">
181- <title>waterfall-basic · pygal · pyplots.ai</title>
182- <style>
183- .connector-line {{
184- stroke: { CONNECTOR_COLOR } ;
185- stroke-width: 3;
186- stroke-dasharray: 10, 5;
187- }}
188- </style>
189- </head>
190- <body>
191- { base_svg }
192- <script>
193- // Add connector lines after chart renders
194- document.addEventListener('DOMContentLoaded', function() {{
195- var svg = document.querySelector('svg');
196- if (!svg) return;
197-
198- // Get the plot area dimensions
199- var plotArea = svg.querySelector('.plot');
200- if (!plotArea) return;
201-
202- var rect = plotArea.getBBox();
203- var barWidth = rect.width / { num_bars } ;
204-
205- // Connector levels (cumulative values) from Python
206- var levels = { connector_levels } ;
207- var yMax = { y_max } ;
208-
209- // Calculate y scale
210- var yScale = rect.height / yMax;
211-
212- // Add connector lines
213- var ns = 'http://www.w3.org/2000/svg';
214- var connectorGroup = document.createElementNS(ns, 'g');
215- connectorGroup.setAttribute('class', 'connectors');
216-
217- for (var i = 0; i < levels.length - 1; i++) {{
218- var line = document.createElementNS(ns, 'line');
219- var x1 = rect.x + (i + 0.5) * barWidth + barWidth * 0.35;
220- var x2 = rect.x + (i + 1.5) * barWidth - barWidth * 0.35;
221- var y = rect.y + rect.height - levels[i] * yScale;
222-
223- line.setAttribute('x1', x1);
224- line.setAttribute('y1', y);
225- line.setAttribute('x2', x2);
226- line.setAttribute('y2', y);
227- line.setAttribute('class', 'connector-line');
228- connectorGroup.appendChild(line);
229- }}
230-
231- plotArea.appendChild(connectorGroup);
232- }});
233- </script>
234- </body>
235- </html>"""
236-
237- # For PNG output, we need to add connector lines directly to the SVG
238- # Parse SVG and inject lines before rendering to PNG
239-
240-
241- def add_connector_lines_to_svg (svg_content , bar_data , connector_levels ):
242- """Add horizontal connector lines between bars in the SVG."""
243- # Parse the actual plot dimensions from pygal's SVG
244- # Plot group is at translate(350, 138) with width 4399.2 and height 2218.0
245- plot_translate_match = re .search (r'translate\(([0-9.]+),\s*([0-9.]+)\)"\s*class="plot"' , svg_content )
246- plot_bg_match = re .search (
247- r'class="plot"[^>]*>.*?<rect[^>]*width="([0-9.]+)"[^>]*height="([0-9.]+)"' , svg_content , re .DOTALL
248- )
249-
250- # Default values based on pygal's typical 4800x2700 layout
251- plot_x = 350
252- plot_y = 138
253- plot_width = 4399.2
123+ # Add connector lines directly in SVG
124+ # Pygal uses a plot area with transform - parse the SVG to find dimensions
125+ # Plot area is typically at translate(X, Y) with the plot background rect
254126
255- if plot_translate_match :
256- plot_x = float (plot_translate_match .group (1 ))
257- plot_y = float (plot_translate_match .group (2 ))
127+ # Find plot transform: translate(X, Y)
128+ plot_match = re .search (r'transform="translate\(([0-9.]+),\s*([0-9.]+)\)"[^>]*class="plot"' , base_svg )
129+ plot_x = float (plot_match .group (1 )) if plot_match else 350.0
130+ plot_y = float (plot_match .group (2 )) if plot_match else 138.0
258131
259- if plot_bg_match :
260- plot_width = float (plot_bg_match .group (1 ))
261-
262- # Calculate y-axis range from data
263- y_max = max (bd ["base" ] + bd ["height" ] for bd in bar_data )
264-
265- # Build connector line elements
266- num_bars = len (connector_levels )
267- bar_width = plot_width / num_bars
268-
269- # Extract actual y-axis range from the SVG guides
270- # Default padding values based on typical pygal layout
271- y_axis_top = 42.65 # Y coordinate for max value
272- y_axis_bottom = 2175.35 # Y coordinate for zero
273-
274- # Extract y positions from guides if possible
275- guides = re .findall (r'path d="M0\.000000 ([0-9.]+) h[^"]*" class="(?:major )?(?:guide )?line"' , svg_content )
276- if guides :
277- y_axis_top = float (min (guides , key = float ))
278- y_axis_bottom = float (max (guides , key = float ))
132+ # Find plot dimensions from the background rect inside plot group
133+ bg_match = re .search (
134+ r'class="plot"[^>]*>.*?<rect class="background"[^>]*width="([0-9.]+)"[^>]*height="([0-9.]+)"' , base_svg , re .DOTALL
135+ )
136+ plot_width = float (bg_match .group (1 )) if bg_match else 4399.2
137+ plot_height = float (bg_match .group (2 )) if bg_match else 2132.7
279138
280- y_axis_range = y_axis_bottom - y_axis_top
139+ # Y-axis range from data
140+ y_max = max (bd ["base" ] + bd ["height" ] for bd in bar_data )
141+ y_min = 0
281142
282- # Create connector group with transform to match plot area
283- lines_svg = f'<g class="connectors" transform="translate({ plot_x } , { plot_y } )" stroke="{ CONNECTOR_COLOR } " stroke-width="6" stroke-dasharray="20,10">\n '
143+ # Extract y positions from guide lines to get accurate scaling
144+ guides = re .findall (r'path d="M0\.000000 ([0-9.]+) h[0-9.]+" class="[^"]*guide[^"]*line"' , base_svg )
145+ if guides :
146+ y_axis_top = float (min (guides , key = float ))
147+ y_axis_bottom = float (max (guides , key = float ))
148+ else :
149+ # Default based on typical pygal layout with margins
150+ y_axis_top = 42.65
151+ y_axis_bottom = plot_height - 42.65
284152
285- # Y scale: map data values to SVG coordinates (inverted, origin at top)
286- # y=0 in data maps to y_axis_bottom, y=y_max maps to y_axis_top
287- def data_to_svg_y (value ):
288- return y_axis_bottom - (value / y_max ) * y_axis_range
153+ y_axis_range = y_axis_bottom - y_axis_top
289154
290- # Add horizontal connector lines between consecutive bars
291- # Each line goes from right edge of current bar to left edge of next bar
292- for i in range (num_bars - 1 ):
293- level = connector_levels [i ]
294- # Bar center positions within plot area: (i + 0.5) * bar_width
295- # Line starts at right side of bar i and ends at left side of bar i+1
296- bar_center_i = (i + 0.5 ) * bar_width
297- bar_center_next = (i + 1.5 ) * bar_width
298- # Approximate bar half-width (with spacing)
299- bar_half_width = bar_width * 0.4
155+ # Build connector lines SVG group
156+ num_bars = len (connector_levels )
157+ bar_width = plot_width / num_bars
300158
301- x1 = bar_center_i + bar_half_width # Right edge of current bar
302- x2 = bar_center_next - bar_half_width # Left edge of next bar
303- y = data_to_svg_y (level )
159+ connector_lines = f'<g class="connectors" transform="translate({ plot_x } , { plot_y } )" stroke="{ CONNECTOR_COLOR } " stroke-width="6" stroke-dasharray="20,10">\n '
304160
305- lines_svg += f' <line x1="{ x1 :.1f} " y1="{ y :.1f} " x2="{ x2 :.1f} " y2="{ y :.1f} "/>\n '
161+ for i in range (num_bars - 1 ):
162+ level = connector_levels [i ]
163+ # Map data value to SVG y coordinate (inverted - 0 at bottom)
164+ y = y_axis_bottom - (level / y_max ) * y_axis_range
306165
307- lines_svg += "</g>\n "
166+ # Horizontal line from right edge of bar i to left edge of bar i+1
167+ bar_center_i = (i + 0.5 ) * bar_width
168+ bar_center_next = (i + 1.5 ) * bar_width
169+ bar_half_width = bar_width * 0.35 # Leave gap from bar edges
308170
309- # Insert before closing </svg>
310- svg_content = svg_content . replace ( "</svg>" , lines_svg + "</svg>" )
171+ x1 = bar_center_i + bar_half_width
172+ x2 = bar_center_next - bar_half_width
311173
312- return svg_content
174+ connector_lines += f' <line x1=" { x1 :.1f } " y1=" { y :.1f } " x2=" { x2 :.1f } " y2=" { y :.1f } "/> \n '
313175
176+ connector_lines += "</g>\n "
314177
315- # Render SVG with connector lines
316- svg_with_connectors = add_connector_lines_to_svg ( base_svg , bar_data , connector_levels )
178+ # Insert connector lines before closing </svg>
179+ svg_with_connectors = base_svg . replace ( "</svg>" , connector_lines + "</svg>" )
317180
318- # Save SVG with connectors
319- with open ("plot_with_connectors .svg" , "w" ) as f :
181+ # Save SVG
182+ with open ("plot .svg" , "w" ) as f :
320183 f .write (svg_with_connectors )
321184
322185# Render to PNG using cairosvg
323186cairosvg .svg2png (bytestring = svg_with_connectors .encode ("utf-8" ), write_to = "plot.png" )
324187
325- # Save HTML with interactive connectors
188+ # HTML with embedded SVG
189+ html_content = f"""<!DOCTYPE html>
190+ <html>
191+ <head>
192+ <meta charset="utf-8">
193+ <title>waterfall-basic · pygal · pyplots.ai</title>
194+ </head>
195+ <body>
196+ { svg_with_connectors }
197+ </body>
198+ </html>"""
199+
326200with open ("plot.html" , "w" ) as f :
327201 f .write (html_content )
0 commit comments