Skip to content

Commit 4f1728a

Browse files
committed
fix: torch group modified
1 parent c4a1e7c commit 4f1728a

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

checkpoint_engine/ps.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ def update(
573573
is_master=self._rank == 0,
574574
)
575575
# if ranks is None or [], it will use fully broadcast to update to all ranks
576-
ranks_group = dist.new_group(ranks if ranks else None)
576+
ranks_group = dist.new_group(ranks) if ranks else None
577577
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
578578
self.store_based_barrier(manager_store)
579579
except Exception as e:
@@ -604,7 +604,7 @@ def zmq_handle(device_uuid: str) -> str:
604604
return socket, socket_paths
605605

606606
def _detect_bucket_size(
607-
self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
607+
self, ranks_group: dist.ProcessGroup | None, *, disable_h2d_buffer: bool = False
608608
) -> tuple[int, bool]:
609609
GiB = 1 << 30 # noqa: N806
610610
# auto detect bucket size
@@ -723,7 +723,7 @@ def _update_per_bucket(
723723
self,
724724
checkpoint_name: str,
725725
req_func: Callable[[list[tuple[str, str]]], None],
726-
ranks_group: dist.ProcessGroup,
726+
ranks_group: dist.ProcessGroup | None,
727727
ranks: list[int] | None = None,
728728
):
729729
assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"

0 commit comments

Comments
 (0)