Skip to content

Commit 68d604d

Browse files
Move new puzzle dist utils from feature/compress to main (#746)
- Move the new `modelopt.torch.utils.distributed` utilities from the `feature/compress` branch to `main` so they can be used via modelopt in the puzzletron GitLab repository. Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 9a3b986 commit 68d604d

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

modelopt/torch/utils/distributed.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import os
2121
import time
2222
from collections.abc import Callable
23+
from contextlib import suppress
24+
from datetime import timedelta
2325
from typing import Any
2426

2527
import torch
@@ -70,11 +72,23 @@ def rank(group=None) -> int:
7072
return 0
7173

7274

75+
def local_rank() -> int:
    """Return the local rank of this process, as exported by the launcher.

    Raises:
        RuntimeError: If the ``LOCAL_RANK`` environment variable is not set.
    """
    value = os.environ.get("LOCAL_RANK")
    if value is None:
        raise RuntimeError("LOCAL_RANK environment variable not found.")
    return int(value)
80+
81+
7382
def is_master(group=None) -> bool:
    """Returns whether the current process is the master process."""
    # The master process is, by convention, the one with rank 0 in the group.
    current = rank(group=group)
    return current == 0
7685

7786

87+
def is_last_process(group=None) -> bool:
    """Returns whether the current process is the last process."""
    # The last process holds rank (world_size - 1) within the group.
    last = size(group=group) - 1
    return rank(group=group) == last
90+
91+
7892
def _serialize(obj: Any) -> torch.Tensor:
7993
buffer = io.BytesIO()
8094
torch.save(obj, buffer)
@@ -184,6 +198,21 @@ def wrapper(*args, **kwargs):
184198
return wrapper
185199

186200

201+
def setup(timeout: timedelta | None = None):
    """Sets up the distributed environment.

    Binds this process to its local CUDA device (when CUDA is available) and
    initializes the default process group with gloo for CPU tensors and nccl
    for CUDA tensors. This is a no-op for the process group if it is already
    initialized.

    Args:
        timeout: Optional timeout for process-group operations, forwarded to
            ``torch.distributed.init_process_group``.
    """
    # Only bind a CUDA device when one exists: the unconditional set_device
    # call would raise on CPU-only hosts, even though the "cpu:gloo" part of
    # the backend spec is perfectly usable there.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank())
    if not is_initialized():
        torch.distributed.init_process_group("cpu:gloo,cuda:nccl", timeout=timeout)
206+
207+
208+
def cleanup():
    """Cleans up the distributed environment."""
    if not is_initialized():
        return
    # Best-effort synchronization so ranks tear down together; a failing
    # barrier must not prevent the process group from being destroyed.
    try:
        barrier()
    except Exception:
        pass
    torch.distributed.destroy_process_group()
214+
215+
187216
class DistributedProcessGroup:
188217
"""A convenient wrapper around torch.distributed.ProcessGroup objects."""
189218

0 commit comments

Comments
 (0)