Skip to content

Commit f0047b9

Browse files
authored
[feat] Add SeqlenBalancedSampler and enhance StreamingDataset support (#70)
## 🎯 Summary

This PR introduces the `SeqlenBalancedSampler` to optimize sequence length distribution across Data Parallel (DP) ranks during GRPO training. It also enhances `StreamingDataset` with proper streaming mode support and refactors the controller's polling mechanism to improve efficiency when data is insufficient.

## ✨ Key Features & Enhancements

### 1. `SeqlenBalancedSampler` (Sequence-Length Balanced Sampling)

- **Karmarkar-Karp Algorithm:** Added a new sampler that extends `GRPOGroupNSampler`. It uses the Karmarkar-Karp largest differencing method to balance sequence lengths (`total_lengths`) across DP ranks, ensuring that each rank processes approximately the same total token count.
- **Group Integrity:** Guarantees that complete prompt groups remain intact across ranks to fulfill pass@k metrics and GRPO advantage normalization requirements.
- **Assignment Caching:** Implements state caching (`_balanced_cache`) so that once global sampling and balancing are computed for a batch, subsequent DP ranks can quickly retrieve their assigned chunks.

### 2. `StreamingDataset` Improvements

- **Finite vs. Infinite Stream:** Introduced the `should_check_consumption_status` parameter.
  - `False` (Default): Operates in an **infinite stream** mode, continuously polling for new data (ideal for online/streaming pipelines).
  - `True`: Operates in **finite-dataset** mode, terminating iteration only after all samples in the partition are fully consumed.
- **Client Initialization Refactor:** Refactored `_create_client()` to build the `TransferQueueClient` directly from `config.controller.zmq_info` and `config.backend`, removing the deprecated `controller_info` / `storage_backend` fallbacks. It intentionally does not call `init()` from `transfer_queue.interface`, because that path relies on Ray, which is unsafe in forked PyTorch DataLoader worker processes (`num_workers > 0`).

### 3. Controller Optimizations

- **Polling Mode Sampler Cache Lookup:** Updated `get_metadata` to look up cached sampler states when operating in `polling_mode`. If `dp_rank` and `batch_index` are cached, it immediately returns the data instead of failing or entering redundant wait loops when `ready_for_consume_indexes` are insufficient.
- **Variable-size Batch Support:** Updated the sampler length validation logic to accommodate variable-size batches returned by samplers like `SeqlenBalancedSampler`.

## 🛠️ Refactoring & Minor Fixes

- **Log Level Adjustments:** Downgraded the 1D tensor shape warnings to `logger.info()` in `client.py` and `metadata.py` to reduce unnecessary noise.
- **Pre-allocation Scope:** Moved `TQ_PRE_ALLOC_SAMPLE_NUM` environment variable resolution from module level into the class and method scopes that use it.

## 🧪 Testing

- Added comprehensive unit tests for `SeqlenBalancedSampler` covering initialization, fallback behavior, balanced partitioning logic with mock custom meta, group level integrity, and caching mechanisms.
- Added explicit utility tests for the `karmarkar_karp` and `get_seqlen_balanced_partitions` functions (`TestKarmarkarKarp`).

---------

Signed-off-by: 宁本哲 <ningbenzhe@xiaohongshu.com>
1 parent 2f24c68 commit f0047b9

11 files changed

Lines changed: 956 additions & 84 deletions

File tree

tests/test_samplers.py

Lines changed: 466 additions & 2 deletions
Large diffs are not rendered by default.

transfer_queue/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@
3838
from .sampler import BaseSampler
3939
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
4040
from .sampler.rank_aware_sampler import RankAwareSampler
41+
from .sampler.seqlen_balanced_sampler import SeqlenBalancedSampler
4142
from .sampler.sequential_sampler import SequentialSampler
4243

4344
__all__ = (
@@ -76,6 +77,7 @@
7677
"GRPOGroupNSampler",
7778
"SequentialSampler",
7879
"RankAwareSampler",
80+
"SeqlenBalancedSampler",
7981
]
8082
)
8183

