Skip to content

Commit 9835e74

Browse files
committed
Fix callback iteration index monotonicity
Signed-off-by: Nandana Dileep <nandanadileep29@gmail.com>
1 parent 8959a30 commit 9835e74

File tree

3 files changed

+57
-8
lines changed

3 files changed

+57
-8
lines changed

python/interpret-core/interpret/glassbox/_ebm/_boost.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def boost(
8686
else None
8787
)
8888

89-
step_idx = 0
89+
step_idx = 0 # number of accepted boosting updates
90+
callback_idx = 0 # number of callback invocations (attempted steps)
9091
cur_metric = np.nan
9192

9293
_log.info("Start boosting")
@@ -369,12 +370,13 @@ def boost(
369370

370371
if callback is not None:
371372
is_done = callback(
372-
bag_idx, step_idx, make_progress, cur_metric
373+
bag_idx, callback_idx, make_progress, cur_metric
373374
)
374375
if is_done:
375376
if stop_flag is not None:
376377
stop_flag[0] = True
377378
break
379+
callback_idx += 1
378380

379381
state_idx = state_idx + 1
380382
if len(term_features) <= state_idx:

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3086,12 +3086,13 @@ class ExplainableBoostingClassifier(ClassifierMixin, EBMModel):
30863086
amount of overfitting of the individual models can improve the accuracy of
30873087
the ensemble as a whole.
30883088
callback : Optional[Callable[[int, int, bool, float], bool]], default=None
3089-
A user-defined function that is invoked at the end of each boosting step to determine
3089+
A user-defined function that is invoked at the end of each boosting step attempt to determine
30903090
whether to terminate boosting or continue. If it returns True, the boosting loop is
30913091
stopped immediately. By default, no callback is used and training proceeds according
30923092
to the early stopping settings. The callback function receives:
3093-
(1) the bag index, (2) the number of boosting steps completed,
3094-
(3) a boolean indicating whether progress was made in the current step, and (4) the current best score.
3093+
(1) the bag index, (2) the callback step index (increments on every callback, even when
3094+
no update is applied), (3) a boolean indicating whether progress was made in the current step,
3095+
and (4) the current best score.
30953096
min_samples_leaf : int, default=4
30963097
Minimum number of samples allowed in the leaves.
30973098
min_hessian : float, default=1e-4
@@ -3596,12 +3597,13 @@ class ExplainableBoostingRegressor(RegressorMixin, EBMModel):
35963597
amount of overfitting of the individual models can improve the accuracy of
35973598
the ensemble as a whole.
35983599
callback : Optional[Callable[[int, int, bool, float], bool]], default=None
3599-
A user-defined function that is invoked at the end of each boosting step to determine
3600+
A user-defined function that is invoked at the end of each boosting step attempt to determine
36003601
whether to terminate boosting or continue. If it returns True, the boosting loop is
36013602
stopped immediately. By default, no callback is used and training proceeds according
36023603
to the early stopping settings. The callback function receives:
3603-
(1) the bag index, (2) the number of boosting steps completed,
3604-
(3) a boolean indicating whether progress was made in the current step, and (4) the current best score.
3604+
(1) the bag index, (2) the callback step index (increments on every callback, even when
3605+
no update is applied), (3) a boolean indicating whether progress was made in the current step,
3606+
and (4) the current best score.
36053607
min_samples_leaf : int, default=4
36063608
Minimum number of samples allowed in the leaves.
36073609
min_hessian : float, default=0.0
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) 2026 The InterpretML Contributors
2+
# Distributed under the MIT software license
3+
4+
import numpy as np
5+
import pytest
6+
7+
from interpret.glassbox import ExplainableBoostingClassifier
8+
from interpret.utils._native import Native
9+
10+
# Skip if the native lib is not available (common in source checkouts without build artifacts)
11+
try:
12+
Native._get_ebm_lib_path()
13+
except Exception: # pragma: no cover - just gating environments without libebm
14+
pytest.skip("libebm shared library not available", allow_module_level=True)
15+
16+
17+
def test_callback_iteration_is_monotonic():
18+
"""Ensure callback receives strictly increasing iteration indexes even when no progress is made."""
19+
20+
X = np.array([[0], [1], [0], [1]], dtype=np.float64)
21+
y = np.array([0, 1, 0, 1], dtype=np.int64)
22+
23+
iterations = []
24+
25+
def cb(bag_idx, iteration_idx, made_progress, metric):
26+
iterations.append(iteration_idx)
27+
# stop early to keep test fast; plenty of callback invocations happen before the first boost
28+
return len(iterations) >= 15
29+
30+
ebm = ExplainableBoostingClassifier(
31+
interactions=0,
32+
max_rounds=1,
33+
cyclic_progress=0.1, # forces several no-progress iterations up front
34+
outer_bags=1,
35+
max_bins=2,
36+
max_leaves=2,
37+
min_samples_leaf=2,
38+
n_jobs=1,
39+
random_state=1,
40+
callback=cb,
41+
)
42+
43+
ebm.fit(X, y)
44+
45+
assert iterations == sorted(set(iterations))

0 commit comments

Comments
 (0)