@@ -573,7 +573,7 @@ def update(
573573 is_master = self ._rank == 0 ,
574574 )
575575 # if ranks is None or [], it will use fully broadcast to update to all ranks
576- ranks_group = dist .new_group (ranks if ranks else None )
576+ ranks_group = dist .new_group (ranks ) if ranks else None
577577 self ._update_per_bucket (checkpoint_name , req_func , ranks_group , ranks )
578578 self .store_based_barrier (manager_store )
579579 except Exception as e :
@@ -604,7 +604,7 @@ def zmq_handle(device_uuid: str) -> str:
604604 return socket , socket_paths
605605
606606 def _detect_bucket_size (
607- self , ranks_group : dist .ProcessGroup , * , disable_h2d_buffer : bool = False
607+ self , ranks_group : dist .ProcessGroup | None , * , disable_h2d_buffer : bool = False
608608 ) -> tuple [int , bool ]:
609609 GiB = 1 << 30 # noqa: N806
610610 # auto detect bucket size
@@ -723,7 +723,7 @@ def _update_per_bucket(
723723 self ,
724724 checkpoint_name : str ,
725725 req_func : Callable [[list [tuple [str , str ]]], None ],
726- ranks_group : dist .ProcessGroup ,
726+ ranks_group : dist .ProcessGroup | None ,
727727 ranks : list [int ] | None = None ,
728728 ):
729729 assert len (self ._current_global_parameter_metas ) != 0 , "parameter metas is empty"
0 commit comments