transfer_queue/client.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -389,7 +389,7 @@ async def async_put(
389389

390390
for field_name, field_data in data.items():
391391
if isinstance(field_data, torch.Tensor) and field_data.ndim == 1:
392-
logger.warning(
392+
logger.info(
393393
f"[{self.client_id}]: Data field '{field_name}' is a tensor with only one dimension. "
394394
f"You may receive 2D tensors in key-value based backend."
395395
)

transfer_queue/controller.py

Lines changed: 30 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,6 @@
6363
# Sample pre-allocation for StreamingDataLoader compatibility.
6464
# By pre-allocating sample indices (typically global_batch_size), consumers can accurately
6565
# determine consumption status even before producers have generated the samples.
66-
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))
6766

6867

6968
class PartitionIndexManager:
@@ -335,6 +334,7 @@ class DataPartitionStatus:
335334

336335
# Production status tensor - dynamically expandable
337336
# Values: 0 = not produced, 1 = ready for consumption
337+
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))
338338

339339
production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
340340

@@ -1050,6 +1050,8 @@ def create_partition(self, partition_id: str) -> bool:
10501050
Returns:
10511051
True if partition was created successfully, False if it already exists
10521052
"""
1053+
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))
1054+
10531055
if partition_id in self.partitions:
10541056
logger.warning(f"Partition {partition_id} already exists")
10551057
return False
@@ -1313,38 +1315,49 @@ def get_metadata(
13131315

13141316
if len(ready_for_consume_indexes) < batch_size:
13151317
if self.polling_mode:
1316-
logger.debug(
1317-
f"[{self.controller_id}]: Not enough data for task {task_name} in partition {partition_id}."
1318-
f" Required: {batch_size}, Available: {len(ready_for_consume_indexes)}."
1319-
f" Returning None due to polling mode."
1318+
# Return cached result if available
1319+
if self.sampler.has_cached_result(partition_id, task_name, sampling_config):
1320+
break
1321+
else:
1322+
logger.debug(
1323+
f"[{self.controller_id}]: Not enough data for task {task_name} in "
1324+
f"partition {partition_id}. Required: {batch_size}, "
1325+
f"Available: {len(ready_for_consume_indexes)}."
1326+
f" Returning None due to polling mode."
1327+
)
1328+
return BatchMeta.empty()
1329+
else:
1330+
logger.warning(
1331+
f"[{self.controller_id}]: Insufficient data for task {task_name}. Required: {batch_size} "
1332+
f"samples with fields {data_fields} in partition {partition_id}, but only have "
1333+
f"{len(ready_for_consume_indexes)} samples meeting the criteria. "
1334+
f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
13201335
)
1321-
return BatchMeta.empty()
1336+
time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
13221337
if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT:
13231338
raise TimeoutError(
13241339
f"Timeout while waiting for sufficient data for task {task_name}. "
13251340
f"Required: {batch_size}, Available: {len(ready_for_consume_indexes)}"
13261341
)
1327-
logger.warning(
1328-
f"[{self.controller_id}]: Insufficient data for task {task_name}. Required: {batch_size} "
1329-
f"samples with fields {data_fields} in partition {partition_id}, but only have "
1330-
f"{len(ready_for_consume_indexes)} samples meeting the criteria. "
1331-
f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
1332-
)
1333-
time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
13341342
else:
13351343
break
13361344

13371345
batch_global_indexes, consumed_indexes = self.sampler(
13381346
ready_for_consume_indexes,
13391347
batch_size,
1348+
partition=self._get_partition(partition_id),
13401349
**(sampling_config or {}),
13411350
**kwargs,
13421351
)
13431352

1344-
# Check if we got valid results from the sampler
1345-
if len(batch_global_indexes) != batch_size:
1353+
# Check if we got valid results from the sampler.
1354+
# Some samplers (e.g. SeqlenBalancedSampler) may return variable-size
1355+
# batches per DP rank, so we only check for empty results.
1356+
if len(batch_global_indexes) == 0:
1357+
if self.polling_mode:
1358+
return BatchMeta.empty()
13461359
raise RuntimeError(
1347-
f"Sampler returned insufficient samples. Please check the sampler logic. "
1360+
f"Sampler returned no samples. Please check the sampler logic. "
13481361
f"Expected: {batch_size}, before sampling: {len(ready_for_consume_indexes)}, "
13491362
f"after sampling: {len(batch_global_indexes)}"
13501363
)
@@ -1826,7 +1839,7 @@ def _process_request(self):
18261839
partition_id=params["partition_id"],
18271840
mode=params.get("mode", "fetch"),
18281841
task_name=params.get("task_name"),
1829-
sampling_config=params.get("sampling_config"),
1842+
sampling_config=params.get("sampling_config", {}),
18301843
)
18311844

18321845
response_msg = ZMQMessage.create(

transfer_queue/dataloader/streaming_dataset.py

Lines changed: 48 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -17,16 +17,14 @@
1717
import os
1818
import time
1919
import uuid
20-
import warnings
2120
from typing import Callable, Iterator
2221

2322
from omegaconf import DictConfig
2423
from tensordict import TensorDict
2524
from torch.utils.data import IterableDataset
2625

27-
from transfer_queue import TransferQueueClient
26+
from transfer_queue.client import TransferQueueClient
2827
from transfer_queue.metadata import BatchMeta
29-
from transfer_queue.utils.zmq_utils import ZMQServerInfo
3028

3129
TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float(
3230
os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1)
@@ -77,6 +75,7 @@ def __init__(
7775
partition_id: str,
7876
task_name: str,
7977
dp_rank: int,
78+
should_check_consumption_status: bool = False,
8079
fetch_batch_fn: Callable | None = None,
8180
process_batch_fn: Callable | None = None,
8281
):
@@ -98,6 +97,14 @@ def __init__(
9897
which samples have been consumed by which task.
9998
dp_rank: The group ID of the current data group. All
10099
ranks with the same dp_rank will receive identical samples.
100+
should_check_consumption_status: Whether to check the consumption status of the
101+
partition to decide when to stop iterating. Defaults to ``False``, which
102+
means the iterator runs as an **infinite stream** — it will continuously
103+
poll for new data and never exit on its own. This is the typical mode for
104+
online/streaming training where producers keep feeding data indefinitely.
105+
Set to ``True`` when the total number of samples is known in advance (i.e.
106+
finite-dataset mode); the iterator will then stop once all samples in the
107+
partition have been consumed.
101108
fetch_batch_fn: Optional custom function to retrieve batch data.
102109
If None, uses default_fetch_batch_fn function.
103110
process_batch_fn: Optional custom function to post-process
@@ -123,6 +130,7 @@ def __init__(
123130
self.partition_id = partition_id
124131
self.task_name = task_name
125132
self.dp_rank = dp_rank
133+
self.should_check_consumption_status = should_check_consumption_status
126134
self.fetch_batch_fn = fetch_batch_fn if fetch_batch_fn else default_fetch_batch_fn
127135
self.process_batch_fn = process_batch_fn if process_batch_fn else chunk_batch_fn
128136

@@ -149,65 +157,45 @@ def __init__(
149157
super().__init__()
150158

151159
def _create_client(self):
152-
"""Create and initialize a TransferQueue client.
153-
154-
This method initializes the TransferQueueClient with the provided configuration
155-
and storage backend, and sets up the storage manager for data retrieval.
156-
157-
Raises:
158-
ValueError: If controller_info or storage_backend is missing or invalid.
160+
"""Create and initialize a TransferQueue client directly from config.
161+
162+
This method creates a ``TransferQueueClient`` using the ZMQ address and
163+
storage backend information already present in ``self.config``. It
164+
intentionally does **not** call ``tq.init()`` because that relies on Ray
165+
internally (``ray.get_actor`` / ``ray.get``), which is **unsafe in
166+
forked subprocesses** spawned by PyTorch DataLoader (``num_workers > 0``).
167+
Creating the client directly via ZMQ avoids this issue.
159168
"""
160-
client_id = uuid.uuid4().hex[:8]
161-
162-
# TODO: DEPRECATE in future
163-
controller_config = self.config.get("controller", None)
164-
if controller_config:
165-
controller_info = controller_config.get("zmq_info", None)
166-
else:
167-
controller_info = self.config.get("controller_info", None)
168-
if controller_info:
169-
warnings.warn(
170-
"Config entry `controller_info` will be deprecated in 0.1.7, please "
171-
"use `controller.zmq_info` instead.",
172-
category=DeprecationWarning,
173-
stacklevel=2,
174-
)
175-
176-
if not controller_info or not isinstance(controller_info, ZMQServerInfo):
177-
raise ValueError("Invalid or missing controller.zmq_info in config")
178-
179-
backend_config = self.config.get("backend", None)
180-
if not backend_config:
181-
storage_backend = self.config.get("storage_backend", None)
182-
backend_config = self.config
183-
if storage_backend:
184-
warnings.warn(
185-
"Config entry `storage_backend` will be deprecated in 0.1.7, please "
186-
"use `backend.storage_backend` instead.",
187-
category=DeprecationWarning,
188-
stacklevel=2,
189-
)
190-
else:
191-
storage_backend = backend_config.get("storage_backend", None)
192-
backend_config = self.config.backend[storage_backend]
193-
194-
if not storage_backend:
195-
raise ValueError("Missing storage_backend in config")
169+
client_id = f"StreamingDataset_{uuid.uuid4().hex[:8]}"
170+
171+
controller_info = self.config.controller.zmq_info
172+
storage_backend = self.config.backend.storage_backend
173+
backend_config = self.config.backend[storage_backend]
196174

