|
20 | 20 | import os |
21 | 21 | import time |
22 | 22 | from collections.abc import Callable |
| 23 | +from contextlib import suppress |
| 24 | +from datetime import timedelta |
23 | 25 | from typing import Any |
24 | 26 |
|
25 | 27 | import torch |
@@ -70,11 +72,23 @@ def rank(group=None) -> int: |
70 | 72 | return 0 |
71 | 73 |
|
72 | 74 |
|
def local_rank() -> int:
    """Returns the local rank of the current process.

    Raises:
        RuntimeError: If the ``LOCAL_RANK`` environment variable is unset.
    """
    value = os.environ.get("LOCAL_RANK")
    if value is None:
        raise RuntimeError("LOCAL_RANK environment variable not found.")
    return int(value)
| 80 | + |
| 81 | + |
def is_master(group=None) -> bool:
    """Returns whether the current process is the master process."""
    # Rank 0 is the master by convention.
    current = rank(group=group)
    return current == 0
76 | 85 |
|
77 | 86 |
|
def is_last_process(group=None) -> bool:
    """Returns whether the current process is the last process."""
    # The last process holds rank (world_size - 1) within the group.
    last = size(group=group) - 1
    return rank(group=group) == last
| 90 | + |
| 91 | + |
78 | 92 | def _serialize(obj: Any) -> torch.Tensor: |
79 | 93 | buffer = io.BytesIO() |
80 | 94 | torch.save(obj, buffer) |
@@ -184,6 +198,21 @@ def wrapper(*args, **kwargs): |
184 | 198 | return wrapper |
185 | 199 |
|
186 | 200 |
|
def setup(timeout: timedelta | None = None):
    """Sets up the distributed environment.

    Binds this process to its local GPU (when CUDA is available) and
    initializes the default process group with a gloo backend for CPU
    tensors and nccl for CUDA tensors.

    Args:
        timeout: Optional timeout for collective operations; ``None`` uses
            the torch.distributed default.
    """
    # Pin the device before NCCL initialization so collectives use the
    # right GPU. Guarded so CPU-only hosts can still set up the gloo
    # backend (the process-group spec below explicitly supports it).
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank())
    if not is_initialized():
        torch.distributed.init_process_group("cpu:gloo,cuda:nccl", timeout=timeout)
| 206 | + |
| 207 | + |
def cleanup():
    """Cleans up the distributed environment."""
    if not is_initialized():
        return
    # Best-effort synchronization before teardown; swallow any failure so
    # shutdown always proceeds to destroy the process group.
    with suppress(Exception):
        barrier()
    torch.distributed.destroy_process_group()
| 214 | + |
| 215 | + |
187 | 216 | class DistributedProcessGroup: |
188 | 217 | """A convenient wrapper around torch.distributed.ProcessGroup objects.""" |
189 | 218 |
|
|
0 commit comments