@@ -48,12 +48,12 @@ class is a subclass of ``torch.Tensor``, with the special behavior that
4848class TinyModel (torch .nn .Module ):
4949
5050 def __init__ (self ):
51- super (TinyModel , self ).__init__ ()
51+ super ().__init__ ()
5252
5353 self .linear1 = torch .nn .Linear (100 , 200 )
5454 self .activation = torch .nn .ReLU ()
5555 self .linear2 = torch .nn .Linear (200 , 10 )
56- self .softmax = torch .nn .Softmax ()
56+ self .softmax = torch .nn .Softmax (dim = 1 )
5757
5858 def forward (self , x ):
5959 x = self .linear1 (x )
@@ -150,7 +150,7 @@ def forward(self, x):
150150class LeNet (torch .nn .Module ):
151151
152152 def __init__ (self ):
153- super (LeNet , self ).__init__ ()
153+ super ().__init__ ()
154154 # 1 input image channel (black & white), 6 output channels, 5x5 square convolution
155155 # kernel
156156 self .conv1 = torch .nn .Conv2d (1 , 6 , 5 )
@@ -249,7 +249,7 @@ def num_flat_features(self, x):
249249class LSTMTagger (torch .nn .Module ):
250250
251251 def __init__ (self , embedding_dim , hidden_dim , vocab_size , tagset_size ):
252- super (LSTMTagger , self ).__init__ ()
252+ super ().__init__ ()
253253 self .hidden_dim = hidden_dim
254254
255255 self .word_embeddings = torch .nn .Embedding (vocab_size , embedding_dim )
0 commit comments