99import subprocess
1010import threading
1111import time
12- import uuid
1312from collections import defaultdict
1413from datetime import timedelta
15- from functools import cached_property , lru_cache
14+ from functools import lru_cache
1615from typing import Callable , NamedTuple
1716
1817import numpy as np
@@ -82,7 +81,7 @@ class MemoryBufferMetaList(BaseModel):
8281
8382class DataToGather (MemoryBufferMetaList ):
8483 host_ip : str
85- zmq_socket_path : tuple [ str , str ]
84+ device_uuid : str
8685
8786
8887# 256 bytes alignment when flatten torch tensors to uint8 buffer
@@ -516,13 +515,14 @@ def __init__(self, *, auto_pg: bool = False):
516515 self ._local_rank = self ._rank % self ._gpu_count
517516 self ._auto_pg = auto_pg
518517 self ._all_hosts = []
519- self ._global_socket_paths : list [tuple [ str , str ] ] = []
518+ self ._global_device_uuids : list [str ] = []
520519
521520 assert self ._rank is not None and self ._rank >= 0 , self ._rank
522521 assert self ._world_size and self ._world_size > 0 , self ._world_size
523522
524523 self ._device_uuid = _get_physical_gpu_id (self ._local_rank )
525524 self ._zmq_ctx = zmq .Context ()
525+ self ._zmq_addr_counter = 0
526526
527527 self ._memory_pool : dict [str , list [MemoryBuffer ]] = {}
528528 # dict key is owner_rank, value is a bucket metas list in owner_rank
@@ -583,10 +583,6 @@ def unregister_checkpoint(self, checkpoint_name: str):
583583 # this works by using torch>=2.5.0
584584 torch ._C ._host_emptyCache ()
585585
586- @cached_property
587- def _zmq_socket_path (self ) -> str :
588- return f"ipc://@checkpoint-engine-{ uuid .uuid4 ()} .sock"
589-
590586 def gather_metas (self , checkpoint_name : str ):
591587 """
592588 Gather the parameter metas from all ranks. This will gather memory_buffer, and other metadatas.
@@ -610,28 +606,28 @@ def gather_metas(self, checkpoint_name: str):
610606 ),
611607 p2p_store_addr = None if self ._p2p_store is None else self ._p2p_store .addr ,
612608 host_ip = _get_ip (),
613- zmq_socket_path = ( self ._device_uuid , self . _zmq_socket_path ) ,
609+ device_uuid = self ._device_uuid ,
614610 )
615611
616612 dist .all_gather_object (metas_lst , metas )
617613
618614 self ._current_global_parameter_metas = {}
619615 num_parameters = 0
620616 all_hosts : list [str ] = []
621- global_socket_paths : list [tuple [ str , str ] ] = []
617+ global_device_uuids : list [str ] = []
622618 for i , metas_buckets in enumerate (metas_lst ):
623619 assert metas_buckets is not None , f"metas_buckets { i } should not be None"
624620 if i % self ._gpu_count == 0 and not self ._all_hosts :
625621 all_hosts .append (metas_buckets .host_ip )
626- if not self ._global_socket_paths :
627- global_socket_paths .append (metas_buckets .zmq_socket_path )
622+ if not self ._global_device_uuids :
623+ global_device_uuids .append (metas_buckets .device_uuid )
628624 if metas_buckets .memory_buffer_metas_list :
629625 self ._current_global_parameter_metas [i ] = metas_buckets
630626 num_parameters += sum (map (lambda x : len (x .metas ), metas_buckets .memory_buffer_metas_list ))
631627 if not self ._all_hosts :
632628 self ._all_hosts = all_hosts
633- if not self ._global_socket_paths :
634- self ._global_socket_paths = global_socket_paths
629+ if not self ._global_device_uuids :
630+ self ._global_device_uuids = global_device_uuids
635631 logger .info (f"[rank{ self ._rank } ] gather parameter metas finished, num_parameters: { num_parameters } " )
636632
637633 def init_process_group (self , * , master_port : int | None = None , timeout : timedelta = timedelta (minutes = 10 )):
@@ -695,18 +691,35 @@ def update(
695691 logger .exception (f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } with ranks { ranks } error { e } " )
696692 raise e
697693
698- def _get_bucket_size (self , * , disable_h2d_buffer : bool = False ) -> tuple [int , bool ]:
def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
    """Bind a fresh REQ socket for this device and list every device's address.

    Addresses are derived from a device uuid plus the current value of
    ``self._zmq_addr_counter``, so each call uses a new address generation.
    The counter is incremented only after a successful bind.

    NOTE(review): the peer addresses are only meaningful if all ranks hold
    the same counter value when this is called; ``_detect_bucket_size``
    all-reduces the counter across ranks before the callers get here.

    Returns:
        A ``(socket, socket_paths)`` tuple. ``socket`` is the REQ socket
        bound to this rank's own address. ``socket_paths`` holds one
        ``(device_uuid, address)`` pair per entry of
        ``self._global_device_uuids``, in the same order.

    Raises:
        Whatever ``socket.bind`` raises; the socket is closed first so it
        does not linger in the shared context.
    """
    # Snapshot the counter once so every address built in this call uses
    # the same generation, independent of the increment below.
    counter = self._zmq_addr_counter

    def zmq_handle(device_uuid: str) -> str:
        # presumably the "@" prefix selects an abstract-namespace ipc
        # endpoint (no file on disk) — confirm against the ZeroMQ ipc docs
        return f"ipc://@checkpoint-engine-{device_uuid}-{counter}.sock"

    socket_paths = [(uid, zmq_handle(uid)) for uid in self._global_device_uuids]
    socket = self._zmq_ctx.socket(zmq.REQ)
    try:
        socket.bind(zmq_handle(self._device_uuid))
    except Exception:
        # Do not leak a half-initialized socket when the bind fails.
        socket.close()
        raise
    self._zmq_addr_counter += 1
    return socket, socket_paths
704+
705+ def _detect_bucket_size (self , * , disable_h2d_buffer : bool = False ) -> tuple [int , bool ]:
699706 GiB_bytes = 1 << 30
700707 # auto detect bucket size
701- free_bytes_tensor = torch .tensor (
702- int (float (torch .cuda .mem_get_info ()[0 ]) * 0.9 ),
708+ tensor = torch .tensor (
709+ [
710+ # 90% of current cuda free memory bytes
711+ int (float (torch .cuda .mem_get_info ()[0 ]) * 0.9 ),
712+ # we use negative value to reuse allreduce min operation
713+ # for getting the max value of zmq_addr_counter in all ranks
714+ - self ._zmq_addr_counter ,
715+ ],
703716 dtype = torch .int64 ,
704717 device = "cuda" ,
705718 )
706- dist .all_reduce (free_bytes_tensor , op = dist .ReduceOp .MIN )
707- free_bytes = free_bytes_tensor .item ()
719+ dist .all_reduce (tensor , op = dist .ReduceOp .MIN )
720+ tensor = tensor .cpu ()
721+ free_bytes , self ._zmq_addr_counter = tensor [0 ].item (), - tensor [1 ].item ()
708722 max_tensor_bytes = 0
709- max_bytes = int (os .getenv ("PS_MAX_BUCKET_SIZE_GB" , 8 )) * GiB_bytes
710723 for items in self ._current_global_parameter_metas .values ():
711724 for metas_list in items .memory_buffer_metas_list :
712725 for meta in metas_list .metas :
@@ -729,6 +742,7 @@ def _get_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bo
729742 f"max_tensor_bytes { max_tensor_bytes } should be less than free_bytes { free_bytes } "
730743 )
731744 disable_h2d_buffer = True
745+ max_bytes = int (os .getenv ("PS_MAX_BUCKET_SIZE_GB" , 8 )) * GiB_bytes
732746 bucket_size = min (max (max_bytes , max_tensor_bytes ), free_bytes )
733747 logger .info (f"[rank{ self ._rank } ] auto detect bucket size { bucket_size / GiB_bytes :.2f} GiB" )
734748 return bucket_size , disable_h2d_buffer
@@ -814,7 +828,7 @@ def _update_per_bucket_p2p(
814828 # first execute a barrier to avoid subsequent cuda oom
815829 dist .barrier ()
816830
817- bucket_size , _ = self ._get_bucket_size (disable_h2d_buffer = True )
831+ bucket_size , _ = self ._detect_bucket_size (disable_h2d_buffer = True )
818832 buffer = torch .empty (bucket_size * 2 , dtype = torch .uint8 , device = "cuda" )
819833 IPC_BUFFER_NAME = "__ipc_buffer___"
820834 self ._p2p_store .register_named_tensors ({IPC_BUFFER_NAME : buffer })
@@ -825,13 +839,12 @@ def _update_per_bucket_p2p(
825839
826840 gidx = 0
827841 buckets = _gen_h2d_buckets (self ._current_global_parameter_metas , bucket_size )
842+ socket , socket_paths = self ._bind_zmq_socket ()
828843 req_thread = threading .Thread (
829844 target = req_func ,
830- args = (self . _global_socket_paths ,),
845+ args = (socket_paths ,),
831846 )
832847 req_thread .start ()
833- socket = self ._zmq_ctx .socket (zmq .REQ )
834- socket .bind (self ._zmq_socket_path )
835848 socket .send_pyobj (handle )
836849 for owner_rank , bucket in buckets :
837850 self ._logger_rank0 (
@@ -891,7 +904,7 @@ def _update_per_bucket(
891904
892905 logger .info (f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } " )
893906
894- bucket_size , disable_h2d_buffer = self ._get_bucket_size ()
907+ bucket_size , disable_h2d_buffer = self ._detect_bucket_size ()
895908 buckets = _gen_h2d_buckets (self ._current_global_parameter_metas , bucket_size )
896909
897910 h2d_buffer : torch .Tensor | None = (
@@ -914,13 +927,12 @@ def _update_per_bucket(
914927 if len (buckets_by_owner_rank [owner_rank ]) > max_len :
915928 max_len = len (buckets_by_owner_rank [owner_rank ])
916929
930+ socket , socket_paths = self ._bind_zmq_socket ()
917931 req_thread = threading .Thread (
918932 target = req_func ,
919- args = (self . _global_socket_paths ,),
933+ args = (socket_paths ,),
920934 )
921935 req_thread .start ()
922- socket = self ._zmq_ctx .socket (zmq .REQ )
923- socket .bind (self ._zmq_socket_path )
924936 socket .send_pyobj (handle )
925937
926938 gidx = 0
0 commit comments