Skip to content

Commit 9d86fba

Browse files
committed
Update torch.accelerator API usage in Imagenet example
1 parent 9844fee commit 9d86fba

2 files changed

Lines changed: 14 additions & 15 deletions

File tree

imagenet/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ python main.py -a resnet18 --dummy
3333

3434
## Multi-processing Distributed Data Parallel Training
3535

36-
You should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance.
36+
If running on CUDA, you should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance.
37+
38+
For XPU, multi-processing distributed training is not supported as of PyTorch 2.6.
3739

3840
### Single node, multiple GPUs:
3941

imagenet/main.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,13 @@ def main_worker(gpu, ngpus_per_node, args):
141141
use_accel = not args.no_accel and torch.accelerator.is_available()
142142

143143
if use_accel:
144+
if args.gpu is not None:
145+
torch.accelerator.set_device_index(args.gpu)
146+
print("Use GPU: {} for training".format(args.gpu))
144147
device = torch.accelerator.current_accelerator()
145148
else:
146149
device = torch.device("cpu")
147150

148-
if args.gpu is not None:
149-
print("Use GPU: {} for training".format(args.gpu))
150-
151151
if args.distributed:
152152
if args.dist_url == "env://" and args.rank == -1:
153153
args.rank = int(os.environ["RANK"])
@@ -173,8 +173,8 @@ def main_worker(gpu, ngpus_per_node, args):
173173
# DistributedDataParallel will use all available devices.
174174
if device.type == 'cuda':
175175
if args.gpu is not None:
176-
torch.accelerator.set_device_index(args.gpu)
177-
model.to(device)
176+
torch.cuda.set_device(args.gpu)
177+
model.cuda(device)
178178
# When using a single GPU per process and per
179179
# DistributedDataParallel, we need to divide the batch size
180180
# ourselves based on the total number of GPUs of the current node.
@@ -186,19 +186,16 @@ def main_worker(gpu, ngpus_per_node, args):
186186
# DistributedDataParallel will divide and allocate batch_size to all
187187
# available GPUs if device_ids are not set
188188
model = torch.nn.parallel.DistributedDataParallel(model)
189-
190-
elif args.gpu is not None and device.type=='cuda':
191-
torch.accelerator.set_device_index(args.gpu)
192-
model.to(device)
193-
elif device.type != 'cuda':
194-
model.to(device)
195-
else:
189+
elif device.type == 'cuda':
196190
# DataParallel will divide and allocate batch_size to all available GPUs
197191
if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
198192
model.features = torch.nn.DataParallel(model.features)
199193
model.cuda()
200194
else:
201195
model = torch.nn.DataParallel(model).cuda()
196+
else:
197+
model.to(device)
198+
202199

203200
# define loss function (criterion), optimizer, and learning rate scheduler
204201
criterion = nn.CrossEntropyLoss().to(device)
@@ -216,9 +213,9 @@ def main_worker(gpu, ngpus_per_node, args):
216213
print("=> loading checkpoint '{}'".format(args.resume))
217214
if args.gpu is None:
218215
checkpoint = torch.load(args.resume)
219-
elif device.type=='cuda':
216+
else:
220217
# Map model to be loaded to specified single gpu.
221-
loc = 'cuda:{}'.format(args.gpu)
218+
loc = f'{device.type}:{args.gpu}'
222219
checkpoint = torch.load(args.resume, map_location=loc)
223220
args.start_epoch = checkpoint['epoch']
224221
best_acc1 = checkpoint['best_acc1']

0 commit comments

Comments
 (0)