You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[feat] Add SeqlenBalancedSampler and enhance StreamingDataset support (#70)
## 🎯 Summary
This PR introduces the `SeqlenBalancedSampler` to optimize sequence
length distribution across Data Parallel (DP) ranks during GRPO
training. It also enhances `StreamingDataset` with proper streaming mode
support and refactors the controller's polling mechanism to improve
efficiency when data is insufficient.
## ✨ Key Features & Enhancements
### 1. `SeqlenBalancedSampler` (Sequence-Length Balanced Sampling)
- **Karmarkar-Karp Algorithm:** Added a new sampler that extends
`GRPOGroupNSampler`. It uses the Karmarkar-Karp largest differencing
method to balance sequence lengths (`total_lengths`) across DP ranks,
ensuring that each rank processes approximately the same total token
count.
- **Group Integrity:** Guarantees that complete prompt groups remain
intact across ranks to fulfill pass@k metrics and GRPO advantage
normalization requirements.
- **Assignment Caching:** Implements state caching (`_balanced_cache`)
so that once global sampling and balancing are computed for a batch,
subsequent DP ranks can quickly retrieve their assigned chunks.
### 2. `StreamingDataset` Improvements
- **Finite vs. Infinite Stream:** Introduced the
`should_check_consumption_status` parameter.
- `False` (Default): Operates in an **infinite stream** mode,
continuously polling for new data (ideal for online/streaming
pipelines).
- `True`: Operates in **finite-dataset** mode, terminating iteration
only after all samples in the partition are fully consumed.
- **Client Initialization Refactor:** Refactored `_create_client()` to
use `init()` and `get_client()` from `transfer_queue.interface` instead
of manually setting up `TransferQueueClient`.
### 3. Controller Optimizations
- **Polling Mode Sampler Cache Lookup:** Updated `get_metadata` to look
up cached sampler states when operating in `polling_mode`. If `dp_rank`
and `batch_index` are cached, it immediately returns the data instead of
failing or entering redundant wait loops when
`ready_for_consume_indexes` are insufficient.
- **Variable-size Batch Support:** Updated the sampler length validation
logic to accommodate variable-size batches returned by samplers like
`SeqlenBalancedSampler`.
## 🛠️ Refactoring & Minor Fixes
- **Log Level Adjustments:** Downgraded the 1D tensor shape warnings to
`logger.info()` in `client.py` and `metadata.py` to reduce unnecessary
noise.
- **Pre-allocation Scope:** Moved `TQ_PRE_ALLOC_SAMPLE_NUM` environment
variable resolution into local method scopes where appropriate.
## 🧪 Testing
- Added comprehensive unit tests for `SeqlenBalancedSampler` covering
initialization, fallback behavior, balanced partitioning logic with mock
custom meta, group level integrity, and caching mechanisms.
- Added explicit utility tests for the `karmarkar_karp` and
`get_seqlen_balanced_partitions` functions (`TestKarmarkarKarp`).
---------
Signed-off-by: 宁本哲 <ningbenzhe@xiaohongshu.com>
0 commit comments