diff --git a/python/interpret-core/interpret/glassbox/_ebm.py b/python/interpret-core/interpret/glassbox/_ebm.py index 838536a62..8404ceeff 100644 --- a/python/interpret-core/interpret/glassbox/_ebm.py +++ b/python/interpret-core/interpret/glassbox/_ebm.py @@ -999,6 +999,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): delayed(booster)( shm_name=shm_name, bag_idx=idx, + stage=0, callback=callback, dataset=( shared.name if shared.name is not None else shared.dataset @@ -1238,6 +1239,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): delayed(booster)( shm_name=shm_name, bag_idx=idx, + stage=1, callback=callback, dataset=( shared.name @@ -1349,6 +1351,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): exception, intercept_change, _, _, rng = booster( shm_name=None, bag_idx=0, + stage=-1, callback=None, dataset=shared.dataset, intercept_rounds=develop.get_option("n_intercept_rounds_final"), @@ -3206,13 +3209,15 @@ class EBMModel(BaseEBM): tradeoff for the ensemble of models --- not the individual models --- a small amount of overfitting of the individual models can improve the accuracy of the ensemble as a whole. - callback : Optional[Callable[[int, int, bool, float], bool]], default=None - A user-defined function that is invoked at the end of each boosting step to determine - whether to terminate boosting or continue. If it returns True, the boosting loop is - stopped immediately. By default, no callback is used and training proceeds according - to the early stopping settings. The callback function receives: - (1) the bag index, (2) the number of boosting steps completed, - (3) a boolean indicating whether progress was made in the current step, and (4) the current best score. + callback : Optional[Callable[..., bool]], default=None + A user-defined function invoked after each progressing boosting step. Must use + keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``. 
+ If it returns True, boosting is stopped immediately. + The callback receives: ``bag`` (int) the outer bag index, + ``stage`` (int) the boosting stage (0=mains, 1=pairs), + ``step`` (int) the number of boosting steps completed, + ``term`` (int) the index of the term that was just boosted, + and ``metric`` (float) the current validation metric. min_samples_leaf : int, default=4 Minimum number of samples allowed in the leaves. min_hessian : float, default=0.0 @@ -3309,7 +3314,7 @@ def __init__( max_rounds: Optional[int] = 50000, early_stopping_rounds: Optional[int] = 100, early_stopping_tolerance: Optional[float] = 1e-5, - callback: Optional[Callable[[int, int, bool, float], bool]] = None, + callback: Optional[Callable[..., bool]] = None, # Trees min_samples_leaf: Optional[int] = 4, min_hessian: Optional[float] = 0.0, @@ -3572,13 +3577,15 @@ class EBMClassifier(EBMClassifierMixin, EBMModel): tradeoff for the ensemble of models --- not the individual models --- a small amount of overfitting of the individual models can improve the accuracy of the ensemble as a whole. - callback : Optional[Callable[[int, int, bool, float], bool]], default=None - A user-defined function that is invoked at the end of each boosting step to determine - whether to terminate boosting or continue. If it returns True, the boosting loop is - stopped immediately. By default, no callback is used and training proceeds according - to the early stopping settings. The callback function receives: - (1) the bag index, (2) the number of boosting steps completed, - (3) a boolean indicating whether progress was made in the current step, and (4) the current best score. + callback : Optional[Callable[..., bool]], default=None + A user-defined function invoked after each progressing boosting step. Must use + keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``. + If it returns True, boosting is stopped immediately. 
+ The callback receives: ``bag`` (int) the outer bag index, + ``stage`` (int) the boosting stage (0=mains, 1=pairs), + ``step`` (int) the number of boosting steps completed, + ``term`` (int) the index of the term that was just boosted, + and ``metric`` (float) the current validation metric. min_samples_leaf : int, default=4 Minimum number of samples allowed in the leaves. min_hessian : float, default=1e-4 @@ -3734,7 +3741,7 @@ def __init__( max_rounds: Optional[int] = 50000, early_stopping_rounds: Optional[int] = 100, early_stopping_tolerance: Optional[float] = 1e-5, - callback: Optional[Callable[[int, int, bool, float], bool]] = None, + callback: Optional[Callable[..., bool]] = None, # Trees min_samples_leaf: Optional[int] = 4, min_hessian: Optional[float] = 1e-4, @@ -3876,13 +3883,15 @@ class EBMRegressor(EBMRegressorMixin, EBMModel): tradeoff for the ensemble of models --- not the individual models --- a small amount of overfitting of the individual models can improve the accuracy of the ensemble as a whole. - callback : Optional[Callable[[int, int, bool, float], bool]], default=None - A user-defined function that is invoked at the end of each boosting step to determine - whether to terminate boosting or continue. If it returns True, the boosting loop is - stopped immediately. By default, no callback is used and training proceeds according - to the early stopping settings. The callback function receives: - (1) the bag index, (2) the number of boosting steps completed, - (3) a boolean indicating whether progress was made in the current step, and (4) the current best score. + callback : Optional[Callable[..., bool]], default=None + A user-defined function invoked after each progressing boosting step. Must use + keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``. + If it returns True, boosting is stopped immediately. 
+ The callback receives: ``bag`` (int) the outer bag index, + ``stage`` (int) the boosting stage (0=mains, 1=pairs), + ``step`` (int) the number of boosting steps completed, + ``term`` (int) the index of the term that was just boosted, + and ``metric`` (float) the current validation metric. min_samples_leaf : int, default=4 Minimum number of samples allowed in the leaves. min_hessian : float, default=0.0 @@ -4039,7 +4048,7 @@ def __init__( max_rounds: Optional[int] = 50000, early_stopping_rounds: Optional[int] = 100, early_stopping_tolerance: Optional[float] = 1e-5, - callback: Optional[Callable[[int, int, bool, float], bool]] = None, + callback: Optional[Callable[..., bool]] = None, # Trees min_samples_leaf: Optional[int] = 4, min_hessian: Optional[float] = 0.0, diff --git a/python/interpret-core/interpret/glassbox/_ebm_core/_boost.py b/python/interpret-core/interpret/glassbox/_ebm_core/_boost.py index 79c4ffe8c..a789c91e8 100644 --- a/python/interpret-core/interpret/glassbox/_ebm_core/_boost.py +++ b/python/interpret-core/interpret/glassbox/_ebm_core/_boost.py @@ -28,6 +28,7 @@ def load_shared_memory(name: str) -> Iterator[shared_memory.SharedMemory]: def boost( shm_name, bag_idx, + stage, callback, dataset, intercept_rounds, @@ -364,18 +365,25 @@ def boost( ): break + if stop_flag is not None and stop_flag[0]: + break + + if callback is not None: + is_done = callback( + bag=bag_idx, + stage=stage, + step=step_idx, + term=term_idx, + metric=cur_metric, + ) + if is_done: + if stop_flag is not None: + stop_flag[0] = True + break + if stop_flag is not None and stop_flag[0]: break - if callback is not None: - is_done = callback( - bag_idx, step_idx, make_progress, cur_metric - ) - if is_done: - if stop_flag is not None: - stop_flag[0] = True - break - state_idx = state_idx + 1 if len(term_features) <= state_idx: if smoothing_rounds > 0: diff --git a/python/interpret-core/tests/glassbox/ebm/test_callback.py b/python/interpret-core/tests/glassbox/ebm/test_callback.py 
new file mode 100644
index 000000000..b32c79aa7
--- /dev/null
+++ b/python/interpret-core/tests/glassbox/ebm/test_callback.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2023 The InterpretML Contributors
+# Distributed under the MIT software license
+
+"""Regression tests for issue #635: callback API uses keyword-only args."""
+
+import numpy as np
+
+from interpret.glassbox import (
+    ExplainableBoostingClassifier,
+    ExplainableBoostingRegressor,
+)
+from interpret.utils import make_synthetic
+
+
+class RecordingCallback:
+    """Picklable callback that records all invocations.
+
+    Uses n_jobs=1 in tests so that state is shared in-process.
+    """
+
+    def __init__(self):
+        self.records = []
+
+    def __call__(self, *, bag, stage, step, term, metric):
+        self.records.append((bag, stage, step, term, metric))
+        return False
+
+
+class StopAfterCallback:
+    """Picklable callback that stops training after N calls."""
+
+    def __init__(self, stop_after):
+        self.stop_after = stop_after
+        self.call_count = 0
+
+    def __call__(self, *, bag, stage, step, term, metric):
+        self.call_count += 1
+        return self.call_count >= self.stop_after
+
+
+def _fit_with_callback(model_cls, callback, *, classes=2, max_rounds=50):
+    """Fit ``model_cls`` on a small synthetic dataset with ``callback`` attached.
+
+    A single outer bag and ``n_jobs=1`` keep boosting in-process, so any state
+    the callback accumulates is visible to the test afterwards.
+
+    Returns the fitted model together with the ``X`` and ``y`` it was fit on.
+    """
+    X, y, names, types = make_synthetic(
+        seed=42, classes=classes, output_type="float", n_samples=500
+    )
+
+    ebm = model_cls(
+        names,
+        types,
+        outer_bags=1,
+        max_rounds=max_rounds,
+        n_jobs=1,
+        callback=callback,
+    )
+    ebm.fit(X, y)
+    return ebm, X, y
+
+
+def _assert_steps_strictly_increase(records):
+    """Assert that ``(stage, step)`` strictly increases within every bag."""
+    assert len(records) > 0, "Callback should have been invoked at least once"
+
+    steps_by_bag = {}
+    for bag, stage, step, _, _ in records:
+        steps_by_bag.setdefault(bag, []).append((stage, step))
+
+    for bag, steps in steps_by_bag.items():
+        for prev, cur in zip(steps, steps[1:]):
+            assert cur > prev, (
+                f"Bag {bag}: (stage, step) went from {prev} to "
+                f"{cur} (expected strictly increasing)"
+            )
+
+
+def test_callback_no_repeated_steps_classifier():
+    """Verify the callback receives strictly increasing step values.
+
+    Before the fix, the callback was invoked on every internal loop
+    iteration — including non-progressing cycles — which caused
+    the same step value to be reported multiple times.
+    """
+    cb = RecordingCallback()
+    _fit_with_callback(ExplainableBoostingClassifier, cb)
+    _assert_steps_strictly_increase(cb.records)
+
+
+def test_callback_no_repeated_steps_regressor():
+    """Same test as above but for ExplainableBoostingRegressor."""
+    cb = RecordingCallback()
+    _fit_with_callback(ExplainableBoostingRegressor, cb, classes=None)
+    _assert_steps_strictly_increase(cb.records)
+
+
+def test_callback_receives_term_index():
+    """Verify the callback receives a valid term index."""
+    cb = RecordingCallback()
+    _fit_with_callback(ExplainableBoostingClassifier, cb)
+
+    assert len(cb.records) > 0, "Callback should have been invoked at least once"
+
+    for i, (_, _, _, term, _) in enumerate(cb.records):
+        assert isinstance(term, (int, np.integer)), (
+            f"term at call {i} should be an int, got {type(term)}"
+        )
+        assert term >= 0, f"term at call {i} should be non-negative, got {term}"
+
+
+def test_callback_early_termination():
+    """Verify the callback can still terminate training early."""
+    cb = StopAfterCallback(stop_after=5)
+    ebm, X, y = _fit_with_callback(ExplainableBoostingClassifier, cb, max_rounds=5000)
+
+    assert cb.call_count == cb.stop_after, (
+        f"Expected callback to be called exactly {cb.stop_after} times "
+        f"before stopping, but was called {cb.call_count} times"
+    )
+
+    # The model should still be valid after early stopping
+    predictions = ebm.predict(X)
+    assert len(predictions) == len(y)
+
+
+def test_callback_receives_valid_metrics():
+    """Verify the callback receives valid (finite) metric values."""
+    cb = RecordingCallback()
+    _fit_with_callback(ExplainableBoostingClassifier, cb)
+
+    assert len(cb.records) > 0, "Callback should have been invoked at least once"
+
+    for i, (_, _, _, _, metric) in enumerate(cb.records):
+        assert np.isfinite(metric), f"Metric at call {i} is not finite: {metric}"
+
+
+def test_callback_keyword_only_signature():
+    """Verify the callback is invoked with keyword-only arguments.
+
+    The callback below declares every parameter keyword-only, so a positional
+    invocation would raise ``TypeError`` instead of reaching its body.
+    """
+
+    class KeywordOnlyCallback:
+        def __init__(self):
+            self.called = False
+
+        def __call__(self, *, bag, stage, step, term, metric):
+            self.called = True
+            # Reaching this body already proves keyword passing; additionally
+            # check the type of every argument, including ``stage``.
+            assert isinstance(bag, int)
+            assert isinstance(stage, int)
+            assert isinstance(step, int)
+            assert isinstance(term, (int, np.integer))
+            assert isinstance(metric, float)
+            return True  # stop immediately
+
+    cb = KeywordOnlyCallback()
+    _fit_with_callback(ExplainableBoostingClassifier, cb)
+
+    assert cb.called, "Keyword-only callback should have been invoked"
diff --git a/python/interpret-core/tests/glassbox/ebm/test_ebm.py b/python/interpret-core/tests/glassbox/ebm/test_ebm.py index b419779c2..eadd334a6 100644 --- a/python/interpret-core/tests/glassbox/ebm/test_ebm.py +++ b/python/interpret-core/tests/glassbox/ebm/test_ebm.py @@ -1334,7 +1334,7 @@ class Callback: def __init__(self, seconds): self._seconds = seconds - def __call__(self, bag_index, step_index, progress, metric): + def __call__(self, *, bag, stage, step, term, metric): import time if not hasattr(self, "_end_time"): @@ -1362,7 +1362,7 @@ class Callback: def __init__(self, seconds): self._seconds = seconds - def __call__(self, bag_index, step_index, progress, metric): + def __call__(self, *, bag, stage, step, term, metric): import time if not hasattr(self, "_end_time"): diff --git a/python/interpret-core/tests/glassbox/ebm/test_merge_ebms.py b/python/interpret-core/tests/glassbox/ebm/test_merge_ebms.py index fe2c69696..0307fca3f 100644 --- a/python/interpret-core/tests/glassbox/ebm/test_merge_ebms.py +++ b/python/interpret-core/tests/glassbox/ebm/test_merge_ebms.py @@ -393,7 +393,7 @@ def test_merge_ebms_callback_is_none(): (e.g., the callback might hold references to training state).
""" - def training_callback(_step_count, _term_count, _is_interaction, _metric_value): + def training_callback(*, bag, stage, step, term, metric): return False # continue training classifier_one = ExplainableBoostingClassifier(