Skip to content

Commit 0be3ff1

Browse files
feat: use zmq_addr_counter to random zmq_handle for each update (#4)
Use a non-repeated zmq_handle for each update by zmq_addr_counter as the same logic from https://github.com/vllm-project/vllm/blob/v0.10.2rc3/examples/offline_inference/rlhf_colocate.py#L102
1 parent 357eee6 commit 0be3ff1

2 files changed

Lines changed: 47 additions & 31 deletions

File tree

checkpoint_engine/ps.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@
99
import subprocess
1010
import threading
1111
import time
12-
import uuid
1312
from collections import defaultdict
1413
from datetime import timedelta
15-
from functools import cached_property, lru_cache
14+
from functools import lru_cache
1615
from typing import Callable, NamedTuple
1716

1817
import numpy as np
@@ -82,7 +81,7 @@ class MemoryBufferMetaList(BaseModel):
8281

8382
class DataToGather(MemoryBufferMetaList):
8483
host_ip: str
85-
zmq_socket_path: tuple[str, str]
84+
device_uuid: str
8685

8786

8887
# 256 bytes alignment when flatten torch tensors to uint8 buffer
@@ -516,13 +515,14 @@ def __init__(self, *, auto_pg: bool = False):
516515
self._local_rank = self._rank % self._gpu_count
517516
self._auto_pg = auto_pg
518517
self._all_hosts = []
519-
self._global_socket_paths: list[tuple[str, str]] = []
518+
self._global_device_uuids: list[str] = []
520519

521520
assert self._rank is not None and self._rank >= 0, self._rank
522521
assert self._world_size and self._world_size > 0, self._world_size
523522

524523
self._device_uuid = _get_physical_gpu_id(self._local_rank)
525524
self._zmq_ctx = zmq.Context()
525+
self._zmq_addr_counter = 0
526526

527527
self._memory_pool: dict[str, list[MemoryBuffer]] = {}
528528
# dict key is owner_rank, value is a bucket metas list in owner_rank
@@ -583,10 +583,6 @@ def unregister_checkpoint(self, checkpoint_name: str):
583583
# this works by using torch>=2.5.0
584584
torch._C._host_emptyCache()
585585

586-
@cached_property
587-
def _zmq_socket_path(self) -> str:
588-
return f"ipc://@checkpoint-engine-{uuid.uuid4()}.sock"
589-
590586
def gather_metas(self, checkpoint_name: str):
591587
"""
592588
Gather the parameter metas from all ranks. This will gather memory_buffer, and other metadatas.
@@ -610,28 +606,28 @@ def gather_metas(self, checkpoint_name: str):
610606
),
611607
p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
612608
host_ip=_get_ip(),
613-
zmq_socket_path=(self._device_uuid, self._zmq_socket_path),
609+
device_uuid=self._device_uuid,
614610
)
615611

616612
dist.all_gather_object(metas_lst, metas)
617613

618614
self._current_global_parameter_metas = {}
619615
num_parameters = 0
620616
all_hosts: list[str] = []
621-
global_socket_paths: list[tuple[str, str]] = []
617+
global_device_uuids: list[str] = []
622618
for i, metas_buckets in enumerate(metas_lst):
623619
assert metas_buckets is not None, f"metas_buckets {i} should not be None"
624620
if i % self._gpu_count == 0 and not self._all_hosts:
625621
all_hosts.append(metas_buckets.host_ip)
626-
if not self._global_socket_paths:
627-
global_socket_paths.append(metas_buckets.zmq_socket_path)
622+
if not self._global_device_uuids:
623+
global_device_uuids.append(metas_buckets.device_uuid)
628624
if metas_buckets.memory_buffer_metas_list:
629625
self._current_global_parameter_metas[i] = metas_buckets
630626
num_parameters += sum(map(lambda x: len(x.metas), metas_buckets.memory_buffer_metas_list))
631627
if not self._all_hosts:
632628
self._all_hosts = all_hosts
633-
if not self._global_socket_paths:
634-
self._global_socket_paths = global_socket_paths
629+
if not self._global_device_uuids:
630+
self._global_device_uuids = global_device_uuids
635631
logger.info(f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}")
636632

637633
def init_process_group(self, *, master_port: int | None = None, timeout: timedelta = timedelta(minutes=10)):
@@ -695,18 +691,35 @@ def update(
695691
logger.exception(f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}")
696692
raise e
697693

698-
def _get_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
694+
def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
695+
zmq_handle = lambda device_uuid: (
696+
f"ipc://@checkpoint-engine-{device_uuid}-{self._zmq_addr_counter}.sock"
697+
)
698+
socket_paths = [(uid, zmq_handle(uid))
699+
for uid in self._global_device_uuids]
700+
socket = self._zmq_ctx.socket(zmq.REQ)
701+
socket.bind(zmq_handle(self._device_uuid))
702+
self._zmq_addr_counter += 1
703+
return socket, socket_paths
704+
705+
def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
699706
GiB_bytes = 1 << 30
700707
# auto detect bucket size
701-
free_bytes_tensor = torch.tensor(
702-
int(float(torch.cuda.mem_get_info()[0]) * 0.9),
708+
tensor = torch.tensor(
709+
[
710+
# 90% of current cuda free memory bytes
711+
int(float(torch.cuda.mem_get_info()[0]) * 0.9),
712+
# we use negative value to reuse allreduce min operation
713+
# for getting the max value of zmq_addr_counter in all ranks
714+
-self._zmq_addr_counter,
715+
],
703716
dtype=torch.int64,
704717
device="cuda",
705718
)
706-
dist.all_reduce(free_bytes_tensor, op=dist.ReduceOp.MIN)
707-
free_bytes = free_bytes_tensor.item()
719+
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
720+
tensor = tensor.cpu()
721+
free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
708722
max_tensor_bytes = 0
709-
max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB_bytes
710723
for items in self._current_global_parameter_metas.values():
711724
for metas_list in items.memory_buffer_metas_list:
712725
for meta in metas_list.metas:
@@ -729,6 +742,7 @@ def _get_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bo
729742
f"max_tensor_bytes {max_tensor_bytes} should be less than free_bytes {free_bytes}"
730743
)
731744
disable_h2d_buffer = True
745+
max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB_bytes
732746
bucket_size = min(max(max_bytes, max_tensor_bytes), free_bytes)
733747
logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB_bytes:.2f} GiB")
734748
return bucket_size, disable_h2d_buffer
@@ -814,7 +828,7 @@ def _update_per_bucket_p2p(
814828
# first execute a barrier to avoid subsequent cuda oom
815829
dist.barrier()
816830

817-
bucket_size, _ = self._get_bucket_size(disable_h2d_buffer=True)
831+
bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
818832
buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
819833
IPC_BUFFER_NAME = "__ipc_buffer___"
820834
self._p2p_store.register_named_tensors({IPC_BUFFER_NAME: buffer})
@@ -825,13 +839,12 @@ def _update_per_bucket_p2p(
825839

826840
gidx = 0
827841
buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
842+
socket, socket_paths = self._bind_zmq_socket()
828843
req_thread = threading.Thread(
829844
target=req_func,
830-
args=(self._global_socket_paths,),
845+
args=(socket_paths,),
831846
)
832847
req_thread.start()
833-
socket = self._zmq_ctx.socket(zmq.REQ)
834-
socket.bind(self._zmq_socket_path)
835848
socket.send_pyobj(handle)
836849
for owner_rank, bucket in buckets:
837850
self._logger_rank0(
@@ -891,7 +904,7 @@ def _update_per_bucket(
891904

892905
logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
893906

894-
bucket_size, disable_h2d_buffer = self._get_bucket_size()
907+
bucket_size, disable_h2d_buffer = self._detect_bucket_size()
895908
buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
896909

897910
h2d_buffer: torch.Tensor | None = (
@@ -914,13 +927,12 @@ def _update_per_bucket(
914927
if len(buckets_by_owner_rank[owner_rank]) > max_len:
915928
max_len = len(buckets_by_owner_rank[owner_rank])
916929

930+
socket, socket_paths = self._bind_zmq_socket()
917931
req_thread = threading.Thread(
918932
target=req_func,
919-
args=(self._global_socket_paths,),
933+
args=(socket_paths,),
920934
)
921935
req_thread.start()
922-
socket = self._zmq_ctx.socket(zmq.REQ)
923-
socket.bind(self._zmq_socket_path)
924936
socket.send_pyobj(handle)
925937

926938
gidx = 0

tests/test_update.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import random
3+
import time
34

45
import torch
56
import zmq
@@ -56,7 +57,7 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str,
5657
assert all(names_to_check.values())
5758

5859
while True:
59-
socket_paths = queue.get()
60+
socket_paths: list[tuple[str, str]] = queue.get()
6061
if socket_paths is None:
6162
break
6263
names_to_check = {name: False for name in named_tensors.keys()}
@@ -76,8 +77,11 @@ def run():
7677
proc.start()
7778
ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
7879
ps.gather_metas(checkpoint_name)
79-
ps.update(checkpoint_name, queue.put)
80-
ps.update(checkpoint_name, queue.put, ranks=list(range(world_size)))
80+
ranks_list = [[], list(range(world_size // 2)), [], list(range(world_size))]
81+
for ranks in ranks_list:
82+
ps.update(checkpoint_name, queue.put, ranks=ranks)
83+
# sleep 3s to wait process group is destroyed
84+
time.sleep(3)
8185
ps.unregister_checkpoint(checkpoint_name)
8286
queue.put(None)
8387
proc.join()

0 commit comments

Comments
 (0)