Skip to content

Commit 366e2f1

Browse files
committed
updating the linear regression visual
1 parent 32de0f9 commit 366e2f1

1 file changed

Lines changed: 52 additions & 11 deletions

File tree

linear_regression/visuals.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,22 @@
1010
from IPython.display import display, clear_output
1111

1212

13+
def _float_slider_range(lo: float, hi: float, value: float, *, eps: float = 1e-12):
14+
"""Return (min_, max_, clamped_value) with max_ > min_; works for negative slopes and edge cases."""
15+
vmin = float(min(lo, hi))
16+
vmax = float(max(lo, hi))
17+
if not (np.isfinite(vmin) and np.isfinite(vmax)):
18+
v = float(value) if np.isfinite(value) else 0.0
19+
return v - 1.0, v + 1.0, v
20+
if vmax - vmin < eps:
21+
v = float(value) if np.isfinite(value) else (vmin + vmax) / 2
22+
half = max(eps, abs(v) * 1e-9 if v != 0 else 1e-9)
23+
vmin, vmax = v - half, v + half
24+
v = float(value) if np.isfinite(value) else (vmin + vmax) / 2
25+
v = min(max(v, vmin), vmax)
26+
return vmin, vmax, v
27+
28+
1329
def show_interactive_correlation():
1430
"""Interactive visual for correlation direction and strength."""
1531
r_slider = widgets.FloatSlider(
@@ -75,39 +91,64 @@ def show_interactive_line_tuner(
7591
):
7692
"""Interactive visual to compare a user-selected line with least-squares."""
7793
y_spread = float(np.std(y))
94+
if not np.isfinite(y_spread):
95+
y_spread = 1.0
7896
b_float = float(b)
97+
if not np.isfinite(b_float):
98+
b_float = 0.0
7999
low, high = b_float * 0.2, b_float * 1.8
80100
if abs(b_float) < 1e-15:
81101
sx = float(np.std(x))
82102
span = (y_spread / sx) if sx > 1e-15 else 1.0
83-
b_min, b_max = -abs(span), abs(span)
103+
span = max(abs(span), 1e-9)
104+
b_min, b_max, b_val = _float_slider_range(-span, span, b_float)
84105
else:
85-
b_min, b_max = min(low, high), max(low, high)
106+
b_min, b_max, b_val = _float_slider_range(low, high, b_float)
107+
b_rng = b_max - b_min
108+
step_b = (
109+
max(abs(b_float) * 0.02, 1e-9)
110+
if abs(b_float) >= 1e-15
111+
else max(y_spread / 80, 0.05, 1e-9)
112+
)
113+
step_b = min(step_b, b_rng / 5) if b_rng > 0 else step_b
114+
step_b = max(step_b, 1e-12)
115+
86116
b_try = widgets.FloatSlider(
87-
value=b_float,
117+
value=b_val,
88118
min=b_min,
89119
max=b_max,
90-
step=max(abs(b_float) * 0.02, 1e-6) if abs(b_float) >= 1e-15 else max(y_spread / 80, 0.05, 1e-6),
120+
step=step_b,
91121
description="Your slope b:",
92122
style={"description_width": "initial"},
93123
layout=widgets.Layout(width="500px"),
94124
)
125+
a_float = float(a)
126+
if not np.isfinite(a_float):
127+
a_float = 0.0
95128
if money_yaxis:
129+
a_lo, a_hi = a_float - 8_000_000, a_float + 8_000_000
130+
a_min, a_max, a_val = _float_slider_range(a_lo, a_hi, a_float)
96131
a_try = widgets.FloatSlider(
97-
value=float(a),
98-
min=float(a - 8_000_000),
99-
max=float(a + 8_000_000),
132+
value=a_val,
133+
min=a_min,
134+
max=a_max,
100135
step=100_000,
101136
description="Your intercept a:",
102137
style={"description_width": "initial"},
103138
layout=widgets.Layout(width="500px"),
104139
)
105140
else:
141+
half = max(4 * y_spread, 1.0)
142+
a_lo, a_hi = a_float - half, a_float + half
143+
a_min, a_max, a_val = _float_slider_range(a_lo, a_hi, a_float)
144+
step_a = max(y_spread / 80, 0.05, 1e-9)
145+
step_a = min(step_a, (a_max - a_min) / 5) if a_max > a_min else step_a
146+
step_a = max(step_a, 1e-12)
106147
a_try = widgets.FloatSlider(
107-
value=float(a),
108-
min=float(a - 4 * y_spread),
109-
max=float(a + 4 * y_spread),
110-
step=max(y_spread / 80, 0.05),
148+
value=a_val,
149+
min=a_min,
150+
max=a_max,
151+
step=step_a,
111152
description="Your intercept a:",
112153
style={"description_width": "initial"},
113154
layout=widgets.Layout(width="500px"),

0 commit comments

Comments
 (0)