Skip to content

Commit 4020a3f

Browse files
committed
Added arguments that enable setting batch norm momentum and eps. Easier optimizer choice in main.py.
1 parent 0732bc3 commit 4020a3f

3 files changed

Lines changed: 30 additions & 15 deletions

File tree

main.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,12 @@ def reload_checkpoint(path, device=None):
5656
parser.add_argument('--num_workers', default=0, type=int, help="Number of parallel workers for the train dataset.")
5757
parser.add_argument('--learning_rate', default=0.025, type=float, help='base learning rate')
5858
parser.add_argument('--lr_decay_method', default='COSINE_BY_STEP', type=str, help='learning decay method')
59+
parser.add_argument('--optimizer', default='sgd', type=str, help='Optimizer (sgd or rmsprop)')
5960
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
6061
parser.add_argument('--weight_decay', default=1e-4, type=float, help='L2 regularization weight')
6162
parser.add_argument('--grad_clip', default=5, type=float, help='gradient clipping')
63+
parser.add_argument('--batch_norm_momentum', default=0.1, type=float, help='Batch normalization momentum')
64+
parser.add_argument('--batch_norm_eps', default=1e-5, type=float, help='Batch normalization epsilon')
6265
parser.add_argument('--load_checkpoint', default='', type=str, help='Reload model from checkpoint')
6366
parser.add_argument('--num_labels', default=10, type=int, help='#classes')
6467
parser.add_argument('--device', default='cuda', type=str, help='Device for network training.')
@@ -77,14 +80,23 @@ def reload_checkpoint(path, device=None):
7780
# model
7881
spec = ModelSpec(matrix, operations)
7982
net = Network(spec, num_labels=args.num_labels, in_channels=args.in_channels, stem_out_channels=args.stem_out_channels,
80-
num_stacks=args.num_stacks, num_modules_per_stack=args.num_modules_per_stack)
83+
num_stacks=args.num_stacks, num_modules_per_stack=args.num_modules_per_stack,
84+
momentum=args.batch_norm_momentum, eps=args.batch_norm_eps)
8185

8286
if args.load_checkpoint != '':
8387
net.load_state_dict(reload_checkpoint(args.load_checkpoint))
8488
net.to(args.device)
8589

8690
criterion = nn.CrossEntropyLoss()
87-
optimizer = optim.SGD(net.parameters(), lr=args.learning_rate, momentum=args.momentum,
91+
92+
if args.optimizer.lower() == 'sgd':
93+
optimizer = optim.SGD
94+
elif args.optimizer.lower() == 'rmsprop':
95+
optimizer = optim.RMSprop
96+
else:
97+
raise ValueError(f"Invalid optimizer {args.optimizer}, possible: SGD, RMSProp")
98+
99+
optimizer = optimizer(net.parameters(), lr=args.learning_rate, momentum=args.momentum,
88100
weight_decay=args.weight_decay)
89101
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
90102

nasbench_pytorch/model/base_ops.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
import torch.nn.functional as F
1010

1111
class ConvBnRelu(nn.Module):
12-
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
12+
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, momentum=0.1, eps=1e-5):
1313
super(ConvBnRelu, self).__init__()
1414

1515
self.conv_bn_relu = nn.Sequential(
1616
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
17-
nn.BatchNorm2d(out_channels),
17+
nn.BatchNorm2d(out_channels, eps=eps, momentum=momentum),
1818
nn.ReLU()
1919
)
2020

@@ -23,21 +23,21 @@ def forward(self, x):
2323

2424
class Conv3x3BnRelu(nn.Module):
2525
"""3x3 convolution with batch norm and ReLU activation."""
26-
def __init__(self, in_channels, out_channels):
26+
def __init__(self, in_channels, out_channels, **kwargs):
2727
super(Conv3x3BnRelu, self).__init__()
2828

29-
self.conv3x3 = ConvBnRelu(in_channels, out_channels, 3, 1, 1)
29+
self.conv3x3 = ConvBnRelu(in_channels, out_channels, 3, 1, 1, **kwargs)
3030

