Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions python/interpret-core/interpret/glassbox/_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
delayed(booster)(
shm_name=shm_name,
bag_idx=idx,
stage=0,
callback=callback,
dataset=(
shared.name if shared.name is not None else shared.dataset
Expand Down Expand Up @@ -1238,6 +1239,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
delayed(booster)(
shm_name=shm_name,
bag_idx=idx,
stage=1,
callback=callback,
dataset=(
shared.name
Expand Down Expand Up @@ -1349,6 +1351,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
exception, intercept_change, _, _, rng = booster(
shm_name=None,
bag_idx=0,
stage=-1,
callback=None,
dataset=shared.dataset,
intercept_rounds=develop.get_option("n_intercept_rounds_final"),
Expand Down Expand Up @@ -3206,13 +3209,15 @@ class EBMModel(BaseEBM):
tradeoff for the ensemble of models --- not the individual models --- a small
amount of overfitting of the individual models can improve the accuracy of
the ensemble as a whole.
callback : Optional[Callable[[int, int, bool, float], bool]], default=None
A user-defined function that is invoked at the end of each boosting step to determine
whether to terminate boosting or continue. If it returns True, the boosting loop is
stopped immediately. By default, no callback is used and training proceeds according
to the early stopping settings. The callback function receives:
(1) the bag index, (2) the number of boosting steps completed,
(3) a boolean indicating whether progress was made in the current step, and (4) the current best score.
callback : Optional[Callable[..., bool]], default=None
A user-defined function invoked after each progressing boosting step. Must use
keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
If it returns True, boosting is stopped immediately.
The callback receives: ``bag`` (int) the outer bag index,
``stage`` (int) the boosting stage (0=mains, 1=pairs),
``step`` (int) the number of boosting steps completed,
``term`` (int) the index of the term that was just boosted,
and ``metric`` (float) the current validation metric.
min_samples_leaf : int, default=4
Minimum number of samples allowed in the leaves.
min_hessian : float, default=0.0
Expand Down Expand Up @@ -3309,7 +3314,7 @@ def __init__(
max_rounds: Optional[int] = 50000,
early_stopping_rounds: Optional[int] = 100,
early_stopping_tolerance: Optional[float] = 1e-5,
callback: Optional[Callable[[int, int, bool, float], bool]] = None,
callback: Optional[Callable[..., bool]] = None,
# Trees
min_samples_leaf: Optional[int] = 4,
min_hessian: Optional[float] = 0.0,
Expand Down Expand Up @@ -3572,13 +3577,15 @@ class EBMClassifier(EBMClassifierMixin, EBMModel):
tradeoff for the ensemble of models --- not the individual models --- a small
amount of overfitting of the individual models can improve the accuracy of
the ensemble as a whole.
callback : Optional[Callable[[int, int, bool, float], bool]], default=None
A user-defined function that is invoked at the end of each boosting step to determine
whether to terminate boosting or continue. If it returns True, the boosting loop is
stopped immediately. By default, no callback is used and training proceeds according
to the early stopping settings. The callback function receives:
(1) the bag index, (2) the number of boosting steps completed,
(3) a boolean indicating whether progress was made in the current step, and (4) the current best score.
callback : Optional[Callable[..., bool]], default=None
A user-defined function invoked after each progressing boosting step. Must use
keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
If it returns True, boosting is stopped immediately.
The callback receives: ``bag`` (int) the outer bag index,
``stage`` (int) the boosting stage (0=mains, 1=pairs),
``step`` (int) the number of boosting steps completed,
``term`` (int) the index of the term that was just boosted,
and ``metric`` (float) the current validation metric.
min_samples_leaf : int, default=4
Minimum number of samples allowed in the leaves.
min_hessian : float, default=1e-4
Expand Down Expand Up @@ -3734,7 +3741,7 @@ def __init__(
max_rounds: Optional[int] = 50000,
early_stopping_rounds: Optional[int] = 100,
early_stopping_tolerance: Optional[float] = 1e-5,
callback: Optional[Callable[[int, int, bool, float], bool]] = None,
callback: Optional[Callable[..., bool]] = None,
# Trees
min_samples_leaf: Optional[int] = 4,
min_hessian: Optional[float] = 1e-4,
Expand Down Expand Up @@ -3876,13 +3883,15 @@ class EBMRegressor(EBMRegressorMixin, EBMModel):
tradeoff for the ensemble of models --- not the individual models --- a small
amount of overfitting of the individual models can improve the accuracy of
the ensemble as a whole.
callback : Optional[Callable[[int, int, bool, float], bool]], default=None
A user-defined function that is invoked at the end of each boosting step to determine
whether to terminate boosting or continue. If it returns True, the boosting loop is
stopped immediately. By default, no callback is used and training proceeds according
to the early stopping settings. The callback function receives:
(1) the bag index, (2) the number of boosting steps completed,
(3) a boolean indicating whether progress was made in the current step, and (4) the current best score.
callback : Optional[Callable[..., bool]], default=None
A user-defined function invoked after each progressing boosting step. Must use
keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
If it returns True, boosting is stopped immediately.
The callback receives: ``bag`` (int) the outer bag index,
``stage`` (int) the boosting stage (0=mains, 1=pairs),
``step`` (int) the number of boosting steps completed,
``term`` (int) the index of the term that was just boosted,
and ``metric`` (float) the current validation metric.
min_samples_leaf : int, default=4
Minimum number of samples allowed in the leaves.
min_hessian : float, default=0.0
Expand Down Expand Up @@ -4039,7 +4048,7 @@ def __init__(
max_rounds: Optional[int] = 50000,
early_stopping_rounds: Optional[int] = 100,
early_stopping_tolerance: Optional[float] = 1e-5,
callback: Optional[Callable[[int, int, bool, float], bool]] = None,
callback: Optional[Callable[..., bool]] = None,
# Trees
min_samples_leaf: Optional[int] = 4,
min_hessian: Optional[float] = 0.0,
Expand Down
26 changes: 17 additions & 9 deletions python/interpret-core/interpret/glassbox/_ebm_core/_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def load_shared_memory(name: str) -> Iterator[shared_memory.SharedMemory]:
def boost(
shm_name,
bag_idx,
stage,
callback,
dataset,
intercept_rounds,
Expand Down Expand Up @@ -364,18 +365,25 @@ def boost(
):
break

if stop_flag is not None and stop_flag[0]:
break

if callback is not None:
is_done = callback(
bag=bag_idx,
stage=stage,
step=step_idx,
term=term_idx,
metric=cur_metric,
)
if is_done:
if stop_flag is not None:
stop_flag[0] = True
break

Comment thread
paulbkoch marked this conversation as resolved.
if stop_flag is not None and stop_flag[0]:
break

if callback is not None:
is_done = callback(
bag_idx, step_idx, make_progress, cur_metric
)
if is_done:
if stop_flag is not None:
stop_flag[0] = True
break

state_idx = state_idx + 1
if len(term_features) <= state_idx:
if smoothing_rounds > 0:
Expand Down
226 changes: 226 additions & 0 deletions python/interpret-core/tests/glassbox/ebm/test_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Copyright (c) 2023 The InterpretML Contributors
# Distributed under the MIT software license

"""Regression tests for issue #635: callback API uses keyword-only args."""

import numpy as np

from interpret.glassbox import (
ExplainableBoostingClassifier,
ExplainableBoostingRegressor,
)
from interpret.utils import make_synthetic


class RecordingCallback:
    """Picklable callback that records every invocation.

    Intended for use with n_jobs=1 so the recorded state lives in the
    test process.
    """

    def __init__(self):
        # One (bag, stage, step, term, metric) tuple per invocation.
        self.records = []

    def __call__(self, *, bag, stage, step, term, metric):
        """Record the call's arguments; never request a stop."""
        record = (bag, stage, step, term, metric)
        self.records.append(record)
        return False


class StopAfterCallback:
    """Picklable callback that requests a stop after N invocations."""

    def __init__(self, stop_after):
        # Invocation count at which True (stop boosting) is returned.
        self.stop_after = stop_after
        self.call_count = 0

    def __call__(self, *, bag, stage, step, term, metric):
        """Count the call and return True once stop_after calls are reached."""
        self.call_count = self.call_count + 1
        should_stop = self.call_count >= self.stop_after
        return should_stop



def test_callback_no_repeated_steps_classifier():
    """Verify the callback sees strictly increasing (stage, step) pairs.

    Before the fix, the callback was invoked on every internal loop
    iteration — including non-progressing cycles — which caused
    the same step value to be reported multiple times.
    """
    recorder = RecordingCallback()

    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=recorder,
    )
    model.fit(X, y)

    assert recorder.records, "Callback should have been invoked at least once"

    # Group the (stage, step) pairs per outer bag before checking ordering.
    per_bag = {}
    for bag, stage, step, _, _ in recorder.records:
        per_bag.setdefault(bag, []).append((stage, step))

    for bag, pairs in per_bag.items():
        # Tuple comparison orders by stage first, then by step within a stage.
        for prev, cur in zip(pairs, pairs[1:]):
            assert cur > prev, (
                f"Bag {bag}: (stage, step) went from {prev} to "
                f"{cur} (expected strictly increasing)"
            )


def test_callback_no_repeated_steps_regressor():
    """Same test as above but for ExplainableBoostingRegressor."""
    recorder = RecordingCallback()

    X, y, names, types = make_synthetic(
        seed=42, classes=None, output_type="float", n_samples=500
    )

    model = ExplainableBoostingRegressor(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=recorder,
    )
    model.fit(X, y)

    assert recorder.records, "Callback should have been invoked at least once"

    # Group the (stage, step) pairs per outer bag before checking ordering.
    per_bag = {}
    for bag, stage, step, _, _ in recorder.records:
        per_bag.setdefault(bag, []).append((stage, step))

    for bag, pairs in per_bag.items():
        # Tuple comparison orders by stage first, then by step within a stage.
        for prev, cur in zip(pairs, pairs[1:]):
            assert cur > prev, (
                f"Bag {bag}: (stage, step) went from {prev} to "
                f"{cur} (expected strictly increasing)"
            )


def test_callback_receives_term_index():
    """Verify the callback receives a valid term index."""
    recorder = RecordingCallback()

    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=recorder,
    )
    model.fit(X, y)

    assert recorder.records, "Callback should have been invoked at least once"

    # Every recorded term must be an integral, non-negative index.
    for call_idx, record in enumerate(recorder.records):
        term = record[3]
        assert isinstance(term, (int, np.integer)), (
            f"term at call {call_idx} should be an int, got {type(term)}"
        )
        assert term >= 0, f"term at call {call_idx} should be non-negative, got {term}"


def test_callback_early_termination():
    """Verify the callback can still terminate training early."""
    stopper = StopAfterCallback(stop_after=5)

    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=5000,
        n_jobs=1,
        callback=stopper,
    )
    model.fit(X, y)

    assert stopper.call_count == stopper.stop_after, (
        f"Expected callback to be called exactly {stopper.stop_after} times "
        f"before stopping, but was called {stopper.call_count} times"
    )

    # The model should still be valid after early stopping
    preds = model.predict(X)
    assert len(preds) == len(y)


def test_callback_receives_valid_metrics():
    """Verify the callback receives valid (finite) metric values."""
    recorder = RecordingCallback()

    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    model = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=recorder,
    )
    model.fit(X, y)

    assert recorder.records, "Callback should have been invoked at least once"

    # Metric is the last element of each recorded tuple; it must be finite.
    for call_idx, record in enumerate(recorder.records):
        metric = record[-1]
        assert np.isfinite(metric), f"Metric at step {call_idx} is not finite: {metric}"


def test_callback_keyword_only_signature():
    """Verify the callback is invoked with keyword-only arguments.

    This test ensures that the callback cannot be invoked with positional
    arguments, which is the core API change in this PR.
    """

    class KeywordOnlyCallback:
        def __init__(self):
            self.called = False

        def __call__(self, *, bag, stage, step, term, metric):
            self.called = True
            # Verify all args were passed as keywords by checking they exist
            # and carry the documented types. `bag` and `stage` originate
            # from Python-level loop indices / literals, so plain int is
            # expected; `step`, `term`, and `metric` may surface as numpy
            # scalars from the native booster, so accept those too —
            # mirroring the (int, np.integer) tolerance used by the other
            # tests in this file.
            assert isinstance(bag, int)
            assert isinstance(stage, int)
            assert isinstance(step, (int, np.integer))
            assert isinstance(term, (int, np.integer))
            assert isinstance(metric, (float, np.floating))
            return True  # stop immediately

    cb = KeywordOnlyCallback()

    X, y, names, types = make_synthetic(
        seed=42, classes=2, output_type="float", n_samples=500
    )

    ebm = ExplainableBoostingClassifier(
        names,
        types,
        outer_bags=1,
        max_rounds=50,
        n_jobs=1,
        callback=cb,
    )
    ebm.fit(X, y)

    assert cb.called, "Keyword-only callback should have been invoked"
Loading
Loading