Skip to content

Commit 86f4df3

Browse files
authored
Early stopping and other parameters for learning rate scheduler (#184)
* Add support for early stopping and other parameters for the learning rate scheduler
* Add an example config demonstrating early stopping and custom scheduler parameters
1 parent 9723144 commit 86f4df3

File tree

2 files changed

+89
-8
lines changed

2 files changed

+89
-8
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
---
2+
ops: [train]
3+
create_subdirectory: False
4+
lr: 0.0003
5+
model: {
6+
path: ./Data/SeleneFiles/enhancer_resnet_regression.py,
7+
class: EnhancerResnet,
8+
class_args: {
9+
sequence_length: 164,
10+
},
11+
#non_strand_specific: "mean",
12+
}
13+
sampler: !obj:selene_sdk.samplers.MultiSampler {
14+
features: ["expression_log2_standardized"],
15+
train_sampler: !obj:selene_sdk.samplers.file_samplers.MatFileSampler {
16+
filepath: ./Data/SeleneFiles/train_regression.mat,
17+
sequence_key: sequence,
18+
targets_key: activity,
19+
shuffle: True,
20+
},
21+
validate_sampler: !obj:selene_sdk.samplers.file_samplers.MatFileSampler {
22+
filepath: ./Data/SeleneFiles/validate_regression.mat,
23+
sequence_key: sequence,
24+
targets_key: activity,
25+
shuffle: False,
26+
},
27+
}
28+
train_model: !obj:selene_sdk.TrainModel {
29+
batch_size: 128, # 25757 training examples
30+
report_stats_every_n_steps: 202, # 201.23 steps for full epoch
31+
max_steps: 10062, # 50 epochs
32+
use_cuda: True,
33+
data_parallel: False,
34+
logging_verbosity: 2,
35+
metrics: {
36+
pcc: !import metrics.pearson,
37+
scc: !import metrics.spearman,
38+
},
39+
scheduler_kwargs: {
40+
patience: 3,
41+
factor: 0.2,
42+
verbose: True,
43+
},
44+
stopping_criteria: ["scc", 10],
45+
}
46+
...

selene_sdk/train_model.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,24 @@ class TrainModel(object):
139139
use_scheduler : bool, optional
140140
Default is `True`. If `True`, learning rate scheduler is used to
141141
reduce learning rate on plateau. PyTorch ReduceLROnPlateau scheduler
142-
with patience=16 and factor=0.8 is used.
142+
with patience=16 and factor=0.8 is used. Different scheduler parameters
143+
can be specified with `scheduler_kwargs`.
143144
deterministic : bool, optional
144145
Default is `False`. If `True`, will set
145146
`torch.backends.cudnn.deterministic` to True and
146147
`torch.backends.cudnn.benchmark = False`. In Selene CLI,
147148
if `random_seed` is set in the configuration YAML, Selene automatically
148149
passes in `deterministic=True` to the TrainModel class.
150+
scheduler_kwargs : dict, optional
151+
Default is patience=16, verbose=True, and factor=0.8. Set the parameters
152+
for the PyTorch ReduceLROnPlateau scheduler.
153+
stopping_criteria : list or None, optional
154+
Default is `None`. If `stopping_criteria` is not None, it should be a
155+
list specifying how to use early stopping. The first value should be
156+
a str corresponding to one of `metrics`. The second value should be an
157+
int indicating the patience. If the specified metric does not improve
158+
in the given patience (usually corresponding to the number of epochs),
159+
training stops early.
149160
150161
Attributes
151162
----------
@@ -197,7 +208,11 @@ def __init__(self,
197208
metrics=dict(roc_auc=roc_auc_score,
198209
average_precision=average_precision_score),
199210
use_scheduler=True,
200-
deterministic=False):
211+
deterministic=False,
212+
scheduler_kwargs=dict(patience=16,
213+
verbose=True,
214+
factor=0.8),
215+
stopping_criteria=None):
201216
"""
202217
Constructs a new `TrainModel` object.
203218
"""
@@ -259,13 +274,26 @@ def __init__(self,
259274
self._n_test_samples = n_test_samples
260275
self._use_scheduler = use_scheduler
261276

262-
self._init_train()
277+
self._init_train(scheduler_kwargs)
263278
self._init_validate()
264279
if "test" in self.sampler.modes:
265280
self._init_test()
266281
if checkpoint_resume is not None:
267282
self._load_checkpoint(checkpoint_resume)
268283

284+
if type(stopping_criteria) is list and len(stopping_criteria) == 2:
285+
stopping_metric, stopping_patience = stopping_criteria
286+
self._early_stopping = True
287+
if stopping_metric in self._metrics:
288+
self._stopping_metric = stopping_metric
289+
self._stopping_patience = stopping_patience
290+
self._stopping_reached = False
291+
else:
292+
logger.warning("Did not recognize stopping metric. Not performing early stopping.")
293+
self._early_stopping = False
294+
else:
295+
self._early_stopping = False
296+
269297
def _load_checkpoint(self, checkpoint_resume):
270298
checkpoint = torch.load(
271299
checkpoint_resume,
@@ -297,7 +325,7 @@ def _load_checkpoint(self, checkpoint_resume):
297325
("Resuming from checkpoint: step {0}, min loss {1}").format(
298326
self._start_step, self._min_loss))
299327

300-
def _init_train(self):
328+
def _init_train(self, scheduler_kwargs):
301329
self._start_step = 0
302330
self._train_logger = _metrics_logger(
303331
"{0}.train".format(__name__), self.output_dir)
@@ -306,9 +334,7 @@ def _init_train(self):
306334
self.scheduler = ReduceLROnPlateau(
307335
self.optimizer,
308336
'min',
309-
patience=16,
310-
verbose=True,
311-
factor=0.8)
337+
**scheduler_kwargs)
312338
self._time_per_step = []
313339
self._train_loss = []
314340

@@ -433,10 +459,12 @@ def train_and_validate(self):
433459
self._checkpoint()
434460
if self.step and self.step % self.nth_step_report_stats == 0:
435461
self.validate()
462+
if self._early_stopping and self._stopping_reached:
463+
logger.debug("Patience ran out. Stopping early.")
464+
break
436465

437466
self.sampler.save_dataset_to_file("train", close_filehandle=True)
438467

439-
440468
def train(self):
441469
"""
442470
Trains the model on a batch of data.
@@ -565,6 +593,13 @@ def validate(self):
565593
logger.debug("Updating `best_model.pth.tar`")
566594
logger.info("validation loss: {0}".format(validation_loss))
567595

596+
# check for early stopping
597+
if self._early_stopping:
598+
stopping_metric = self._validation_metrics.metrics[self._stopping_metric].data
599+
index = np.argmax(stopping_metric)
600+
if self._stopping_patience - (len(stopping_metric) - index - 1) <= 0:
601+
self._stopping_reached = True
602+
568603
def evaluate(self):
569604
"""
570605
Measures the model test performance.

0 commit comments

Comments
 (0)