Skip to content

Commit 2a2c9d8

Browse files
feat(highcharts): implement shap-waterfall (#5902)
## Implementation: `shap-waterfall` - python/highcharts Implements the **python/highcharts** version of `shap-waterfall`. **File:** `plots/shap-waterfall/implementations/python/highcharts.py` **Parent Issue:** #5237 --- :robot: *[impl-generate workflow](https://github.com/MarkusNeusinger/anyplot/actions/runs/25493326647)* --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 9245686 commit 2a2c9d8

2 files changed

Lines changed: 491 additions & 0 deletions

File tree

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
""" anyplot.ai
2+
shap-waterfall: SHAP Waterfall Plot for Feature Attribution
3+
Library: highcharts unknown | Python 3.13.13
4+
Quality: 86/100 | Created: 2026-05-07
5+
"""
6+
7+
import base64
8+
import json
9+
import os
10+
import tempfile
11+
import time
12+
import urllib.request
13+
from pathlib import Path
14+
15+
from selenium import webdriver
16+
from selenium.webdriver.chrome.options import Options
17+
18+
19+
# Theme tokens
20+
THEME = os.getenv("ANYPLOT_THEME", "light")
21+
PAGE_BG = "#FAF8F1" if THEME == "light" else "#1A1A17"
22+
INK = "#1A1A17" if THEME == "light" else "#F0EFE8"
23+
INK_SOFT = "#4A4A44" if THEME == "light" else "#B8B7B0"
24+
GRID = "rgba(26,26,23,0.10)" if THEME == "light" else "rgba(240,239,232,0.10)"
25+
26+
BRAND = "#009E73" # Okabe-Ito pos 1 — baseline & prediction bars
27+
POSITIVE_COLOR = "#D55E00" # Okabe-Ito pos 2 — positive SHAP (pushes risk up)
28+
NEGATIVE_COLOR = "#0072B2" # Okabe-Ito pos 3 — negative SHAP (pushes risk down)
29+
30+
# Data — credit scoring model explaining a single loan application
31+
# Features sorted by absolute SHAP value (largest contribution first = top of chart)
32+
BASE_VALUE = 0.35 # Expected probability of default across all applicants
33+
34+
features = [
35+
"Credit Score",
36+
"Debt-to-Income",
37+
"Annual Income",
38+
"Loan Amount",
39+
"Employment Years",
40+
"Payment History",
41+
"Open Accounts",
42+
"Credit Inquiries",
43+
"Credit Age",
44+
"Savings Balance",
45+
]
46+
shap_values = [-0.18, +0.15, -0.12, +0.09, -0.07, -0.05, +0.04, +0.03, -0.02, -0.02]
47+
FINAL_VALUE = round(BASE_VALUE + sum(shap_values), 4)
48+
49+
# Build waterfall data: base bar → feature deltas → isSum final
50+
categories = ["E[f(x)] Baseline", *features, "f(x) Prediction"]
51+
52+
data_points = [{"y": BASE_VALUE, "color": BRAND}]
53+
for sv in shap_values:
54+
data_points.append({"y": sv, "color": POSITIVE_COLOR if sv > 0 else NEGATIVE_COLOR})
55+
data_points.append({"isSum": True, "color": BRAND})
56+
57+
# Chart configuration (JSON-serializable; JS functions injected via string replace)
58+
chart_config = {
59+
"chart": {
60+
"type": "waterfall",
61+
"inverted": True,
62+
"width": 4800,
63+
"height": 2700,
64+
"backgroundColor": PAGE_BG,
65+
"marginLeft": 340,
66+
"marginRight": 220,
67+
"marginTop": 130,
68+
"marginBottom": 220,
69+
"style": {"fontFamily": "Arial, sans-serif", "color": INK},
70+
},
71+
"title": {
72+
"text": "Credit Default Risk · shap-waterfall · highcharts · anyplot.ai",
73+
"style": {"fontSize": "28px", "fontWeight": "normal", "color": INK},
74+
"align": "left",
75+
"x": 340,
76+
},
77+
"xAxis": {
78+
"categories": categories,
79+
"title": {"text": "Feature", "style": {"fontSize": "22px", "color": INK}},
80+
"labels": {"style": {"fontSize": "20px", "color": INK_SOFT}},
81+
"lineColor": INK_SOFT,
82+
"tickColor": INK_SOFT,
83+
"gridLineColor": GRID,
84+
},
85+
"yAxis": {
86+
"title": {"text": "Probability of Default", "style": {"fontSize": "22px", "color": INK}},
87+
"labels": {"style": {"fontSize": "18px", "color": INK_SOFT}, "formatter": "__YAXIS_FORMATTER__"},
88+
"lineColor": INK_SOFT,
89+
"tickColor": INK_SOFT,
90+
"gridLineColor": GRID,
91+
"gridLineWidth": 1,
92+
"min": -0.02,
93+
"max": 0.50,
94+
"plotLines": [
95+
{
96+
"value": BASE_VALUE,
97+
"color": INK_SOFT,
98+
"width": 2,
99+
"dashStyle": "Dash",
100+
"zIndex": 2,
101+
"label": {
102+
"text": f"Baseline {BASE_VALUE:.2f}",
103+
"align": "right",
104+
"rotation": 0,
105+
"x": -6,
106+
"y": -12,
107+
"style": {"fontSize": "16px", "color": INK_SOFT},
108+
},
109+
},
110+
{
111+
"value": FINAL_VALUE,
112+
"color": BRAND,
113+
"width": 2,
114+
"dashStyle": "ShortDot",
115+
"zIndex": 2,
116+
"label": {
117+
"text": f"Prediction {FINAL_VALUE:.2f}",
118+
"align": "right",
119+
"rotation": 0,
120+
"x": -6,
121+
"y": -12,
122+
"style": {"fontSize": "16px", "color": BRAND},
123+
},
124+
},
125+
],
126+
},
127+
"legend": {"enabled": False},
128+
"tooltip": {"enabled": False},
129+
"plotOptions": {
130+
"waterfall": {
131+
"lineWidth": 2,
132+
"lineColor": INK_SOFT,
133+
"borderWidth": 0,
134+
"groupPadding": 0.05,
135+
"pointPadding": 0.08,
136+
"dataLabels": {
137+
"enabled": True,
138+
"formatter": "__DATA_LABEL_FORMATTER__",
139+
"style": {"fontSize": "18px", "fontWeight": "bold", "color": INK, "textOutline": "none"},
140+
"inside": False,
141+
},
142+
}
143+
},
144+
"series": [{"name": "SHAP Attribution", "data": data_points}],
145+
}
146+
147+
# Inject JavaScript formatter functions (not JSON-serializable, so replace placeholders)
148+
config_json = json.dumps(chart_config)
149+
150+
yaxis_formatter = """function() {
151+
return Highcharts.numberFormat(this.value, 2);
152+
}"""
153+
config_json = config_json.replace('"__YAXIS_FORMATTER__"', yaxis_formatter)
154+
155+
data_label_formatter = """function() {
156+
if (this.point.isSum) {
157+
return 'f(x) = ' + Highcharts.numberFormat(this.y, 2);
158+
}
159+
if (this.point.index === 0) {
160+
return 'E[f(x)] = ' + Highcharts.numberFormat(this.y, 2);
161+
}
162+
var sign = this.y > 0 ? '+' : '';
163+
return sign + Highcharts.numberFormat(this.y, 3);
164+
}"""
165+
config_json = config_json.replace('"__DATA_LABEL_FORMATTER__"', data_label_formatter)
166+
167+
168+
# Download Highcharts JS with multiple CDN fallbacks
169+
def download_js(paths, timeout=15):
170+
cdn_bases = [
171+
"https://cdn.jsdelivr.net/npm/highcharts@11/",
172+
"https://unpkg.com/highcharts@11/",
173+
"https://code.highcharts.com/",
174+
]
175+
for path in paths:
176+
for base in cdn_bases:
177+
url = base + path
178+
for attempt in range(2):
179+
try:
180+
req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
181+
with urllib.request.urlopen(req, timeout=timeout) as resp:
182+
return resp.read().decode("utf-8")
183+
except Exception:
184+
if attempt == 0:
185+
time.sleep(1)
186+
return None
187+
188+
189+
highcharts_js = download_js(["highcharts.js"])
190+
if highcharts_js is None:
191+
raise RuntimeError("Failed to download highcharts.js from all CDNs")
192+
193+
# Waterfall chart type lives in highcharts-more.js
194+
highcharts_more_js = download_js(["highcharts-more.js"])
195+
if highcharts_more_js is None:
196+
raise RuntimeError("Failed to download highcharts-more.js from all CDNs")
197+
198+
# Generate HTML with inline Highcharts JS (core + more module for waterfall type)
199+
html_content = f"""<!DOCTYPE html>
200+
<html>
201+
<head>
202+
<meta charset="utf-8">
203+
<script>{highcharts_js}</script>
204+
<script>{highcharts_more_js}</script>
205+
</head>
206+
<body style="margin:0; background:{PAGE_BG};">
207+
<div id="container" style="width: 4800px; height: 2700px;"></div>
208+
<script>
209+
Highcharts.chart('container', {config_json});
210+
</script>
211+
</body>
212+
</html>"""
213+
214+
# Save interactive HTML artifact
215+
with open(f"plot-{THEME}.html", "w", encoding="utf-8") as f:
216+
f.write(html_content)
217+
218+
# Screenshot via headless Chrome with CDP for full-resolution capture
219+
with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False, encoding="utf-8") as f:
220+
f.write(html_content)
221+
temp_path = f.name
222+
223+
chrome_options = Options()
224+
chrome_options.add_argument("--headless=new")
225+
chrome_options.add_argument("--no-sandbox")
226+
chrome_options.add_argument("--disable-dev-shm-usage")
227+
chrome_options.add_argument("--disable-gpu")
228+
chrome_options.add_argument("--hide-scrollbars")
229+
chrome_options.add_argument("--force-device-scale-factor=1")
230+
231+
driver = webdriver.Chrome(options=chrome_options)
232+
driver.get(f"file://{temp_path}")
233+
time.sleep(5)
234+
235+
screenshot_config = {"captureBeyondViewport": True, "clip": {"x": 0, "y": 0, "width": 4800, "height": 2700, "scale": 1}}
236+
result = driver.execute_cdp_cmd("Page.captureScreenshot", screenshot_config)
237+
with open(f"plot-{THEME}.png", "wb") as f:
238+
f.write(base64.b64decode(result["data"]))
239+
driver.quit()
240+
241+
Path(temp_path).unlink()

0 commit comments

Comments
 (0)