@@ -553,22 +553,19 @@ def update(
553553 try :
554554 master_addr = os .getenv ("MASTER_ADDR" ) or master_addr
555555 assert master_addr , "master_addr is required"
556- if self ._auto_pg :
557- if not dist .is_initialized ():
558- self .init_process_group (
559- timeout = timeout , master_addr = master_addr , master_port = master_port
560- )
561- manager_store = torch .distributed .distributed_c10d ._get_default_store ()
562- else :
563- # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
564- # If master_port is provided, use master_port+1 for barrier store
565- manager_store = torch .distributed .TCPStore (
566- master_addr ,
567- _get_master_port (master_port ) + 1 ,
568- self ._world_size ,
569- timeout = timeout ,
570- is_master = self ._rank == 0 ,
556+ if self ._auto_pg and not dist .is_initialized ():
557+ self .init_process_group (
558+ timeout = timeout , master_addr = master_addr , master_port = master_port
571559 )
560+ # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
561+ # If master_port is provided, use master_port+1 for barrier store
562+ manager_store = torch .distributed .TCPStore (
563+ master_addr ,
564+ _get_master_port (master_port ) + 1 ,
565+ self ._world_size ,
566+ timeout = timeout ,
567+ is_master = self ._rank == 0 ,
568+ )
572569 # if ranks is None or [], it will use fully broadcast to update to all ranks
573570 ranks_group = dist .new_group (ranks ) if ranks else None
574571 self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
@@ -583,6 +580,7 @@ def update(
583580 dist .destroy_process_group (ranks_group )
584581 if self ._auto_pg and dist .is_initialized ():
585582 dist .destroy_process_group ()
583+ del manager_store
586584 self .device_manager .device_module .empty_cache ()
587585 logger .info (
588586 f"[rank{ self ._rank } ] update checkpoint { checkpoint_name } with ranks { ranks } done. "
0 commit comments