5757of data they contain.
5858
5959For this tutorial, we’ll be using the Fashion-MNIST dataset provided by
60- TorchVision. We use ``torchvision.transforms.Normalize()`` to
60+ TorchVision. We use ``torchvision.transforms.v2. Normalize()`` to
6161zero-center and normalize the distribution of the image tile content,
6262and download both training and validation data splits.
6363
6464"""
6565
6666import torch
6767import torchvision
68- import torchvision .transforms as transforms
68+ from torchvision .transforms import v2
6969
7070# PyTorch TensorBoard support
7171from torch .utils .tensorboard import SummaryWriter
7272from datetime import datetime
7373
7474
75- transform = transforms .Compose (
76- [transforms .ToTensor (),
77- transforms .Normalize ((0.5 ,), (0.5 ,))])
75+ transform = v2 .Compose ([
76+ v2 .ToImage (),
77+ v2 .ToDtype (torch .float32 , scale = True ),
78+ v2 .Normalize ((0.5 ,), (0.5 ,))
79+ ])
7880
7981# Create datasets for training & validation, download if necessary
8082training_set = torchvision .datasets .FashionMNIST ('./data' , train = True , transform = transform , download = True )
8991 'Sandal' , 'Shirt' , 'Sneaker' , 'Bag' , 'Ankle Boot' )
9092
9193# Report split sizes
92- print ('Training set has {} instances' . format ( len (training_set )) )
93- print ('Validation set has {} instances' . format ( len (validation_set )) )
94+ print (f 'Training set has { len (training_set )} instances' )
95+ print (f 'Validation set has { len (validation_set )} instances' )
9496
9597
9698######################################################################
@@ -134,7 +136,7 @@ def matplotlib_imshow(img, one_channel=False):
134136# PyTorch models inherit from torch.nn.Module
135137class GarmentClassifier (nn .Module ):
136138 def __init__ (self ):
137- super (GarmentClassifier , self ).__init__ ()
139+ super ().__init__ ()
138140 self .conv1 = nn .Conv2d (1 , 6 , 5 )
139141 self .pool = nn .MaxPool2d (2 , 2 )
140142 self .conv2 = nn .Conv2d (6 , 16 , 5 )
@@ -176,7 +178,7 @@ def forward(self, x):
176178print (dummy_labels )
177179
178180loss = loss_fn (dummy_outputs , dummy_labels )
179- print ('Total loss for this batch: {}' . format ( loss .item ()) )
181+ print (f 'Total loss for this batch: { loss .item ()} ' )
180182
181183
182184#################################################################################
@@ -251,7 +253,7 @@ def train_one_epoch(epoch_index, tb_writer):
251253 running_loss += loss .item ()
252254 if i % 1000 == 999 :
253255 last_loss = running_loss / 1000 # loss per batch
254- print (' batch {} loss: {}' . format ( i + 1 , last_loss ) )
256+ print (f ' batch { i + 1 } loss: { last_loss } ' )
255257 tb_x = epoch_index * len (training_loader ) + i + 1
256258 tb_writer .add_scalar ('Loss/train' , last_loss , tb_x )
257259 running_loss = 0.
@@ -276,15 +278,15 @@ def train_one_epoch(epoch_index, tb_writer):
276278
277279# Initializing in a separate cell so we can easily add more epochs to the same run
278280timestamp = datetime .now ().strftime ('%Y%m%d_%H%M%S' )
279- writer = SummaryWriter ('runs/fashion_trainer_{}' . format ( timestamp ) )
281+ writer = SummaryWriter (f 'runs/fashion_trainer_{ timestamp } ' )
280282epoch_number = 0
281283
282284EPOCHS = 5
283285
284286best_vloss = 1_000_000.
285287
286288for epoch in range (EPOCHS ):
287- print ('EPOCH {}:' . format ( epoch_number + 1 ) )
289+ print (f 'EPOCH { epoch_number + 1 } :' )
288290
289291 # Make sure gradient tracking is on, and do a pass over the data
290292 model .train (True )
@@ -305,7 +307,7 @@ def train_one_epoch(epoch_index, tb_writer):
305307 running_vloss += vloss
306308
307309 avg_vloss = running_vloss / (i + 1 )
308- print ('LOSS train {} valid {}' . format ( avg_loss , avg_vloss ) )
310+ print (f 'LOSS train { avg_loss } valid { avg_vloss } ' )
309311
310312 # Log the running loss averaged per batch
311313 # for both training and validation
@@ -317,7 +319,7 @@ def train_one_epoch(epoch_index, tb_writer):
317319 # Track best performance, and save the model's state
318320 if avg_vloss < best_vloss :
319321 best_vloss = avg_vloss
320- model_path = 'model_{}_{}' . format ( timestamp , epoch_number )
322+ model_path = f 'model_{ timestamp } _{ epoch_number } '
321323 torch .save (model .state_dict (), model_path )
322324
323325 epoch_number += 1
0 commit comments