3131
def forward(self, x):
3232
x = self.conv3x3(x)
3333
return x
3434

3535
class Conv1x1BnRelu(nn.Module):
3636
"""1x1 convolution with batch norm and ReLU activation."""
37-
def __init__(self, in_channels, out_channels):
37+
def __init__(self, in_channels, out_channels, **kwargs):
3838
super(Conv1x1BnRelu, self).__init__()
3939

40-
self.conv1x1 = ConvBnRelu(in_channels, out_channels, 1, 1, 0)
40+
self.conv1x1 = ConvBnRelu(in_channels, out_channels, 1, 1, 0, **kwargs)
4141

4242
def forward(self, x):
4343
x = self.conv1x1(x)

nasbench_pytorch/model/model.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
class Network(nn.Module):
2727
def __init__(self, spec, num_labels=10,
28-
in_channels=3, stem_out_channels=128, num_stacks=3, num_modules_per_stack=3):
28+
in_channels=3, stem_out_channels=128, num_stacks=3, num_modules_per_stack=3, momentum=0.1, eps=1e-5):
2929
"""
3030
3131
Args:
@@ -49,7 +49,7 @@ def __init__(self, spec, num_labels=10,
4949

5050
# initial stem convolution
5151
out_channels = stem_out_channels
52-
stem_conv = ConvBnRelu(in_channels, out_channels, 3, 1, 1)
52+
stem_conv = ConvBnRelu(in_channels, out_channels, 3, 1, 1, momentum=momentum, eps=eps)
5353
self.layers.append(stem_conv)
5454

5555
# stacked cells
@@ -63,7 +63,7 @@ def __init__(self, spec, num_labels=10,
6363
out_channels *= 2
6464

6565
for module_num in range(num_modules_per_stack):
66-
cell = Cell(spec, in_channels, out_channels)
66+
cell = Cell(spec, in_channels, out_channels, momentum=momentum, eps=eps)
6767
self.layers.append(cell)
6868
in_channels = out_channels
6969

@@ -102,7 +102,7 @@ class Cell(nn.Module):
102102
determined via equally splitting the channel count whenever there is a
103103
concatenation of Tensors.
104104
"""
105-
def __init__(self, spec, in_channels, out_channels):
105+
def __init__(self, spec, in_channels, out_channels, momentum=0.1, eps=1e-5):
106106
super(Cell, self).__init__()
107107

108108
self.dev_param = nn.Parameter(torch.empty(0))
@@ -124,7 +124,7 @@ def __init__(self, spec, in_channels, out_channels):
124124
self.input_op = nn.ModuleList([Placeholder()])
125125
for t in range(1, self.num_vertices):
126126
if self.matrix[0, t]:
127-
self.input_op.append(Projection(in_channels, self.vertex_channels[t]))
127+
self.input_op.append(Projection(in_channels, self.vertex_channels[t], momentum=momentum, eps=eps))
128128
else:
129129
self.input_op.append(Placeholder())
130130

@@ -179,9 +179,11 @@ def forward(self, x):
179179

180180
return outputs
181181

182-
def Projection(in_channels, out_channels):
182+
183+
def Projection(in_channels, out_channels, momentum=0.1, eps=1e-5):
183184
"""1x1 projection (as in ResNet) followed by batch normalization and ReLU."""
184-
return ConvBnRelu(in_channels, out_channels, 1)
185+
return ConvBnRelu(in_channels, out_channels, 1, momentum=momentum, eps=eps)
186+
185187

186188
def Truncate(inputs, channels):
187189
"""Slice the inputs to channels if necessary."""
@@ -197,6 +199,7 @@ def Truncate(inputs, channels):
197199
assert input_channels - channels == 1
198200
return inputs[:, :channels, :, :]
199201

202+
200203
def ComputeVertexChannels(in_channels, out_channels, matrix):
201204
"""Computes the number of channels at every vertex.
202205

0 commit comments

Comments
 (0)