diff --git a/ignite/handlers/early_stopping.py b/ignite/handlers/early_stopping.py index fd6f01d70863..0a14bf516983 100644 --- a/ignite/handlers/early_stopping.py +++ b/ignite/handlers/early_stopping.py @@ -1,6 +1,7 @@ +import warnings from collections import OrderedDict from collections.abc import Callable, Mapping -from typing import Any, cast, Literal +from typing import Any, cast, Literal, Optional from ignite.base import Serializable, ResettableHandler from ignite.engine import Engine, Events @@ -18,25 +19,28 @@ class EarlyStopping(Serializable, ResettableHandler): object, and return a score `float`. An improvement is considered if the score is higher (for ``mode='max'``) or lower (for ``mode='min'``). trainer: Trainer engine to stop the run if no improvement. - min_delta: A minimum change in the score to qualify as an improvement. For ``mode='max'``, it's a minimum + threshold: A minimum change in the score to qualify as an improvement. For ``mode='max'``, it's a minimum increase; for ``mode='min'``, it's a minimum decrease. An improvement is only considered if the change - exceeds the threshold determined by `min_delta` and `min_delta_mode`. - cumulative_delta: If True, `min_delta` defines the change since the last `patience` reset, otherwise, + exceeds the threshold determined by `threshold` and `threshold_mode`. + cumulative: If True, `threshold` defines the change since the last `patience` reset, otherwise, it defines the change after the last event. Default value is False. - min_delta_mode: Determines whether `min_delta` is an absolute change or a relative change. + threshold_mode: Determines whether `threshold` is an absolute change or a relative change. - In 'abs' mode: - - For ``mode='max'``: improvement if score > best_score + min_delta - - For ``mode='min'``: improvement if score < best_score - min_delta + - For ``mode='max'``: improvement if score > best_score + threshold + - For ``mode='min'``: improvement if score < best_score - threshold - In 'rel' mode: - - For ``mode='max'``: improvement if score > best_score * (1 + min_delta) - - For ``mode='min'``: improvement if score < best_score * (1 - min_delta) + - For ``mode='max'``: improvement if score > best_score * (1 + threshold) + - For ``mode='min'``: improvement if score < best_score * (1 - threshold) Possible values are "abs" and "rel". Default value is "abs". mode: Whether to maximize ('max') or minimize ('min') the score. Default is 'max'. + min_delta: Deprecated, use `threshold` instead. + cumulative_delta: Deprecated, use `cumulative` instead. + min_delta_mode: Deprecated, use `threshold_mode` instead. Examples: .. code-block:: python @@ -56,6 +60,11 @@ def score_function(engine): Added `mode` parameter to support minimization in addition to maximization. Added `min_delta_mode` parameter to support both absolute and relative improvements. + .. versionchanged:: 0.5.5 + Renamed `min_delta` to `threshold`, `min_delta_mode` to `threshold_mode`, and + `cumulative_delta` to `cumulative`. The old parameter names are deprecated and + will be removed in a future version. + """ _state_dict_all_req_keys = ( @@ -68,38 +77,87 @@ def __init__( patience: int, score_function: Callable, trainer: Engine, - min_delta: float = 0.0, - cumulative_delta: bool = False, - min_delta_mode: Literal["abs", "rel"] = "abs", + threshold: float = 0.0, + cumulative: bool = False, + threshold_mode: Literal["abs", "rel"] = "abs", mode: Literal["min", "max"] = "max", + # Deprecated parameter names kept for backward compatibility + min_delta: Optional[float] = None, + cumulative_delta: Optional[bool] = None, + min_delta_mode: Optional[Literal["abs", "rel"]] = None, ): + # Handle deprecated parameter: min_delta -> threshold + if min_delta is not None: + if threshold != 0.0: + raise ValueError( + "Cannot specify both 'min_delta' and 'threshold'. " + "'min_delta' is deprecated, use 'threshold' instead." + ) + warnings.warn( + "'min_delta' is deprecated and will be removed in a future version. " + "Please use 'threshold' instead.", + DeprecationWarning, + stacklevel=2, + ) + threshold = min_delta + + # Handle deprecated parameter: cumulative_delta -> cumulative + if cumulative_delta is not None: + if cumulative is not False: + raise ValueError( + "Cannot specify both 'cumulative_delta' and 'cumulative'. " + "'cumulative_delta' is deprecated, use 'cumulative' instead." + ) + warnings.warn( + "'cumulative_delta' is deprecated and will be removed in a future version. " + "Please use 'cumulative' instead.", + DeprecationWarning, + stacklevel=2, + ) + cumulative = cumulative_delta + + # Handle deprecated parameter: min_delta_mode -> threshold_mode + if min_delta_mode is not None: + if threshold_mode != "abs": + raise ValueError( + "Cannot specify both 'min_delta_mode' and 'threshold_mode'. " + "'min_delta_mode' is deprecated, use 'threshold_mode' instead." + ) + warnings.warn( + "'min_delta_mode' is deprecated and will be removed in a future version. " + "Please use 'threshold_mode' instead.", + DeprecationWarning, + stacklevel=2, + ) + threshold_mode = min_delta_mode + if not callable(score_function): raise TypeError("Argument score_function should be a function.") if patience < 1: raise ValueError("Argument patience should be positive integer.") - if min_delta < 0.0: - raise ValueError("Argument min_delta should not be a negative number.") + if threshold < 0.0: + raise ValueError("Argument threshold should not be a negative number.") if not isinstance(trainer, Engine): raise TypeError("Argument trainer should be an instance of Engine.") - if min_delta_mode not in ("abs", "rel"): - raise ValueError("Argument min_delta_mode should be either 'abs' or 'rel'.") + if threshold_mode not in ("abs", "rel"): + raise ValueError("Argument threshold_mode should be either 'abs' or 'rel'.") if mode not in ("min", "max"): raise ValueError("Argument mode should be either 'min' or 'max'.") self.score_function = score_function self.patience = patience - self.min_delta = min_delta - self.cumulative_delta = cumulative_delta + self.threshold = threshold + self.cumulative = cumulative self.trainer = trainer self.counter = 0 self.best_score: float | None = None self.logger = setup_logger(__name__ + "." + self.__class__.__name__) - self.min_delta_mode = min_delta_mode + self.threshold_mode = threshold_mode self.mode = mode def __call__(self, engine: Engine) -> None: @@ -109,16 +167,16 @@ def __call__(self, engine: Engine) -> None: self.best_score = score return - min_delta = -self.min_delta if self.mode == "min" else self.min_delta - if self.min_delta_mode == "abs": - improvement_threshold = self.best_score + min_delta + threshold = -self.threshold if self.mode == "min" else self.threshold + if self.threshold_mode == "abs": + improvement_threshold = self.best_score + threshold else: - improvement_threshold = self.best_score * (1 + min_delta) + improvement_threshold = self.best_score * (1 + threshold) no_improvement = score <= improvement_threshold if self.mode == "max" else score >= improvement_threshold if no_improvement: - if not self.cumulative_delta: + if not self.cumulative: self.best_score = max(score, self.best_score) if self.mode == "max" else min(score, self.best_score) self.counter += 1 self.logger.debug("EarlyStopping: %i / %i" % (self.counter, self.patience)) diff --git a/tests/ignite/handlers/test_early_stopping.py b/tests/ignite/handlers/test_early_stopping.py index b43fa75d7ad6..6caa529fbef5 100644 --- a/tests/ignite/handlers/test_early_stopping.py +++ b/tests/ignite/handlers/test_early_stopping.py @@ -1,4 +1,5 @@ import os +import warnings import pytest import torch @@ -18,8 +19,8 @@ def test_args_validation(): with pytest.raises(ValueError, match=r"Argument patience should be positive integer."): EarlyStopping(patience=-1, score_function=lambda engine: 0, trainer=trainer) - with pytest.raises(ValueError, match=r"Argument min_delta should not be a negative number."): - EarlyStopping(patience=2, min_delta=-0.1, score_function=lambda engine: 0, trainer=trainer) + with pytest.raises(ValueError, match=r"Argument threshold should not be a negative number."): + EarlyStopping(patience=2, threshold=-0.1, score_function=lambda engine: 0, trainer=trainer) with pytest.raises(TypeError, match=r"Argument score_function should be a function."): EarlyStopping(patience=2, score_function=12345, trainer=trainer) @@ -27,13 +28,70 @@ def test_args_validation(): with pytest.raises(TypeError, match=r"Argument trainer should be an instance of Engine."): EarlyStopping(patience=2, score_function=lambda engine: 0, trainer=None) - with pytest.raises(ValueError, match=r"Argument min_delta_mode should be either 'abs' or 'rel'."): - EarlyStopping(patience=2, min_delta_mode="invalid_mode", score_function=lambda engine: 0, trainer=trainer) + with pytest.raises(ValueError, match=r"Argument threshold_mode should be either 'abs' or 'rel'."): + EarlyStopping(patience=2, threshold_mode="invalid_mode", score_function=lambda engine: 0, trainer=trainer) with pytest.raises(ValueError, match=r"Argument mode should be either 'min' or 'max'."): EarlyStopping(patience=2, mode="invalid_mode", score_function=lambda engine: 0, trainer=trainer) +def test_args_validation_deprecated_names(): + """Test that deprecated parameter names still trigger proper validation.""" + trainer = Engine(do_nothing_update_fn) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + with pytest.raises(ValueError, match=r"Argument threshold should not be a negative number."): + EarlyStopping(patience=2, min_delta=-0.1, score_function=lambda engine: 0, trainer=trainer) + + with pytest.raises(ValueError, match=r"Argument threshold_mode should be either 'abs' or 'rel'."): + EarlyStopping( + patience=2, min_delta_mode="invalid_mode", score_function=lambda engine: 0, trainer=trainer + ) + + +def test_deprecated_min_delta(): + trainer = Engine(do_nothing_update_fn) + + with pytest.warns(DeprecationWarning, match="'min_delta' is deprecated"): + h = EarlyStopping(patience=2, min_delta=0.1, score_function=lambda engine: 0, trainer=trainer) + assert h.threshold == 0.1 + + +def test_deprecated_cumulative_delta(): + trainer = Engine(do_nothing_update_fn) + + with pytest.warns(DeprecationWarning, match="'cumulative_delta' is deprecated"): + h = EarlyStopping(patience=2, cumulative_delta=True, score_function=lambda engine: 0, trainer=trainer) + assert h.cumulative is True + + +def test_deprecated_min_delta_mode(): + trainer = Engine(do_nothing_update_fn) + + with pytest.warns(DeprecationWarning, match="'min_delta_mode' is deprecated"): + h = EarlyStopping(patience=2, min_delta_mode="rel", score_function=lambda engine: 0, trainer=trainer) + assert h.threshold_mode == "rel" + + +def test_deprecated_and_new_param_conflict(): + trainer = Engine(do_nothing_update_fn) + + with pytest.raises(ValueError, match="Cannot specify both 'min_delta' and 'threshold'"): + EarlyStopping(patience=2, min_delta=0.1, threshold=0.2, score_function=lambda engine: 0, trainer=trainer) + + with pytest.raises(ValueError, match="Cannot specify both 'cumulative_delta' and 'cumulative'"): + EarlyStopping( + patience=2, cumulative_delta=True, cumulative=True, score_function=lambda engine: 0, trainer=trainer + ) + + with pytest.raises(ValueError, match="Cannot specify both 'min_delta_mode' and 'threshold_mode'"): + EarlyStopping( + patience=2, min_delta_mode="rel", threshold_mode="rel", score_function=lambda engine: 0, trainer=trainer + ) + + def test_simple_early_stopping(): scores = iter([1.0, 0.8, 0.88]) @@ -86,17 +144,21 @@ def score_function(engine): trainer = Engine(do_nothing_update_fn) # Use "rel" mode - h = EarlyStopping(patience=2, score_function=score_function, trainer=trainer, min_delta=0.1, min_delta_mode="rel") + h = EarlyStopping( + patience=2, score_function=score_function, trainer=trainer, threshold=0.1, threshold_mode="rel" + ) h(None) # best_score=1.0 h(None) # score=2.0 (improvement) state = h.state_dict() # New handler with "rel" mode - h2 = EarlyStopping(patience=2, score_function=score_function, trainer=trainer, min_delta=0.1, min_delta_mode="rel") + h2 = EarlyStopping( + patience=2, score_function=score_function, trainer=trainer, threshold=0.1, threshold_mode="rel" + ) h2.load_state_dict(state) - assert h2.min_delta_mode == "rel" + assert h2.threshold_mode == "rel" h2(None) # score=2.1 (no improvement: 2.1 <= 2.0 * 1.1 = 2.2) assert h2.counter == 1 assert not trainer.should_terminate @@ -110,7 +172,7 @@ def test_early_stopping_on_delta(): trainer = Engine(do_nothing_update_fn) - h = EarlyStopping(patience=2, min_delta=0.1, score_function=lambda _: next(scores), trainer=trainer) + h = EarlyStopping(patience=2, threshold=0.1, score_function=lambda _: next(scores), trainer=trainer) assert not trainer.should_terminate h(None) # counter == 0 @@ -132,9 +194,9 @@ def test_early_stopping_on_rel_delta(): trainer = Engine(do_nothing_update_fn) - # upper_bound = best_score * (1 + min_delta) + # upper_bound = best_score * (1 + threshold) h = EarlyStopping( - patience=2, min_delta=0.1, min_delta_mode="rel", score_function=lambda _: next(scores), trainer=trainer + patience=2, threshold=0.1, threshold_mode="rel", score_function=lambda _: next(scores), trainer=trainer ) assert not trainer.should_terminate @@ -158,7 +220,7 @@ def test_early_stopping_on_last_event_delta(): trainer = Engine(do_nothing_update_fn) h = EarlyStopping( - patience=2, min_delta=0.4, cumulative_delta=False, score_function=lambda _: next(scores), trainer=trainer + patience=2, threshold=0.4, cumulative=False, score_function=lambda _: next(scores), trainer=trainer ) assert not trainer.should_terminate @@ -176,7 +238,7 @@ def test_early_stopping_on_cumulative_delta(): trainer = Engine(do_nothing_update_fn) h = EarlyStopping( - patience=2, min_delta=0.4, cumulative_delta=True, score_function=lambda _: next(scores), trainer=trainer + patience=2, threshold=0.4, cumulative=True, score_function=lambda _: next(scores), trainer=trainer ) assert not trainer.should_terminate @@ -323,7 +385,7 @@ def test_early_stopping_min_mode_with_delta(): trainer = Engine(do_nothing_update_fn) - h = EarlyStopping(patience=2, min_delta=0.1, score_function=lambda _: next(scores), trainer=trainer, mode="min") + h = EarlyStopping(patience=2, threshold=0.1, score_function=lambda _: next(scores), trainer=trainer, mode="min") assert not trainer.should_terminate h(None) # best_score=1.1 @@ -343,10 +405,10 @@ def test_early_stopping_min_mode_with_delta_cumulative(): h = EarlyStopping( patience=2, - min_delta=0.1, + threshold=0.1, score_function=lambda _: next(scores), trainer=trainer, - cumulative_delta=True, + cumulative=True, mode="min", ) @@ -368,8 +430,8 @@ def test_early_stopping_min_mode_rel_delta(): h = EarlyStopping( patience=2, - min_delta=0.1, - min_delta_mode="rel", + threshold=0.1, + threshold_mode="rel", score_function=lambda _: next(scores), trainer=trainer, mode="min", @@ -386,6 +448,79 @@ def test_early_stopping_min_mode_rel_delta(): assert trainer.should_terminate +def test_backward_compat_deprecated_params_still_work(): + """Test that using deprecated parameter names still produces correct behavior.""" + scores = iter([1.0, 2.0, 2.01, 3.0, 3.01, 3.02]) + + trainer = Engine(do_nothing_update_fn) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + h = EarlyStopping(patience=2, min_delta=0.1, score_function=lambda _: next(scores), trainer=trainer) + + assert not trainer.should_terminate + h(None) # counter == 0 + assert not trainer.should_terminate + h(None) # delta == 1.0; counter == 0 + assert not trainer.should_terminate + h(None) # delta == 0.01; counter == 1 + assert not trainer.should_terminate + h(None) # delta == 0.99; counter == 0 + assert not trainer.should_terminate + h(None) # delta == 0.01; counter == 1 + assert not trainer.should_terminate + h(None) # delta == 0.01; counter == 2 + assert trainer.should_terminate + + +def test_backward_compat_cumulative_delta(): + """Test that using deprecated cumulative_delta still produces correct behavior.""" + scores = iter([0.0, 0.3, 0.6]) + + trainer = Engine(do_nothing_update_fn) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + h = EarlyStopping( + patience=2, min_delta=0.4, cumulative_delta=True, score_function=lambda _: next(scores), trainer=trainer + ) + + assert not trainer.should_terminate + h(None) # counter == 0 + assert not trainer.should_terminate + h(None) # delta == 0.3; counter == 1 + assert not trainer.should_terminate + h(None) # delta == 0.6; counter == 0 + assert not trainer.should_terminate + + +def test_backward_compat_min_delta_mode(): + """Test that using deprecated min_delta_mode still produces correct behavior.""" + scores = iter([1.0, 2.0, 2.1, 3.0, 3.2, 3.25]) + + trainer = Engine(do_nothing_update_fn) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + h = EarlyStopping( + patience=2, min_delta=0.1, min_delta_mode="rel", score_function=lambda _: next(scores), trainer=trainer + ) + + assert not trainer.should_terminate + h(None) # best_score = 1.0; counter == 0 + assert not trainer.should_terminate + h(None) # score = 2.0; upper_bound = 1.0 * (1.1) = 1.1; 2.0 > 1.1; best_score = 2.0; counter == 0 + assert not trainer.should_terminate + h(None) # score = 2.1; upper_bound = 2.0 * (1.1) = 2.2; 2.1 <= 2.2; counter == 1 + assert not trainer.should_terminate + h(None) # score = 3.0; upper_bound = 2.0 * (1.1) = 2.2; 3.0 > 2.2; best_score = 3.0; counter == 0 + assert not trainer.should_terminate + h(None) # score = 3.2; upper_bound = 3.0 * (1.1) = 3.3; 3.2 <= 3.3; counter == 1 + assert not trainer.should_terminate + h(None) # score = 3.25; upper_bound = 3.0 * (1.1) = 3.3; 3.25 <= 3.3; counter == 2 + assert trainer.should_terminate + + def _test_distrib_with_engine_early_stopping(device): if device is None: device = idist.device()