@@ -139,13 +139,24 @@ class TrainModel(object):
139139 use_scheduler : bool, optional
140140 Default is `True`. If `True`, learning rate scheduler is used to
141141 reduce learning rate on plateau. PyTorch ReduceLROnPlateau scheduler
142- with patience=16 and factor=0.8 is used.
142+ with patience=16 and factor=0.8 is used. Different scheduler parameters
143+ can be specified with `scheduler_kwargs`.
143144 deterministic : bool, optional
144145 Default is `False`. If `True`, will set
145146 `torch.backends.cudnn.deterministic` to True and
146147 `torch.backends.cudnn.benchmark = False`. In Selene CLI,
147148 if `random_seed` is set in the configuration YAML, Selene automatically
148149 passes in `deterministic=True` to the TrainModel class.
150+ scheduler_kwargs : dict, optional
151+ Keyword arguments passed to the PyTorch ReduceLROnPlateau scheduler.
152+ Default is `patience=16`, `verbose=True`, and `factor=0.8`.
153+ stopping_criteria : list or None, optional
154+ Default is `None`. If not None, it should be a list of two values
155+ specifying how to apply early stopping. The first value should be a
156+ str corresponding to one of the keys in `metrics`. The second value
157+ should be an int indicating the patience. If the specified metric does
158+ not improve within the given patience (measured in validation checks,
159+ which usually correspond to epochs), training stops early.
149160
150161 Attributes
151162 ----------
@@ -197,7 +208,11 @@ def __init__(self,
197208 metrics = dict (roc_auc = roc_auc_score ,
198209 average_precision = average_precision_score ),
199210 use_scheduler = True ,
200- deterministic = False ):
211+ deterministic = False ,
212+ scheduler_kwargs = dict (patience = 16 ,
213+ verbose = True ,
214+ factor = 0.8 ),
215+ stopping_criteria = None ):
201216 """
202217 Constructs a new `TrainModel` object.
203218 """
@@ -259,13 +274,26 @@ def __init__(self,
259274 self ._n_test_samples = n_test_samples
260275 self ._use_scheduler = use_scheduler
261276
262- self ._init_train ()
277+ self ._init_train (scheduler_kwargs )
263278 self ._init_validate ()
264279 if "test" in self .sampler .modes :
265280 self ._init_test ()
266281 if checkpoint_resume is not None :
267282 self ._load_checkpoint (checkpoint_resume )
268283
284+ if type (stopping_criteria ) is list and len (stopping_criteria ) == 2 :
285+ stopping_metric , stopping_patience = stopping_criteria
286+ self ._early_stopping = True
287+ if stopping_metric in self ._metrics :
288+ self ._stopping_metric = stopping_metric
289+ self ._stopping_patience = stopping_patience
290+ self ._stopping_reached = False
291+ else :
292+ logger .warning ("Did not recognize stopping metric. Not performing early stopping." )
293+ self ._early_stopping = False
294+ else :
295+ self ._early_stopping = False
296+
269297 def _load_checkpoint (self , checkpoint_resume ):
270298 checkpoint = torch .load (
271299 checkpoint_resume ,
@@ -297,7 +325,7 @@ def _load_checkpoint(self, checkpoint_resume):
297325 ("Resuming from checkpoint: step {0}, min loss {1}" ).format (
298326 self ._start_step , self ._min_loss ))
299327
300- def _init_train (self ):
328+ def _init_train (self , scheduler_kwargs ):
301329 self ._start_step = 0
302330 self ._train_logger = _metrics_logger (
303331 "{0}.train" .format (__name__ ), self .output_dir )
@@ -306,9 +334,7 @@ def _init_train(self):
306334 self .scheduler = ReduceLROnPlateau (
307335 self .optimizer ,
308336 'min' ,
309- patience = 16 ,
310- verbose = True ,
311- factor = 0.8 )
337+ ** scheduler_kwargs )
312338 self ._time_per_step = []
313339 self ._train_loss = []
314340
@@ -433,10 +459,12 @@ def train_and_validate(self):
433459 self ._checkpoint ()
434460 if self .step and self .step % self .nth_step_report_stats == 0 :
435461 self .validate ()
462+ if self ._early_stopping and self ._stopping_reached :
463+ logger .debug ("Patience ran out. Stopping early." )
464+ break
436465
437466 self .sampler .save_dataset_to_file ("train" , close_filehandle = True )
438467
439-
440468 def train (self ):
441469 """
442470 Trains the model on a batch of data.
@@ -565,6 +593,13 @@ def validate(self):
565593 logger .debug ("Updating `best_model.pth.tar`" )
566594 logger .info ("validation loss: {0}" .format (validation_loss ))
567595
596+ # check for early stopping
597+ if self ._early_stopping :
598+ stopping_metric = self ._validation_metrics .metrics [self ._stopping_metric ].data
599+ index = np .argmax (stopping_metric )
600+ if self ._stopping_patience - (len (stopping_metric ) - index - 1 ) <= 0 :
601+ self ._stopping_reached = True
602+
568603 def evaluate (self ):
569604 """
570605 Measures the model test performance.
0 commit comments