
Getting Started with DeviceMesh

์ €์ž: Iris Zhang, Wanchao Liang


Prerequisites:

๋ถ„์‚ฐ ํ•™์Šต์„ ์œ„ํ•ด ๋ถ„์‚ฐ ํ†ต์‹ ๊ธฐ(communicator), ์ฆ‰ NVIDIA Collective Communication Library(NCCL) ํ†ต์‹ ๊ธฐ๋ฅผ ์„ค์ •ํ•˜๋Š” ์ผ์€ ์ƒ๋‹นํ•œ ์–ด๋ ค์›€์ด ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์„œ๋กœ ๋‹ค๋ฅธ ๋ณ‘๋ ฌํ™” ๋ฐฉ์‹์„ ์กฐํ•ฉํ•ด์•ผ ํ•˜๋Š” ์ž‘์—…์ด๋ผ๋ฉด, ๊ฐ ๋ณ‘๋ ฌํ™” ๋ฐฉ์‹๋งˆ๋‹ค NCCL ํ†ต์‹ ๊ธฐ(์˜ˆ: :class:`ProcessGroup`)๋ฅผ ์ง์ ‘ ์„ค์ •ํ•˜๊ณ  ๊ด€๋ฆฌํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ด ๊ณผ์ •์€ ๋ณต์žกํ•˜๊ณ  ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•˜๊ธฐ ์‰ฝ์Šต๋‹ˆ๋‹ค. :class:`DeviceMesh` ๋Š” ์ด ๊ณผ์ •์„ ๋‹จ์ˆœํ™”ํ•  ์ˆ˜ ์žˆ๊ณ , ๋” ๋‹ค๋ฃจ๊ธฐ ์‰ฝ๊ฒŒ ๋งŒ๋“ค๋ฉฐ ์˜ค๋ฅ˜ ๋ฐœ์ƒ ๊ฐ€๋Šฅ์„ฑ๋„ ์ค„์—ฌ์ค๋‹ˆ๋‹ค.

DeviceMesh๋ž€ ๋ฌด์—‡์ธ๊ฐ€

:class:`DeviceMesh` is a higher-level abstraction that manages :class:`ProcessGroup`. It lets users create inter-node and intra-node process groups without worrying about how to set up ranks correctly for the different sub process groups. Users can also use :class:`DeviceMesh` to manage the underlying process groups and devices for multi-dimensional parallelism.

[Figure: PyTorch DeviceMesh]

DeviceMesh๊ฐ€ ์œ ์šฉํ•œ ์ด์œ 

DeviceMesh is useful when working with multi-dimensional parallelism (e.g. 3-D parallelism) where parallelisms must be composed, for example when your parallelism solutions require communication both across hosts and within each host. The image above shows that, in a homogeneous setup, we can create a 2D mesh that connects the devices within each host and connects each device with its counterpart on the other hosts.
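To make the picture concrete, the plain-Python sketch below computes where a rank sits in a homogeneous 2D mesh of shape (num_hosts, devices_per_host) and which two groups it belongs to. It assumes ranks are laid out row-major, host by host, which is how `init_device_mesh` arranges ranks (it builds the mesh from `torch.arange(...).view(mesh_shape)`). The helpers `mesh_coordinate` and `mesh_peers` are hypothetical illustrations, not PyTorch APIs.

```python
from typing import List, Tuple


def mesh_coordinate(rank: int, num_hosts: int, devices_per_host: int) -> Tuple[int, int]:
    """Return (host_index, local_device_index) for a rank in a row-major 2D mesh."""
    if not 0 <= rank < num_hosts * devices_per_host:
        raise ValueError(f"rank {rank} is outside the mesh")
    return rank // devices_per_host, rank % devices_per_host


def mesh_peers(rank: int, num_hosts: int, devices_per_host: int) -> Tuple[List[int], List[int]]:
    """Return (intra_host_group, cross_host_group), the two groups containing `rank`."""
    host, local = mesh_coordinate(rank, num_hosts, devices_per_host)
    # Every device on the same host: one row of the mesh.
    intra_host = [host * devices_per_host + d for d in range(devices_per_host)]
    # The device with the same local index on every host: one column of the mesh.
    cross_host = [h * devices_per_host + local for h in range(num_hosts)]
    return intra_host, cross_host
```

For rank 5 on 2 hosts with 4 GPUs each, `mesh_peers(5, 2, 4)` returns `([4, 5, 6, 7], [1, 5])`: the row is the intra-host group and the column pairs the rank with its counterpart on the other host.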

Without DeviceMesh, users need to manually set up NCCL communicators and CUDA devices on each process before applying any parallelism, which can be quite complicated. The following code snippet illustrates a hybrid sharding 2-D parallel pattern set up without :class:`DeviceMesh`. First, we manually calculate the shard groups and the replicate groups. Then, we assign the correct shard group and replicate group to each rank.

import os

import torch
import torch.distributed as dist

# Understand world topology
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
print(f"Running example on {rank=} in a world with {world_size=}")

# Create process groups to manage a 2-D parallel pattern
dist.init_process_group("nccl")
# Bind this process to its GPU on the local node (torchrun sets LOCAL_RANK)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
# and assign the correct shard group to each rank
num_node_devices = torch.cuda.device_count()
shard_rank_lists = list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices))
shard_groups = (
    dist.new_group(shard_rank_lists[0]),
    dist.new_group(shard_rank_lists[1]),
)
current_shard_group = (
    shard_groups[0] if rank in shard_rank_lists[0] else shard_groups[1]
)

# Create replicate groups (e.g. (0, 4), (1, 5), (2, 6), (3, 7))
# and assign the correct replicate group to each rank.
# Every rank calls new_group for every group, in the same order.
current_replicate_group = None
shard_factor = len(shard_rank_lists[0])
for i in range(num_node_devices // 2):
    replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
    replicate_group = dist.new_group(replicate_group_ranks)
    if rank in replicate_group_ranks:
        current_replicate_group = replicate_group

To run the above code snippet, we can leverage PyTorch Elastic. Create a file named 2d_setup.py, then run the following torch elastic/torchrun command.

torchrun --nproc_per_node=8 --rdzv_id=100 --rdzv_endpoint=localhost:29400 2d_setup.py

Note

For simplicity of demonstration, we simulate 2D parallelism using only one node. The same pattern also applies to a multi-host setup, but the rank lists above are derived from torch.cuda.device_count() (the devices on a single node), so they would then need to be computed from world_size instead.
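A hedged sketch of a world-size-based computation follows. The hypothetical helper `hybrid_shard_rank_lists` (not a PyTorch API) returns the shard groups, which are consecutive blocks of `shard_size` ranks, and the replicate groups, which collect the ranks at the same offset in every block, for any `world_size` divisible by `shard_size`. Each rank would then pass every returned list to `dist.new_group` in the same order, because `new_group` must be entered by all processes in the default group, even those that are not members.

```python
from typing import List, Tuple


def hybrid_shard_rank_lists(world_size: int, shard_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (shard_rank_lists, replicate_rank_lists) for a 2D hybrid-sharding layout."""
    if world_size % shard_size != 0:
        raise ValueError("world_size must be divisible by shard_size")
    num_replicas = world_size // shard_size
    # Shard groups: consecutive blocks of ranks, e.g. one block per host.
    shard_rank_lists = [
        list(range(r * shard_size, (r + 1) * shard_size)) for r in range(num_replicas)
    ]
    # Replicate groups: the rank at the same position inside every block.
    replicate_rank_lists = [
        list(range(i, world_size, shard_size)) for i in range(shard_size)
    ]
    return shard_rank_lists, replicate_rank_lists
```

With `hybrid_shard_rank_lists(8, 4)` this reproduces the single-node groups from 2d_setup.py; on two nodes with eight GPUs each, `hybrid_shard_rank_lists(16, 8)` yields shard groups 0-7 and 8-15 and replicate groups such as `[3, 11]`.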

With the help of :func:`init_device_mesh`, we can accomplish the above 2D setup in just two lines, and we can still access the underlying :class:`ProcessGroup` if needed.

from torch.distributed.device_mesh import init_device_mesh
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

# Users can access the underlying process groups through the `get_group` API.
replicate_group = mesh_2d.get_group(mesh_dim="replicate")
shard_group = mesh_2d.get_group(mesh_dim="shard")

2d_setup_with_device_mesh.py ๋ผ๋Š” ํŒŒ์ผ์„ ๋งŒ๋“  ๋’ค, torch elastic/torchrun ๋ช…๋ น์„ ์‹คํ–‰ํ•˜์„ธ์š”.

torchrun --nproc_per_node=8 2d_setup_with_device_mesh.py

HSDP์—์„œ DeviceMesh๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ•

Hybrid Sharding Data Parallel(HSDP)์€ ํ˜ธ์ŠคํŠธ ๋‚ด๋ถ€์—์„œ๋Š” FSDP๋ฅผ, ํ˜ธ์ŠคํŠธ ๊ฐ„์—๋Š” DDP๋ฅผ ์ˆ˜ํ–‰ํ•˜๋Š” 2D ์ „๋žต์ž…๋‹ˆ๋‹ค.

Let's look at an example of how DeviceMesh helps apply HSDP to your model with a simple setup. With DeviceMesh, users do not need to create and manage the shard group and replicate group by hand.

import torch
import torch.nn as nn

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard as FSDP


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


# HSDP: MeshShape(2, 4)
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))
# With a 2-D mesh, fully_shard replicates across dim 0 and shards across dim 1
model = FSDP(
    ToyModel(), mesh=mesh_2d
)

hsdp.py ๋ผ๋Š” ํŒŒ์ผ์„ ๋งŒ๋“  ๋’ค, torch elastic/torchrun ๋ช…๋ น์„ ์‹คํ–‰ํ•˜์„ธ์š”.

torchrun --nproc_per_node=8 hsdp.py

์‚ฌ์šฉ์ž ์ •์˜ ๋ณ‘๋ ฌ ๋ฐฉ์‹์—์„œ DeviceMesh๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ•

When working with large-scale training, you might have a more complex custom composition of parallel training strategies. For example, you may need to slice out sub-meshes for different parallelism solutions. DeviceMesh lets users slice a child mesh from the parent mesh and reuse the NCCL communicators that were already created when the parent mesh was initialized.
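To see which ranks end up in each slice, the plain-Python sketch below enumerates the rank groups along every dimension of the (2, 2, 2) mesh created in the snippet that follows. It assumes the same row-major rank layout that `init_device_mesh` uses. The helpers `groups_along` and `mesh_groups` are hypothetical; they only mirror the bookkeeping that slicing and `get_group()` handle for you.

```python
import itertools
from typing import Dict, List, Sequence


def groups_along(shape: Sequence[int], dim: int) -> List[List[int]]:
    """Return every group of ranks that differ only in coordinate `dim` (row-major mesh)."""
    # Row-major strides, e.g. shape (2, 2, 2) -> strides [4, 2, 1].
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    other_dims = [d for d in range(len(shape)) if d != dim]
    groups = []
    # Fix the coordinates of all other dims, then walk along `dim`.
    for fixed in itertools.product(*(range(shape[d]) for d in other_dims)):
        base = sum(c * strides[d] for c, d in zip(fixed, other_dims))
        groups.append([base + i * strides[dim] for i in range(shape[dim])])
    return groups


def mesh_groups(shape: Sequence[int], names: Sequence[str]) -> Dict[str, List[List[int]]]:
    """Map each mesh dimension name to its list of rank groups."""
    return {name: groups_along(shape, d) for d, name in enumerate(names)}
```

For `mesh_groups((2, 2, 2), ("replicate", "shard", "tp"))`, rank 5 lands in tp group `[4, 5]`, shard group `[5, 7]`, and replicate group `[1, 5]`, and the 2-D `hsdp_mesh` it receives spans ranks 1, 3, 5, and 7 (all ranks sharing its tp coordinate).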

from torch.distributed.device_mesh import init_device_mesh
mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("replicate", "shard", "tp"))

# ์ƒ์œ„ ๋ฉ”์‹œ์—์„œ ํ•˜์œ„ ๋ฉ”์‹œ๋ฅผ ์ž˜๋ผ๋‚ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
hsdp_mesh = mesh_3d["replicate", "shard"]
tp_mesh = mesh_3d["tp"]

# Users can access the underlying process groups through the `get_group` API.
replicate_group = hsdp_mesh["replicate"].get_group()
shard_group = hsdp_mesh["shard"].get_group()
tp_group = tp_mesh.get_group()

Conclusion

So far, we have learned about :class:`DeviceMesh` and :func:`init_device_mesh`, and how they can be used to describe the layout of devices across the cluster.

๋” ์ž์„ธํ•œ ๋‚ด์šฉ์€ ๋‹ค์Œ ์ž๋ฃŒ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”.