Skip to content

Commit 6fdd8b1

Browse files
Update introyt model tutorial APIs (#3895)
Fixes #3856 ## Description Updates the Building Models tutorial to use the current PyTorch API style: - specifies `dim=1` for the `TinyModel` softmax layer - replaces Python 2-style `super(...)` calls with `super().__init__()` ## Verification - [x] `python3 -m py_compile beginner_source/introyt/modelsyt_tutorial.py` - [x] `MPLBACKEND=Agg /tmp/docathon-torch-venv/bin/python beginner_source/introyt/modelsyt_tutorial.py` ## Checklist - [x] The issue that is being fixed is referred in the description - [x] Only one issue is addressed in this pull request - [x] Labels from the issue that this PR is fixing are added to this pull request - [x] No unnecessary issues are included into this pull request. cc @subramen
1 parent 046a3f3 commit 6fdd8b1

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

beginner_source/introyt/modelsyt_tutorial.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,12 @@ class is a subclass of ``torch.Tensor``, with the special behavior that
4848
class TinyModel(torch.nn.Module):
4949

5050
def __init__(self):
51-
super(TinyModel, self).__init__()
51+
super().__init__()
5252

5353
self.linear1 = torch.nn.Linear(100, 200)
5454
self.activation = torch.nn.ReLU()
5555
self.linear2 = torch.nn.Linear(200, 10)
56-
self.softmax = torch.nn.Softmax()
56+
self.softmax = torch.nn.Softmax(dim=1)
5757

5858
def forward(self, x):
5959
x = self.linear1(x)
@@ -150,7 +150,7 @@ def forward(self, x):
150150
class LeNet(torch.nn.Module):
151151

152152
def __init__(self):
153-
super(LeNet, self).__init__()
153+
super().__init__()
154154
# 1 input image channel (black & white), 6 output channels, 5x5 square convolution
155155
# kernel
156156
self.conv1 = torch.nn.Conv2d(1, 6, 5)
@@ -249,7 +249,7 @@ def num_flat_features(self, x):
249249
class LSTMTagger(torch.nn.Module):
250250

251251
def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
252-
super(LSTMTagger, self).__init__()
252+
super().__init__()
253253
self.hidden_dim = hidden_dim
254254

255255
self.word_embeddings = torch.nn.Embedding(vocab_size, embedding_dim)

0 commit comments

Comments
 (0)