|
57 | 57 | """ |
58 | 58 | import torch |
59 | 59 | import torchvision |
60 | | -import torchvision.transforms as transforms |
| 60 | +from torchvision.transforms import v2 |
61 | 61 |
|
62 | 62 | ######################################################################## |
63 | 63 | # The output of torchvision datasets are PILImage images of range [0, 1]. |
|
69 | 69 | # BrokenPipeError or RuntimeError related to multiprocessing, try setting |
70 | 70 | # the num_worker of torch.utils.data.DataLoader() to 0. |
71 | 71 |
|
72 | | -transform = transforms.Compose( |
73 | | - [transforms.ToTensor(), |
74 | | - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) |
| 72 | +transform = v2.Compose([ |
| 73 | + v2.ToImage(), |
| 74 | + v2.ToDtype(torch.float32, scale=True), |
| 75 | + v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) |
75 | 76 |
|
76 | 77 | batch_size = 4 |
77 | 78 |
|
@@ -191,7 +192,7 @@ def forward(self, x): |
191 | 192 | ######################################################################## |
192 | 193 | # Let's quickly save our trained model: |
193 | 194 |
|
194 | | -PATH = './cifar_net.pth' |
| 195 | +PATH = './cifar_net.pt' |
195 | 196 | torch.save(net.state_dict(), PATH) |
196 | 197 |
|
197 | 198 | ######################################################################## |
@@ -302,7 +303,7 @@ def forward(self, x): |
302 | 303 | # Let's first define our device as the first visible cuda device if we have |
303 | 304 | # CUDA available: |
304 | 305 |
|
305 | | -device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') |
| 306 | +device = torch.device(torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu') |
306 | 307 |
|
307 | 308 | # Assuming that we are on a CUDA machine, this should print a CUDA device: |
308 | 309 |
|
@@ -355,9 +356,9 @@ def forward(self, x): |
355 | 356 | # - `Discuss PyTorch on the Forums`_ |
356 | 357 | # - `Chat with other users on Slack`_ |
357 | 358 | # |
358 | | -# .. _Train a state-of-the-art ResNet network on imagenet: https://github.com/pytorch/examples/tree/master/imagenet |
359 | | -# .. _Train a face generator using Generative Adversarial Networks: https://github.com/pytorch/examples/tree/master/dcgan |
360 | | -# .. _Train a word-level language model using Recurrent LSTM networks: https://github.com/pytorch/examples/tree/master/word_language_model |
| 359 | +# .. _Train a state-of-the-art ResNet network on imagenet: https://github.com/pytorch/examples/tree/main/imagenet |
| 360 | +# .. _Train a face generator using Generative Adversarial Networks: https://github.com/pytorch/examples/tree/main/dcgan |
| 361 | +# .. _Train a word-level language model using Recurrent LSTM networks: https://github.com/pytorch/examples/tree/main/word_language_model |
361 | 362 | # .. _More examples: https://github.com/pytorch/examples |
362 | 363 | # .. _More tutorials: https://github.com/pytorch/tutorials |
363 | 364 | # .. _Discuss PyTorch on the Forums: https://discuss.pytorch.org/ |
|
0 commit comments