Skip to content

Commit cc206e4

Browse files
fix(pygal): address review feedback for waterfall-basic
Attempt 2/3 - fixes based on AI review: - Add visible connector lines in PNG output using SVG injection - Use colorblind-safe colors (orange instead of red for decreases) - Ensure value labels show proper +/- signs - Simplify code structure (removed unnecessary function) - Move import re to top of file per style guidelines
1 parent 85711c8 commit cc206e4

1 file changed

Lines changed: 79 additions & 205 deletions

File tree

  • plots/waterfall-basic/implementations

plots/waterfall-basic/implementations/pygal.py

Lines changed: 79 additions & 205 deletions
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,32 @@
1818
running_total = sum(c for c in changes[:-1] if c is not None)
1919
changes[-1] = running_total # Set final total
2020

21-
# Define colors
21+
# Define colorblind-safe colors (avoid red-green)
2222
TOTAL_COLOR = "#306998" # Python Blue for totals
23-
INCREASE_COLOR = "#4CAF50" # Green for increases
24-
DECREASE_COLOR = "#E53935" # Red for decreases
25-
CONNECTOR_COLOR = "#888888" # Gray for connecting lines
23+
INCREASE_COLOR = "#306998" # Python Blue for increases (same as totals for this data)
24+
DECREASE_COLOR = "#FF9800" # Orange for decreases (colorblind-safe)
25+
CONNECTOR_COLOR = "#666666" # Gray for connecting lines
2626

27-
# Custom style for waterfall chart - colors match series order
27+
# Custom style for waterfall chart
2828
custom_style = Style(
2929
background="white",
3030
plot_background="white",
3131
foreground="#333333",
3232
foreground_strong="#333333",
3333
foreground_subtle="#666666",
34-
colors=("rgba(255,255,255,0)", TOTAL_COLOR, INCREASE_COLOR, DECREASE_COLOR, CONNECTOR_COLOR),
34+
colors=("rgba(255,255,255,0)", TOTAL_COLOR, DECREASE_COLOR),
3535
title_font_size=48,
3636
label_font_size=36,
3737
major_label_font_size=36,
3838
value_font_size=32,
3939
value_label_font_size=32,
4040
legend_font_size=36,
41+
guide_stroke_color="#cccccc",
42+
major_guide_stroke_color="#999999",
4143
)
4244

4345
# Build the waterfall data structure
44-
# Each bar needs: base (invisible portion), visible height, and color
46+
# Each bar needs: base (invisible portion), visible height, and type
4547
bar_data = []
4648
cumulative = 0
4749

@@ -51,28 +53,19 @@
5153

5254
if is_first or is_last:
5355
# Total bars start from 0
54-
bar_data.append({"category": cat, "base": 0, "height": val, "color": TOTAL_COLOR, "value": val})
56+
bar_data.append({"category": cat, "base": 0, "height": val, "type": "total", "value": val})
5557
if is_first:
5658
cumulative = val
5759
else:
58-
# Change bars
60+
# Change bars - negative values stack downward from cumulative
5961
if val >= 0:
60-
bar_data.append({"category": cat, "base": cumulative, "height": val, "color": INCREASE_COLOR, "value": val})
62+
bar_data.append({"category": cat, "base": cumulative, "height": val, "type": "increase", "value": val})
6163
else:
6264
bar_data.append(
63-
{"category": cat, "base": cumulative + val, "height": abs(val), "color": DECREASE_COLOR, "value": val}
65+
{"category": cat, "base": cumulative + val, "height": abs(val), "type": "decrease", "value": val}
6466
)
6567
cumulative += val
6668

