Skip to content

Commit 9cb4d09

Browse files
committed
Add default hyperparams also to model args, not only to main.py. When training without validation set, num_workers was not being set - fixed.
1 parent e7c82a2 commit 9cb4d09

3 files changed

Lines changed: 5 additions & 6 deletions

File tree

main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ def reload_checkpoint(path, device=None):
5050
parser.add_argument('--epochs', default=108, type=int, help='#epochs of training')
5151
parser.add_argument('--validation_size', default=10000, type=int, help="Size of the validation set to split off.")
5252
parser.add_argument('--num_workers', default=0, type=int, help="Number of parallel workers for the train dataset.")
53-
parser.add_argument('--learning_rate', default=0.02, type=float, help='base learning rate')
54-
parser.add_argument('--lr_decay_method', default='COSINE_BY_STEP', type=str, help='learning decay method')
53+
parser.add_argument('--learning_rate', default=0.2, type=float, help='base learning rate')
5554
parser.add_argument('--optimizer', default='rmsprop', type=str, help='Optimizer (sgd, rmsprop or rmsprop_tf)')
5655
parser.add_argument('--rmsprop_eps', default=1.0, type=float, help='RMSProp eps parameter.')
5756
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')

nasbench_pytorch/datasets/cifar10.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def prepare_dataset(batch_size, test_batch_size=256, root='./data/', use_validat
109109
worker_init_fn=worker_fn)
110110
else:
111111
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
112-
worker_init_fn=worker_fn)
112+
num_workers=num_workers, worker_init_fn=worker_fn)
113113
valid_loader = None
114114

115115
test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=test_transform)

nasbench_pytorch/trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def train(net, train_loader, loss=None, optimizer=None, scheduler=None, grad_cli
3737

3838
if optimizer is not None and not isinstance(optimizer, str):
3939
pass
40-
elif optimizer is None or optimizer.lower() == 'sgd':
40+
elif optimizer is None or optimizer.lower() == 'rmsprop':
41+
optimizer = torch.optim.RMSprop(net.parameters(), lr=0.2, momentum=0.9, weight_decay=1e-4, eps=1.0)
42+
elif optimizer.lower() == 'sgd':
4143
optimizer = torch.optim.SGD(net.parameters(), lr=0.025, momentum=0.9, weight_decay=1e-4)
42-
elif optimizer.lower() == 'rmsprop':
43-
optimizer = torch.optim.RMSprop(net.parameters(), lr=0.2, momentum=0.9, weight_decay=1e-4)
4444
elif optimizer.lower() == 'adam':
4545
optimizer = torch.optim.Adam(net.parameters())
4646

0 commit comments

Comments
 (0)