Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 82 additions & 24 deletions ignite/handlers/early_stopping.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from collections import OrderedDict
from collections.abc import Callable, Mapping
from typing import Any, cast, Literal
from typing import Any, cast, Literal, Optional

from ignite.base import Serializable, ResettableHandler
from ignite.engine import Engine, Events
Expand All @@ -18,25 +19,28 @@ class EarlyStopping(Serializable, ResettableHandler):
object, and return a score `float`. An improvement is considered if the score is higher (for ``mode='max'``)
or lower (for ``mode='min'``).
trainer: Trainer engine to stop the run if no improvement.
min_delta: A minimum change in the score to qualify as an improvement. For ``mode='max'``, it's a minimum
threshold: A minimum change in the score to qualify as an improvement. For ``mode='max'``, it's a minimum
increase; for ``mode='min'``, it's a minimum decrease. An improvement is only considered if the change
exceeds the threshold determined by `min_delta` and `min_delta_mode`.
cumulative_delta: If True, `min_delta` defines the change since the last `patience` reset, otherwise,
exceeds the threshold determined by `threshold` and `threshold_mode`.
cumulative: If True, `threshold` defines the change since the last `patience` reset, otherwise,
it defines the change after the last event. Default value is False.
min_delta_mode: Determines whether `min_delta` is an absolute change or a relative change.
threshold_mode: Determines whether `threshold` is an absolute change or a relative change.

- In 'abs' mode:

- For ``mode='max'``: improvement if score > best_score + min_delta
- For ``mode='min'``: improvement if score < best_score - min_delta
- For ``mode='max'``: improvement if score > best_score + threshold
- For ``mode='min'``: improvement if score < best_score - threshold

- In 'rel' mode:

- For ``mode='max'``: improvement if score > best_score * (1 + min_delta)
- For ``mode='min'``: improvement if score < best_score * (1 - min_delta)
- For ``mode='max'``: improvement if score > best_score * (1 + threshold)
- For ``mode='min'``: improvement if score < best_score * (1 - threshold)

Possible values are "abs" and "rel". Default value is "abs".
mode: Whether to maximize ('max') or minimize ('min') the score. Default is 'max'.
min_delta: Deprecated, use `threshold` instead.
cumulative_delta: Deprecated, use `cumulative` instead.
min_delta_mode: Deprecated, use `threshold_mode` instead.

Examples:
.. code-block:: python
Expand All @@ -56,6 +60,11 @@ def score_function(engine):
Added `mode` parameter to support minimization in addition to maximization.
Added `min_delta_mode` parameter to support both absolute and relative improvements.

.. versionchanged:: 0.5.5
Renamed `min_delta` to `threshold`, `min_delta_mode` to `threshold_mode`, and
`cumulative_delta` to `cumulative`. The old parameter names are deprecated and
will be removed in a future version.

"""

_state_dict_all_req_keys = (
Expand All @@ -68,38 +77,87 @@ def __init__(
patience: int,
score_function: Callable,
trainer: Engine,
min_delta: float = 0.0,
cumulative_delta: bool = False,
min_delta_mode: Literal["abs", "rel"] = "abs",
threshold: float = 0.0,
cumulative: bool = False,
threshold_mode: Literal["abs", "rel"] = "abs",
mode: Literal["min", "max"] = "max",
# Deprecated parameter names kept for backward compatibility
min_delta: Optional[float] = None,
cumulative_delta: Optional[bool] = None,
min_delta_mode: Optional[Literal["abs", "rel"]] = None,
):
# Handle deprecated parameter: min_delta -> threshold
if min_delta is not None:
if threshold != 0.0:
raise ValueError(
"Cannot specify both 'min_delta' and 'threshold'. "
"'min_delta' is deprecated, use 'threshold' instead."
)
warnings.warn(
"'min_delta' is deprecated and will be removed in a future version. "
"Please use 'threshold' instead.",
DeprecationWarning,
stacklevel=2,
)
threshold = min_delta

# Handle deprecated parameter: cumulative_delta -> cumulative
if cumulative_delta is not None:
if cumulative is not False:
raise ValueError(
"Cannot specify both 'cumulative_delta' and 'cumulative'. "
"'cumulative_delta' is deprecated, use 'cumulative' instead."
)
warnings.warn(
"'cumulative_delta' is deprecated and will be removed in a future version. "
"Please use 'cumulative' instead.",
DeprecationWarning,
stacklevel=2,
)
cumulative = cumulative_delta

# Handle deprecated parameter: min_delta_mode -> threshold_mode
if min_delta_mode is not None:
if threshold_mode != "abs":
raise ValueError(
"Cannot specify both 'min_delta_mode' and 'threshold_mode'. "
"'min_delta_mode' is deprecated, use 'threshold_mode' instead."
)
warnings.warn(
"'min_delta_mode' is deprecated and will be removed in a future version. "
"Please use 'threshold_mode' instead.",
DeprecationWarning,
stacklevel=2,
)
threshold_mode = min_delta_mode

if not callable(score_function):
raise TypeError("Argument score_function should be a function.")

if patience < 1:
raise ValueError("Argument patience should be positive integer.")

if min_delta < 0.0:
raise ValueError("Argument min_delta should not be a negative number.")
if threshold < 0.0:
raise ValueError("Argument threshold should not be a negative number.")

if not isinstance(trainer, Engine):
raise TypeError("Argument trainer should be an instance of Engine.")

if min_delta_mode not in ("abs", "rel"):
raise ValueError("Argument min_delta_mode should be either 'abs' or 'rel'.")
if threshold_mode not in ("abs", "rel"):
raise ValueError("Argument threshold_mode should be either 'abs' or 'rel'.")

if mode not in ("min", "max"):
raise ValueError("Argument mode should be either 'min' or 'max'.")

self.score_function = score_function
self.patience = patience
self.min_delta = min_delta
self.cumulative_delta = cumulative_delta
self.threshold = threshold
self.cumulative = cumulative
self.trainer = trainer
self.counter = 0
self.best_score: float | None = None
self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
self.min_delta_mode = min_delta_mode
self.threshold_mode = threshold_mode
self.mode = mode

def __call__(self, engine: Engine) -> None:
Expand All @@ -109,16 +167,16 @@ def __call__(self, engine: Engine) -> None:
self.best_score = score
return

min_delta = -self.min_delta if self.mode == "min" else self.min_delta
if self.min_delta_mode == "abs":
improvement_threshold = self.best_score + min_delta
threshold = -self.threshold if self.mode == "min" else self.threshold
if self.threshold_mode == "abs":
improvement_threshold = self.best_score + threshold
else:
improvement_threshold = self.best_score * (1 + min_delta)
improvement_threshold = self.best_score * (1 + threshold)

no_improvement = score <= improvement_threshold if self.mode == "max" else score >= improvement_threshold

if no_improvement:
if not self.cumulative_delta:
if not self.cumulative:
self.best_score = max(score, self.best_score) if self.mode == "max" else min(score, self.best_score)
self.counter += 1
self.logger.debug("EarlyStopping: %i / %i" % (self.counter, self.patience))
Expand Down
Loading