Skip to content

Commit 97165f0

Browse files
committed
Adding torch accelerator and requirements file to FSDP2 example
1 parent 2944a9d commit 97165f0

5 files changed

Lines changed: 45 additions & 8 deletions

File tree

distributed/FSDP2/README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
11
## FSDP2
22
To run FSDP2 on transformer model:
3+
34
```
45
cd distributed/FSDP2
5-
torchrun --nproc_per_node 2 train.py
6+
pip install -r requirements.txt
7+
torchrun --nproc_per_node 2 example.py
68
```
79
* On the first run, it creates a "checkpoints" folder and saves state dicts there
810
* On the second run, it loads from the previous checkpoints
911

1012
To enable explicit prefetching
1113
```
12-
torchrun --nproc_per_node 2 train.py --explicit-prefetch
14+
torchrun --nproc_per_node 2 example.py --explicit-prefetch
1315
```
1416

1517
To enable mixed precision
1618
```
17-
torchrun --nproc_per_node 2 train.py --mixed-precision
19+
torchrun --nproc_per_node 2 example.py --mixed-precision
1820
```
1921

2022
To showcase DCP API
2123
```
22-
torchrun --nproc_per_node 2 train.py --dcp-api
24+
torchrun --nproc_per_node 2 example.py --dcp-api
2325
```
2426

2527
## Ensure you are running a recent version of PyTorch:
Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
88
from utils import inspect_mixed_precision, inspect_model
99

10+
def verify_min_gpu_count(min_gpus: int = 2) -> bool:
11+
"""Return True when an accelerator is available and its device count is at least `min_gpus`."""
12+
has_gpu = torch.accelerator.is_available()
13+
gpu_count = torch.accelerator.device_count()
14+
return has_gpu and gpu_count >= min_gpus
1015

1116
def set_modules_to_forward_prefetch(model, num_to_forward_prefetch):
1217
for i, layer in enumerate(model.layers):
@@ -30,9 +35,18 @@ def set_modules_to_backward_prefetch(model, num_to_backward_prefetch):
3035

3136
def main(args):
3237
rank = int(os.environ["LOCAL_RANK"])
33-
device = torch.device(f"cuda:{rank}")
34-
torch.cuda.set_device(device)
35-
torch.distributed.init_process_group(backend="nccl", device_id=device)
38+
if torch.accelerator.is_available():
39+
device_type = torch.accelerator.current_accelerator()
40+
device: torch.device = torch.device(f"{device_type}:{rank}")
41+
# Bind this process to its local accelerator. `device_index(...)` is only a context manager and has no effect when called bare.
torch.accelerator.set_device_index(rank)
42+
print(f"Running on rank {rank} on device {device}")
43+
backend = torch.distributed.get_default_backend_for_device(device)
44+
torch.distributed.init_process_group(backend=backend, device_id=device)
45+
else:
46+
device = torch.device("cpu")
47+
print(f"Running on device {device}")
48+
torch.distributed.init_process_group(backend="gloo", device_id=device)
49+
3650
torch.manual_seed(0)
3751
vocab_size = 1024
3852
batch_size = 32
@@ -64,7 +78,7 @@ def main(args):
6478

6579
checkpointer = Checkpointer("checkpoints", dcp_api=args.dcp_api)
6680
if checkpointer.last_training_time is None:
67-
model.to_empty(device="cuda")
81+
model.to_empty(device=device)
6882
model.reset_parameters()
6983
else:
7084
checkpointer.load_model(model)
@@ -96,4 +110,8 @@ def main(args):
96110
parser.add_argument("--mixed-precision", action="store_true", default=False)
97111
parser.add_argument("--dcp-api", action="store_true", default=False)
98112
args = parser.parse_args()
113+
_min_gpu_count = 2
114+
if not verify_min_gpu_count(min_gpus=_min_gpu_count):
115+
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
116+
exit()
99117
main(args)

distributed/FSDP2/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
torch>=2.7
2+
numpy

distributed/FSDP2/run_example.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/bin/bash
2+
# bash run_example.sh {file_to_run.py} {num_gpus}
3+
# where file_to_run = example to run. Default = 'example.py'
4+
# num_gpus = num local gpus to use (must be at least 2). Default = 4
5+
6+
# samples to run include:
7+
# example.py
8+
9+
echo "Launching ${1:-example.py} with ${2:-4} gpus"
10+
# Single-node launch: every worker runs on the local machine's GPUs.
torchrun --nnodes=1 --nproc_per_node="${2:-4}" "${1:-example.py}"
11+

run_distributed_examples.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ function distributed_tensor_parallelism() {
5050
uv run bash run_example.sh fsdp_tp_example.py || error "2D parallel example failed"
5151
}
5252

53+
# Launch the FSDP2 example.py through run_example.sh; report a failure via error().
function distributed_FSDP2() {
54+
uv run bash run_example.sh example.py || error "FSDP2 example failed"
55+
}
56+
5357
function distributed_ddp() {
5458
uv run main.py || error "ddp example failed"
5559
}

0 commit comments

Comments
 (0)