Skip to content

Commit 1f0d7d3

Browse files
committed
Adding torch accelerator to FSDP2 example
Signed-off-by: dggaytan <diana.gaytan.munoz@intel.com>
1 parent 3d54e15 commit 1f0d7d3

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

distributed/FSDP2/example.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ def main(args):
4141
rank = int(os.environ["LOCAL_RANK"])
4242
if torch.accelerator.is_available():
4343
device_type = torch.accelerator.current_accelerator()
44-
device: torch.device = torch.device(f"{device_type}:{rank}")
44+
device = torch.device(f"{device_type}:{rank}")
4545
torch.accelerator.device_index(rank)
4646
print(f"Running on rank {rank} on device {device}")
47-
backend = torch.distributed.get_default_backend_for_device(device)
48-
torch.distributed.init_process_group(backend=backend, device_id=device)
4947
else:
5048
device = torch.device("cpu")
5149
print(f"Running on device {device}")
52-
torch.distributed.init_process_group(backend="gloo", device_id=device)
50+
51+
backend = torch.distributed.get_default_backend_for_device(device)
52+
torch.distributed.init_process_group(backend=backend, device_id=device)
5353

5454
torch.manual_seed(0)
5555
vocab_size = 1024

0 commit comments

Comments
 (0)