diff --git a/docs/data/transfer_queue.md b/docs/data/transfer_queue.md
index 4532d42ed56..877a46b9063 100644
--- a/docs/data/transfer_queue.md
+++ b/docs/data/transfer_queue.md
@@ -1,52 +1,73 @@
# TransferQueue Data System
-Last updated: 09/28/2025.
+Last updated: 11/17/2025.
This doc introduce [TransferQueue](https://github.com/TransferQueue/TransferQueue), an asynchronous streaming data management system for efficient post-training.
Overview
-TransferQueue is a high-performance data storage and transfer system with panoramic data visibility and streaming scheduling capabilities, optimized for efficient dataflow in post-training workflows.
+TransferQueue is a high-performance data storage and transfer module with panoramic data visibility and streaming scheduling capabilities, optimized for efficient dataflow in post-training workflows.
-
+
-
-TransferQueue offers **fine-grained, sample-level** data management capabilities, serving as a data gateway that decouples explicit data dependencies across computational tasks. This enables a divide-and-conquer approach, significantly simplifying the design of the algorithm controller.
-
+TransferQueue offers **fine-grained, sample-level** data management and **load-balancing** (on the way) capabilities, serving as a data gateway that decouples explicit data dependencies across computational tasks. This enables a divide-and-conquer approach, significantly simplifies the algorithm controller design.
-
+
+ Updates
-
+ - **Nov 10, 2025**: We disentangle the data retrieval logic from TransferQueueController [PR#101](https://github.com/TransferQueue/TransferQueue/pull/101). Now you can implement your own `Sampler` to control how to consume the data.
+ - **Nov 5, 2025**: We provide a `KVStorageManager` that simplifies the integration with KV-based storage backends [PR#96](https://github.com/TransferQueue/TransferQueue/pull/96). The first available KV-based backend is [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem).
+ - **Nov 4, 2025**: Data partition capability is available in [PR#98](https://github.com/TransferQueue/TransferQueue/pull/98). Now you can define logical data partitions to manage your train/val/test datasets.
+ - **Oct 25, 2025**: We make storage backends pluggable in [PR#66](https://github.com/TransferQueue/TransferQueue/pull/66). You can try to integrate your own storage backend with TransferQueue now!
+ - **Oct 21, 2025**: Official integration into verl is ready [verl/pulls/3649](https://github.com/volcengine/verl/pull/3649). Following PRs will optimize the single controller architecture by fully decoupling data & control flows.
+ - **July 22, 2025**: We present a series of Chinese blogs on Zhihu 1, 2.
+ - **July 21, 2025**: We started an RFC on verl community [verl/RFC#2662](https://github.com/volcengine/verl/discussions/2662).
+ - **July 2, 2025**: We publish the paper [AsyncFlow](https://arxiv.org/abs/2507.01663).
Components
+### Control Plane: Panoramic Data Management
+In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. When all the required data fields are ready (i.e., written to the `TransferQueueStorageManager`), we know that this data sample can be consumed by downstream tasks.
-### Control Plane: Panoramic Data Management
-
-In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. When all the required data fields are ready (i.e., written to the `TransferQueueStorage`), we know that this data sample can be consumed by downstream tasks.
-
-For consumption status, we record the consumption records for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even different computation tasks require the same data field, they can consume the data independently without interfering with each other.
-
+For consumption status, we record the consumption records for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even when different computation tasks require the same data field, they can consume the data independently without interfering with each other.
-
+
+To make the data retrieval process more customizable, we provide a `Sampler` class that allows users to define their own data retrieval and consumption logic. Refer to the [Customize](#customize) section for details.
-> In the future, we plan to support **load-balancing** and **dynamic batching** capabilities in the control plane. Besides, we will support data management for disaggregated frameworks where each rank manages the data retrieval by itself, rather than coordinated by a single controller.
+> In the future, we plan to support **load-balancing** and **dynamic batching** capabilities in the control plane. Additionally, we will support data management for disaggregated frameworks where each rank manages the data retrieval by itself, rather than coordinated by a single controller.
### Data Plane: Distributed Data Storage
-In the data plane, `TransferQueueStorageSimpleUnit` serves as a naive storage unit based on CPU memory, responsible for the actual storage and retrieval of data. Each storage unit can be deployed on a separate node, allowing for distributed data management.
+In the data plane, we provide a pluggable design that enables TransferQueue to integrate with different storage backends according to user requirements.
+
+Specifically, we provide a `TransferQueueStorageManager` abstraction class that defines the core APIs as follows:
-`TransferQueueStorageSimpleUnit` employs a 2D data structure as follows:
+- `async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None`
+- `async def get_data(self, metadata: BatchMeta) -> TensorDict`
+- `async def clear_data(self, metadata: BatchMeta) -> None`
+
+This class encapsulates the core interaction logic within the TransferQueue system. You only need to write a simple subclass to integrate your own storage backend. Refer to the [Customize](#customize) section for details.
+
+Currently, we support the following storage backends:
+
+- SimpleStorageUnit: A basic CPU memory storage with minimal data format constraints and easy usability.
+- [MoonCakeStore](https://github.com/kvcache-ai/Mooncake): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM.
+- [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem): An Ascend native data system that provides hierarchical storage interfaces including HBM/DRAM/SSD.
+- [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html): Ray's new feature that allows Ray to store and pass objects directly between Ray actors.
+
+Among them, `SimpleStorageUnit` serves as our default storage backend, coordinated by the `AsyncSimpleStorageManager` class. Each storage unit can be deployed on a separate node, allowing for distributed data management.
+
+`SimpleStorageUnit` employs a 2D data structure as follows:
- Each row corresponds to a training sample, assigned a unique index within the corresponding global batch.
- Each column represents the input/output data fields for computational tasks.
@@ -54,29 +75,22 @@ In the data plane, `TransferQueueStorageSimpleUnit` serves as a naive storage un
This data structure design is motivated by the computational characteristics of the post-training process, where each training sample is generated in a relayed manner across task pipelines. It provides an accurate addressing capability, which allows fine-grained, concurrent data read/write operations in a streaming manner.
-
+
-
-> In the future, we plan to implement a **general storage abstraction layer** to support various storage backends. Through this abstraction, we hope to integrate high-performance storage solutions such as [MoonCakeStore](https://github.com/kvcache-ai/Mooncake) to support device-to-device data transfer through RDMA, further enhancing data transfer efficiency for large-scale data.
-
-
### User Interface: Asynchronous & Synchronous Client
-
The interaction workflow of TransferQueue system is as follows:
1. A process sends a read request to the `TransferQueueController`.
2. `TransferQueueController` scans the production and consumption metadata for each sample (row), and dynamically assembles a micro-batch metadata according to the load-balancing policy. This mechanism enables sample-level data scheduling.
3. The process retrieves the actual data from distributed storage units using the metadata provided by the controller.
-To simplify the usage of TransferQueue, we have encapsulated this process into `AsyncTransferQueueClient` and `TransferQueueClient`. These clients provide both asynchronous and synchronous interfaces for data transfer, allowing users to easily integrate TransferQueue to their framework.
-
-
-> In the future, we will provide a `StreamingDataLoader` interface for disaggregated frameworks as discussed in [RFC#2662](https://github.com/volcengine/verl/discussions/2662). Leveraging this abstraction, each rank can automatically get its own data like `DataLoader` in PyTorch. The TransferQueue system will handle the underlying data scheduling and transfer logic caused by different parallelism strategies, significantly simplifying the design of disaggregated frameworks.
+To simplify the usage of TransferQueue, we have encapsulated this process into `AsyncTransferQueueClient` and `TransferQueueClient`. These clients provide both asynchronous and synchronous interfaces for data transfer, allowing users to easily integrate TransferQueue into their framework.
+> In the future, we will provide a `StreamingDataLoader` interface for disaggregated frameworks as discussed in [issue#85](https://github.com/TransferQueue/TransferQueue/issues/85) and [verl/RFC#2662](https://github.com/volcengine/verl/discussions/2662). Leveraging this abstraction, each rank can automatically get its own data like `DataLoader` in PyTorch. The TransferQueue system will handle the underlying data scheduling and transfer logic caused by different parallelism strategies, significantly simplifying the design of disaggregated frameworks.
- Show Cases
+🔥 Showcases
### General Usage
@@ -89,16 +103,15 @@ Core interfaces:
- (async_)put(data:TensorDict, metadata:BatchMeta, global_step)
- (async_)clear(global_step: int)
-
We will soon release a detailed tutorial and API documentation.
### verl Example
+The primary motivation for integrating TransferQueue to verl now is to **alleviate the data transfer bottleneck of the single controller `RayPPOTrainer`**. Currently, all `DataProto` objects must be routed through `RayPPOTrainer`, resulting in a single point bottleneck of the whole post-training system.
-The primary motivation for integrating TransferQueue to verl now is to **alleviate the data transfer bottleneck of the single controller `RayPPOTrainer`**. Currently, all `DataProto` objects must be routed through `RayPPOTrainer`, resulting in a single point bottleneck of the whole post-training system.
+
-
Leveraging TransferQueue, we separate experience data transfer from metadata dispatch by
@@ -106,12 +119,134 @@ Leveraging TransferQueue, we separate experience data transfer from metadata dis
- Preserving verl's original Dispatch/Collect logic via BatchMeta (maintaining single-controller debuggability)
- Accelerating data transfer by TransferQueue's distributed storage units
-
+
+
+
+You may refer to the [recipe](https://github.com/TransferQueue/TransferQueue/tree/dev/recipe/simple_use_case), where we mimic the verl usage in both async & sync scenarios. Official integration to verl is also available now at [verl/pulls/3649](https://github.com/volcengine/verl/pull/3649) (with subsequent PRs to further optimize the integration).
-You may refer to the [recipe](https://github.com/TransferQueue/TransferQueue/tree/dev/recipe/simple_use_case), where we mimic the verl usage in both async & sync scenarios.
+### Use Python package
+```bash
+pip install TransferQueue==0.1.1.dev2
+```
+### Build wheel package from source code
+
+Follow these steps to build and install:
+1. Clone the source code from the GitHub repository
+ ```bash
+ git clone https://github.com/TransferQueue/TransferQueue/
+ cd TransferQueue
+ ```
+
+2. Install dependencies
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+3. Build and install
+ ```bash
+ python -m build --wheel
+ pip install dist/*.whl
+ ```
+
+
+
+
+
+
+> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorageUnit` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. Warmly welcome contributions from the community!
+
+For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/tml8ke0zkgn6roey?singleDoc#).
+
+ 🛠️ Customize TransferQueue
+
+### Define your own data retrieval logic
+We provide a `BaseSampler` abstraction class, which defines the following interface:
+
+```python3
+@abstractmethod
+def sample(
+ self,
+ ready_indexes: list[int],
+ batch_size: int,
+ *args: Any,
+ **kwargs: Any,
+) -> tuple[list[int], list[int]]:
+ """Sample a batch of indices from the ready indices.
+
+ Args:
+ ready_indexes: List of global indices for which all required fields of the
+ corresponding samples have been produced, and the samples are not labeled as
+ consumed in the corresponding task.
+ batch_size: Number of samples to select
+ *args: Additional positional arguments for specific sampler implementations
+ **kwargs: Additional keyword arguments for specific sampler implementations
+
+ Returns:
+ List of sampled global indices of length batch_size
+ List of global indices of length batch_size that should be labeled as consumed
+ (will never be retrieved in the future)
+
+ Raises:
+ ValueError: If batch_size is invalid or ready_indexes is insufficient
+ """
+ raise NotImplementedError("Subclasses must implement sample")
+```
+
+In this design, we separate data retrieval and data consumption through the two return values, which enables us to easily control sample replacement. We have implemented two reference designs: `SequentialSampler` and `GRPOGroupNSampler`.
+
+The `Sampler` class or instance should be passed to the `TransferQueueController` during initialization. During each `get_meta` call, you can provide dynamic sampling parameters to the `Sampler`.
+
+```python3
+from transfer_queue import TransferQueueController, TransferQueueClient, GRPOGroupNSampler, process_zmq_server_info
+
+# Option 1: Pass the sampler class to the TransferQueueController
+controller = TransferQueueController.remote(GRPOGroupNSampler)
+
+# Option 2: Pass the sampler instance to the TransferQueueController (if you need custom configuration)
+your_own_sampler = YourOwnSampler(config)
+controller = TransferQueueController.remote(your_own_sampler)
+
+# Use the sampler
+batch_meta = client.get_meta(
+ data_fields=["input_ids", "attention_mask"],
+ batch_size=8,
+ partition_id="train_0",
+ task_name="generate_sequences",
+ sampling_config={"n_samples_per_prompt": 4} # Put the required sampling parameters here
+)
+```
+
+### How to integrate a new storage backend
+
+The data plane is organized as follows:
+```text
+ transfer_queue/
+ ├── storage/
+ │ ├── __init__.py
+ │ │── simple_backend.py # SimpleStorageUnit、StorageUnitData、StorageMetaGroup
+ │ ├── managers/ # Managers are upper level interfaces that encapsulate the interaction logic with TQ system.
+ │ │ ├── __init__.py
+ │ │ ├──base.py # TransferQueueStorageManager, KVStorageManager
+ │ │ ├──simple_backend_manager.py # AsyncSimpleStorageManager
+ │ │ ├──yuanrong_manager.py # YuanrongStorageManager
+ │ │ ├──mooncake_manager.py # MooncakeStorageManager
+ │ │ └──factory.py # TransferQueueStorageManagerFactory
+ │ └── clients/ # Clients are lower level interfaces that directly manipulate the target storage backend.
+ │ │ ├── __init__.py
+ │ │ ├── base.py # TransferQueueStorageKVClient
+ │ │ ├── yuanrong_client.py # YRStorageClient
+ │ │ ├── mooncake_client.py # MooncakeStoreClient
+ │ │ └── factory.py # TransferQueueStorageClientFactory
+```
+
+To integrate TransferQueue with a custom storage backend, start by implementing a subclass that inherits from `TransferQueueStorageManager`. This subclass acts as an adapter between the TransferQueue system and the target storage backend. For KV-based storage backends, you can simply inherit from `KVStorageManager`, which can serve as the general manager for all KV-based backends.
+
+Distributed storage backends often come with their own native clients serving as the interface of the storage system. In such cases, a low-level adapter for this client can be written, following the examples provided in the `storage/clients` directory.
+
+Factory classes are provided for both `StorageManager` and `StorageClient` to facilitate easy integration. Adding necessary descriptions of required parameters in the factory class helps enhance the overall user experience.
diff --git a/recipe/transfer_queue/agent_loop.py b/recipe/transfer_queue/agent_loop.py
index 871ae8025c0..7f936e6730e 100644
--- a/recipe/transfer_queue/agent_loop.py
+++ b/recipe/transfer_queue/agent_loop.py
@@ -67,10 +67,7 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data
return timing
- def create_transferqueue_client(self, controller_infos, storage_infos, role):
+ def create_transferqueue_client(self, controller_info, config):
ray.get(
- [
- worker.create_transferqueue_client.remote(controller_infos, storage_infos, role)
- for worker in self.agent_loop_workers
- ]
+ [worker.create_transferqueue_client.remote(controller_info, config) for worker in self.agent_loop_workers]
)
diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py
index d6adbddb676..daa2f8d95b6 100644
--- a/recipe/transfer_queue/ray_trainer.py
+++ b/recipe/transfer_queue/ray_trainer.py
@@ -41,8 +41,8 @@
from tqdm import tqdm
from transfer_queue import (
BatchMeta,
+ SimpleStorageUnit,
TransferQueueController,
- TransferQueueStorageSimpleUnit,
get_placement_group,
process_zmq_server_info,
)
@@ -81,6 +81,7 @@
from verl.utils.metric import reduce_metrics
from verl.utils.rollout_skip import RolloutSkip
from verl.utils.seqlen_balancing import (
+ calculate_workload,
get_seqlen_balanced_partitions,
log_seqlen_unbalance,
)
@@ -89,7 +90,6 @@
from verl.utils.transferqueue_utils import (
create_transferqueue_client,
get_transferqueue_client,
- get_val_transferqueue_client,
tqbridge,
)
@@ -412,109 +412,66 @@ def __init__(
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
- self.data_system_client = self._initialize_train_data_system(
- self.config.data.train_batch_size, self.config.actor_rollout_ref.rollout.n
+ self.data_system_client = self._initialize_data_system()
+
+ def _initialize_data_system(self):
+ # 1. initialize TransferQueueStorage
+ train_data_size = (
+ self.config.data.train_batch_size
+ * self.config.trainer.num_global_batch
+ * self.config.actor_rollout_ref.rollout.n
)
- self.val_data_system_client = self._initialize_val_data_system(
- self.val_batch_size, self.config.actor_rollout_ref.rollout.val_kwargs.n
+ val_data_size = (
+ self.val_batch_size
+ * self.config.trainer.num_global_batch
+ * self.config.actor_rollout_ref.rollout.val_kwargs.n
)
- def _initialize_train_data_system(self, global_batch_size, num_n_samples, role="train"):
- # 1. initialize TransferQueueStorage
- total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples
+ total_storage_size = train_data_size + val_data_size
self.data_system_storage_units = {}
storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
for storage_unit_rank in range(self.config.trainer.num_data_storage_units):
- storage_node = TransferQueueStorageSimpleUnit.options(
+ storage_node = SimpleStorageUnit.options(
placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
- ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
+ ).remote(storage_unit_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
self.data_system_storage_units[storage_unit_rank] = storage_node
- logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")
-
- # 2. initialize TransferQueueController
- # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
- # one controller for a single WorkerGroup.
- self.data_system_controllers = {}
- controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)
- for controller_rank in range(self.config.trainer.num_data_controllers):
- self.data_system_controllers[controller_rank] = TransferQueueController.options(
- placement_group=controller_placement_group, placement_group_bundle_index=controller_rank
- ).remote(
- num_storage_units=self.config.trainer.num_data_storage_units,
- global_batch_size=global_batch_size,
- num_global_batch=self.config.trainer.num_global_batch,
- num_n_samples=num_n_samples,
- )
- logging.info(f"TransferQueueController #{controller_rank} has been created.")
+ logging.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")
- # 3. register controller & storage
- self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers)
- self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)
+ # 2. Initialize TransferQueueController (single controller only)
- ray.get(
- [
- storage_unit.register_controller_info.remote(self.data_system_controller_infos)
- for storage_unit in self.data_system_storage_units.values()
- ]
- )
+ # Sampler usage instructions:
+ # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler:
+ # Option 1: Pass sampler class (will be instantiated automatically)
+ # self.data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler)
- # 4. create client
- # each client should be allocated to exactly one controller
- create_transferqueue_client(
- client_id="Trainer-" + role,
- controller_infos=self.data_system_controller_infos,
- storage_infos=self.data_system_storage_unit_infos,
- )
- data_system_client = get_transferqueue_client()
- return data_system_client
+ # Option 2: Pass sampler instance (if you need custom configuration)
+ # grpo_sampler = GRPOGroupNSampler()
+ # self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler)
- def _initialize_val_data_system(self, global_batch_size, num_n_samples, role="val"):
- # 1. initialize TransferQueueStorage
- total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples
- self.val_data_system_storage_units = {}
- storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
- for storage_unit_rank in range(self.config.trainer.num_data_storage_units):
- storage_node = TransferQueueStorageSimpleUnit.options(
- placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
- ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
- self.val_data_system_storage_units[storage_unit_rank] = storage_node
- logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")
-
- # 2. initialize TransferQueueController
- # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
- # one controller for a single WorkerGroup.
- self.val_data_system_controllers = {}
- controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)
- for controller_rank in range(self.config.trainer.num_data_controllers):
- self.val_data_system_controllers[controller_rank] = TransferQueueController.options(
- placement_group=controller_placement_group, placement_group_bundle_index=controller_rank
- ).remote(
- num_storage_units=self.config.trainer.num_data_storage_units,
- global_batch_size=global_batch_size,
- num_global_batch=self.config.trainer.num_global_batch,
- num_n_samples=num_n_samples,
- )
- logging.info(f"TransferQueueController #{controller_rank} has been created.")
+ # Then use sampling_config in get_meta calls:
+ # sampling_config={"n_samples_per_prompt": 4}
+ self.data_system_controller = TransferQueueController.remote()
+ logging.info("TransferQueueController has been created.")
- # 3. register controller & storage
- self.val_data_system_controller_infos = process_zmq_server_info(self.val_data_system_controllers)
- self.val_data_system_storage_unit_infos = process_zmq_server_info(self.val_data_system_storage_units)
+ # 3. register controller & storage and prepare necessary information
+ self.data_system_controller_info = process_zmq_server_info(self.data_system_controller)
+ self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)
- ray.get(
- [
- storage_unit.register_controller_info.remote(self.val_data_system_controller_infos)
- for storage_unit in self.val_data_system_storage_units.values()
- ]
- )
+ # Note: Need to generate a new DictConfig with allow_objects=True to preserve ZMQServerInfo instances
+ # (which contain socket connection details). Without this flag, OmegaConf would flatten these objects to dicts,
+ # breaking the transfer queue client initialization.
+ tq_config = OmegaConf.create({}, flags={"allow_objects": True})
+ tq_config.controller_info = self.data_system_controller_info
+ tq_config.storage_unit_infos = self.data_system_storage_unit_infos
+ self.config = OmegaConf.merge(tq_config, self.config)
# 4. create client
- # each client should be allocated to exactly one controller
create_transferqueue_client(
- client_id="Trainer-" + role,
- controller_infos=self.val_data_system_controller_infos,
- storage_infos=self.val_data_system_storage_unit_infos,
+ client_id="Trainer",
+ controller_info=self.data_system_controller_info,
+ config=self.config,
)
- data_system_client = get_val_transferqueue_client()
+ data_system_client = get_transferqueue_client()
return data_system_client
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
@@ -726,19 +683,18 @@ def _validate(self):
if self.config.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model":
return {}
- asyncio.run(self.val_data_system_client.async_put(data=test_batch, global_step=self.global_steps - 1))
+ asyncio.run(self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}"))
# Store original inputs
batch_meta = asyncio.run(
- self.val_data_system_client.async_get_meta(
+ self.data_system_client.async_get_meta(
data_fields=["input_ids", "uid", "reward_model"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
- global_step=self.global_steps - 1,
- get_n_samples=False,
+ partition_id=f"val_{self.global_steps - 1}",
task_name="get_data",
)
)
- data = asyncio.run(self.val_data_system_client.async_get_data(batch_meta))
+ data = asyncio.run(self.data_system_client.async_get_data(batch_meta))
input_ids = data["input_ids"]
# TODO: Can we keep special tokens except for padding tokens?
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
@@ -749,11 +705,10 @@ def _validate(self):
sample_gts.extend(ground_truths)
test_gen_meta = asyncio.run(
- self.val_data_system_client.async_get_meta(
+ self.data_system_client.async_get_meta(
data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
- global_step=self.global_steps - 1, # self.global_steps start from 1
- get_n_samples=False,
+ partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
task_name="generate_sequences",
)
)
@@ -779,15 +734,14 @@ def _validate(self):
# Store generated outputs
test_response_meta = asyncio.run(
- self.val_data_system_client.async_get_meta(
+ self.data_system_client.async_get_meta(
data_fields=["responses"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
- global_step=self.global_steps - 1, # self.global_steps start from 1
- get_n_samples=False,
+ partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
task_name="get_response",
)
)
- data = asyncio.run(self.val_data_system_client.async_get_data(test_response_meta))
+ data = asyncio.run(self.data_system_client.async_get_data(test_response_meta))
output_ids = data["responses"]
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
sample_outputs.extend(output_texts)
@@ -808,11 +762,10 @@ def _validate(self):
if "rm_scores" in batch_meta.field_names:
compute_reward_fields = ["rm_scores"]
val_reward_meta = asyncio.run(
- self.val_data_system_client.async_get_meta(
+ self.data_system_client.async_get_meta(
data_fields=compute_reward_fields,
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
- global_step=self.global_steps - 1,
- get_n_samples=False,
+ partition_id=f"val_{self.global_steps - 1}",
task_name="compute_reward",
)
)
@@ -832,29 +785,27 @@ def _validate(self):
# collect num_turns of each prompt
if "__num_turns__" in test_batch_meta.field_names:
num_turns_meta = asyncio.run(
- self.val_data_system_client.async_get_meta(
+ self.data_system_client.async_get_meta(
data_fields=["__num_turns__"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
- global_step=self.global_steps - 1, # self.global_steps start from 1
- get_n_samples=False,
+ partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
task_name="get_num_turns",
)
)
- data = asyncio.run(self.val_data_system_client.async_get_data(num_turns_meta))
+ data = asyncio.run(self.data_system_client.async_get_data(num_turns_meta))
sample_turns.append(data["__num_turns__"])
data_source = ["unknown"] * reward_tensor.shape[0]
if "data_source" in test_batch_meta.field_names:
data_source_meta = asyncio.run(
- self.val_data_system_client.async_get_meta(
+ self.data_system_client.async_get_meta(
data_fields=["data_source"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
- global_step=self.global_steps - 1, # self.global_steps start from 1
- get_n_samples=False,
+ partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
task_name="get_data_source",
)
)
- data = asyncio.run(self.val_data_system_client.async_get_data(data_source_meta))
+ data = asyncio.run(self.data_system_client.async_get_data(data_source_meta))
data_source = data["data_source"]
data_source_lst.append(data_source)
@@ -902,7 +853,7 @@ def _validate(self):
metric_dict["val-aux/num_turns/max"] = sample_turns.max()
metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()
- asyncio.run(self.val_data_system_client.async_clear(self.global_steps - 1))
+ asyncio.run(self.data_system_client.async_clear(partition_id=f"val_{self.global_steps - 1}"))
return metric_dict
def init_workers(self):
@@ -1003,12 +954,7 @@ def init_workers(self):
# set transferqueue server info for each worker
for _, wg in all_wg.items():
- wg.create_transferqueue_client(
- self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
- )
- wg.create_transferqueue_client(
- self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
- )
+ wg.create_transferqueue_client(self.data_system_controller_info, self.config)
# create async rollout manager and request scheduler
self.async_rollout_mode = False
@@ -1020,12 +966,7 @@ def init_workers(self):
config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
)
- self.async_rollout_manager.create_transferqueue_client(
- self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
- )
- self.async_rollout_manager.create_transferqueue_client(
- self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
- )
+ self.async_rollout_manager.create_transferqueue_client(self.data_system_controller_info, self.config)
def _save_checkpoint(self):
from verl.utils.fs import local_mkdir_safe
@@ -1164,17 +1105,41 @@ def _stop_profiling(self, do_profile: bool) -> None:
if self.use_rm:
self.rm_wg.stop_profile()
- def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen"):
+ def _balance_batch(
+ self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen", keep_minibatch=False
+ ):
"""Reorder the batchmeta on single controller such that each dp rank gets similar total tokens"""
data = asyncio.run(data_system_client.async_get_data(batch))
attention_mask = data["attention_mask"]
batch_size = attention_mask.shape[0]
- global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,)
+ global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,)
+ global_seqlen_lst = calculate_workload(global_seqlen_lst)
world_size = self.actor_rollout_wg.world_size
- global_partition_lst = get_seqlen_balanced_partitions(
- global_seqlen_lst, k_partitions=world_size, equal_size=True
- )
+ if keep_minibatch:
+ # Decouple the DP balancing and mini-batching.
+ minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size", None)
+ if minibatch_size is None:
+ raise ValueError("'ppo_mini_batch_size' must be set in actor config when 'keep_minibatch' is True.")
+ minibatch_num = len(global_seqlen_lst) // minibatch_size
+ global_partition_lst = [[] for _ in range(world_size)]
+ for i in range(minibatch_num):
+ rearrange_minibatch_lst = get_seqlen_balanced_partitions(
+ global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],
+ k_partitions=world_size,
+ equal_size=True,
+ )
+ for j, part in enumerate(rearrange_minibatch_lst):
+ global_partition_lst[j].extend([x + minibatch_size * i for x in part])
+ else:
+ global_partition_lst = get_seqlen_balanced_partitions(
+ global_seqlen_lst, k_partitions=world_size, equal_size=True
+ )
+ # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.
+ for idx, partition in enumerate(global_partition_lst):
+ partition.sort(key=lambda x: (global_seqlen_lst[x], x))
+ ordered_partition = partition[::2] + partition[1::2][::-1]
+ global_partition_lst[idx] = ordered_partition
# reorder based on index. The data will be automatically equally partitioned by dispatch function
global_idx = [j for partition in global_partition_lst for j in partition]
global_balance_stats = log_seqlen_unbalance(
@@ -1313,8 +1278,7 @@ def fit(self):
timing_raw = {}
base_get_meta_kwargs = dict(
batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
- global_step=self.global_steps - 1, # self.global_steps starts from 1
- get_n_samples=False,
+ partition_id=f"train_{self.global_steps - 1}", # self.global_steps starts from 1
)
with marked_timer("start_profile", timing_raw):
@@ -1333,7 +1297,9 @@ def fit(self):
batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
)
batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict)
- asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1))
+ asyncio.run(
+ self.data_system_client.async_put(data=batch, partition_id=f"train_{self.global_steps - 1}")
+ )
gen_meta = asyncio.run(
self.data_system_client.async_get_meta(
@@ -1709,8 +1675,7 @@ def fit(self):
],
batch_size=self.config.data.train_batch_size
* self.config.actor_rollout_ref.rollout.n,
- global_step=self.global_steps - 1,
- get_n_samples=False,
+ partition_id=f"train_{self.global_steps - 1}",
task_name="update_actor",
)
)
@@ -1735,8 +1700,7 @@ def fit(self):
self.data_system_client.async_get_meta(
data_fields=data_fields,
batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
- global_step=self.global_steps - 1,
- get_n_samples=False,
+ partition_id=f"train_{self.global_steps - 1}",
task_name="log_rollout",
)
)
@@ -1857,7 +1821,7 @@ def fit(self):
# TODO: (TQ) support transfer queue
self.train_dataloader.sampler.update(batch=batch)
- asyncio.run(self.data_system_client.async_clear(self.global_steps - 1))
+ asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{self.global_steps - 1}"))
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
diff --git a/requirements_transferqueue.txt b/requirements_transferqueue.txt
index 8479d27bb21..621682abbf7 100644
--- a/requirements_transferqueue.txt
+++ b/requirements_transferqueue.txt
@@ -1,2 +1,2 @@
# requirements.txt records the full set of dependencies for development
-git+https://github.com/TransferQueue/TransferQueue.git@68c04e7
+transferqueue==0.1.1.dev2
diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py
index 2513c57f99c..399ac75a063 100644
--- a/verl/single_controller/base/worker.py
+++ b/verl/single_controller/base/worker.py
@@ -131,13 +131,13 @@ def _query_collect_info(self, mesh_name: str):
return self.__collect_dp_rank[mesh_name]
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
- def create_transferqueue_client(self, controller_infos, storage_infos, role="train"):
+ def create_transferqueue_client(self, controller_info, config):
from verl.utils.transferqueue_utils import create_transferqueue_client
create_transferqueue_client(
- client_id=f"{role}_worker_{self.rank}",
- controller_infos=controller_infos,
- storage_infos=storage_infos,
+ client_id=f"worker_{self.rank}",
+ controller_info=controller_info,
+ config=config,
)
@classmethod
diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py
index 27160571ef3..c692578e3a0 100644
--- a/verl/utils/transferqueue_utils.py
+++ b/verl/utils/transferqueue_utils.py
@@ -38,32 +38,24 @@ class BatchMeta:
from verl.protocol import DataProto
_TRANSFER_QUEUE_CLIENT = None
-_VAL_TRANSFER_QUEUE_CLIENT = None
is_transferqueue_enabled = os.environ.get("TRANSFER_QUEUE_ENABLE", False)
def create_transferqueue_client(
client_id: str,
- controller_infos: dict[Any, "ZMQServerInfo"],
- storage_infos: dict[Any, "ZMQServerInfo"],
+ controller_info: dict[Any, "ZMQServerInfo"],
+ config,
) -> None:
global _TRANSFER_QUEUE_CLIENT
- global _VAL_TRANSFER_QUEUE_CLIENT
- if "val" in client_id:
- _VAL_TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos)
- else:
- _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos)
+ _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_info)
+ _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
def get_transferqueue_client() -> "AsyncTransferQueueClient":
return _TRANSFER_QUEUE_CLIENT
-def get_val_transferqueue_client() -> "AsyncTransferQueueClient":
- return _VAL_TRANSFER_QUEUE_CLIENT
-
-
def _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any:
# Use a temporary event loop in a new thread because event
# loop may already exist in server mode
@@ -109,10 +101,7 @@ async def _async_batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto:
meta_info=batchmeta.extra_info.copy(),
)
- if batchmeta.extra_info.get("validate", False):
- tensordict = await _VAL_TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta)
- else:
- tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta)
+ tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta)
return DataProto.from_tensordict(tensordict, meta_info=batchmeta.extra_info.copy())
@@ -130,10 +119,7 @@ async def _async_update_batchmeta_with_output(output: DataProto, batchmeta: "Bat
for key in output.meta_info.keys():
tensordict.pop(key)
batchmeta.add_fields(tensordict)
- if batchmeta.extra_info.get("validate", False):
- await _VAL_TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)
- else:
- await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)
+ await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)
def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta") -> None: