Skip to content

Commit 7a85b2c

Browse files
0oshowero0ji-huazhongbaymax591jianjunzhongLLLLxmmm
committed
implement RankAwareSampler
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com> Co-authored-by: ji-huazhong <hzji210@gmail.com> Co-authored-by: baymax591 <cbai@mail.nwpu.edu.cn> Co-authored-by: jianjunzhong <jianjunzhong@foxmail.com> Co-authored-by: LLLLxmmm <liuqianmeng@huawei.com> Co-authored-by: dpj135 <958208521@qq.com> Co-authored-by: Evelynn-V <liwenlin0223l@gmail.com> Co-authored-by: liujia7 <liujia7@xiaohongshu.com> Co-authored-by: 赵海源 <zhaohaiyuan@xiaohongshu.com> Co-authored-by: NINGBENZHE <ningbenzhe@xiaohongshu.com>
1 parent d292bec commit 7a85b2c

5 files changed

Lines changed: 302 additions & 6 deletions

File tree

tests/test_samplers.py

Lines changed: 160 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,20 @@
1515

1616
"""Unit tests for TransferQueue samplers."""
1717

18+
import sys
19+
from pathlib import Path
1820
from typing import Any
1921

2022
import pytest
2123

22-
from transfer_queue.sampler import BaseSampler
23-
from transfer_queue.sampler.grpo_group_n_sampler import GRPOGroupNSampler
24-
from transfer_queue.sampler.sequential_sampler import SequentialSampler
24+
# Setup path
25+
parent_dir = Path(__file__).resolve().parent.parent
26+
sys.path.append(str(parent_dir))
27+
28+
from transfer_queue.sampler import BaseSampler # noqa: E402
29+
from transfer_queue.sampler.grpo_group_n_sampler import GRPOGroupNSampler # noqa: E402
30+
from transfer_queue.sampler.rank_aware_sampler import RankAwareSampler # noqa: E402
31+
from transfer_queue.sampler.sequential_sampler import SequentialSampler # noqa: E402
2532

2633

2734
class TestBaseSampler:
@@ -427,6 +434,156 @@ def test_grpo_sampler_insufficient_groups(self):
427434
assert consumed == []
428435

429436

