@@ -22,10 +22,9 @@ class RankAwareSampler(BaseSampler):
2222 """Rank-aware sampler for distributed training with TransferQueue.
2323
2424 This sampler is designed for distributed data parallel training scenarios
25- where each ranks independently retrieve data by themselves .
25+ where each rank retrieves data independently.
2626
27- Each rank independently calls the sampler, passing its own rank information,
28- and the sampler guarantees that all ranks within the same DP group receive
27+ This sampler guarantees that all ranks within the same DP group receive
2928 the same sample indices.
3029
3130 The sampler maintains per-DP-group state to coordinate sampling across ranks:
@@ -72,7 +71,7 @@ def sample(
7271 ready_indexes: List of global indices for which all required fields of the
7372 corresponding samples have been produced, and the samples are not labeled
7473 as consumed in the corresponding task.
75- batch_size: batch_size: Number of samples to select. If larger than available
74+ batch_size: Number of samples to select. If larger than available
7675 ready samples, all available samples will be returned.
7776 dp_group: The group id of current data parallel group. Used to
7877 identify which DP group this rank belongs to.
@@ -89,9 +88,8 @@ def sample(
8988 List of global indices of length batch_size that should be labeled as consumed
9089 (will never be retrieved in the future)
9190
92- Raises:
93- RuntimeError: If the fetch count exceeds the expected number of
94- fetches per DP group.
92+ Raises:
92+ RuntimeError: If ``world_size`` is not divisible by ``dp_world_size``.
9593
9694 Note:
9795 The ``world_size // dp_world_size`` calculation determines how many
@@ -102,6 +100,9 @@ def sample(
102100 data_for_dp_group = self ._states .get (dp_group , None )
103101
104102 # Calculate how many times this batch should be fetched across all ranks
103+ if world_size % dp_world_size != 0 :
104+ raise RuntimeError (f"world_size ({ world_size } ) is not divisible by dp_world_size ({ dp_world_size } )" )
105+
105106 fetches_per_batch = world_size // dp_world_size
106107
107108 if data_for_dp_group is None :
0 commit comments