Skip to content

Commit af48ebe

Browse files
Fix repeated callback iterations (#635) (#662)
* Fix repeated callback iterations (#635) Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com> * Change callback API to keyword-only args (#635) Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com> * Address maintainer review: add stage to callback, fix stop_flag order Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com> * Remove _split_into_phases and draft files per review Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com> --------- Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com>
1 parent 2298ab5 commit af48ebe

File tree

5 files changed

+279
-36
lines changed

5 files changed

+279
-36
lines changed

python/interpret-core/interpret/glassbox/_ebm.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -999,6 +999,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
999999
delayed(booster)(
10001000
shm_name=shm_name,
10011001
bag_idx=idx,
1002+
stage=0,
10021003
callback=callback,
10031004
dataset=(
10041005
shared.name if shared.name is not None else shared.dataset
@@ -1238,6 +1239,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
12381239
delayed(booster)(
12391240
shm_name=shm_name,
12401241
bag_idx=idx,
1242+
stage=1,
12411243
callback=callback,
12421244
dataset=(
12431245
shared.name
@@ -1349,6 +1351,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
13491351
exception, intercept_change, _, _, rng = booster(
13501352
shm_name=None,
13511353
bag_idx=0,
1354+
stage=-1,
13521355
callback=None,
13531356
dataset=shared.dataset,
13541357
intercept_rounds=develop.get_option("n_intercept_rounds_final"),
@@ -3206,13 +3209,15 @@ class EBMModel(BaseEBM):
32063209
tradeoff for the ensemble of models --- not the individual models --- a small
32073210
amount of overfitting of the individual models can improve the accuracy of
32083211
the ensemble as a whole.
3209-
callback : Optional[Callable[[int, int, bool, float], bool]], default=None
3210-
A user-defined function that is invoked at the end of each boosting step to determine
3211-
whether to terminate boosting or continue. If it returns True, the boosting loop is
3212-
stopped immediately. By default, no callback is used and training proceeds according
3213-
to the early stopping settings. The callback function receives:
3214-
(1) the bag index, (2) the number of boosting steps completed,
3215-
(3) a boolean indicating whether progress was made in the current step, and (4) the current best score.
3212+
callback : Optional[Callable[..., bool]], default=None
3213+
A user-defined function invoked after each progressing boosting step. Must use
3214+
keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
3215+
If it returns True, boosting is stopped immediately.
3216+
The callback receives: ``bag`` (int) the outer bag index,
3217+
``stage`` (int) the boosting stage (0=mains, 1=pairs),
3218+
``step`` (int) the number of boosting steps completed,
3219+
``term`` (int) the index of the term that was just boosted,
3220+
and ``metric`` (float) the current validation metric.
32163221
min_samples_leaf : int, default=4
32173222
Minimum number of samples allowed in the leaves.
32183223
min_hessian : float, default=0.0
@@ -3309,7 +3314,7 @@ def __init__(
33093314
max_rounds: Optional[int] = 50000,
33103315
early_stopping_rounds: Optional[int] = 100,
33113316
early_stopping_tolerance: Optional[float] = 1e-5,
3312-
callback: Optional[Callable[[int, int, bool, float], bool]] = None,
3317+
callback: Optional[Callable[..., bool]] = None,
33133318
# Trees
33143319
min_samples_leaf: Optional[int] = 4,
33153320
min_hessian: Optional[float] = 0.0,
@@ -3572,13 +3577,15 @@ class EBMClassifier(EBMClassifierMixin, EBMModel):
35723577
tradeoff for the ensemble of models --- not the individual models --- a small
35733578
amount of overfitting of the individual models can improve the accuracy of
35743579
the ensemble as a whole.
3575-
callback : Optional[Callable[[int, int, bool, float], bool]], default=None
3576-
A user-defined function that is invoked at the end of each boosting step to determine
3577-
whether to terminate boosting or continue. If it returns True, the boosting loop is
3578-
stopped immediately. By default, no callback is used and training proceeds according
3579-
to the early stopping settings. The callback function receives:
3580-
(1) the bag index, (2) the number of boosting steps completed,
3581-
(3) a boolean indicating whether progress was made in the current step, and (4) the current best score.
3580+
callback : Optional[Callable[..., bool]], default=None
3581+
A user-defined function invoked after each progressing boosting step. Must use
3582+
keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
3583+
If it returns True, boosting is stopped immediately.
3584+
The callback receives: ``bag`` (int) the outer bag index,
3585+
``stage`` (int) the boosting stage (0=mains, 1=pairs),
3586+
``step`` (int) the number of boosting steps completed,
3587+
``term`` (int) the index of the term that was just boosted,
3588+
and ``metric`` (float) the current validation metric.
35823589
min_samples_leaf : int, default=4
35833590
Minimum number of samples allowed in the leaves.
35843591
min_hessian : float, default=1e-4
@@ -3734,7 +3741,7 @@ def __init__(
37343741
max_rounds: Optional[int] = 50000,
37353742
early_stopping_rounds: Optional[int] = 100,
37363743
early_stopping_tolerance: Optional[float] = 1e-5,
3737-
callback: Optional[Callable[[int, int, bool, float], bool]] = None,
3744+
callback: Optional[Callable[..., bool]] = None,
37383745
# Trees
37393746
min_samples_leaf: Optional[int] = 4,
37403747
min_hessian: Optional[float] = 1e-4,
@@ -3876,13 +3883,15 @@ class EBMRegressor(EBMRegressorMixin, EBMModel):
38763883
tradeoff for the ensemble of models --- not the individual models --- a small
38773884
amount of overfitting of the individual models can improve the accuracy of
38783885
the ensemble as a whole.
3879-
callback : Optional[Callable[[int, int, bool, float], bool]], default=None
3880-
A user-defined function that is invoked at the end of each boosting step to determine
3881-
whether to terminate boosting or continue. If it returns True, the boosting loop is
3882-
stopped immediately. By default, no callback is used and training proceeds according
3883-
to the early stopping settings. The callback function receives:
3884-
(1) the bag index, (2) the number of boosting steps completed,
3885-
(3) a boolean indicating whether progress was made in the current step, and (4) the current best score.
3886+
callback : Optional[Callable[..., bool]], default=None
3887+
A user-defined function invoked after each progressing boosting step. Must use
3888+
keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
3889+
If it returns True, boosting is stopped immediately.
3890+
The callback receives: ``bag`` (int) the outer bag index,
3891+
``stage`` (int) the boosting stage (0=mains, 1=pairs),
3892+
``step`` (int) the number of boosting steps completed,
3893+
``term`` (int) the index of the term that was just boosted,
3894+
and ``metric`` (float) the current validation metric.
38863895
min_samples_leaf : int, default=4
38873896
Minimum number of samples allowed in the leaves.
38883897
min_hessian : float, default=0.0
@@ -4039,7 +4048,7 @@ def __init__(
40394048
max_rounds: Optional[int] = 50000,
40404049
early_stopping_rounds: Optional[int] = 100,
40414050
early_stopping_tolerance: Optional[float] = 1e-5,
4042-
callback: Optional[Callable[[int, int, bool, float], bool]] = None,
4051+
callback: Optional[Callable[..., bool]] = None,
40434052
# Trees
40444053
min_samples_leaf: Optional[int] = 4,
40454054
min_hessian: Optional[float] = 0.0,

python/interpret-core/interpret/glassbox/_ebm_core/_boost.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def load_shared_memory(name: str) -> Iterator[shared_memory.SharedMemory]:
2828
def boost(
2929
shm_name,
3030
bag_idx,
31+
stage,
3132
callback,
3233
dataset,
3334
intercept_rounds,
@@ -364,18 +365,25 @@ def boost(
364365
):
365366
break
366367

368+
if stop_flag is not None and stop_flag[0]:
369+
break
370+
371+
if callback is not None:
372+
is_done = callback(
373+
bag=bag_idx,
374+
stage=stage,
375+
step=step_idx,
376+
term=term_idx,
377+
metric=cur_metric,
378+
)
379+
if is_done:
380+
if stop_flag is not None:
381+
stop_flag[0] = True
382+
break
383+
367384
if stop_flag is not None and stop_flag[0]:
368385
break
369386

370-
if callback is not None:
371-
is_done = callback(
372-
bag_idx, step_idx, make_progress, cur_metric
373-
)
374-
if is_done:
375-
if stop_flag is not None:
376-
stop_flag[0] = True
377-
break
378-
379387
state_idx = state_idx + 1
380388
if len(term_features) <= state_idx:
381389
if smoothing_rounds > 0:
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
# Copyright (c) 2023 The InterpretML Contributors
2+
# Distributed under the MIT software license
3+
4+
"""Regression tests for issue #635: callback API uses keyword-only args."""
5+
6+
import numpy as np
7+
8+
from interpret.glassbox import (
9+
ExplainableBoostingClassifier,
10+
ExplainableBoostingRegressor,
11+
)
12+
from interpret.utils import make_synthetic
13+
14+
15+
class RecordingCallback:
    """Callback that logs every invocation for later inspection.

    The instance is picklable; tests run with ``n_jobs=1`` so the
    recorded state lives in the same process as the assertions.
    """

    def __init__(self):
        # One (bag, stage, step, term, metric) tuple per invocation.
        self.records = []

    def __call__(self, *, bag, stage, step, term, metric):
        entry = (bag, stage, step, term, metric)
        self.records += [entry]
        # Never request early termination.
        return False
27+
28+
29+
class StopAfterCallback:
    """Picklable callback that requests a stop once invoked N times."""

    def __init__(self, stop_after):
        # Number of invocations after which True (stop) is returned.
        self.stop_after = stop_after
        # Running count of how many times __call__ has fired.
        self.call_count = 0

    def __call__(self, *, bag, stage, step, term, metric):
        self.call_count = self.call_count + 1
        # Keep returning True once the threshold has been reached.
        return not (self.call_count < self.stop_after)
39+
40+
41+
42+
def test_callback_no_repeated_steps_classifier():
    """Check the classifier callback reports strictly increasing steps.

    Prior to the fix for #635 the callback fired on every internal loop
    iteration, including non-progressing cycles, so a single step value
    could be reported more than once per bag.
    """
    recorder = RecordingCallback()

    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=recorder,
    )
    model.fit(X, y)

    assert len(recorder.records) > 0, "Callback should have been invoked at least once"

    # Group the (stage, step) pairs by outer bag index.
    per_bag = {}
    for bag, stage, step, _, _ in recorder.records:
        per_bag.setdefault(bag, []).append((stage, step))

    # Within each bag the (stage, step) sequence must be strictly increasing.
    for bag, seq in per_bag.items():
        for prev, cur in zip(seq, seq[1:]):
            assert cur > prev, (
                f"Bag {bag}: (stage, step) went from {prev} to "
                f"{cur} (expected strictly increasing)"
            )
77+
78+
79+
def test_callback_no_repeated_steps_regressor():
    """Same monotonic-step check as the classifier test, for the regressor."""
    recorder = RecordingCallback()

    X, y, names, types = make_synthetic(
        seed=42, classes=None, output_type="float", n_samples=500
    )

    model = ExplainableBoostingRegressor(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=recorder,
    )
    model.fit(X, y)

    assert len(recorder.records) > 0, "Callback should have been invoked at least once"

    # Group the (stage, step) pairs by outer bag index.
    per_bag = {}
    for bag, stage, step, _, _ in recorder.records:
        per_bag.setdefault(bag, []).append((stage, step))

    # Within each bag the (stage, step) sequence must be strictly increasing.
    for bag, seq in per_bag.items():
        for prev, cur in zip(seq, seq[1:]):
            assert cur > prev, (
                f"Bag {bag}: (stage, step) went from {prev} to "
                f"{cur} (expected strictly increasing)"
            )
109+
110+
111+
def test_callback_receives_term_index():
    """The ``term`` argument handed to the callback must be a valid index."""
    recorder = RecordingCallback()

    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=recorder,
    )
    model.fit(X, y)

    assert len(recorder.records) > 0, "Callback should have been invoked at least once"

    # Every recorded term must be an integer index >= 0.
    for i, record in enumerate(recorder.records):
        term = record[3]
        assert isinstance(term, (int, np.integer)), (
            f"term at call {i} should be an int, got {type(term)}"
        )
        assert term >= 0, f"term at call {i} should be non-negative, got {term}"
136+
137+
138+
def test_callback_early_termination():
    """A True return from the callback must halt boosting immediately."""
    stopper = StopAfterCallback(stop_after=5)

    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    # max_rounds is deliberately huge so only the callback can stop training.
    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=5000,
        n_jobs=1,
        callback=stopper,
    )
    model.fit(X, y)

    assert stopper.call_count == stopper.stop_after, (
        f"Expected callback to be called exactly {stopper.stop_after} times "
        f"before stopping, but was called {stopper.call_count} times"
    )

    # A model stopped early must still be usable for prediction.
    preds = model.predict(X)
    assert len(preds) == len(y)
164+
165+
166+
def test_callback_receives_valid_metrics():
    """Every metric value handed to the callback must be finite."""
    recorder = RecordingCallback()

    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=recorder,
    )
    model.fit(X, y)

    assert len(recorder.records) > 0, "Callback should have been invoked at least once"

    # NaN or +/-inf in the validation metric would indicate a broken pipeline.
    for i, record in enumerate(recorder.records):
        metric = record[-1]
        assert np.isfinite(metric), f"Metric at step {i} is not finite: {metric}"
188+
189+
190+
def test_callback_keyword_only_signature():
    """Ensure the trainer invokes the callback with keyword-only arguments.

    The keyword-only signature is the core API change of this PR: a
    callback declared with a bare ``*`` would raise TypeError if the
    trainer passed any argument positionally.
    """

    class KeywordOnlyCallback:
        def __init__(self):
            self.called = False

        def __call__(self, *, bag, stage, step, term, metric):
            self.called = True
            # Verify all args were passed as keywords by checking they exist
            assert isinstance(bag, int)
            assert isinstance(step, int)
            assert isinstance(term, (int, np.integer))
            assert isinstance(metric, float)
            return True  # stop immediately

    checker = KeywordOnlyCallback()

    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=checker,
    )
    model.fit(X, y)

    assert checker.called, "Keyword-only callback should have been invoked"

0 commit comments

Comments
 (0)