"""
This script tests Tensor Parallel (TP) on a toy model in a
Megatron-LM SPMD style. It shows an end-to-end workflow covering forward,
backward, and optimization.
More context about API designs can be found in the design:
https://github.com/pytorch/pytorch/issues/89884.
It is built on top of Distributed Tensor (DTensor), which is proposed in:
https://github.com/pytorch/pytorch/issues/88838.
We use two `nn.Linear` layers with an element-wise `nn.ReLU` in between
to illustrate the Megatron-LM approach, which was proposed in the paper:
https://arxiv.org/abs/1909.08053.
The basic idea is to parallelize the first linear layer by column
and the second linear layer by row, so that only one all-reduce is
needed, at the end of the second linear layer.
We can speed up the model training by avoiding communications between
two layers.
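
As a rough single-process illustration of why one all-reduce suffices (it
is not part of the distributed run below), the following sketch splits the
first weight by output features and the second by input features, then sums
the partial outputs in place of the all-reduce. The shard count of 2, the
float64 dtype, and all variable names are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64  # chosen only to make the equality check tight
x = torch.rand(20, 10, dtype=dtype)
# nn.Linear stores weights as (out_features, in_features).
w1, b1 = torch.rand(32, 10, dtype=dtype), torch.rand(32, dtype=dtype)
w2, b2 = torch.rand(5, 32, dtype=dtype), torch.rand(5, dtype=dtype)

# Unsharded reference: out_proj(relu(in_proj(x))).
reference = torch.relu(x @ w1.T + b1) @ w2.T + b2

n_shards = 2  # stand-in for the TP world size
# Column-wise: each shard owns a slice of in_proj's output features.
w1_shards = w1.chunk(n_shards, dim=0)
b1_shards = b1.chunk(n_shards, dim=0)
# Row-wise: each shard owns the matching slice of out_proj's input features.
w2_shards = w2.chunk(n_shards, dim=1)

# ReLU is element-wise, so each shard applies it locally with no communication.
partials = [
    torch.relu(x @ w1_s.T + b1_s) @ w2_s.T
    for w1_s, b1_s, w2_s in zip(w1_shards, b1_shards, w2_shards)
]
# Summing the partial outputs plays the role of the single all-reduce;
# the out_proj bias is added once afterwards.
combined = torch.stack(partials).sum(dim=0) + b2

matches = torch.allclose(reference, combined)
print(matches)
```

The partial outputs can be summed only because the hidden dimension is split
consistently across both layers; no exchange is needed between them.
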
To parallelize an nn module, we specify the parallel style we want to use,
and the `parallelize_module` API parses and parallelizes the modules
based on the given `ParallelStyle`. This example uses these PyTorch-native
Tensor Parallelism APIs to show users how to apply them.
The following is an example command to run this code:
    torchrun --nnodes 1 --nproc-per-node 4 tensor_parallel_example.py
"""
import os
import sys
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)
from log_utils import rank_log, get_logger, verify_min_gpu_count
# ---- GPU check ------------
_min_gpu_count = 2
if not verify_min_gpu_count(min_gpus=_min_gpu_count):
    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
    sys.exit()
# ---------------------------
# Use the public module path rather than the private `_tensor` one.
from torch.distributed.device_mesh import init_device_mesh


class ToyModel(nn.Module):
    """MLP based model"""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))


"""
Main body of the demo of a basic version of tensor parallel by using
PyTorch native APIs.
"""
logger = get_logger()
# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])
device_type = torch.accelerator.current_accelerator().type
device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()
print(f"Starting PyTorch TP example on rank {_rank}.")
assert (
    _world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"
rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
# create model and move it to the accelerator - init_device_mesh has already mapped device ids.
tp_model = ToyModel().to(device_type)
# Custom parallelization plan for the model
tp_model = parallelize_module(
    module=tp_model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)
# Create an optimizer for the parallelized module.
lr = 0.25
optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
# Perform a number of forward/backward iterations
# and optimizer steps for the sharded module.
num_iters = 10
rank_log(_rank, logger, "Tensor Parallel training starting...")
for i in range(num_iters):
    # For TP, the input needs to be the same across all TP ranks.
    # Setting the random seed mimics the behavior of a dataloader.
    torch.manual_seed(i)
    inp = torch.rand(20, 10, device=device_type)
    # Clear gradients from the previous iteration so they do not accumulate.
    optimizer.zero_grad()
    output = tp_model(inp)
    output.sum().backward()
    optimizer.step()
    rank_log(_rank, logger, f"Tensor Parallel iter {i} completed")

rank_log(_rank, logger, "Tensor Parallel training completed!")

if dist.is_initialized():
    dist.destroy_process_group()