437+
class TestRankAwareSampler:
438+
"""Test cases for RankAwareSampler."""
439+
440+
def test_rank_aware_sampler_initialization(self):
441+
"""Test RankAwareSampler initialization."""
442+
sampler = RankAwareSampler()
443+
assert isinstance(sampler, BaseSampler)
444+
assert hasattr(sampler, "_states")
445+
assert sampler._states == {}
446+
447+
def test_rank_aware_sampler_first_rank_sampling(self):
448+
"""Test that first rank in DP group performs actual sampling."""
449+
sampler = RankAwareSampler()
450+
ready_indexes = [0, 1, 2, 3, 4, 5]
451+
batch_size = 3
452+
453+
# When world_size == dp_world_size, fetches_per_batch = 1
454+
# First rank samples and immediately marks consumed (no other ranks to wait for)
455+
sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
456+
457+
assert sampled == [0, 1, 2]
458+
# consumed is returned
459+
assert consumed == [0, 1, 2]
460+
assert len(sampled) == batch_size
461+
# State should be cleaned up
462+
assert sampler._states == {}
463+
464+
def test_rank_aware_sampler_second_rank_gets_cached(self):
465+
"""Test that second rank in DP group gets cached indices."""
466+
sampler = RankAwareSampler()
467+
ready_indexes = [0, 1, 2, 3, 4, 5]
468+
batch_size = 3
469+
dp_world_size = 2
470+
world_size = 4 # Use world_size=4 so fetches_per_batch=2
471+
472+
# Rank 0 (dp_group=0) samples first
473+
sampled1, consumed1 = sampler.sample(
474+
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
475+
)
476+
477+
# Rank 1 (dp_group=0) should get same cached indices
478+
sampled2, consumed2 = sampler.sample(
479+
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
480+
)
481+
482+
assert sampled1 == sampled2 == [0, 1, 2]
483+
# First rank returns empty consumed (not all ranks have fetched yet)
484+
assert consumed1 == [0, 1, 2]
485+
# Last rank returns consumed when all ranks have fetched
486+
assert consumed2 == [0, 1, 2]
487+
# State should be cleaned up
488+
assert sampler._states == {}
489+
490+
def test_rank_aware_sampler_multiple_dp_groups(self):
491+
"""Test that multiple DP groups work independently."""
492+
sampler = RankAwareSampler()
493+
ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7]
494+
batch_size = 2
495+
dp_world_size = 4
496+
world_size = 8
497+
498+
# DP group 0: rank 0 samples first
499+
sampled0_g0, consumed0_g0 = sampler.sample(
500+
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
501+
)
502+
# minic the consumption status update managed in TransferQueueController
503+
ready_indexes = [i for i in ready_indexes if i not in consumed0_g0]
504+
505+
# DP group 1: rank 0 samples first
506+
sampled0_g1, consumed0_g1 = sampler.sample(
507+
ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size
508+
)
509+
ready_indexes = [i for i in ready_indexes if i not in consumed0_g1]
510+
511+
# Both should have sampled their first batch
512+
assert sampled0_g0 == [0, 1]
513+
assert sampled0_g1 == [2, 3]
514+
assert consumed0_g0 == [0, 1]
515+
assert consumed0_g1 == [2, 3]
516+
517+
# DP group 0: rank 1 fetches cached, and all the data should be labeled as consumed
518+
sampled1_g0, consumed1_g0 = sampler.sample(
519+
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
520+
)
521+
ready_indexes = [i for i in ready_indexes if i not in consumed1_g0]
522+
assert sampled1_g0 == [0, 1]
523+
assert consumed1_g0 == [0, 1]
524+
525+
# DP group 1: rank 1 fetches cached, and all the data should be labeled as consumed
526+
sampled1_g1, consumed1_g1 = sampler.sample(
527+
ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size
528+
)
529+
ready_indexes = [i for i in ready_indexes if i not in consumed1_g1]
530+
assert sampled1_g1 == [2, 3]
531+
assert consumed1_g1 == [2, 3]
532+
533+
# DP group 0: rank 0 fetches again, this should return new data
534+
sampled2_g0, consumed2_g0 = sampler.sample(
535+
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
536+
)
537+
ready_indexes = [i for i in ready_indexes if i not in consumed2_g0]
538+
assert sampled2_g0 == [4, 5]
539+
assert consumed2_g0 == [4, 5]
540+
541+
# DP group 0: rank 1 fetches cached
542+
sampled3_g0, consumed3_g0 = sampler.sample(
543+
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
544+
)
545+
assert sampled3_g0 == [4, 5]
546+
assert consumed3_g0 == [4, 5]
547+
548+
# Both groups should be cleaned up
549+
assert sampler._states == {}
550+
551+
def test_rank_aware_sampler_empty_ready_indexes(self):
552+
"""Test behavior with empty ready indexes."""
553+
sampler = RankAwareSampler()
554+
ready_indexes = []
555+
batch_size = 3
556+
557+
sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
558+
559+
assert sampled == []
560+
assert consumed == []
561+
562+
def test_rank_aware_sampler_batch_size_larger_than_ready(self):
563+
"""Test behavior when batch_size > len(ready_indexes)."""
564+
sampler = RankAwareSampler()
565+
ready_indexes = [0, 1]
566+
batch_size = 5
567+
568+
# When world_size == dp_world_size, fetches_per_batch=1, consumed returned immediately
569+
sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
570+
571+
assert sampled == [0, 1]
572+
assert consumed == [0, 1]
573+
assert len(sampled) == len(ready_indexes)
574+
575+
def test_rank_aware_sampler_zero_batch_size(self):
576+
"""Test behavior with zero batch size."""
577+
sampler = RankAwareSampler()
578+
ready_indexes = [0, 1, 2, 3]
579+
batch_size = 0
580+
581+
sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
582+
583+
assert sampled == []
584+
assert consumed == []
585+
586+
430587
class TestSamplerIntegration:
431588
"""Integration tests for samplers."""
432589

transfer_queue/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
from .metadata import BatchMeta
2525
from .sampler import BaseSampler
2626
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
27+
from .sampler.rank_aware_sampler import RankAwareSampler
2728
from .sampler.sequential_sampler import SequentialSampler
2829
from .storage import SimpleStorageUnit
30+
from .streaming_dataloader import StreamDataLoader, StreamingDataset
2931
from .utils.utils import get_placement_group
3032
from .utils.zmq_utils import ZMQServerInfo
3133

@@ -41,6 +43,9 @@
4143
"BaseSampler",
4244
"GRPOGroupNSampler",
4345
"SequentialSampler",
46+
"RankAwareSampler",
47+
"StreamingDataset",
48+
"StreamDataLoader",
4449
]
4550

4651
version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))

transfer_queue/sampler/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from .base import BaseSampler
1717
from .grpo_group_n_sampler import GRPOGroupNSampler
18+
from .rank_aware_sampler import RankAwareSampler
1819
from .sequential_sampler import SequentialSampler
1920

20-
__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler"]
21+
__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler"]

