-
Notifications
You must be signed in to change notification settings - Fork 9.8k
Expand file tree
/
Copy pathsequence_parallel_example.py
More file actions
113 lines (83 loc) · 3.25 KB
/
Copy pathsequence_parallel_example.py
File metadata and controls
113 lines (83 loc) · 3.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
This is the script to test Sequence Parallel (SP) on a toy model in a
Megatron-LM SPMD style. We show an end-to-end working flow covering the
forward pass, the backward pass and the optimizer step.
We use the example of two `nn.Linear` layers with an element-wise `nn.ReLU`
in between to show an example of sequence parallel, which was proposed in the paper:
https://arxiv.org/pdf/2205.05198.pdf.
Like tensor parallel, we parallelize the first linear layer by column
and the second linear layer by row. Unlike tensor parallel, the input on
each rank is now different, so we need one all-gather for the input and one
reduce-scatter at the end of the second linear layer.
The following is an example command to run this code
torchrun --nnodes 1 --nproc-per-node 4 sequence_parallel_example.py
"""
import os
import sys
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel,
)
from log_utils import rank_log, get_logger, verify_min_gpu_count
# ---- GPU check ------------
# The example is meant to run on at least two GPUs. Check up front, before any
# device mesh is created, and stop the script if the check fails.
_min_gpu_count = 2

_enough_gpus = verify_min_gpu_count(min_gpus=_min_gpu_count)
if not _enough_gpus:
    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
    # sys.exit() with no argument terminates with exit status 0.
    sys.exit()
# ---------------------------
from torch.distributed._tensor.device_mesh import init_device_mesh
class ToyModel(nn.Module):
    """MLP based model: ``in_proj`` -> ReLU -> ``out_proj``.

    The submodule names ``in_proj`` and ``out_proj`` are the keys used by the
    parallelize plan in this script, so they must stay as they are.

    Args:
        in_features: Size of the last dimension of the input (default 10).
        hidden_features: Output width of ``in_proj`` and input width of
            ``out_proj`` (default 32).
        out_features: Size of the last dimension of the output (default 5).
    """

    def __init__(
        self,
        in_features: int = 10,
        hidden_features: int = 32,
        out_features: int = 5,
    ) -> None:
        super().__init__()
        # Keep this creation order: it fixes both the child-module order and
        # the order in which the default random initialisation is drawn.
        self.in_proj = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``out_proj(relu(in_proj(x)))``."""
        return self.out_proj(self.relu(self.in_proj(x)))
"""
Main body of the demo of a basic version of sequence parallel by using
PyTorch native APIs.
"""
logger = get_logger()

# Type string of the currently active accelerator backend.
device_type = torch.accelerator.current_accelerator().type

# create a device mesh based on the given world_size.
# WORLD_SIZE is read from the environment, so the script has to be started by
# a launcher that sets it (see the torchrun command in the module docstring).
device_mesh = init_device_mesh(
    device_type=device_type, mesh_shape=(int(os.environ["WORLD_SIZE"]),)
)
_rank = device_mesh.get_rank()

print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.")
rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")

# create model and move it to GPU. Init_device_mesh has already assigned gpu ids...
model = ToyModel().to(device_type)

# Custom parallelization plan for the model:
#   * in_proj is parallelized column-wise and its input is declared as
#     sharded on dim 0 (each rank feeds its own slice);
#   * out_proj is parallelized row-wise and its output is laid out as
#     sharded on dim 0 again.
sp_model = parallelize_module(
    module=model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)

# Create an optimizer for the parallelized module.
lr = 0.25
optimizer = torch.optim.AdamW(sp_model.parameters(), lr=lr, foreach=True)

# Perform a num of iterations of forward/backward
# and optimizations for the sharded module.
num_iters = 10
rank_log(_rank, logger, "Sequence Parallel training starting...")

for i in range(num_iters):
    # Clear the gradients of the previous iteration. backward() accumulates
    # into .grad, so without this every step would be taken on the running
    # sum of all earlier gradients instead of the current one.
    optimizer.zero_grad()
    # For SP, input can be different across all ranks.
    inp = torch.rand(20, 10, device=device_type)
    output = sp_model(inp)
    output.sum().backward()
    optimizer.step()
    rank_log(_rank, logger, f"Sequence Parallel iter {i} completed")

rank_log(_rank, logger, "Sequence Parallel training completed!")

# Tear down the process group if one was initialized.
if dist.is_initialized():
    dist.destroy_process_group()