197175
self._tq_client = TransferQueueClient(client_id, controller_info)
198176
self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config)
199177

200178
def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
201179
"""Iterate over the dataset, yielding batches of data.
202180
181+
The iteration behaviour depends on ``should_check_consumption_status``:
182+
183+
- **False (default — streaming mode)**: The iterator runs as an
184+
infinite stream, continuously polling TransferQueue for new data.
185+
It will sleep for `TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL` seconds
186+
(default=1) when no data is available and
187+
resume once new batches are produced. This is the standard mode for
188+
online / streaming training pipelines where producers feed data
189+
indefinitely.
190+
- **True (finite-dataset mode)**: The iterator terminates once all
191+
samples in the partition have been consumed (as reported by
192+
``check_consumption_status``), *and* all buffered batches have been
193+
yielded.
194+
203195
Yields:
204196
Tuple[TensorDict, BatchMeta]: A tuple containing:
205197
- TensorDict: Batch of data with the requested fields.
206198
- BatchMeta: Corresponding metadata to interact with TransferQueue.
207-
Note:
208-
This iterator runs indefinitely until the data source is exhausted.
209-
The caller should handle StopIteration when appropriate (e.g., when
210-
all data has been consumed and no more data will be produced).
211199
"""
212200
if self._tq_client is None:
213201
self._create_client()
@@ -218,24 +206,26 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
218206
# TQ_PRE_ALLOC_SAMPLE_NUM to the required global_batch_size to make sure consumers can accurately
219207
# determine consumption status even before producers have generated the samples.
220208
while (
221-
not self._tq_client.check_consumption_status(self.task_name, self.partition_id)
209+
not self.should_check_consumption_status
210+
or not self._tq_client.check_consumption_status(self.task_name, self.partition_id)
222211
or self.batch_index <= len(self.buffer) - 1
223212
):
224213
try:
225214
if self.batch_index <= len(self.buffer) - 1:
226215
current_data = self.buffer[self.batch_index]
227216
self.batch_index += 1
217+
logger.debug(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")
228218
yield from self.process_batch_fn(*current_data, micro_batch_size=self.micro_batch_size)
229219

230220
else:
231221
batch_data, batch_meta = self.fetch_batch_fn(
232-
self._tq_client,
233-
self.data_fields,
234-
self.batch_size,
235-
self.partition_id,
236-
self.task_name,
237-
self.sampling_config,
238-
self.batch_index,
222+
tq_client=self._tq_client,
223+
data_fields=self.data_fields,
224+
batch_size=self.batch_size,
225+
partition_id=self.partition_id,
226+
task_name=self.task_name,
227+
sampling_config=self.sampling_config,
228+
batch_index=self.batch_index,
239229
)
240230
if batch_data is not None:
241231
self.buffer.append((batch_data, batch_meta))

transfer_queue/interface.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -345,7 +345,7 @@ def _init_from_existing() -> bool:
345345

346346

347347
# ==================== Initialization API ====================
348-
def init(conf: Optional[DictConfig] = None) -> None:
348+
def init(conf: Optional[DictConfig] = None) -> Optional[DictConfig]:
349349
"""Initialize the TransferQueue system.
350350
351351
This function sets up the TransferQueue controller, distributed storage, and client.
@@ -360,6 +360,8 @@ def init(conf: Optional[DictConfig] = None) -> None:
360360
the default config from 'config.yaml'. This is only used for first-time
361361
initializing. When connecting to an existing controller, this parameter
362362
is ignored.
363+
Returns:
364+
The merged configuration, or the ``conf`` argument unchanged (possibly ``None``) when an existing controller was found.
363365
364366
Raises:
365367
ValueError: If config is not valid or required configuration keys are missing.
@@ -377,7 +379,7 @@ def init(conf: Optional[DictConfig] = None) -> None:
377379
>>> data = tq.get_data(metadata)
378380
"""
379381
if _init_from_existing():
380-
return
382+
return conf
381383

382384
# First-time initialize TransferQueue
383385
logger.info("No TransferQueueController found. Starting first-time initialization...")
@@ -415,7 +417,7 @@ def init(conf: Optional[DictConfig] = None) -> None:
415417
except ValueError:
416418
logger.info("Some other rank has initialized TransferQueueController. Try to connect to existing controller.")
417419
_init_from_existing()
418-
return
420+
return final_conf
419421

420422
controller_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_CONTROLLER)
421423
final_conf.controller.zmq_info = controller_zmq_info
@@ -429,6 +431,7 @@ def init(conf: Optional[DictConfig] = None) -> None:
429431

430432
# create client
431433
_maybe_create_transferqueue_client(final_conf)
434+
return final_conf
432435

433436

434437
def close():

transfer_queue/metadata.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -169,7 +169,7 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]:
169169
f"expected batch_size[0]={batch_size}, got value.shape[0]={value.shape[0]}"
170170
)
171171
if len(value.shape) == 1:
172-
logger.warning(f"Receiving 1D tensor for field '{field_name}'. Unsqueeze the last dimension.")
172+
logger.info(f"Receiving 1D tensor for field '{field_name}'. Unsqueeze the last dimension.")
173173
value = value.unsqueeze(-1)
174174
first_item = value[0]
175175
else:

transfer_queue/sampler/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from .base import BaseSampler
1717
from .grpo_group_n_sampler import GRPOGroupNSampler
1818
from .rank_aware_sampler import RankAwareSampler
19+
from .seqlen_balanced_sampler import SeqlenBalancedSampler
1920
from .sequential_sampler import SequentialSampler
2021

21-
__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler"]
22+
__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler", "SeqlenBalancedSampler"]

0 commit comments

Comments
 (0)