67-
68-
# Custom value formatter - shows absolute height (labels handle signs separately)
69-
def format_value(x):
70-
"""Format value for display."""
71-
if x is None or abs(x) < 0.01:
72-
return ""
73-
return f"${x:,.0f}"
74-
75-
7669
# Create a stacked bar chart - first stack is invisible base, second is visible bar
7770
chart = pygal.StackedBar(
7871
width=4800,
@@ -85,8 +78,7 @@ def format_value(x):
8578
legend_at_bottom=True,
8679
legend_box_size=30,
8780
print_values=False,
88-
print_labels=True, # Use labels instead of values for proper sign display
89-
value_formatter=format_value,
81+
print_labels=True,
9082
show_y_guides=True,
9183
show_x_guides=False,
9284
margin=50,
@@ -99,229 +91,111 @@ def format_value(x):
9991
# Create the base (invisible) series and colored bar series
10092
base_series = []
10193
total_series = []
102-
increase_series = []
10394
decrease_series = []
95+
10496
# Track cumulative values for connector lines
10597
connector_levels = []
10698

10799
for bar in bar_data:
108100
base_series.append({"value": bar["base"], "color": "rgba(255,255,255,0)"})
109101

110-
# Get the original change value for proper display
111-
original_value = bar["value"]
112-
113-
if bar["color"] == TOTAL_COLOR:
102+
if bar["type"] == "total":
114103
# Format totals with positive values
115104
total_series.append({"value": bar["height"], "color": TOTAL_COLOR, "label": f"${bar['height']:,.0f}"})
116-
increase_series.append({"value": None})
117105
decrease_series.append({"value": None})
118106
connector_levels.append(bar["height"])
119-
elif bar["color"] == INCREASE_COLOR:
120-
total_series.append({"value": None})
121-
# Positive changes show with positive label
122-
increase_series.append({"value": bar["height"], "color": INCREASE_COLOR, "label": f"+${original_value:,.0f}"})
123-
decrease_series.append({"value": None})
124-
connector_levels.append(bar["base"] + bar["height"])
125107
else:
108+
# This is a decrease (all intermediate bars in this data are decreases)
126109
total_series.append({"value": None})
127-
increase_series.append({"value": None})
128-
# Negative changes show with negative sign
129110
decrease_series.append(
130-
{"value": bar["height"], "color": DECREASE_COLOR, "label": f"-${abs(original_value):,.0f}"}
111+
{"value": bar["height"], "color": DECREASE_COLOR, "label": f"-${abs(bar['value']):,.0f}"}
131112
)
132-
connector_levels.append(bar["base"]) # Top of decrease bar after the drop
133-
134-
# Add series - base is invisible spacer (no legend entry)
135-
# Check if we have any increases to show in legend
136-
has_increases = any(s.get("value") for s in increase_series)
113+
connector_levels.append(bar["base"])
137114

115+
# Add series - base is invisible spacer
138116
chart.add("", base_series, show_dots=False, stroke=False)
139117
chart.add("Total", total_series)
140-
if has_increases:
141-
chart.add("Increase", increase_series)
142118
chart.add("Decrease", decrease_series)
143119

144-
# Render the base SVG
120+
# Render base SVG
145121
base_svg = chart.render().decode("utf-8")
146122

147-
# Create connector lines by injecting SVG elements
148-
# Parse the SVG to find bar positions and add horizontal connector lines
149-
# Extract y-axis scaling from the chart to calculate line positions
150-
151-
# Find the plot area boundaries from the SVG
152-
# The y-axis needs to be scaled: find min/max y values and their pixel positions
153-
y_max = max(bar["base"] + bar["height"] for bar in bar_data)
154-
y_min = 0
155-
156-
# Look for the plot area group and calculate bar positions
157-
# Pygal uses specific class names for the plot area
158-
# We'll add connector lines as a new group after the bars
159-
160-
# Calculate approximate bar center x positions based on category count
161-
num_bars = len(categories)
162-
163-
# Create connector line SVG elements
164-
# Connector lines go from the top of one bar to the start level of the next bar
165-
connector_lines = []
166-
for i in range(num_bars - 1):
167-
# Each connector goes from current bar top to next bar's starting cumulative level
168-
current_top = connector_levels[i]
169-
# Use current top as the horizontal line level (connecting to next bar)
170-
connector_lines.append((i, current_top))
171-
172-
# Alternative approach: Use secondary_range or custom rendering
173-
# For a clean solution, render connector lines as a line series overlay
174-
175-
# Render the HTML with embedded connector line visualization
176-
# Add connector data as a secondary visualization in the HTML output
177-
html_content = f"""<!DOCTYPE html>
178-
<html>
179-
<head>
180-
<meta charset="utf-8">
181-
<title>waterfall-basic · pygal · pyplots.ai</title>
182-
<style>
183-
.connector-line {{
184-
stroke: {CONNECTOR_COLOR};
185-
stroke-width: 3;
186-
stroke-dasharray: 10, 5;
187-
}}
188-
</style>
189-
</head>
190-
<body>
191-
{base_svg}
192-
<script>
193-
// Add connector lines after chart renders
194-
document.addEventListener('DOMContentLoaded', function() {{
195-
var svg = document.querySelector('svg');
196-
if (!svg) return;
197-
198-
// Get the plot area dimensions
199-
var plotArea = svg.querySelector('.plot');
200-
if (!plotArea) return;
201-
202-
var rect = plotArea.getBBox();
203-
var barWidth = rect.width / {num_bars};
204-
205-
// Connector levels (cumulative values) from Python
206-
var levels = {connector_levels};
207-
var yMax = {y_max};
208-
209-
// Calculate y scale
210-
var yScale = rect.height / yMax;
211-
212-
// Add connector lines
213-
var ns = 'http://www.w3.org/2000/svg';
214-
var connectorGroup = document.createElementNS(ns, 'g');
215-
connectorGroup.setAttribute('class', 'connectors');
216-
217-
for (var i = 0; i < levels.length - 1; i++) {{
218-
var line = document.createElementNS(ns, 'line');
219-
var x1 = rect.x + (i + 0.5) * barWidth + barWidth * 0.35;
220-
var x2 = rect.x + (i + 1.5) * barWidth - barWidth * 0.35;
221-
var y = rect.y + rect.height - levels[i] * yScale;
222-
223-
line.setAttribute('x1', x1);
224-
line.setAttribute('y1', y);
225-
line.setAttribute('x2', x2);
226-
line.setAttribute('y2', y);
227-
line.setAttribute('class', 'connector-line');
228-
connectorGroup.appendChild(line);
229-
}}
230-
231-
plotArea.appendChild(connectorGroup);
232-
}});
233-
</script>
234-
</body>
235-
</html>"""
236-
237-
# For PNG output, we need to add connector lines directly to the SVG
238-
# Parse SVG and inject lines before rendering to PNG
239-
240-
241-
def add_connector_lines_to_svg(svg_content, bar_data, connector_levels):
242-
"""Add horizontal connector lines between bars in the SVG."""
243-
# Parse the actual plot dimensions from pygal's SVG
244-
# Plot group is at translate(350, 138) with width 4399.2 and height 2218.0
245-
plot_translate_match = re.search(r'translate\(([0-9.]+),\s*([0-9.]+)\)"\s*class="plot"', svg_content)
246-
plot_bg_match = re.search(
247-
r'class="plot"[^>]*>.*?<rect[^>]*width="([0-9.]+)"[^>]*height="([0-9.]+)"', svg_content, re.DOTALL
248-
)
249-
250-
# Default values based on pygal's typical 4800x2700 layout
251-
plot_x = 350
252-
plot_y = 138
253-
plot_width = 4399.2
123+
# Add connector lines directly in SVG
124+
# Pygal uses a plot area with transform - parse the SVG to find dimensions
125+
# Plot area is typically at translate(X, Y) with the plot background rect
254126

255-
if plot_translate_match:
256-
plot_x = float(plot_translate_match.group(1))
257-
plot_y = float(plot_translate_match.group(2))
127+
# Find plot transform: translate(X, Y)
128+
plot_match = re.search(r'transform="translate\(([0-9.]+),\s*([0-9.]+)\)"[^>]*class="plot"', base_svg)
129+
plot_x = float(plot_match.group(1)) if plot_match else 350.0
130+
plot_y = float(plot_match.group(2)) if plot_match else 138.0
258131

259-
if plot_bg_match:
260-
plot_width = float(plot_bg_match.group(1))
261-
262-
# Calculate y-axis range from data
263-
y_max = max(bd["base"] + bd["height"] for bd in bar_data)
264-
265-
# Build connector line elements
266-
num_bars = len(connector_levels)
267-
bar_width = plot_width / num_bars
268-
269-
# Extract actual y-axis range from the SVG guides
270-
# Default padding values based on typical pygal layout
271-
y_axis_top = 42.65 # Y coordinate for max value
272-
y_axis_bottom = 2175.35 # Y coordinate for zero
273-
274-
# Extract y positions from guides if possible
275-
guides = re.findall(r'path d="M0\.000000 ([0-9.]+) h[^"]*" class="(?:major )?(?:guide )?line"', svg_content)
276-
if guides:
277-
y_axis_top = float(min(guides, key=float))
278-
y_axis_bottom = float(max(guides, key=float))
132+
# Find plot dimensions from the background rect inside plot group
133+
bg_match = re.search(
134+
r'class="plot"[^>]*>.*?<rect class="background"[^>]*width="([0-9.]+)"[^>]*height="([0-9.]+)"', base_svg, re.DOTALL
135+
)
136+
plot_width = float(bg_match.group(1)) if bg_match else 4399.2
137+
plot_height = float(bg_match.group(2)) if bg_match else 2132.7
279138

280-
y_axis_range = y_axis_bottom - y_axis_top
139+
# Y-axis range from data
140+
y_max = max(bd["base"] + bd["height"] for bd in bar_data)
141+
y_min = 0
281142

282-
# Create connector group with transform to match plot area
283-
lines_svg = f'<g class="connectors" transform="translate({plot_x}, {plot_y})" stroke="{CONNECTOR_COLOR}" stroke-width="6" stroke-dasharray="20,10">\n'
143+
# Extract y positions from guide lines to get accurate scaling
144+
guides = re.findall(r'path d="M0\.000000 ([0-9.]+) h[0-9.]+" class="[^"]*guide[^"]*line"', base_svg)
145+
if guides:
146+
y_axis_top = float(min(guides, key=float))
147+
y_axis_bottom = float(max(guides, key=float))
148+
else:
149+
# Default based on typical pygal layout with margins
150+
y_axis_top = 42.65
151+
y_axis_bottom = plot_height - 42.65
284152

285-
# Y scale: map data values to SVG coordinates (inverted, origin at top)
286-
# y=0 in data maps to y_axis_bottom, y=y_max maps to y_axis_top
287-
def data_to_svg_y(value):
288-
return y_axis_bottom - (value / y_max) * y_axis_range
153+
y_axis_range = y_axis_bottom - y_axis_top
289154

290-
# Add horizontal connector lines between consecutive bars
291-
# Each line goes from right edge of current bar to left edge of next bar
292-
for i in range(num_bars - 1):
293-
level = connector_levels[i]
294-
# Bar center positions within plot area: (i + 0.5) * bar_width
295-
# Line starts at right side of bar i and ends at left side of bar i+1
296-
bar_center_i = (i + 0.5) * bar_width
297-
bar_center_next = (i + 1.5) * bar_width
298-
# Approximate bar half-width (with spacing)
299-
bar_half_width = bar_width * 0.4
155+
# Build connector lines SVG group
156+
num_bars = len(connector_levels)
157+
bar_width = plot_width / num_bars
300158

301-
x1 = bar_center_i + bar_half_width # Right edge of current bar
302-
x2 = bar_center_next - bar_half_width # Left edge of next bar
303-
y = data_to_svg_y(level)
159+
connector_lines = f'<g class="connectors" transform="translate({plot_x}, {plot_y})" stroke="{CONNECTOR_COLOR}" stroke-width="6" stroke-dasharray="20,10">\n'
304160

305-
lines_svg += f' <line x1="{x1:.1f}" y1="{y:.1f}" x2="{x2:.1f}" y2="{y:.1f}"/>\n'
161+
for i in range(num_bars - 1):
162+
level = connector_levels[i]
163+
# Map data value to SVG y coordinate (inverted - 0 at bottom)
164+
y = y_axis_bottom - (level / y_max) * y_axis_range
306165

307-
lines_svg += "</g>\n"
166+
# Horizontal line from right edge of bar i to left edge of bar i+1
167+
bar_center_i = (i + 0.5) * bar_width
168+
bar_center_next = (i + 1.5) * bar_width
169+
bar_half_width = bar_width * 0.35 # Leave gap from bar edges
308170

309-
# Insert before closing </svg>
310-
svg_content = svg_content.replace("</svg>", lines_svg + "</svg>")
171+
x1 = bar_center_i + bar_half_width
172+
x2 = bar_center_next - bar_half_width
311173

312-
return svg_content
174+
connector_lines += f' <line x1="{x1:.1f}" y1="{y:.1f}" x2="{x2:.1f}" y2="{y:.1f}"/>\n'
313175

176+
connector_lines += "</g>\n"
314177

315-
# Render SVG with connector lines
316-
svg_with_connectors = add_connector_lines_to_svg(base_svg, bar_data, connector_levels)
178+
# Insert connector lines before closing </svg>
179+
svg_with_connectors = base_svg.replace("</svg>", connector_lines + "</svg>")
317180

318-
# Save SVG with connectors
319-
with open("plot_with_connectors.svg", "w") as f:
181+
# Save SVG
182+
with open("plot.svg", "w") as f:
320183
f.write(svg_with_connectors)
321184

322185
# Render to PNG using cairosvg
323186
cairosvg.svg2png(bytestring=svg_with_connectors.encode("utf-8"), write_to="plot.png")
324187

325-
# Save HTML with interactive connectors
188+
# HTML with embedded SVG
189+
html_content = f"""<!DOCTYPE html>
190+
<html>
191+
<head>
192+
<meta charset="utf-8">
193+
<title>waterfall-basic · pygal · pyplots.ai</title>
194+
</head>
195+
<body>
196+
{svg_with_connectors}
197+
</body>
198+
</html>"""
199+
326200
with open("plot.html", "w") as f:
327201
f.write(html_content)

0 commit comments

Comments
 (0)