diff --git a/docs/data/transfer_queue.md b/docs/data/transfer_queue.md index 4532d42ed56..877a46b9063 100644 --- a/docs/data/transfer_queue.md +++ b/docs/data/transfer_queue.md @@ -1,52 +1,73 @@ # TransferQueue Data System -Last updated: 09/28/2025. +Last updated: 11/17/2025. This doc introduces [TransferQueue](https://github.com/TransferQueue/TransferQueue), an asynchronous streaming data management system for efficient post-training.

Overview

-TransferQueue is a high-performance data storage and transfer system with panoramic data visibility and streaming scheduling capabilities, optimized for efficient dataflow in post-training workflows. +TransferQueue is a high-performance data storage and transfer module with panoramic data visibility and streaming scheduling capabilities, optimized for efficient dataflow in post-training workflows.

- +

- -TransferQueue offers **fine-grained, sample-level** data management capabilities, serving as a data gateway that decouples explicit data dependencies across computational tasks. This enables a divide-and-conquer approach, significantly simplifying the design of the algorithm controller. - +TransferQueue offers **fine-grained, sample-level** data management and **load-balancing** (in progress) capabilities, serving as a data gateway that decouples explicit data dependencies across computational tasks. This enables a divide-and-conquer approach that significantly simplifies the algorithm controller design.

- +

+

Updates

- + - **Nov 10, 2025**: We disentangle the data retrieval logic from `TransferQueueController` in [PR#101](https://github.com/TransferQueue/TransferQueue/pull/101). Now you can implement your own `Sampler` to control how the data is consumed. + - **Nov 5, 2025**: We provide a `KVStorageManager` that simplifies the integration with KV-based storage backends in [PR#96](https://github.com/TransferQueue/TransferQueue/pull/96). The first available KV-based backend is [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem). + - **Nov 4, 2025**: Data partition capability is available in [PR#98](https://github.com/TransferQueue/TransferQueue/pull/98). Now you can define logical data partitions to manage your train/val/test datasets. + - **Oct 25, 2025**: We make storage backends pluggable in [PR#66](https://github.com/TransferQueue/TransferQueue/pull/66). You can try to integrate your own storage backend with TransferQueue now! + - **Oct 21, 2025**: Official integration into verl is ready in [verl/pulls/3649](https://github.com/volcengine/verl/pull/3649). Subsequent PRs will optimize the single controller architecture by fully decoupling data & control flows. + - **July 22, 2025**: We present a series of Chinese-language blog posts on Zhihu 1, 2. + - **July 21, 2025**: We started an RFC in the verl community [verl/RFC#2662](https://github.com/volcengine/verl/discussions/2662). + - **July 2, 2025**: We published the paper [AsyncFlow](https://arxiv.org/abs/2507.01663).

Components

+### Control Plane: Panoramic Data Management +In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. When all the required data fields are ready (i.e., written to the `TransferQueueStorageManager`), we know that this data sample can be consumed by downstream tasks. -### Control Plane: Panoramic Data Management - -In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. When all the required data fields are ready (i.e., written to the `TransferQueueStorage`), we know that this data sample can be consumed by downstream tasks. - -For consumption status, we record the consumption records for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even different computation tasks require the same data field, they can consume the data independently without interfering with each other. - +For consumption status, we record the consumption records for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even when different computation tasks require the same data field, they can consume the data independently without interfering with each other.
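The status tracking described above can be pictured with a small, self-contained sketch. It is not the `TransferQueueController` implementation; `SampleStatus` and `ToyController` are hypothetical names introduced only to show how per-task consumption records let two tasks read the same field independently.

```python
from dataclasses import dataclass, field


@dataclass
class SampleStatus:
    """Hypothetical per-sample metadata record (not the library's class)."""

    produced_fields: set[str] = field(default_factory=set)
    consumed_by: set[str] = field(default_factory=set)


class ToyController:
    """Toy stand-in that tracks production and per-task consumption status."""

    def __init__(self, num_samples: int) -> None:
        self.status = [SampleStatus() for _ in range(num_samples)]

    def mark_produced(self, index: int, field_name: str) -> None:
        self.status[index].produced_fields.add(field_name)

    def ready_indexes(self, required_fields: list[str], task_name: str) -> list[int]:
        # A sample is ready for a task when every required field has been
        # produced and this particular task has not consumed it yet.
        required = set(required_fields)
        return [
            i
            for i, s in enumerate(self.status)
            if required <= s.produced_fields and task_name not in s.consumed_by
        ]

    def mark_consumed(self, indexes: list[int], task_name: str) -> None:
        for i in indexes:
            self.status[i].consumed_by.add(task_name)


controller = ToyController(num_samples=3)
controller.mark_produced(0, "input_ids")
controller.mark_produced(1, "input_ids")

ready_gen = controller.ready_indexes(["input_ids"], "generate_sequences")
controller.mark_consumed(ready_gen, "generate_sequences")
ready_gen_after = controller.ready_indexes(["input_ids"], "generate_sequences")
ready_logprob = controller.ready_indexes(["input_ids"], "compute_log_prob")
```

In this sketch, consuming samples under `generate_sequences` leaves them visible to `compute_log_prob`, because consumption is recorded per task name rather than per sample.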

- +

+To make the data retrieval process more customizable, we provide a `Sampler` class that allows users to define their own data retrieval and consumption logic. Refer to the [Customize](#customize) section for details. -> In the future, we plan to support **load-balancing** and **dynamic batching** capabilities in the control plane. Besides, we will support data management for disaggregated frameworks where each rank manages the data retrieval by itself, rather than coordinated by a single controller. +> In the future, we plan to support **load-balancing** and **dynamic batching** capabilities in the control plane. Additionally, we will support data management for disaggregated frameworks where each rank manages the data retrieval by itself, rather than coordinated by a single controller. ### Data Plane: Distributed Data Storage -In the data plane, `TransferQueueStorageSimpleUnit` serves as a naive storage unit based on CPU memory, responsible for the actual storage and retrieval of data. Each storage unit can be deployed on a separate node, allowing for distributed data management. +In the data plane, we provide a pluggable design that enables TransferQueue to integrate with different storage backends according to user requirements. + +Specifically, we provide a `TransferQueueStorageManager` abstraction class that defines the core APIs as follows: -`TransferQueueStorageSimpleUnit` employs a 2D data structure as follows: +- `async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None` +- `async def get_data(self, metadata: BatchMeta) -> TensorDict` +- `async def clear_data(self, metadata: BatchMeta) -> None` + +This class encapsulates the core interaction logic within the TransferQueue system. You only need to write a simple subclass to integrate your own storage backend. Refer to the [Customize](#customize) section for details. 
+ +Currently, we support the following storage backends: + +- SimpleStorageUnit: A basic CPU memory storage with minimal data format constraints and easy usability. +- [MoonCakeStore](https://github.com/kvcache-ai/Mooncake): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM. +- [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem): An Ascend native data system that provides hierarchical storage interfaces including HBM/DRAM/SSD. +- [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html): Ray's new feature that allows Ray to store and pass objects directly between Ray actors. + +Among them, `SimpleStorageUnit` serves as our default storage backend, coordinated by the `AsyncSimpleStorageManager` class. Each storage unit can be deployed on a separate node, allowing for distributed data management. + +`SimpleStorageUnit` employs a 2D data structure as follows: - Each row corresponds to a training sample, assigned a unique index within the corresponding global batch. - Each column represents the input/output data fields for computational tasks. @@ -54,29 +75,22 @@ In the data plane, `TransferQueueStorageSimpleUnit` serves as a naive storage un This data structure design is motivated by the computational characteristics of the post-training process, where each training sample is generated in a relayed manner across task pipelines. It provides an accurate addressing capability, which allows fine-grained, concurrent data read/write operations in a streaming manner.
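As a rough illustration of the row/column layout described above, the following sketch stores each field as a column and addresses samples by row index. `ToyStorageUnit` is a hypothetical class written for this example; it is not `SimpleStorageUnit` and makes no claim about its internals.

```python
from typing import Any


class ToyStorageUnit:
    """Hypothetical 2D store: rows are sample indexes, columns are field names."""

    def __init__(self, storage_unit_size: int) -> None:
        self.size = storage_unit_size
        self.columns: dict[str, list[Any]] = {}

    def put(self, field_name: str, indexes: list[int], values: list[Any]) -> None:
        # Create the column lazily, then write only the addressed rows.
        column = self.columns.setdefault(field_name, [None] * self.size)
        for i, value in zip(indexes, values):
            column[i] = value

    def get(self, field_names: list[str], indexes: list[int]) -> dict[str, list[Any]]:
        # Read a (rows x columns) slice without touching other samples.
        return {name: [self.columns[name][i] for i in indexes] for name in field_names}


unit = ToyStorageUnit(storage_unit_size=4)
unit.put("prompt", [0, 1], ["a", "b"])  # an upstream task writes two rows
unit.put("response", [1], ["B"])        # a later task fills one row of another column
row = unit.get(["prompt", "response"], [1])
```

Because each write targets a specific (row, column) cell, different tasks can fill in different fields of the same sample at different times, which is the relayed, streaming access pattern the paragraph refers to.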

- +

- -> In the future, we plan to implement a **general storage abstraction layer** to support various storage backends. Through this abstraction, we hope to integrate high-performance storage solutions such as [MoonCakeStore](https://github.com/kvcache-ai/Mooncake) to support device-to-device data transfer through RDMA, further enhancing data transfer efficiency for large-scale data. - - ### User Interface: Asynchronous & Synchronous Client - The interaction workflow of TransferQueue system is as follows: 1. A process sends a read request to the `TransferQueueController`. 2. `TransferQueueController` scans the production and consumption metadata for each sample (row), and dynamically assembles a micro-batch metadata according to the load-balancing policy. This mechanism enables sample-level data scheduling. 3. The process retrieves the actual data from distributed storage units using the metadata provided by the controller. -To simplify the usage of TransferQueue, we have encapsulated this process into `AsyncTransferQueueClient` and `TransferQueueClient`. These clients provide both asynchronous and synchronous interfaces for data transfer, allowing users to easily integrate TransferQueue to their framework. - - -> In the future, we will provide a `StreamingDataLoader` interface for disaggregated frameworks as discussed in [RFC#2662](https://github.com/volcengine/verl/discussions/2662). Leveraging this abstraction, each rank can automatically get its own data like `DataLoader` in PyTorch. The TransferQueue system will handle the underlying data scheduling and transfer logic caused by different parallelism strategies, significantly simplifying the design of disaggregated frameworks. +To simplify the usage of TransferQueue, we have encapsulated this process into `AsyncTransferQueueClient` and `TransferQueueClient`. These clients provide both asynchronous and synchronous interfaces for data transfer, allowing users to easily integrate TransferQueue into their framework. 
+> In the future, we will provide a `StreamingDataLoader` interface for disaggregated frameworks as discussed in [issue#85](https://github.com/TransferQueue/TransferQueue/issues/85) and [verl/RFC#2662](https://github.com/volcengine/verl/discussions/2662). Leveraging this abstraction, each rank can automatically get its own data like `DataLoader` in PyTorch. The TransferQueue system will handle the underlying data scheduling and transfer logic caused by different parallelism strategies, significantly simplifying the design of disaggregated frameworks. -

Show Cases

+

🔥 Showcases

### General Usage @@ -89,16 +103,15 @@ Core interfaces: - (async_)put(data:TensorDict, metadata:BatchMeta, global_step) - (async_)clear(global_step: int) - We will soon release a detailed tutorial and API documentation. ### verl Example +The primary motivation for integrating TransferQueue to verl now is to **alleviate the data transfer bottleneck of the single controller `RayPPOTrainer`**. Currently, all `DataProto` objects must be routed through `RayPPOTrainer`, resulting in a single point bottleneck of the whole post-training system. -The primary motivation for integrating TransferQueue to verl now is to **alleviate the data transfer bottleneck of the single controller `RayPPOTrainer`**. Currently, all `DataProto` objects must be routed through `RayPPOTrainer`, resulting in a single point bottleneck of the whole post-training system. +![verl_dataflow_DataProto](https://github.com/TransferQueue/community_doc/blob/main/docs/verl_workflow.jpeg?raw=true) -![verl_dataflow_DataProto](https://cdn.nlark.com/yuque/0/2025/jpeg/23208217/1758704289414-bcc54228-716b-4d4a-ad3b-f9ace6d10fcf.jpeg) Leveraging TransferQueue, we separate experience data transfer from metadata dispatch by @@ -106,12 +119,134 @@ Leveraging TransferQueue, we separate experience data transfer from metadata dis - Preserving verl's original Dispatch/Collect logic via BatchMeta (maintaining single-controller debuggability) - Accelerating data transfer by TransferQueue's distributed storage units -![verl_dataflow_TransferQueue](https://cdn.nlark.com/yuque/0/2025/jpeg/23208217/1758704301666-0807dc06-766c-4a2d-9cde-889a6bb56b34.jpeg) +![verl_dataflow_TransferQueue](https://github.com/TransferQueue/community_doc/blob/main/docs/verl_workflow_with_tq.jpeg?raw=true) + + +You may refer to the [recipe](https://github.com/TransferQueue/TransferQueue/tree/dev/recipe/simple_use_case), where we mimic the verl usage in both async & sync scenarios. 
Official integration to verl is also available now at [verl/pulls/3649](https://github.com/volcengine/verl/pull/3649) (with subsequent PRs to further optimize the integration). -You may refer to the [recipe](https://github.com/TransferQueue/TransferQueue/tree/dev/recipe/simple_use_case), where we mimic the verl usage in both async & sync scenarios. +### Use Python package +```bash +pip install TransferQueue==0.1.1.dev2 +``` +### Build wheel package from source code + +Follow these steps to build and install: +1. Clone the source code from the GitHub repository + ```bash + git clone https://github.com/TransferQueue/TransferQueue/ + cd TransferQueue + ``` + +2. Install dependencies + ```bash + pip install -r requirements.txt + ``` + +3. Build and install + ```bash + python -m build --wheel + pip install dist/*.whl + ``` + +

📊 Performance

+ +

+ +

+> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorageUnit` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. We warmly welcome contributions from the community! + +For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/tml8ke0zkgn6roey?singleDoc#). + +

🛠️ Customize TransferQueue

+ +### Define your own data retrieval logic +We provide a `BaseSampler` abstraction class, which defines the following interface: + +```python3 +@abstractmethod +def sample( + self, + ready_indexes: list[int], + batch_size: int, + *args: Any, + **kwargs: Any, +) -> tuple[list[int], list[int]]: + """Sample a batch of indices from the ready indices. + + Args: + ready_indexes: List of global indices for which all required fields of the + corresponding samples have been produced, and the samples are not labeled as + consumed in the corresponding task. + batch_size: Number of samples to select + *args: Additional positional arguments for specific sampler implementations + **kwargs: Additional keyword arguments for specific sampler implementations + + Returns: + List of sampled global indices of length batch_size + List of global indices of length batch_size that should be labeled as consumed + (will never be retrieved in the future) + + Raises: + ValueError: If batch_size is invalid or ready_indexes is insufficient + """ + raise NotImplementedError("Subclasses must implement sample") +``` + +In this design, we separate data retrieval and data consumption through the two return values, which enables us to easily control sample replacement. We have implemented two reference designs: `SequentialSampler` and `GRPOGroupNSampler`. + +The `Sampler` class or instance should be passed to the `TransferQueueController` during initialization. During each `get_meta` call, you can provide dynamic sampling parameters to the `Sampler`. 
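To make the two return values of `sample` concrete, here is a minimal sketch of a custom sampler. To stay self-contained it defines a local `BaseSampler` stand-in that only mirrors the signature quoted above; in real use you would presumably subclass the class shipped with the package instead. `RandomSampler` and its `seed` argument are hypothetical and are not one of the reference designs.

```python
import random
from abc import ABC, abstractmethod
from typing import Any


class BaseSampler(ABC):
    """Local stand-in mirroring the `sample` signature quoted above."""

    @abstractmethod
    def sample(
        self, ready_indexes: list[int], batch_size: int, *args: Any, **kwargs: Any
    ) -> tuple[list[int], list[int]]:
        raise NotImplementedError("Subclasses must implement sample")


class RandomSampler(BaseSampler):
    """Hypothetical sampler: draws a random batch and marks it as consumed."""

    def __init__(self, seed: int = 0) -> None:
        self._rng = random.Random(seed)

    def sample(
        self, ready_indexes: list[int], batch_size: int, *args: Any, **kwargs: Any
    ) -> tuple[list[int], list[int]]:
        if batch_size <= 0 or batch_size > len(ready_indexes):
            raise ValueError("batch_size is invalid or ready_indexes is insufficient")
        # Draw without replacement inside one call.
        sampled = sorted(self._rng.sample(ready_indexes, batch_size))
        # Marking every sampled index as consumed means it is never retrieved
        # again for this task.
        consumed = list(sampled)
        return sampled, consumed


sampled, consumed = RandomSampler(seed=0).sample([3, 5, 7, 9], batch_size=2)
```

Here the sampled and consumed lists are identical, so each sample is served at most once per task; a sampler that wants different replacement behavior would change what it reports in the second return value.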
+ +```python3 +from transfer_queue import TransferQueueController, TransferQueueClient, GRPOGroupNSampler, process_zmq_server_info + +# Option 1: Pass the sampler class to the TransferQueueController +controller = TransferQueueController.remote(GRPOGroupNSampler) + +# Option 2: Pass the sampler instance to the TransferQueueController (if you need custom configuration) +your_own_sampler = YourOwnSampler(config) +controller = TransferQueueController.remote(your_own_sampler) + +# Use the sampler +batch_meta = client.get_meta( + data_fields=["input_ids", "attention_mask"], + batch_size=8, + partition_id="train_0", + task_name="generate_sequences", + sampling_config={"n_samples_per_prompt": 4} # Put the required sampling parameters here +) +``` + +### How to integrate a new storage backend + +The data plane is organized as follows: +```text + transfer_queue/ + ├── storage/ + │ ├── __init__.py + │ │── simple_backend.py # SimpleStorageUnit、StorageUnitData、StorageMetaGroup + │ ├── managers/ # Managers are upper level interfaces that encapsulate the interaction logic with TQ system. + │ │ ├── __init__.py + │ │ ├──base.py # TransferQueueStorageManager, KVStorageManager + │ │ ├──simple_backend_manager.py # AsyncSimpleStorageManager + │ │ ├──yuanrong_manager.py # YuanrongStorageManager + │ │ ├──mooncake_manager.py # MooncakeStorageManager + │ │ └──factory.py # TransferQueueStorageManagerFactory + │ └── clients/ # Clients are lower level interfaces that directly manipulate the target storage backend. + │ │ ├── __init__.py + │ │ ├── base.py # TransferQueueStorageKVClient + │ │ ├── yuanrong_client.py # YRStorageClient + │ │ ├── mooncake_client.py # MooncakeStoreClient + │ │ └── factory.py # TransferQueueStorageClientFactory +``` + +To integrate TransferQueue with a custom storage backend, start by implementing a subclass that inherits from `TransferQueueStorageManager`. This subclass acts as an adapter between the TransferQueue system and the target storage backend. 
For KV-based storage backends, you can simply inherit from `KVStorageManager`, which can serve as the general manager for all KV-based backends. + +Distributed storage backends often come with their own native clients serving as the interface of the storage system. In such cases, a low-level adapter for this client can be written, following the examples provided in the `storage/clients` directory. + +Factory classes are provided for both `StorageManager` and `StorageClient` to facilitate easy integration. Adding necessary descriptions of required parameters in the factory class helps enhance the overall user experience. diff --git a/recipe/transfer_queue/agent_loop.py b/recipe/transfer_queue/agent_loop.py index 871ae8025c0..7f936e6730e 100644 --- a/recipe/transfer_queue/agent_loop.py +++ b/recipe/transfer_queue/agent_loop.py @@ -67,10 +67,7 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data return timing - def create_transferqueue_client(self, controller_infos, storage_infos, role): + def create_transferqueue_client(self, controller_info, config): ray.get( - [ - worker.create_transferqueue_client.remote(controller_infos, storage_infos, role) - for worker in self.agent_loop_workers - ] + [worker.create_transferqueue_client.remote(controller_info, config) for worker in self.agent_loop_workers] ) diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index d6adbddb676..daa2f8d95b6 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -41,8 +41,8 @@ from tqdm import tqdm from transfer_queue import ( BatchMeta, + SimpleStorageUnit, TransferQueueController, - TransferQueueStorageSimpleUnit, get_placement_group, process_zmq_server_info, ) @@ -81,6 +81,7 @@ from verl.utils.metric import reduce_metrics from verl.utils.rollout_skip import RolloutSkip from verl.utils.seqlen_balancing import ( + calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance, ) 
@@ -89,7 +90,6 @@ from verl.utils.transferqueue_utils import ( create_transferqueue_client, get_transferqueue_client, - get_val_transferqueue_client, tqbridge, ) @@ -412,109 +412,66 @@ def __init__( self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) - self.data_system_client = self._initialize_train_data_system( - self.config.data.train_batch_size, self.config.actor_rollout_ref.rollout.n + self.data_system_client = self._initialize_data_system() + + def _initialize_data_system(self): + # 1. initialize TransferQueueStorage + train_data_size = ( + self.config.data.train_batch_size + * self.config.trainer.num_global_batch + * self.config.actor_rollout_ref.rollout.n ) - self.val_data_system_client = self._initialize_val_data_system( - self.val_batch_size, self.config.actor_rollout_ref.rollout.val_kwargs.n + val_data_size = ( + self.val_batch_size + * self.config.trainer.num_global_batch + * self.config.actor_rollout_ref.rollout.val_kwargs.n ) - def _initialize_train_data_system(self, global_batch_size, num_n_samples, role="train"): - # 1. 
initialize TransferQueueStorage - total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples + total_storage_size = train_data_size + val_data_size self.data_system_storage_units = {} storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1) for storage_unit_rank in range(self.config.trainer.num_data_storage_units): - storage_node = TransferQueueStorageSimpleUnit.options( + storage_node = SimpleStorageUnit.options( placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank - ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) + ).remote(storage_unit_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) self.data_system_storage_units[storage_unit_rank] = storage_node - logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") - - # 2. initialize TransferQueueController - # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly - # one controller for a single WorkerGroup. - self.data_system_controllers = {} - controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1) - for controller_rank in range(self.config.trainer.num_data_controllers): - self.data_system_controllers[controller_rank] = TransferQueueController.options( - placement_group=controller_placement_group, placement_group_bundle_index=controller_rank - ).remote( - num_storage_units=self.config.trainer.num_data_storage_units, - global_batch_size=global_batch_size, - num_global_batch=self.config.trainer.num_global_batch, - num_n_samples=num_n_samples, - ) - logging.info(f"TransferQueueController #{controller_rank} has been created.") + logging.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.") - # 3. 
register controller & storage - self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers) - self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units) + # 2. Initialize TransferQueueController (single controller only) - ray.get( - [ - storage_unit.register_controller_info.remote(self.data_system_controller_infos) - for storage_unit in self.data_system_storage_units.values() - ] - ) + # Sampler usage instructions: + # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler: + # Option 1: Pass sampler class (will be instantiated automatically) + # self.data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler) - # 4. create client - # each client should be allocated to exactly one controller - create_transferqueue_client( - client_id="Trainer-" + role, - controller_infos=self.data_system_controller_infos, - storage_infos=self.data_system_storage_unit_infos, - ) - data_system_client = get_transferqueue_client() - return data_system_client + # Option 2: Pass sampler instance (if you need custom configuration) + # grpo_sampler = GRPOGroupNSampler() + # self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler) - def _initialize_val_data_system(self, global_batch_size, num_n_samples, role="val"): - # 1. 
initialize TransferQueueStorage - total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples - self.val_data_system_storage_units = {} - storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1) - for storage_unit_rank in range(self.config.trainer.num_data_storage_units): - storage_node = TransferQueueStorageSimpleUnit.options( - placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank - ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) - self.val_data_system_storage_units[storage_unit_rank] = storage_node - logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") - - # 2. initialize TransferQueueController - # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly - # one controller for a single WorkerGroup. - self.val_data_system_controllers = {} - controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1) - for controller_rank in range(self.config.trainer.num_data_controllers): - self.val_data_system_controllers[controller_rank] = TransferQueueController.options( - placement_group=controller_placement_group, placement_group_bundle_index=controller_rank - ).remote( - num_storage_units=self.config.trainer.num_data_storage_units, - global_batch_size=global_batch_size, - num_global_batch=self.config.trainer.num_global_batch, - num_n_samples=num_n_samples, - ) - logging.info(f"TransferQueueController #{controller_rank} has been created.") + # Then use sampling_config in get_meta calls: + # sampling_config={"n_samples_per_prompt": 4} + self.data_system_controller = TransferQueueController.remote() + logging.info("TransferQueueController has been created.") - # 3. 
register controller & storage - self.val_data_system_controller_infos = process_zmq_server_info(self.val_data_system_controllers) - self.val_data_system_storage_unit_infos = process_zmq_server_info(self.val_data_system_storage_units) + # 3. register controller & storage and prepare necessary information + self.data_system_controller_info = process_zmq_server_info(self.data_system_controller) + self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units) - ray.get( - [ - storage_unit.register_controller_info.remote(self.val_data_system_controller_infos) - for storage_unit in self.val_data_system_storage_units.values() - ] - ) + # Note: Need to generate a new DictConfig with allow_objects=True to preserve ZMQServerInfo instances + # (which contain socket connection details). Without this flag, OmegaConf would flatten these objects to dicts, + # breaking the transfer queue client initialization. + tq_config = OmegaConf.create({}, flags={"allow_objects": True}) + tq_config.controller_info = self.data_system_controller_info + tq_config.storage_unit_infos = self.data_system_storage_unit_infos + self.config = OmegaConf.merge(tq_config, self.config) # 4. 
create client - # each client should be allocated to exactly one controller create_transferqueue_client( - client_id="Trainer-" + role, - controller_infos=self.val_data_system_controller_infos, - storage_infos=self.val_data_system_storage_unit_infos, + client_id="Trainer", + controller_info=self.data_system_controller_info, + config=self.config, ) - data_system_client = get_val_transferqueue_client() + data_system_client = get_transferqueue_client() return data_system_client def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): @@ -726,19 +683,18 @@ def _validate(self): if self.config.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model": return {} - asyncio.run(self.val_data_system_client.async_put(data=test_batch, global_step=self.global_steps - 1)) + asyncio.run(self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}")) # Store original inputs batch_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["input_ids", "uid", "reward_model"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", task_name="get_data", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(batch_meta)) + data = asyncio.run(self.data_system_client.async_get_data(batch_meta)) input_ids = data["input_ids"] # TODO: Can we keep special tokens except for padding tokens? 
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] @@ -749,11 +705,10 @@ def _validate(self): sample_gts.extend(ground_truths) test_gen_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 task_name="generate_sequences", ) ) @@ -779,15 +734,14 @@ def _validate(self): # Store generated outputs test_response_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["responses"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 task_name="get_response", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(test_response_meta)) + data = asyncio.run(self.data_system_client.async_get_data(test_response_meta)) output_ids = data["responses"] output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] sample_outputs.extend(output_texts) @@ -808,11 +762,10 @@ def _validate(self): if "rm_scores" in batch_meta.field_names: compute_reward_fields = ["rm_scores"] val_reward_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=compute_reward_fields, batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", task_name="compute_reward", ) ) @@ 
-832,29 +785,27 @@ def _validate(self):
             # collect num_turns of each prompt
             if "__num_turns__" in test_batch_meta.field_names:
                 num_turns_meta = asyncio.run(
-                    self.val_data_system_client.async_get_meta(
+                    self.data_system_client.async_get_meta(
                         data_fields=["__num_turns__"],
                         batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
-                        global_step=self.global_steps - 1,  # self.global_steps start from 1
-                        get_n_samples=False,
+                        partition_id=f"val_{self.global_steps - 1}",  # self.global_steps start from 1
                         task_name="get_num_turns",
                     )
                 )
-                data = asyncio.run(self.val_data_system_client.async_get_data(num_turns_meta))
+                data = asyncio.run(self.data_system_client.async_get_data(num_turns_meta))
                 sample_turns.append(data["__num_turns__"])
             data_source = ["unknown"] * reward_tensor.shape[0]
             if "data_source" in test_batch_meta.field_names:
                 data_source_meta = asyncio.run(
-                    self.val_data_system_client.async_get_meta(
+                    self.data_system_client.async_get_meta(
                         data_fields=["data_source"],
                         batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
-                        global_step=self.global_steps - 1,  # self.global_steps start from 1
-                        get_n_samples=False,
+                        partition_id=f"val_{self.global_steps - 1}",  # self.global_steps start from 1
                         task_name="get_data_source",
                     )
                 )
-                data = asyncio.run(self.val_data_system_client.async_get_data(data_source_meta))
+                data = asyncio.run(self.data_system_client.async_get_data(data_source_meta))
                 data_source = data["data_source"]
             data_source_lst.append(data_source)
@@ -902,7 +853,7 @@ def _validate(self):
             metric_dict["val-aux/num_turns/max"] = sample_turns.max()
             metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()
-        asyncio.run(self.val_data_system_client.async_clear(self.global_steps - 1))
+        asyncio.run(self.data_system_client.async_clear(partition_id=f"val_{self.global_steps - 1}"))
         return metric_dict
     def init_workers(self):
@@ -1003,12 +954,7 @@ def init_workers(self):
         # set transferqueue server info for each worker
         for _, wg in all_wg.items():
-            wg.create_transferqueue_client(
-                self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
-            )
-            wg.create_transferqueue_client(
-                self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
-            )
+            wg.create_transferqueue_client(self.data_system_controller_info, self.config)
         # create async rollout manager and request scheduler
         self.async_rollout_mode = False
@@ -1020,12 +966,7 @@ def init_workers(self):
                 config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
             )
-            self.async_rollout_manager.create_transferqueue_client(
-                self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train"
-            )
-            self.async_rollout_manager.create_transferqueue_client(
-                self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val"
-            )
+            self.async_rollout_manager.create_transferqueue_client(self.data_system_controller_info, self.config)
     def _save_checkpoint(self):
         from verl.utils.fs import local_mkdir_safe
@@ -1164,17 +1105,41 @@ def _stop_profiling(self, do_profile: bool) -> None:
             if self.use_rm:
                 self.rm_wg.stop_profile()
-    def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen"):
+    def _balance_batch(
+        self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen", keep_minibatch=False
+    ):
         """Reorder the batchmeta on single controller such that each dp rank gets similar total tokens"""
         data = asyncio.run(data_system_client.async_get_data(batch))
         attention_mask = data["attention_mask"]
         batch_size = attention_mask.shape[0]
-        global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
+        global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1)  # (train_batch_size,)
+        global_seqlen_lst = calculate_workload(global_seqlen_lst)
         world_size = self.actor_rollout_wg.world_size
-        global_partition_lst = get_seqlen_balanced_partitions(
-            global_seqlen_lst, k_partitions=world_size, equal_size=True
-        )
+        if keep_minibatch:
+            # Decouple the DP balancing and mini-batching.
+            minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size", None)
+            if minibatch_size is None:
+                raise ValueError("'ppo_mini_batch_size' must be set in actor config when 'keep_minibatch' is True.")
+            minibatch_num = len(global_seqlen_lst) // minibatch_size
+            global_partition_lst = [[] for _ in range(world_size)]
+            for i in range(minibatch_num):
+                rearrange_minibatch_lst = get_seqlen_balanced_partitions(
+                    global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],
+                    k_partitions=world_size,
+                    equal_size=True,
+                )
+                for j, part in enumerate(rearrange_minibatch_lst):
+                    global_partition_lst[j].extend([x + minibatch_size * i for x in part])
+        else:
+            global_partition_lst = get_seqlen_balanced_partitions(
+                global_seqlen_lst, k_partitions=world_size, equal_size=True
+            )
+        # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.
+        for idx, partition in enumerate(global_partition_lst):
+            partition.sort(key=lambda x: (global_seqlen_lst[x], x))
+            ordered_partition = partition[::2] + partition[1::2][::-1]
+            global_partition_lst[idx] = ordered_partition
         # reorder based on index. The data will be automatically equally partitioned by dispatch function
         global_idx = [j for partition in global_partition_lst for j in partition]
         global_balance_stats = log_seqlen_unbalance(
@@ -1313,8 +1278,7 @@ def fit(self):
                 timing_raw = {}
                 base_get_meta_kwargs = dict(
                     batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
-                    global_step=self.global_steps - 1,  # self.global_steps starts from 1
-                    get_n_samples=False,
+                    partition_id=f"train_{self.global_steps - 1}",  # self.global_steps starts from 1
                 )
                 with marked_timer("start_profile", timing_raw):
@@ -1333,7 +1297,9 @@ def fit(self):
                     batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
                 )
                 batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict)
-                asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1))
+                asyncio.run(
+                    self.data_system_client.async_put(data=batch, partition_id=f"train_{self.global_steps - 1}")
+                )
                 gen_meta = asyncio.run(
                     self.data_system_client.async_get_meta(
@@ -1709,8 +1675,7 @@ def fit(self):
                                     ],
                                     batch_size=self.config.data.train_batch_size
                                     * self.config.actor_rollout_ref.rollout.n,
-                                    global_step=self.global_steps - 1,
-                                    get_n_samples=False,
+                                    partition_id=f"train_{self.global_steps - 1}",
                                     task_name="update_actor",
                                 )
                             )
@@ -1735,8 +1700,7 @@ def fit(self):
                                 self.data_system_client.async_get_meta(
                                     data_fields=data_fields,
                                     batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
-                                    global_step=self.global_steps - 1,
-                                    get_n_samples=False,
+                                    partition_id=f"train_{self.global_steps - 1}",
                                     task_name="log_rollout",
                                 )
                             )
@@ -1857,7 +1821,7 @@ def fit(self):
                     # TODO: (TQ) support transfer queue
                     self.train_dataloader.sampler.update(batch=batch)
-                asyncio.run(self.data_system_client.async_clear(self.global_steps - 1))
+                asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{self.global_steps - 1}"))
                 # TODO: make a canonical logger that supports various backend
                 logger.log(data=metrics, step=self.global_steps)
diff --git a/requirements_transferqueue.txt b/requirements_transferqueue.txt
index 8479d27bb21..621682abbf7 100644
--- a/requirements_transferqueue.txt
+++ b/requirements_transferqueue.txt
@@ -1,2 +1,2 @@
 # requirements.txt records the full set of dependencies for development
-git+https://github.com/TransferQueue/TransferQueue.git@68c04e7
+transferqueue==0.1.1.dev2
diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py
index 2513c57f99c..399ac75a063 100644
--- a/verl/single_controller/base/worker.py
+++ b/verl/single_controller/base/worker.py
@@ -131,13 +131,13 @@ def _query_collect_info(self, mesh_name: str):
         return self.__collect_dp_rank[mesh_name]
     @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
-    def create_transferqueue_client(self, controller_infos, storage_infos, role="train"):
+    def create_transferqueue_client(self, controller_info, config):
         from verl.utils.transferqueue_utils import create_transferqueue_client
         create_transferqueue_client(
-            client_id=f"{role}_worker_{self.rank}",
-            controller_infos=controller_infos,
-            storage_infos=storage_infos,
+            client_id=f"worker_{self.rank}",
+            controller_info=controller_info,
+            config=config,
         )
     @classmethod
diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py
index 27160571ef3..c692578e3a0 100644
--- a/verl/utils/transferqueue_utils.py
+++ b/verl/utils/transferqueue_utils.py
@@ -38,32 +38,24 @@ class BatchMeta:
 from verl.protocol import DataProto
 _TRANSFER_QUEUE_CLIENT = None
-_VAL_TRANSFER_QUEUE_CLIENT = None
 is_transferqueue_enabled = os.environ.get("TRANSFER_QUEUE_ENABLE", False)
 def create_transferqueue_client(
     client_id: str,
-    controller_infos: dict[Any, "ZMQServerInfo"],
-    storage_infos: dict[Any, "ZMQServerInfo"],
+    controller_info: dict[Any, "ZMQServerInfo"],
+    config,
 ) -> None:
     global _TRANSFER_QUEUE_CLIENT
-    global _VAL_TRANSFER_QUEUE_CLIENT
-    if "val" in client_id:
-        _VAL_TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos)
-    else:
-        _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos)
+    _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_info)
+    _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
 def get_transferqueue_client() -> "AsyncTransferQueueClient":
     return _TRANSFER_QUEUE_CLIENT
-def get_val_transferqueue_client() -> "AsyncTransferQueueClient":
-    return _VAL_TRANSFER_QUEUE_CLIENT
-
-
 def _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any:
     # Use a temporary event loop in a new thread because event
     # loop may already exist in server mode
@@ -109,10 +101,7 @@ async def _async_batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto:
             meta_info=batchmeta.extra_info.copy(),
         )
-    if batchmeta.extra_info.get("validate", False):
-        tensordict = await _VAL_TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta)
-    else:
-        tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta)
+    tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta)
     return DataProto.from_tensordict(tensordict, meta_info=batchmeta.extra_info.copy())
@@ -130,10 +119,7 @@ async def _async_update_batchmeta_with_output(output: DataProto, batchmeta: "Bat
     for key in output.meta_info.keys():
         tensordict.pop(key)
     batchmeta.add_fields(tensordict)
-    if batchmeta.extra_info.get("validate", False):
-        await _VAL_TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)
-    else:
-        await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)
+    await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)
 def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta") -> None:
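Note on the reordering step added to `_balance_batch` in this diff ("Place smaller micro-batches at both ends"): the slice expression `partition[::2] + partition[1::2][::-1]` is easier to follow in isolation. The sketch below is a standalone illustration only; the helper name `order_small_at_ends` and the sample lengths are hypothetical and not part of verl. It assumes, as in the diff, that each partition is a list of sample indices and that workloads can be looked up by index.

```python
def order_small_at_ends(partition, workloads):
    """Return the indices in `partition` ordered so workloads rise, then fall.

    Hypothetical helper mirroring the per-partition loop in the diff:
    sort by (workload, index), take every other element in ascending order,
    then append the skipped elements in descending order.
    """
    ordered = sorted(partition, key=lambda i: (workloads[i], i))
    return ordered[::2] + ordered[1::2][::-1]


# Illustrative workloads (e.g. sequence lengths) for five samples.
workloads = [50, 10, 40, 20, 30]
partition = [0, 1, 2, 3, 4]

result = order_small_at_ends(partition, workloads)
result_workloads = [workloads[i] for i in result]
print(result)            # [1, 4, 0, 2, 3]
print(result_workloads)  # [10, 30, 50, 40, 20]
```

The smallest items land at the two ends and the largest in the middle. The secondary sort key (the index itself) only breaks ties, so the ordering is deterministic when two samples have equal workload.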