transfer_queue/sampler/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@ class BaseSampler(ABC):
3434
- **SequentialSampler**: Default sampler, selects samples sequentially without replacement
3535
- **GRPOGroupNSampler**: A sampler that performs sampling on continuous N samples only when all of them are ready.
3636
It assumes the N samples associated with the same prompt are stored contiguously
37-
- **RankAwareSampler**: Rank-aware sampling for distributed scenarios (TODO)
37+
- **RankAwareSampler**: Rank-aware sampling for distributed training where each ranks independently retrieve data
38+
by themselves. This sampler will guarantee ranks of the same DP group consume identical
39+
samples.
3840
3941
NOTE: Always return both sampled and consumed indexes (may be identical).
4042
"""
4143

4244
def __init__(self):
43-
self._states: dict[str, Any] = {}
45+
self._states: dict[Any, Any] = {}
4446

4547
@abstractmethod
4648
def sample(
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2+
# Copyright 2025 The TransferQueue Team
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import Any
17+
18+
from transfer_queue.sampler import BaseSampler
19+
20+
21+
class RankAwareSampler(BaseSampler):
22+
"""Rank-aware sampler for distributed training with TransferQueue.
23+
24+
This sampler is designed for distributed data parallel training scenarios
25+
where each ranks independently retrieve data by themselves.
26+
27+
Each rank independently calls the sampler, passing its own rank information,
28+
and the sampler guarantees that all ranks within the same DP group receive
29+
the same sample indices.
30+
31+
The sampler maintains per-DP-group state to coordinate sampling across ranks:
32+
33+
- First rank in a DP group to call :meth:`sample` performs actual sampling from
34+
``ready_indexes`` and caches the result
35+
- Subsequent ranks in the same DP group retrieve the cached indices
36+
- Once all ranks in the DP group have fetched their samples, the indices are
37+
marked as consumed
38+
39+
40+
Please refer to our roadmap for more details:
41+
[Roadmap] StreamingDataLoader for task-separated RL post-training
42+
https://github.com/Ascend/TransferQueue/issues/1
43+
"""
44+
45+
def __init__(self):
46+
"""Initialize the RankAwareSampler.
47+
48+
The sampler maintains internal state to coordinate sampling across ranks
49+
within the same DP group. This state tracks which samples have been sampled
50+
and how many times they have been fetched.
51+
"""
52+
super().__init__()
53+
54+
def sample(
55+
self,
56+
ready_indexes: list[int],
57+
batch_size: int,
58+
dp_group: int,
59+
dp_world_size: int,
60+
world_size: int,
61+
*args: Any,
62+
**kwargs: Any,
63+
) -> tuple[list[int], list[int]]:
64+
"""Sample indices for the current rank, coordinating with other DP ranks.
65+
66+
This method implements coordinated sampling for distributed training.
67+
The first rank in each DP group to call this method performs actual sampling
68+
from ``ready_indexes`` and caches the result. Subsequent ranks in the same
69+
DP group receive the cached indices directly.
70+
71+
Args:
72+
ready_indexes: List of global indices for which all required fields of the
73+
corresponding samples have been produced, and the samples are not labeled
74+
as consumed in the corresponding task.
75+
batch_size: batch_size: Number of samples to select. If larger than available
76+
ready samples, all available samples will be returned.
77+
dp_group: The group id of current data parallel group. Used to
78+
identify which DP group this rank belongs to.
79+
dp_world_size: Number of ranks in the data parallel group. Used to
80+
determine when all ranks have fetched their samples.
81+
world_size: Total number of ranks across all parallelism dimensions.
82+
Used to determine when all ranks have fetched their samples.
83+
*args: Additional positional arguments (ignored).
84+
**kwargs: Additional keyword arguments (ignored).
85+
86+
Returns:
87+
List of sampled global indices of length batch_size
88+
89+
List of global indices of length batch_size that should be labeled as consumed
90+
(will never be retrieved in the future)
91+
92+
Raises:
93+
RuntimeError: If the fetch count exceeds the expected number of
94+
fetches per DP group.
95+
96+
Note:
97+
The ``world_size // dp_world_size`` calculation determines how many
98+
times each batch should be fetched (once per TP/PP/... rank group).
99+
"""
100+
101+
# Check if this DP group already has sampled data cached
102+
data_for_dp_group = self._states.get(dp_group, None)
103+
104+
# Calculate how many times this batch should be fetched across all ranks
105+
fetches_per_batch = world_size // dp_world_size
106+
107+
if data_for_dp_group is None:
108+
# Initialize state for this DP group
109+
self._states[dp_group] = {}
110+
111+
# Select first batch_size indices from ready_indexes
112+
sampled_indexes = ready_indexes[:batch_size]
113+
consumed_indexes = sampled_indexes
114+
115+
# Cache the sampled indices for other ranks in this DP group
116+
self._states[dp_group]["index"] = sampled_indexes
117+
self._states[dp_group]["fetch_count"] = 1
118+
119+
else:
120+
# Return the cached indices (identical to what first rank received)
121+
sampled_indexes = self._states[dp_group]["index"]
122+
consumed_indexes = self._states[dp_group]["index"]
123+
124+
# Increment fetch count to track progress
125+
self._states[dp_group]["fetch_count"] += 1
126+
127+
# Check if this was the last rank in the DP group to fetch
128+
if self._states[dp_group]["fetch_count"] >= fetches_per_batch:
129+
del self._states[dp_group]
130+
131+
return sampled_indexes, consumed_indexes

0 commit comments

Comments
 (0)