Skip to content

Commit 60e38a9

Browse files
authored
set deterministic parameter in TrainModel (#181)
* set deterministic in TrainModel
* update library setup.py with PyTorch 1.9
* move logger init earlier
* fix output_dir init
1 parent a972fbf commit 60e38a9

4 files changed

Lines changed: 27 additions & 13 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ We recommend using Selene with Python 3.6 or above.
1212
Package installation should only take a few minutes (less than 10 minutes, typically ~2-3 minutes) with any of these methods (conda, pip, source).
1313

1414
**First, install [PyTorch](https://pytorch.org/get-started/locally/).** If you have an NVIDIA GPU, install a version of PyTorch that supports it--Selene will run much faster with a discrete GPU.
15-
The library is currently compatible with PyTorch versions between 0.4.1 and 1.4.0.
15+
The library is currently compatible with PyTorch versions between 0.4.1 and 1.9.
1616
We will continue to update Selene to be compatible with the latest version of PyTorch.
1717

1818
### Installing selene with [Anaconda](https://www.anaconda.com/download/) (for Linux):

selene_sdk/train_model.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,14 @@ class TrainModel(object):
138138
using `torch.load`.
139139
use_scheduler : bool, optional
140140
Default is `True`. If `True`, learning rate scheduler is used to
141-
reduce learning rate on plateau. PyTorch ReduceLROnPlateau scheduler
141+
reduce learning rate on plateau. PyTorch ReduceLROnPlateau scheduler
142142
with patience=16 and factor=0.8 is used.
143+
deterministic : bool, optional
144+
Default is `False`. If `True`, will set
145+
`torch.backends.cudnn.deterministic = True` and
146+
`torch.backends.cudnn.benchmark = False`. In Selene CLI,
147+
if `random_seed` is set in the configuration YAML, Selene automatically
148+
passes in `deterministic=True` to the TrainModel class.
143149
144150
Attributes
145151
----------
@@ -190,7 +196,8 @@ def __init__(self,
190196
checkpoint_resume=None,
191197
metrics=dict(roc_auc=roc_auc_score,
192198
average_precision=average_precision_score),
193-
use_scheduler=True):
199+
use_scheduler=True,
200+
deterministic=False):
194201
"""
195202
Constructs a new `TrainModel` object.
196203
"""
@@ -212,6 +219,19 @@ def __init__(self,
212219

213220
self._save_new_checkpoints = save_new_checkpoints_after_n_steps
214221

222+
os.makedirs(output_dir, exist_ok=True)
223+
self.output_dir = output_dir
224+
225+
initialize_logger(
226+
os.path.join(self.output_dir, "{0}.log".format(__name__)),
227+
verbosity=logging_verbosity)
228+
229+
if deterministic:
230+
logger.info("Setting deterministic = True for reproducibility.")
231+
torch.backends.cudnn.deterministic = True
232+
torch.backends.cudnn.benchmark = False
233+
234+
215235
logger.info("Training parameters set: batch size {0}, "
216236
"number of steps per 'epoch': {1}, "
217237
"maximum number of steps: {2}".format(
@@ -233,13 +253,6 @@ def __init__(self,
233253
self.criterion.cuda()
234254
logger.debug("Set modules to use CUDA")
235255

236-
os.makedirs(output_dir, exist_ok=True)
237-
self.output_dir = output_dir
238-
239-
initialize_logger(
240-
os.path.join(self.output_dir, "{0}.log".format(__name__)),
241-
verbosity=logging_verbosity)
242-
243256
self._report_gt_feature_n_positives = report_gt_feature_n_positives
244257
self._metrics = metrics
245258
self._n_validation_samples = n_validation_samples

selene_sdk/utils/config_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ def execute(operations, configs, output_dir):
181181
optimizer_kwargs=optim_kwargs)
182182
if output_dir is not None:
183183
train_model_info.bind(output_dir=output_dir)
184+
if "random_seed" in configs:
185+
train_model_info.bind(deterministic=True)
184186

185187
train_model = instantiate(train_model_info)
186188
# TODO: will find a better way to handle this in the future
@@ -341,8 +343,7 @@ def parse_configs_and_run(configs,
341343
np.random.seed(seed)
342344
torch.manual_seed(seed)
343345
torch.cuda.manual_seed_all(seed)
344-
torch.backends.cudnn.deterministic = True
345-
torch.backends.cudnn.benchmark = False
346+
print("Setting random seed = {0}".format(seed))
346347
else:
347348
print("Warning: no random seed specified in config file. "
348349
"Using a random seed ensures results are reproducible.")

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
"scipy",
6464
"seaborn",
6565
"statsmodels",
66-
"torch>=0.4.1, <=1.4.0",
66+
"torch>=0.4.1, <=1.9",
6767
],
6868
entry_points={
6969
'console_scripts': [

0 commit comments

Comments
 (0)