2525
2626class Network (nn .Module ):
2727 def __init__ (self , spec , num_labels = 10 ,
28- in_channels = 3 , stem_out_channels = 128 , num_stacks = 3 , num_modules_per_stack = 3 ):
28+ in_channels = 3 , stem_out_channels = 128 , num_stacks = 3 , num_modules_per_stack = 3 , momentum = 0.1 , eps = 1e-5 ):
2929 """
3030
3131 Args:
@@ -49,7 +49,7 @@ def __init__(self, spec, num_labels=10,
4949
5050 # initial stem convolution
5151 out_channels = stem_out_channels
52- stem_conv = ConvBnRelu (in_channels , out_channels , 3 , 1 , 1 )
52+ stem_conv = ConvBnRelu (in_channels , out_channels , 3 , 1 , 1 , momentum = momentum , eps = eps )
5353 self .layers .append (stem_conv )
5454
5555 # stacked cells
@@ -63,7 +63,7 @@ def __init__(self, spec, num_labels=10,
6363 out_channels *= 2
6464
6565 for module_num in range (num_modules_per_stack ):
66- cell = Cell (spec , in_channels , out_channels )
66+ cell = Cell (spec , in_channels , out_channels , momentum = momentum , eps = eps )
6767 self .layers .append (cell )
6868 in_channels = out_channels
6969
@@ -102,7 +102,7 @@ class Cell(nn.Module):
102102 determined via equally splitting the channel count whenever there is a
103103 concatenation of Tensors.
104104 """
105- def __init__ (self , spec , in_channels , out_channels ):
105+ def __init__ (self , spec , in_channels , out_channels , momentum = 0.1 , eps = 1e-5 ):
106106 super (Cell , self ).__init__ ()
107107
108108 self .dev_param = nn .Parameter (torch .empty (0 ))
@@ -124,7 +124,7 @@ def __init__(self, spec, in_channels, out_channels):
124124 self .input_op = nn .ModuleList ([Placeholder ()])
125125 for t in range (1 , self .num_vertices ):
126126 if self .matrix [0 , t ]:
127- self .input_op .append (Projection (in_channels , self .vertex_channels [t ]))
127+ self .input_op .append (Projection (in_channels , self .vertex_channels [t ], momentum = momentum , eps = eps ))
128128 else :
129129 self .input_op .append (Placeholder ())
130130
@@ -179,9 +179,11 @@ def forward(self, x):
179179
180180 return outputs
181181
182- def Projection (in_channels , out_channels ):
182+
183+ def Projection (in_channels , out_channels , momentum = 0.1 , eps = 1e-5 ):
183184 """1x1 projection (as in ResNet) followed by batch normalization and ReLU."""
184- return ConvBnRelu (in_channels , out_channels , 1 )
185+ return ConvBnRelu (in_channels , out_channels , 1 , momentum = momentum , eps = eps )
186+
185187
186188def Truncate (inputs , channels ):
187189 """Slice the inputs to channels if necessary."""
@@ -197,6 +199,7 @@ def Truncate(inputs, channels):
197199 assert input_channels - channels == 1
198200 return inputs [:, :channels , :, :]
199201
202+
200203def ComputeVertexChannels (in_channels , out_channels , matrix ):
201204 """Computes the number of channels at every vertex.
202205
0 commit comments