์ ์: Iris Zhang, Wanchao Liang
์ฌ์ ์ค๋น(Prerequisites):
- ๋ถ์ฐ ํต์ ํจํค์ง - torch.distributed
- Python 3.8 - 3.11
- PyTorch 2.2
๋ถ์ฐ ํ์ต์ ์ํด ๋ถ์ฐ ํต์ ๊ธฐ(communicator), ์ฆ NVIDIA Collective Communication Library(NCCL) ํต์ ๊ธฐ๋ฅผ ์ค์ ํ๋ ์ผ์ ์๋นํ ์ด๋ ค์์ด ๋ ์ ์์ต๋๋ค. ์๋ก ๋ค๋ฅธ ๋ณ๋ ฌํ ๋ฐฉ์์ ์กฐํฉํด์ผ ํ๋ ์์ ์ด๋ผ๋ฉด, ๊ฐ ๋ณ๋ ฌํ ๋ฐฉ์๋ง๋ค NCCL ํต์ ๊ธฐ(์: :class:`ProcessGroup`)๋ฅผ ์ง์ ์ค์ ํ๊ณ ๊ด๋ฆฌํด์ผ ํฉ๋๋ค. ์ด ๊ณผ์ ์ ๋ณต์กํ๊ณ ์ค๋ฅ๊ฐ ๋ฐ์ํ๊ธฐ ์ฝ์ต๋๋ค. :class:`DeviceMesh` ๋ ์ด ๊ณผ์ ์ ๋จ์ํํ ์ ์๊ณ , ๋ ๋ค๋ฃจ๊ธฐ ์ฝ๊ฒ ๋ง๋ค๋ฉฐ ์ค๋ฅ ๋ฐ์ ๊ฐ๋ฅ์ฑ๋ ์ค์ฌ์ค๋๋ค.
:class:`DeviceMesh` ๋ :class:`ProcessGroup` ์ ๊ด๋ฆฌํ๋ ์์ ์์ค์ ์ถ์ํ์ ๋๋ค. ์๋ก ๋ค๋ฅธ ํ์ ํ๋ก์ธ์ค ๊ทธ๋ฃน์ ๋ํด ๋ญํฌ(rank)๋ฅผ ์ด๋ป๊ฒ ์ฌ๋ฐ๋ฅด๊ฒ ์ค์ ํ ์ง ๊ณ ๋ฏผํ์ง ์๊ณ ๋, ๋ ธ๋ ๊ฐ(inter-node) ๋ฐ ๋ ธ๋ ๋ด(intra-node) ํ๋ก์ธ์ค ๊ทธ๋ฃน์ ์์ฝ๊ฒ ๋ง๋ค ์ ์์ต๋๋ค. ๋ํ :class:`DeviceMesh` ๋ฅผ ํตํด ๋ค์ฐจ์ ๋ณ๋ ฌํ์ ์ฌ์ฉ๋๋ ๋ด๋ถ์ ํ๋ก์ธ์ค ๊ทธ๋ฃน๊ณผ ๋๋ฐ์ด์ค๋ฅผ ์ฝ๊ฒ ๊ด๋ฆฌํ ์ ์์ต๋๋ค.
DeviceMesh๋ ์ฌ๋ฌ ๋ณ๋ ฌํ ๋ฐฉ์์ ์กฐํฉ(composability)ํด์ผ ํ๋ ๋ค์ฐจ์ ๋ณ๋ ฌํ(์: 3-D ๋ณ๋ ฌ)๋ฅผ ๋ค๋ฃฐ ๋ ์ ์ฉํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ๋ณ๋ ฌํ ๋ฐฉ์์ด ํธ์คํธ ๊ฐ ํต์ ๊ณผ ๊ฐ ํธ์คํธ ๋ด๋ถ์ ํต์ ์ ๋ชจ๋ ์๊ตฌํ๋ ๊ฒฝ์ฐ๊ฐ ๊ทธ๋ ์ต๋๋ค. ์ ์ด๋ฏธ์ง๋ ๊ท ์ผํ ํ๊ฒฝ์์ ๊ฐ ํธ์คํธ ๋ด๋ถ์ ๋๋ฐ์ด์ค๋ฅผ ์ฐ๊ฒฐํ๊ณ , ๊ฐ ๋๋ฐ์ด์ค๋ฅผ ๋ค๋ฅธ ํธ์คํธ์ ๋์ ๋๋ฐ์ด์ค์ ์ฐ๊ฒฐํ๋ 2D ๋ฉ์๋ฅผ ๋ง๋ค ์ ์์์ ๋ณด์ฌ์ค๋๋ค.
DeviceMesh๊ฐ ์๋ค๋ฉด, ์ด๋ค ๋ณ๋ ฌํ๋ฅผ ์ ์ฉํ๊ธฐ ์ ์ ๊ฐ ํ๋ก์ธ์ค๋ง๋ค NCCL ํต์ ๊ธฐ์ CUDA ๋๋ฐ์ด์ค๋ฅผ ์ง์ ์ค์ ํด์ผ ํ๋ฉฐ, ์ด๋ ๊ฝค ๋ณต์กํ ์์ ์ ๋๋ค. ๋ค์ ์ฝ๋๋ :class:`DeviceMesh` ์์ด ํ์ด๋ธ๋ฆฌ๋ ์ค๋ฉ(hybrid sharding) 2-D ๋ณ๋ ฌ ํจํด์ ์ค์ ํ๋ ์์์ ๋๋ค. ๋จผ์ ์ค๋(shard) ๊ทธ๋ฃน๊ณผ ๋ณต์ ๊ทธ๋ฃน์ ์ง์ ๊ณ์ฐํ๊ณ , ๊ฐ ๋ญํฌ์ ์๋ง์ ๊ทธ๋ฃน์ ํ ๋นํด์ผ ํฉ๋๋ค.
import os
import torch
import torch.distributed as dist
# ์๋ ํ ํด๋ก์ง ์ดํด
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
print(f"Running example on {rank=} in a world with {world_size=}")
# 2-D ํํ์ ๋ณ๋ ฌ ํจํด์ ๊ด๋ฆฌํ๊ธฐ ์ํ ํ๋ก์ธ์ค ๊ทธ๋ฃน ์์ฑ
dist.init_process_group("nccl")
torch.cuda.set_device(rank)
# ์ค๋ ๊ทธ๋ฃน ์์ฑ (์: (0, 1, 2, 3), (4, 5, 6, 7))
# ๊ฐ ๋ญํฌ์ ์ฌ๋ฐ๋ฅธ ์ค๋ ๊ทธ๋ฃน ํ ๋น
num_node_devices = torch.cuda.device_count()
shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices))
shard_groups = (
dist.new_group(shard_rank_lists[0]),
dist.new_group(shard_rank_lists[1]),
)
current_shard_group = (
shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1]
)
# ๋ณต์ ๊ทธ๋ฃน ์์ฑ (์: (0, 4), (1, 5), (2, 6), (3, 7))
# ๊ฐ ๋ญํฌ์ ์ฌ๋ฐ๋ฅธ ๋ณต์ ๊ทธ๋ฃน ํ ๋น
current_replicate_group = None
shard_factor = len(shard_rank_lists[0])
for i in range(num_node_devices // 2):
replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
replicate_group = dist.new_group(replicate_group_ranks)
if rank in replicate_group_ranks:
current_replicate_group = replicate_group์ ์ฝ๋๋ฅผ ์คํํ๋ ค๋ฉด PyTorch Elastic์ ํ์ฉํ ์ ์์ต๋๋ค. 2d_setup.py ๋ผ๋ ํ์ผ์ ๋ง๋ ๋ค,
torch elastic/torchrun ๋ช
๋ น์ ์คํํ์ธ์.
torchrun --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.pyNote
์์๋ฅผ ๊ฐ๋จํ ๋ณด์ฌ์ฃผ๊ธฐ ์ํด ๋จ์ผ ๋ ธ๋๋ง ์ฌ์ฉํด 2D ๋ณ๋ ฌ์ ์๋ฎฌ๋ ์ด์ ํ๊ณ ์์ต๋๋ค. ์ด ์ฝ๋๋ ๋ฉํฐ ํธ์คํธ ํ๊ฒฝ์์๋ ๊ทธ๋๋ก ์ฌ์ฉํ ์ ์์ต๋๋ค.
:func:`init_device_mesh` ๋ฅผ ํ์ฉํ๋ฉด ์์ 2D ์ค์ ์ ๋จ ๋ ์ค๋ก ๋๋ผ ์ ์๊ณ , ํ์ํ ๋๋ ๋ด๋ถ์ :class:`ProcessGroup` ์๋ ๊ทธ๋๋ก ์ ๊ทผํ ์ ์์ต๋๋ค.
from torch.distributed.device_mesh import init_device_mesh
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
# `get_group` API๋ฅผ ํตํด ๋ด๋ถ ํ๋ก์ธ์ค ๊ทธ๋ฃน์ ์ ๊ทผํ ์ ์์ต๋๋ค.
replicate_group = mesh_2d.get_group(mesh_dim="replicate")
shard_group = mesh_2d.get_group(mesh_dim="shard")2d_setup_with_device_mesh.py ๋ผ๋ ํ์ผ์ ๋ง๋ ๋ค,
torch elastic/torchrun ๋ช
๋ น์ ์คํํ์ธ์.
torchrun --nproc_per_node=8 2d_setup_with_device_mesh.pyHybrid Sharding Data Parallel(HSDP)์ ํธ์คํธ ๋ด๋ถ์์๋ FSDP๋ฅผ, ํธ์คํธ ๊ฐ์๋ DDP๋ฅผ ์ํํ๋ 2D ์ ๋ต์ ๋๋ค.
DeviceMesh๊ฐ ๊ฐ๋จํ ์ค์ ์ผ๋ก ๋ชจ๋ธ์ HSDP๋ฅผ ์ ์ฉํ๋ ๋ฐ ์ด๋ป๊ฒ ๋์์ด ๋๋์ง ์์๋ก ์ดํด๋ณด๊ฒ ์ต๋๋ค. DeviceMesh๋ฅผ ์ฌ์ฉํ๋ฉด ์ค๋ ๊ทธ๋ฃน๊ณผ ๋ณต์ ๊ทธ๋ฃน์ ์ง์ ๋ง๋ค๊ณ ๊ด๋ฆฌํ์ง ์์๋ ๋ฉ๋๋ค.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard as FSDP
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
# HSDP: MeshShape(2, 4)
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))
model = FSDP(
ToyModel(), device_mesh=mesh_2d
)hsdp.py ๋ผ๋ ํ์ผ์ ๋ง๋ ๋ค,
torch elastic/torchrun ๋ช
๋ น์ ์คํํ์ธ์.
torchrun --nproc_per_node=8 hsdp.py๋๊ท๋ชจ ํ์ต ํ๊ฒฝ์์๋ ๋ ๋ณต์กํ ์ฌ์ฉ์ ์ ์ ๋ณ๋ ฌ ํ์ต ๊ตฌ์ฑ์ ๋ค๋ค์ผ ํ ์๋ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ์๋ก ๋ค๋ฅธ ๋ณ๋ ฌํ ๋ฐฉ์์ ๋ง์ถฐ ํ์ ๋ฉ์(sub-mesh)๋ฅผ ์๋ผ๋ด์ผ ํ ์ ์์ต๋๋ค. DeviceMesh๋ฅผ ์ฌ์ฉํ๋ฉด ์์ ๋ฉ์์์ ํ์ ๋ฉ์๋ฅผ ์๋ผ๋ด๊ณ , ์์ ๋ฉ์๋ฅผ ์ด๊ธฐํํ ๋ ์ด๋ฏธ ๋ง๋ค์ด์ง NCCL ํต์ ๊ธฐ๋ฅผ ๊ทธ๋๋ก ์ฌ์ฌ์ฉํ ์ ์์ต๋๋ค.
from torch.distributed.device_mesh import init_device_mesh
mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("replicate", "shard", "tp"))
# ์์ ๋ฉ์์์ ํ์ ๋ฉ์๋ฅผ ์๋ผ๋ผ ์ ์์ต๋๋ค.
hsdp_mesh = mesh_3d["replicate", "shard"]
tp_mesh = mesh_3d["tp"]
# `get_group` API๋ฅผ ํตํด ๋ด๋ถ ํ๋ก์ธ์ค ๊ทธ๋ฃน์ ์ ๊ทผํ ์ ์์ต๋๋ค.
replicate_group = hsdp_mesh["replicate"].get_group()
shard_group = hsdp_mesh["shard"].get_group()
tp_group = tp_mesh.get_group()์ง๊ธ๊น์ง :class:`DeviceMesh` ์ :func:`init_device_mesh` ๋ฅผ ์ดํด๋ณด๊ณ , ์ด๋ฅผ ํ์ฉํด ํด๋ฌ์คํฐ์ ๋ถ์ฐ๋ ๋๋ฐ์ด์ค์ ๋ฐฐ์น๋ฅผ ํํํ๋ ๋ฐฉ๋ฒ๋ ์์๋ดค์ต๋๋ค.
๋ ์์ธํ ๋ด์ฉ์ ๋ค์ ์๋ฃ๋ฅผ ์ฐธ๊ณ ํ์ธ์.
