Skip to content

Commit 6a638fc

Browse files
0oshowero0zhabuye
andcommitted
fix grammar & add more checks
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com> Co-authored-by: zhabuye <2947436155@qq.com>
1 parent fe8c9cc commit 6a638fc

3 files changed

Lines changed: 11 additions & 11 deletions

File tree

tests/test_samplers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def test_rank_aware_sampler_multiple_dp_groups(self):
499499
sampled0_g0, consumed0_g0 = sampler.sample(
500500
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
501501
)
502-
# minic the consumption status update managed in TransferQueueController
502+
# mimic the consumption status update managed in TransferQueueController
503503
ready_indexes = [i for i in ready_indexes if i not in consumed0_g0]
504504

505505
# DP group 1: rank 0 samples first

transfer_queue/sampler/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@ class BaseSampler(ABC):
3434
- **SequentialSampler**: Default sampler, selects samples sequentially without replacement
3535
- **GRPOGroupNSampler**: A sampler that performs sampling on continuous N samples only when all of them are ready.
3636
It assumes the N samples associated with the same prompt are stored contiguously
37-
- **RankAwareSampler**: Rank-aware sampling for distributed training where each ranks independently retrieve data
38-
by themselves. This sampler will guarantee ranks of the same DP group consume identical
39-
samples.
37+
- **RankAwareSampler**: Rank-aware sampling for distributed training where each rank retrieves data independently.
38+
This sampler will guarantee ranks of the same DP group consume identical samples.
4039
4140
NOTE: Always return both sampled and consumed indexes (may be identical).
4241
"""

transfer_queue/sampler/rank_aware_sampler.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@ class RankAwareSampler(BaseSampler):
2222
"""Rank-aware sampler for distributed training with TransferQueue.
2323
2424
This sampler is designed for distributed data parallel training scenarios
25-
where each ranks independently retrieve data by themselves.
25+
where each rank retrieves data independently.
2626
27-
Each rank independently calls the sampler, passing its own rank information,
28-
and the sampler guarantees that all ranks within the same DP group receive
27+
This sampler guarantees that all ranks within the same DP group receive
2928
the same sample indices.
3029
3130
The sampler maintains per-DP-group state to coordinate sampling across ranks:
@@ -72,7 +71,7 @@ def sample(
7271
ready_indexes: List of global indices for which all required fields of the
7372
corresponding samples have been produced, and the samples are not labeled
7473
as consumed in the corresponding task.
75-
batch_size: batch_size: Number of samples to select. If larger than available
74+
batch_size: Number of samples to select. If larger than available
7675
ready samples, all available samples will be returned.
7776
dp_group: The group id of current data parallel group. Used to
7877
identify which DP group this rank belongs to.
@@ -89,9 +88,8 @@ def sample(
8988
List of global indices of length batch_size that should be labeled as consumed
9089
(will never be retrieved in the future)
9190
92-
Raises:
93-
RuntimeError: If the fetch count exceeds the expected number of
94-
fetches per DP group.
91+
Raises:
92+
RuntimeError: If ``world_size`` is not divisible by ``dp_world_size``.
9593
9694
Note:
9795
The ``world_size // dp_world_size`` calculation determines how many
@@ -102,6 +100,9 @@ def sample(
102100
data_for_dp_group = self._states.get(dp_group, None)
103101

104102
# Calculate how many times this batch should be fetched across all ranks
103+
if world_size % dp_world_size != 0:
104+
raise RuntimeError(f"world_size ({world_size}) is not divisible by dp_world_size ({dp_world_size})")
105+
105106
fetches_per_batch = world_size // dp_world_size
106107

107108
if data_for_dp_group is None:

0 commit comments

Comments
 (0)