Skip to content

Commit 1d62c57

Browse files
feat(highcharts): implement logistic-regression (#3570)
## Implementation: `logistic-regression` - highcharts Implements the **highcharts** version of `logistic-regression`. **File:** `plots/logistic-regression/implementations/highcharts.py` **Parent Issue:** #3550 --- :robot: *[impl-generate workflow](https://github.com/MarkusNeusinger/pyplots/actions/runs/20866601331)* --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent aa66c90 commit 1d62c57

2 files changed

Lines changed: 477 additions & 0 deletions

File tree

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
""" pyplots.ai
2+
logistic-regression: Logistic Regression Curve Plot
3+
Library: highcharts unknown | Python 3.13.11
4+
Quality: 91/100 | Created: 2026-01-09
5+
"""
6+
7+
import tempfile
8+
import time
9+
import urllib.request
10+
from pathlib import Path
11+
12+
import numpy as np
13+
from highcharts_core.chart import Chart
14+
from highcharts_core.options import HighchartsOptions
15+
from highcharts_core.options.series.area import AreaRangeSeries
16+
from highcharts_core.options.series.scatter import ScatterSeries
17+
from highcharts_core.options.series.spline import SplineSeries
18+
from selenium import webdriver
19+
from selenium.webdriver.chrome.options import Options
20+
from sklearn.linear_model import LogisticRegression
21+
22+
23+
# Data - Generate binary classification data
24+
np.random.seed(42)
25+
n_points = 150
26+
27+
# Predictor: Study hours (0-10)
28+
x = np.random.uniform(0, 10, n_points)
29+
30+
# Binary outcome: Pass/Fail (1/0) with logistic relationship
31+
true_prob = 1 / (1 + np.exp(-1.5 * (x - 5)))
32+
y = (np.random.random(n_points) < true_prob).astype(int)
33+
34+
# Fit logistic regression model
35+
model = LogisticRegression()
36+
model.fit(x.reshape(-1, 1), y)
37+
38+
# Generate smooth curve for plotting
39+
x_curve = np.linspace(0, 10, 200)
40+
y_prob = model.predict_proba(x_curve.reshape(-1, 1))[:, 1]
41+
42+
# Calculate confidence intervals using bootstrap
43+
n_bootstrap = 100
44+
bootstrap_probs = np.zeros((n_bootstrap, len(x_curve)))
45+
for i in range(n_bootstrap):
46+
indices = np.random.choice(n_points, n_points, replace=True)
47+
x_boot, y_boot = x[indices], y[indices]
48+
model_boot = LogisticRegression()
49+
model_boot.fit(x_boot.reshape(-1, 1), y_boot)
50+
bootstrap_probs[i] = model_boot.predict_proba(x_curve.reshape(-1, 1))[:, 1]
51+
52+
ci_lower = np.percentile(bootstrap_probs, 2.5, axis=0)
53+
ci_upper = np.percentile(bootstrap_probs, 97.5, axis=0)
54+
55+
# Jitter y values for visibility
56+
jitter = np.random.uniform(-0.03, 0.03, n_points)
57+
y_jittered = y + jitter
58+
59+
# Separate points by class
60+
x_class0 = x[y == 0].tolist()
61+
y_class0 = y_jittered[y == 0].tolist()
62+
x_class1 = x[y == 1].tolist()
63+
y_class1 = y_jittered[y == 1].tolist()
64+
65+
# Chart setup
66+
chart = Chart(container="container")
67+
chart.options = HighchartsOptions()
68+
69+
# Chart configuration
70+
chart.options.chart = {
71+
"width": 4800,
72+
"height": 2700,
73+
"backgroundColor": "#ffffff",
74+
"style": {"fontFamily": "Arial, sans-serif"},
75+
"spacingBottom": 100,
76+
"spacingLeft": 50,
77+
"spacingTop": 50,
78+
"spacingRight": 50,
79+
}
80+
81+
# Title
82+
chart.options.title = {
83+
"text": "logistic-regression · highcharts · pyplots.ai",
84+
"style": {"fontSize": "48px", "fontWeight": "bold"},
85+
}
86+
87+
chart.options.subtitle = {"text": "Exam Pass Probability vs Study Hours", "style": {"fontSize": "32px"}}
88+
89+
# X-axis
90+
chart.options.x_axis = {
91+
"title": {"text": "Study Hours", "style": {"fontSize": "36px"}},
92+
"labels": {"style": {"fontSize": "28px"}},
93+
"min": 0,
94+
"max": 10,
95+
"gridLineWidth": 1,
96+
"gridLineColor": "rgba(0, 0, 0, 0.1)",
97+
}
98+
99+
# Y-axis
100+
chart.options.y_axis = {
101+
"title": {"text": "Probability", "style": {"fontSize": "36px"}},
102+
"labels": {"style": {"fontSize": "28px"}},
103+
"min": -0.05,
104+
"max": 1.05,
105+
"gridLineWidth": 1,
106+
"gridLineColor": "rgba(0, 0, 0, 0.1)",
107+
"plotLines": [
108+
{
109+
"value": 0.5,
110+
"color": "#888888",
111+
"width": 3,
112+
"dashStyle": "Dash",
113+
"label": {
114+
"text": "Decision Threshold (0.5)",
115+
"align": "right",
116+
"style": {"fontSize": "24px", "color": "#888888"},
117+
"x": -10,
118+
"y": -10,
119+
},
120+
"zIndex": 4,
121+
}
122+
],
123+
}
124+
125+
# Legend
126+
chart.options.legend = {
127+
"enabled": True,
128+
"itemStyle": {"fontSize": "28px"},
129+
"symbolRadius": 6,
130+
"symbolHeight": 20,
131+
"symbolWidth": 20,
132+
}
133+
134+
# Plot options
135+
chart.options.plot_options = {
136+
"scatter": {"marker": {"radius": 14, "symbol": "circle"}},
137+
"spline": {"lineWidth": 6, "marker": {"enabled": False}},
138+
"arearange": {"fillOpacity": 0.25, "lineWidth": 0, "marker": {"enabled": False}},
139+
}
140+
141+
# Confidence interval (arearange series)
142+
ci_data = [[float(x_curve[i]), float(ci_lower[i]), float(ci_upper[i])] for i in range(len(x_curve))]
143+
ci_series = AreaRangeSeries()
144+
ci_series.data = ci_data
145+
ci_series.name = "95% CI"
146+
ci_series.color = "rgba(48, 105, 152, 0.3)"
147+
ci_series.fill_opacity = 0.3
148+
chart.add_series(ci_series)
149+
150+
# Logistic curve
151+
curve_data = [[float(x_curve[i]), float(y_prob[i])] for i in range(len(x_curve))]
152+
curve_series = SplineSeries()
153+
curve_series.data = curve_data
154+
curve_series.name = "Logistic Curve"
155+
curve_series.color = "#306998"
156+
chart.add_series(curve_series)
157+
158+
# Class 0 points (Fail)
159+
scatter_class0 = ScatterSeries()
160+
scatter_class0.data = [[x_class0[i], y_class0[i]] for i in range(len(x_class0))]
161+
scatter_class0.name = "Fail (0)"
162+
scatter_class0.color = "rgba(48, 105, 152, 0.6)"
163+
scatter_class0.marker = {"radius": 14, "symbol": "circle"}
164+
chart.add_series(scatter_class0)
165+
166+
# Class 1 points (Pass)
167+
scatter_class1 = ScatterSeries()
168+
scatter_class1.data = [[x_class1[i], y_class1[i]] for i in range(len(x_class1))]
169+
scatter_class1.name = "Pass (1)"
170+
scatter_class1.color = "rgba(255, 212, 59, 0.8)"
171+
scatter_class1.marker = {"radius": 14, "symbol": "circle"}
172+
chart.add_series(scatter_class1)
173+
174+
# Add model accuracy annotation
175+
accuracy = model.score(x.reshape(-1, 1), y)
176+
chart.options.annotations = [
177+
{
178+
"labels": [
179+
{
180+
"point": {"x": 8.5, "y": 0.15, "xAxis": 0, "yAxis": 0},
181+
"text": f"Accuracy: {accuracy:.1%}",
182+
"style": {"fontSize": "28px"},
183+
"backgroundColor": "rgba(255, 255, 255, 0.8)",
184+
"borderColor": "#306998",
185+
"borderWidth": 2,
186+
"padding": 15,
187+
}
188+
],
189+
"labelOptions": {"shape": "rect"},
190+
}
191+
]
192+
193+
# Credits
194+
chart.options.credits = {"enabled": False}
195+
196+
# Download Highcharts JS
197+
highcharts_url = "https://code.highcharts.com/highcharts.js"
198+
with urllib.request.urlopen(highcharts_url, timeout=30) as response:
199+
highcharts_js = response.read().decode("utf-8")
200+
201+
# Download highcharts-more for arearange
202+
highcharts_more_url = "https://code.highcharts.com/highcharts-more.js"
203+
with urllib.request.urlopen(highcharts_more_url, timeout=30) as response:
204+
highcharts_more_js = response.read().decode("utf-8")
205+
206+
# Download annotations module
207+
annotations_url = "https://code.highcharts.com/modules/annotations.js"
208+
with urllib.request.urlopen(annotations_url, timeout=30) as response:
209+
annotations_js = response.read().decode("utf-8")
210+
211+
# Generate HTML with inline scripts
212+
html_str = chart.to_js_literal()
213+
html_content = f"""<!DOCTYPE html>
214+
<html>
215+
<head>
216+
<meta charset="utf-8">
217+
<script>{highcharts_js}</script>
218+
<script>{highcharts_more_js}</script>
219+
<script>{annotations_js}</script>
220+
</head>
221+
<body style="margin:0;">
222+
<div id="container" style="width: 4800px; height: 2700px;"></div>
223+
<script>{html_str}</script>
224+
</body>
225+
</html>"""
226+
227+
# Write temp HTML and take screenshot
228+
with tempfile.NamedTemporaryFile(mode="w", suffix=".html", delete=False, encoding="utf-8") as f:
229+
f.write(html_content)
230+
temp_path = f.name
231+
232+
chrome_options = Options()
233+
chrome_options.add_argument("--headless")
234+
chrome_options.add_argument("--no-sandbox")
235+
chrome_options.add_argument("--disable-dev-shm-usage")
236+
chrome_options.add_argument("--disable-gpu")
237+
chrome_options.add_argument("--window-size=4900,2800")
238+
239+
driver = webdriver.Chrome(options=chrome_options)
240+
driver.get(f"file://{temp_path}")
241+
time.sleep(5)
242+
driver.save_screenshot("plot.png")
243+
driver.quit()
244+
245+
Path(temp_path).unlink()
246+
247+
# Save interactive HTML (using CDN scripts for standalone viewing)
248+
html_export = f"""<!DOCTYPE html>
249+
<html>
250+
<head>
251+
<meta charset="utf-8">
252+
<title>Logistic Regression - Highcharts</title>
253+
<script src="https://code.highcharts.com/highcharts.js"></script>
254+
<script src="https://code.highcharts.com/highcharts-more.js"></script>
255+
<script src="https://code.highcharts.com/modules/annotations.js"></script>
256+
</head>
257+
<body style="margin:0;">
258+
<div id="container" style="width: 100%; height: 100vh;"></div>
259+
<script>{html_str}</script>
260+
</body>
261+
</html>"""
262+
with open("plot.html", "w", encoding="utf-8") as f:
263+
f.write(html_export)

0 commit comments

Comments
 (0)