diff --git a/tools/train.py b/tools/train.py
index aa98bba305..d4cff2de23 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -88,6 +88,13 @@ def make_parser():
                 Implemented loggers include `tensorboard`, `mlflow` and `wandb`.",
         default="tensorboard"
     )
+    parser.add_argument(
+        "--early-stopping",
+        dest="early_stopping",
+        default=False,
+        action="store_true",
+        help="Use early stopping to prevent overfitting.",
+    )
     parser.add_argument(
         "opts",
         help="Modify config options using the command-line",
@@ -115,6 +122,13 @@ def main(exp: Exp, args):
     cudnn.benchmark = True
 
     trainer = exp.get_trainer(args)
+
+    # configure early stopping parameters
+    if args.early_stopping:
+        # stop after 10 consecutive evaluations without a >1% relative AP gain
+        # available modes: "max", "min", "percentage"
+        trainer.early_stopper = exp.get_early_stopping(patience=10, min_delta=0.01, mode="percentage")
+
     trainer.train()
 
 
diff --git a/yolox/core/__init__.py b/yolox/core/__init__.py
index c2379c704e..46a5ba7061 100644
--- a/yolox/core/__init__.py
+++ b/yolox/core/__init__.py
@@ -4,3 +4,4 @@
 
 from .launch import launch
 from .trainer import Trainer
+from .trainer import EarlyStopping
diff --git a/yolox/core/trainer.py b/yolox/core/trainer.py
index 8f8016e578..2a78fd8db2 100644
--- a/yolox/core/trainer.py
+++ b/yolox/core/trainer.py
@@ -33,6 +33,69 @@
     synchronize
 )
 
+
+class EarlyStopping:
+    """Signal when a monitored metric has stopped improving.
+
+    Call ``step`` once per evaluation; it returns True after ``patience``
+    consecutive calls without sufficient improvement over the best value seen.
+
+    Args:
+        patience: consecutive non-improving ``step`` calls tolerated before stopping.
+        min_delta: gain that must be exceeded to count as an improvement. Absolute
+            for "max" and "min"; a fraction of ``abs(best)`` for "percentage".
+        mode: "max" (higher is better), "min" (lower is better) or
+            "percentage" (higher is better, relative gain).
+
+    Raises:
+        ValueError: if ``mode`` is not one of the supported modes.
+    """
+
+    MODES = ("max", "min", "percentage")
+
+    def __init__(self, patience: int, min_delta: float, mode="max"):
+        # fail fast: otherwise a bad mode only surfaces at the second evaluation
+        if mode not in self.MODES:
+            raise ValueError(
+                f"Unknown mode: {mode}, supported modes are 'max', 'min', 'percentage'."
+            )
+        self.patience = patience
+        self.min_delta = min_delta
+        self.mode = mode
+        self.best = None  # best value seen so far, None until the first step()
+        self.counter = 0  # consecutive step() calls without sufficient improvement
+
+    def step(self, value):
+        """Record ``value`` and return True when training should stop."""
+        # the first value only establishes the baseline
+        if self.best is None:
+            self.best = value
+            return False
+
+        if self._improvement(value) > self.min_delta:
+            self.best = value
+            self.counter = 0
+        else:
+            self.counter += 1
+
+        return self.counter >= self.patience
+
+    def _improvement(self, value):
+        """Return the gain of ``value`` over ``self.best`` for the current mode."""
+        if self.mode == "max":
+            return value - self.best
+        if self.mode == "min":
+            return self.best - value
+        if self.mode == "percentage":
+            if self.best == 0:
+                # A relative gain is undefined against a zero baseline. Reporting 0
+                # here would pin best at 0 forever, so count any rise as progress.
+                return float("inf") if value > self.best else 0.0
+            return (value - self.best) / abs(self.best)
+        raise ValueError(
+            f"Unknown mode: {self.mode}, supported modes are 'max', 'min', 'percentage'."
+        )
+
 
 class Trainer:
     def __init__(self, exp: Exp, args):
@@ -40,6 +103,7 @@ def __init__(self, exp: Exp, args):
         # before_train methods.
         self.exp = exp
         self.args = args
+        self.early_stopper = None
 
         # training related attr
         self.max_epoch = exp.max_epoch
@@ -234,7 +298,19 @@ def after_epoch(self):
 
         if (self.epoch + 1) % self.exp.eval_interval == 0:
             all_reduce_norm(self.model)
-            self.evaluate_and_save_model()
+            ap50_95 = self.evaluate_and_save_model()
+
+            # Early stopping
+            if self.early_stopper is not None:
+                if self.early_stopper.step(ap50_95):
+                    logger.info(
+                        f"Early stopping triggered at epoch {self.epoch}. "
+                        f"Best AP: {self.early_stopper.best}"
+                    )
+                    # NOTE(review): this saves the current weights, which did not beat
+                    # early_stopper.best -- confirm "best_ckpt" is the intended name.
+                    self.save_ckpt("best_ckpt")
+                    raise SystemExit
 
     def before_iter(self):
         pass
@@ -395,6 +471,7 @@ def evaluate_and_save_model(self):
             }
             self.mlflow_logger.save_checkpoints(self.args, self.exp, self.file_name, self.epoch,
                                                 metadata, update_best_ckpt)
+        return ap50_95
 
     def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None):
         if self.rank == 0:
diff --git a/yolox/exp/yolox_base.py b/yolox/exp/yolox_base.py
index 82e93c21bd..6815c6977c 100644
--- a/yolox/exp/yolox_base.py
+++ b/yolox/exp/yolox_base.py
@@ -349,6 +349,11 @@ def get_trainer(self, args):
         # NOTE: trainer shouldn't be an attribute of exp object
         return trainer
 
+    def get_early_stopping(self, patience, min_delta, mode):
+        from yolox.core import EarlyStopping
+
+        return EarlyStopping(patience=patience, min_delta=min_delta, mode=mode)
+
     def eval(self, model, evaluator, is_distributed, half=False, return_outputs=False):
         return evaluator.evaluate(model, is_distributed, half, return_outputs=return_outputs)
 