From a92a942ef92323a31e30befd0adaeba71da9dbbc Mon Sep 17 00:00:00 2001 From: FightingZhen <295632982@qq.com> Date: Tue, 23 Sep 2025 19:56:49 +0800 Subject: [PATCH 01/16] Support storage unit in TransferQueue --- verl/experimental/transfer_queue/__init__.py | 13 + verl/experimental/transfer_queue/storage.py | 515 +++++++++++++++++++ 2 files changed, 528 insertions(+) create mode 100644 verl/experimental/transfer_queue/__init__.py create mode 100644 verl/experimental/transfer_queue/storage.py diff --git a/verl/experimental/transfer_queue/__init__.py b/verl/experimental/transfer_queue/__init__.py new file mode 100644 index 00000000000..1ce90c5eb35 --- /dev/null +++ b/verl/experimental/transfer_queue/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/experimental/transfer_queue/storage.py b/verl/experimental/transfer_queue/storage.py new file mode 100644 index 00000000000..0c6e0c08538 --- /dev/null +++ b/verl/experimental/transfer_queue/storage.py @@ -0,0 +1,515 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import time
from operator import itemgetter
from threading import Thread
from uuid import uuid4

import ray
import torch
import zmq
from ray.util import get_node_ip_address
from tensordict import NonTensorStack, TensorDict

from transfer_queue.utils.utils import TransferQueueRole
from transfer_queue.utils.zmq_utils import (
    ZMQMessage,
    ZMQRequestType,
    ZMQServerInfo,
    create_zmq_socket,
    get_free_port,
)

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# Passed to zmq.Poller.poll(), which expects a number of milliseconds. The value is cast to int so that an
# override coming from the environment (always a str) does not reach poll() as a string.
TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 1000))
# Both values below are compared against time.time() deltas, i.e. they are in seconds.
TQ_STORAGE_HANDSHAKE_TIMEOUT = int(os.environ.get("TQ_STORAGE_HANDSHAKE_TIMEOUT", 30))
TQ_DATA_UPDATE_RESPONSE_TIMEOUT = int(os.environ.get("TQ_DATA_UPDATE_RESPONSE_TIMEOUT", 600))


class StorageUnitData:
    """
    Class used for storing several elements, each element is composed of several fields and corresponding data, like:
    #####################################################
    # local_index | field_name1 | field_name2 | ...   #
    # 0           | item1       | item2       | ...   #
    # 1           | item3       | item4       | ...   #
    # 2           | item5       | item6       | ...   #
    #####################################################
    """

    def __init__(self, storage_size: int):
        # Dict containing field names and corresponding data in the field, e.g. {"field_name1": [data1, data2, ...]}
        self.field_data: dict[str, list] = {}

        # Maximum number of elements stored in storage unit
        self.storage_size = storage_size

    def _validate_local_indexes(self, local_indexes: list[int], operation: str) -> None:
        """
        Check that every local index lies in [0, storage_size).

        param:
            local_indexes: Local indexes to validate.
            operation: Name of the calling operation, only used in the error message.
        raise:
            ValueError: If any index is negative or not smaller than storage_size.
        """
        for idx in local_indexes:
            if idx < 0 or idx >= self.storage_size:
                raise ValueError(
                    f"StorageUnitData {operation} operation receive invalid local_index: {idx} beyond "
                    f"storage_size: {self.storage_size}"
                )

    def get_data(self, fields: list[str], local_indexes: list[int]) -> TensorDict[str, list]:
        """
        Get data from storage unit according to given fields and local_indexes.

        param:
            fields: Field names used for getting data.
            local_indexes: Local indexes used for getting data.
        return:
            TensorDict with field names as keys, corresponding data list as values. When local_indexes is empty
            no field is added to the result.
        raise:
            ValueError: If a field is unknown or a local index is outside [0, storage_size).
        """
        # Reject negative / out-of-range indexes up front, as put_data and clear do; a negative index would
        # otherwise silently wrap around to the end of the column.
        self._validate_local_indexes(local_indexes, "get_data")

        result: dict[str, list] = {}

        for field in fields:
            # Validate field name
            if field not in self.field_data:
                raise ValueError(
                    f"StorageUnitData get_data operation receive invalid field: {field} beyond {self.field_data.keys()}"
                )

            if not local_indexes:
                # itemgetter() needs at least one index; nothing was requested, so nothing is gathered.
                continue

            column = self.field_data[field]

            if len(local_indexes) == 1:
                # The unsqueeze op make the shape from n to (1, n)
                gathered_item = column[local_indexes[0]]
                if not isinstance(gathered_item, torch.Tensor):
                    result[field] = NonTensorStack(gathered_item).unsqueeze(0)
                else:
                    result[field] = gathered_item.unsqueeze(0)
            else:
                # With two or more indexes itemgetter returns a tuple holding one item per index.
                gathered_items = list(itemgetter(*local_indexes)(column))

                if all(isinstance(x, torch.Tensor) for x in gathered_items):
                    result[field] = torch.nested.as_nested_tensor(gathered_items)
                else:
                    result[field] = NonTensorStack(*gathered_items)

        return TensorDict(result)

    def put_data(self, field_data: TensorDict[str, list], local_indexes: list[int]) -> None:
        """
        Put or update data into storage unit according to given field_data and local_indexes.

        All local indexes are validated before anything is written, so an invalid index never leaves a
        partially written batch behind.

        param:
            field_data: Dict with field names as keys, corresponding data in the field as values.
            local_indexes: Local indexes used for putting data.
        raise:
            ValueError: If a local index is outside [0, storage_size).
        """
        self._validate_local_indexes(local_indexes, "put_data")

        if not local_indexes:
            # Nothing to write; do not create empty columns for the given fields.
            return

        for f in field_data.keys():
            if f not in self.field_data:
                # Initialize new field value list with None
                self.field_data[f] = [None] * self.storage_size

            column = self.field_data[f]
            values = field_data[f]
            for i, idx in enumerate(local_indexes):
                column[idx] = values[i]

    def clear(self, local_indexes: list[int]) -> None:
        """
        Clear data at specified local_indexes by setting all related fields to None.

        param:
            local_indexes: local_indexes to clear.
        raise:
            ValueError: If a local index is outside [0, storage_size).
        """
        # Validate local_indexes
        self._validate_local_indexes(local_indexes, "clear")

        # Clear data at specified local_indexes
        for column in self.field_data.values():
            for idx in local_indexes:
                column[idx] = None


@ray.remote(num_cpus=1)
class TransferQueueStorageSimpleUnit:
    """
    Ray actor wrapping one StorageUnitData behind ZMQ sockets.

    Put/get/clear requests arrive on a ROUTER socket served by a daemon thread; after every successful put the
    unit notifies all registered controllers and waits for their ACKs.
    """

    def __init__(self, storage_size: int):
        super().__init__()
        self.storage_unit_id = f"TQ_STORAGE_UNIT_{uuid4()}"
        self.storage_size = storage_size
        self.controller_infos: dict[str, ZMQServerInfo] = {}

        self.experience_data = StorageUnitData(self.storage_size)

        self.zmq_server_info = ZMQServerInfo.create(
            role=TransferQueueRole.STORAGE,
            id=str(self.storage_unit_id),
            ip=get_node_ip_address(),
            ports={"put_get_socket": get_free_port()},
        )
        self._init_zmq_socket()

    def _init_zmq_socket(self) -> None:
        """
        Initialize ZMQ socket connections between storage unit and controllers/clients:
            - controller_handshake_sockets:
                Handshake between storage unit and controllers.
            - data_status_update_sockets:
                Broadcast data update status from storage unit to controllers when handling put operation.
            - put_get_socket:
                Handle put/get requests from clients.
        """
        self.zmq_context = zmq.Context()

        self.controller_handshake_sockets: dict[str, zmq.Socket] = {}
        self.data_status_update_sockets: dict[str, zmq.Socket] = {}
        # Controller ids whose data status update socket has already been connected, so that
        # _notify_data_update connects each socket only once instead of once per put.
        self._connected_data_status_update_controllers: set[str] = set()

        self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER)
        self.put_get_socket.bind(self.zmq_server_info.to_addr("put_get_socket"))

    def register_controller_info(self, controller_infos: dict[str, ZMQServerInfo]) -> None:
        """
        Build connections between storage unit and controllers, start put/get process.

        param:
            controller_infos: Dict with controller infos.
        """
        self.controller_infos = controller_infos

        self._init_zmq_sockets_with_controller_infos()
        self._connect_to_controller()
        self._start_process_put_get()

    def _init_zmq_sockets_with_controller_infos(self) -> None:
        """Initialize ZMQ sockets between storage unit and controllers for handshake."""
        for controller_id in self.controller_infos.keys():
            self.controller_handshake_sockets[controller_id] = create_zmq_socket(
                self.zmq_context,
                zmq.DEALER,
                identity=f"{self.storage_unit_id}-controller_handshake_sockets-{uuid4()}".encode(),
            )
            self.data_status_update_sockets[controller_id] = create_zmq_socket(
                self.zmq_context,
                zmq.DEALER,
                identity=f"{self.storage_unit_id}-data_status_update_sockets-{uuid4()}".encode(),
            )
            # The socket above is brand new, so it has to be connected again on the next notification.
            self._connected_data_status_update_controllers.discard(controller_id)

    def _connect_to_controller(self) -> None:
        """
        Connect storage unit to all controllers.

        Sends a HANDSHAKE request to every controller and polls for HANDSHAKE_ACK responses until all
        controllers answered or TQ_STORAGE_HANDSHAKE_TIMEOUT seconds elapsed. A missing ACK only produces a
        warning; no exception is raised.
        """
        connected_controllers: set[str] = set()

        # Create zmq poller for handshake confirmation between controller and storage unit
        poller = zmq.Poller()

        for controller_id, controller_info in self.controller_infos.items():
            handshake_socket = self.controller_handshake_sockets[controller_id]
            handshake_socket.connect(controller_info.to_addr("handshake_socket"))
            logger.debug(
                f"[{self.zmq_server_info.id}]: Handshake connection from storage unit id #{self.zmq_server_info.id} "
                f"to controller id #{controller_id} establish successfully."
            )

            # Send handshake request to controllers
            request_msg = ZMQMessage.create(
                request_type=ZMQRequestType.HANDSHAKE,
                sender_id=self.zmq_server_info.id,
                body={
                    "storage_unit_id": self.storage_unit_id,
                    "storage_size": self.storage_size,
                },
            ).serialize()

            handshake_socket.send(request_msg)
            logger.debug(
                f"[{self.zmq_server_info.id}]: Send handshake request from storage unit id #{self.zmq_server_info.id} "
                f"to controller id #{controller_id} successfully."
            )

            poller.register(handshake_socket, zmq.POLLIN)

        start_time = time.time()
        while (
            len(connected_controllers) < len(self.controller_infos)
            and time.time() - start_time < TQ_STORAGE_HANDSHAKE_TIMEOUT
        ):
            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT))

            for controller_handshake_socket in self.controller_handshake_sockets.values():
                if controller_handshake_socket in socks:
                    response_msg = ZMQMessage.deserialize(controller_handshake_socket.recv())

                    if response_msg.request_type == ZMQRequestType.HANDSHAKE_ACK:
                        connected_controllers.add(response_msg.sender_id)
                        logger.debug(
                            f"[{self.zmq_server_info.id}]: Get handshake ACK response from "
                            f"controller id #{str(response_msg.sender_id)} to storage unit id "
                            f"#{self.zmq_server_info.id} successfully."
                        )

        if len(connected_controllers) < len(self.controller_infos):
            logger.warning(
                f"[{self.zmq_server_info.id}]: Only get {len(connected_controllers)} / {len(self.controller_infos)} "
                f"successful handshake connections to controllers from storage unit id #{self.zmq_server_info.id}"
            )

    def _start_process_put_get(self) -> None:
        """Create a daemon thread and start put/get process (at most one such thread is kept running)."""
        # register_controller_info may be called more than once; a second thread receiving on the same
        # put_get_socket must not be started.
        existing_thread = getattr(self, "process_put_get_thread", None)
        if existing_thread is not None and existing_thread.is_alive():
            return

        self.process_put_get_thread = Thread(
            target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.zmq_server_info.id}", daemon=True
        )
        self.process_put_get_thread.start()

    def _process_put_get(self) -> None:
        """
        Process put_get_socket request.

        Runs forever: polls the ROUTER socket, dispatches PUT_DATA / GET_DATA / CLEAR_DATA requests to their
        handlers and sends the handler's response back to the requesting identity.
        """
        poller = zmq.Poller()
        poller.register(self.put_get_socket, zmq.POLLIN)

        while True:
            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT))

            if self.put_get_socket not in socks:
                continue

            frames = self.put_get_socket.recv_multipart()
            if len(frames) != 2:
                # Expected exactly [identity, payload]. Drop anything else instead of letting the unpacking
                # error terminate this serving thread.
                logger.warning(
                    f"[{self.zmq_server_info.id}]: Drop put/get message with unexpected frame count: {len(frames)}"
                )
                continue
            identity, serialized_msg = frames

            try:
                request_msg = ZMQMessage.deserialize(serialized_msg)
                operation = request_msg.request_type
                logger.debug(f"[{self.zmq_server_info.id}]: receive operation: {operation}, message: {request_msg}")

                if operation == ZMQRequestType.PUT_DATA:
                    response_msg = self._handle_put(request_msg)
                elif operation == ZMQRequestType.GET_DATA:
                    response_msg = self._handle_get(request_msg)
                elif operation == ZMQRequestType.CLEAR_DATA:
                    response_msg = self._handle_clear(request_msg)
                else:
                    response_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR,
                        sender_id=self.zmq_server_info.id,
                        body={
                            "message": f"Storage unit id #{self.zmq_server_info.id} "
                            f"receive invalid operation: {operation}."
                        },
                    )
            except Exception as e:
                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.PUT_GET_ERROR,
                    sender_id=self.zmq_server_info.id,
                    body={
                        "message": f"Storage unit id #{self.zmq_server_info.id} occur error in processing "
                        f"put/get/clear request, detail error message: {str(e)}."
                    },
                )

            self.put_get_socket.send_multipart([identity, response_msg.serialize()])

    def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle put request, add or update data into storage unit.

        param:
            data_parts: ZMQMessage from client.
        return:
            Put data success response ZMQMessage, or a PUT_ERROR ZMQMessage if storing or notifying failed.
        """
        try:
            global_indexes = data_parts.body["global_indexes"]
            local_indexes = data_parts.body["local_indexes"]
            field_data = data_parts.body["field_data"]  # field_data should be in {field_name: [real data]} format.

            self.experience_data.put_data(field_data, local_indexes)

            # After put operation finish, send a message to the client
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.zmq_server_info.id, body={}
            )

            # Gather per-tensor dtype and shape information for each field, in
            # {global_index: {field_name: dtype or shape}} format.
            # global_indexes, local_indexes, and field_data correspond one-to-one
            per_tensor_dtypes: dict[int, dict[str, torch.dtype | None]] = {idx: {} for idx in global_indexes}
            per_tensor_shapes: dict[int, dict[str, torch.Size | None]] = {idx: {} for idx in global_indexes}

            # For each field, extract dtype and shape for each sample (None for items without these attributes)
            for field in field_data.keys():
                for i, data_item in enumerate(field_data[field]):
                    global_idx = global_indexes[i]
                    per_tensor_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None
                    per_tensor_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None

            # Broadcast data update message to all controllers with per-tensor dtype/shape information
            self._notify_data_update(list(field_data.keys()), global_indexes, per_tensor_dtypes, per_tensor_shapes)
            return response_msg
        except Exception as e:
            return ZMQMessage.create(
                request_type=ZMQRequestType.PUT_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to put data into storage unit id "
                    f"#{self.zmq_server_info.id}, detail error message: {str(e)}"
                },
            )

    def _notify_data_update(self, fields, global_indexes, dtypes, shapes) -> None:
        """
        Broadcast data status update to all controllers.

        Blocks until every controller answered with NOTIFY_DATA_UPDATE_ACK or
        TQ_DATA_UPDATE_RESPONSE_TIMEOUT seconds elapsed; missing ACKs only produce a warning.

        param:
            fields: data update related fields.
            global_indexes: data update related global_indexes.
            dtypes: per-tensor dtypes for each field, in {global_index: {field: dtype}} format.
            shapes: per-tensor shapes for each field, in {global_index: {field: shape}} format.
        """
        # Create zmq poller for notifying data update information
        poller = zmq.Poller()

        for controller_id, controller_info in self.controller_infos.items():
            data_status_update_socket = self.data_status_update_sockets[controller_id]

            # Connect data status update socket to the controller only once. Calling connect() again on
            # every put would keep adding connections from this socket to the same endpoint.
            if controller_id not in self._connected_data_status_update_controllers:
                data_status_update_socket.connect(controller_info.to_addr("data_status_update_socket"))
                self._connected_data_status_update_controllers.add(controller_id)
                logger.debug(
                    f"[{self.zmq_server_info.id}]: Data status update connection from "
                    f"storage unit id #{self.zmq_server_info.id} to "
                    f"controller id #{controller_id} establish successfully."
                )

            try:
                poller.register(data_status_update_socket, zmq.POLLIN)

                request_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.NOTIFY_DATA_UPDATE,
                    sender_id=self.zmq_server_info.id,
                    body={
                        "fields": fields,
                        "global_indexes": global_indexes,
                        "dtypes": dtypes,
                        "shapes": shapes,
                    },
                ).serialize()

                data_status_update_socket.send(request_msg)
                logger.debug(
                    f"[{self.zmq_server_info.id}]: Send data status update request "
                    f"from storage unit id #{self.zmq_server_info.id} "
                    f"to controller id #{controller_id} successfully."
                )
            except Exception as e:
                # Tell the controller that the regular notification could not be built or sent.
                request_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR,
                    sender_id=self.zmq_server_info.id,
                    body={
                        "message": f"Failed to notify data status update information from "
                        f"storage unit id #{self.zmq_server_info.id}, "
                        f"detail error message: {str(e)}"
                    },
                ).serialize()

                data_status_update_socket.send(request_msg)

        # Make sure all controllers successfully receive data status update information.
        response_controllers: set[str] = set()
        start_time = time.time()

        while (
            len(response_controllers) < len(self.controller_infos)
            and time.time() - start_time < TQ_DATA_UPDATE_RESPONSE_TIMEOUT
        ):
            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT))

            for data_status_update_socket in self.data_status_update_sockets.values():
                if data_status_update_socket in socks:
                    response_msg = ZMQMessage.deserialize(data_status_update_socket.recv())

                    if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK:
                        response_controllers.add(response_msg.sender_id)
                        logger.debug(
                            f"[{self.zmq_server_info.id}]: Get data status update ACK response "
                            f"from controller id #{response_msg.sender_id} "
                            f"to storage unit id #{self.zmq_server_info.id} successfully."
                        )

        if len(response_controllers) < len(self.controller_infos):
            logger.warning(
                f"[{self.zmq_server_info.id}]: Storage unit id #{self.zmq_server_info.id} "
                f"only get {len(response_controllers)} / {len(self.controller_infos)} "
                f"data status update ACK responses from controllers."
            )

    def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle get request, return data from storage unit.

        param:
            data_parts: ZMQMessage from client.
        return:
            Get data success response ZMQMessage, containing target data, or a GET_ERROR ZMQMessage on failure.
        """
        try:
            fields = data_parts.body["fields"]
            local_indexes = data_parts.body["local_indexes"]

            result_data = self.experience_data.get_data(fields, local_indexes)

            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.GET_DATA_RESPONSE,
                sender_id=self.zmq_server_info.id,
                body={
                    "data": result_data,
                },
            )
        except Exception as e:
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.GET_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to get data from storage unit id #{self.zmq_server_info.id}, "
                    f"detail error message: {str(e)}"
                },
            )
        return response_msg

    def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle clear request, clear data in storage unit according to given local_indexes.

        param:
            data_parts: ZMQMessage from client, including target local_indexes.
        return:
            Clear data success response ZMQMessage, or a CLEAR_DATA_ERROR ZMQMessage on failure.
        """
        try:
            local_indexes = data_parts.body["local_indexes"]

            self.experience_data.clear(local_indexes)

            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_DATA_RESPONSE,
                sender_id=self.zmq_server_info.id,
                body={"message": f"Clear data in storage unit id #{self.zmq_server_info.id} successfully."},
            )
        except Exception as e:
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_DATA_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to clear data in storage unit id #{self.zmq_server_info.id}, "
                    f"detail error message: {str(e)}"
                },
            )
        return response_msg

    def get_zmq_server_info(self) -> ZMQServerInfo:
        """Return the ZMQServerInfo created for this storage unit in __init__."""
        return self.zmq_server_info


# NOTE(review): the lines below are patch-series residue (the mail header of the next patch) that shares the
# original physical line with the end of this module; they are carried over unchanged, as comments.
# From bae27bb758ff33db7d1a9a283c8e2ba03c84e658 Mon Sep 17 00:00:00 2001
# From: FightingZhen <295632982@qq.com>
# Date: Tue, 23 Sep 2025 20:23:34 +0800
# Subject: [PATCH 02/16] Fix importance error
# ---
#  verl/experimental/transfer_queue/storage.py | 4 ++--
#  1 file changed, 2 insertions(+), 2 deletions(-)
#
# diff --git
a/verl/experimental/transfer_queue/storage.py b/verl/experimental/transfer_queue/storage.py index 0c6e0c08538..11e05887785 100644 --- a/verl/experimental/transfer_queue/storage.py +++ b/verl/experimental/transfer_queue/storage.py @@ -25,8 +25,8 @@ from ray.util import get_node_ip_address from tensordict import NonTensorStack, TensorDict -from transfer_queue.utils.utils import TransferQueueRole -from transfer_queue.utils.zmq_utils import ( +from verl.experimental.transfer_queue.utils.utils import TransferQueueRole +from verl.experimental.transfer_queue.utils.zmq_utils import ( ZMQMessage, ZMQRequestType, ZMQServerInfo, From 501e1e23e2b535953de79aa7b33d70a65ba1a303 Mon Sep 17 00:00:00 2001 From: LLLLxmmm <130739718+LLLLxmmm@users.noreply.github.com> Date: Wed, 24 Sep 2025 10:17:41 +0800 Subject: [PATCH 03/16] Support controller in TransferQueue (#2) * Support controller in TransferQueue * Fix import * Fix comments --------- Co-authored-by: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> --- .../experimental/transfer_queue/controller.py | 756 ++++++++++++++++++ 1 file changed, 756 insertions(+) create mode 100644 verl/experimental/transfer_queue/controller.py diff --git a/verl/experimental/transfer_queue/controller.py b/verl/experimental/transfer_queue/controller.py new file mode 100644 index 00000000000..607fe857322 --- /dev/null +++ b/verl/experimental/transfer_queue/controller.py @@ -0,0 +1,756 @@ +import logging +import math +import os +import threading +import time +from threading import Thread +from typing import Any, Optional +from uuid import uuid4 + +import numpy as np +import ray +import torch +import zmq +from ray.util import get_node_ip_address + +from verl.experimental.transfer_queue.metadata import ( + BatchMeta, + FieldMeta, + SampleMeta, +) +from verl.experimental.transfer_queue.utils.utils import ( + ProductionStatus, + TransferQueueRole, + random_sampler, +) +from verl.experimental.transfer_queue.utils.zmq_utils import ( + 
ZMQMessage, + ZMQRequestType, + ZMQServerInfo, + create_zmq_socket, + get_free_port, +) + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + +TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 300)) +TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 1)) +TQ_INIT_FIELD_NUM = int(os.environ.get("TQ_INIT_FIELD_NUM", 10)) + + +@ray.remote(num_cpus=1) +class TransferQueueController: + def __init__( + self, + num_storage_units: int, + global_batch_size: int, + num_global_batch: int = 1, + num_n_samples: int = 1, + ) -> None: + """Initialize the TransferQueueController. + + Args: + num_storage_units: Number of storage units in the system + global_batch_size: Size of each global batch + num_global_batch: Number of global batches to maintain in storage + num_n_samples: For each prompt, sample n responses + """ + self.controller_id = f"TQ_CONTROLLER_{uuid4()}" + + self._init_zmq_socket() # Initialize ZMQ sockets for data communication + + self.num_storage_units = num_storage_units + self.global_batch_size = ( + global_batch_size # Used as offset for global index to identify corresponding global step + ) + self.num_global_batch = num_global_batch + self.num_n_samples = num_n_samples + self.total_storage_size = self.global_batch_size * self.num_global_batch * self.num_n_samples + + self.data_production_status = torch.zeros( + self.total_storage_size, TQ_INIT_FIELD_NUM, dtype=torch.int8 + ) # Initialize with default number of fields, dynamically extensible + # task_name -> consumption_status mapping + self.data_consumption_status: dict[str, torch.Tensor] = {} + self.field_name_mapping: dict[ + str, int + ] = {} # Mapping table from field_name to the column indices in self.data_production_status tables + # Per-sample dtype and shape storage: {global_index: {field_name: {'dtype': dtype, 'shape': shape}}} + 
self.per_tensor_dtype_mapping: dict[int, dict[str, torch.dtype]] = {} + self.per_tensor_shape_mapping: dict[int, dict[str, torch.Size]] = {} + + self._build_index_storage_mapping() + + self._start_process_handshake() + self._start_process_update_data_status() + self._start_process_request() + + def _get_consumption_status(self, task_name: str) -> torch.Tensor: + """ + Get or create the consumption status tensor for a specific task. + The consumption status is a binary, 1D tensor that records whether the corresponding sample has been consumed + by the task. + + Args: + task_name: Name of the consumer task + + Returns: + Consumption status tensor for the specified task + """ + # Retrieve or create the consumption state tensor for a specified consumer + if task_name not in self.data_consumption_status: + # Initialize state for a new consumer + self.data_consumption_status[task_name] = torch.zeros(self.total_storage_size, dtype=torch.int8) + return self.data_consumption_status[task_name] + + def _get_per_tensor_dtype(self, global_index: int, field_name: str) -> Optional[torch.dtype]: + """Get dtype for a specific sample and field. + + Args: + global_index: Global index of the sample + field_name: Name of the field + + Returns: + dtype of the specified field for the sample, or None if not found + """ + return self.per_tensor_dtype_mapping.get(global_index, {}).get(field_name) + + def _get_per_tensor_shape(self, global_index: int, field_name: str) -> Optional[torch.Size]: + """Get shape for a specific sample and field. + + Args: + global_index: Global index of the sample + field_name: Name of the field + + Returns: + Shape of the specified field for the sample, or None if not found + """ + return self.per_tensor_shape_mapping.get(global_index, {}).get(field_name) + + def _step_to_global_index_range(self, global_step: int) -> tuple[int, int]: + """Convert global step to corresponding global index range. 
+ + Args: + global_step: The global step to convert + + Returns: + Tuple of (start_index, end_index) for the given global step + """ + start_idx = (global_step % self.num_global_batch) * self.global_batch_size * self.num_n_samples + end_idx = start_idx + self.global_batch_size * self.num_n_samples + + return start_idx, end_idx + + def generate_data_status_mask( + self, data_fields: list[str], global_step: int, task_name: str + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Generate mask matrix for filtering data based on field availability and consumption status. + + This function is called within _get_meta and generates a mask matrix based on + user-specified fields and the current step. The mask matrix selects the required + rows and columns from self.data_production_status while inversely selecting from + self.data_consumption_status to support automated vectorization. + + Args: + data_fields: List of field names to include in the mask + global_step: Current global step for row selection + task_name: Name of the consumer task for consumption status + + Returns: + Tuple of (row_mask, col_mask) tensors for filtering data status matrices + """ + + # Check if all requested fields are registered + for col in data_fields: + if col not in self.field_name_mapping: + # Return empty mask indicating no available data for unregistered columns + empty_row_mask = torch.zeros(self.data_production_status.shape[0], dtype=torch.bool) + empty_col_mask = torch.zeros(self.data_production_status.shape[1], dtype=torch.bool) + return empty_row_mask, empty_col_mask + + # Map steps to global indices + start_idx, end_idx = self._step_to_global_index_range(global_step) + row_mask = torch.zeros(self.data_production_status.shape[0], dtype=torch.bool) + row_mask[start_idx:end_idx] = True + + # Invert selection based on consumption status + consumer_status = self._get_consumption_status(task_name) + unconsumed_mask = consumer_status == 0 + row_mask &= unconsumed_mask + + # Select the 
specified fields + col_mask = torch.zeros(self.data_production_status.shape[1], dtype=torch.bool) + valid_fields = [self.field_name_mapping[col] for col in data_fields] + if valid_fields: + col_mask[valid_fields] = True + + return row_mask, col_mask + + def _build_index_storage_mapping(self): + """ + Build mappings between global indices and storage locations. + + Distributes samples across storage units based on total storage space and + maintains mappings between global index and local index within each storage. + """ + # Assign each sample to a storage node. Here we scatter the samples in each GBS to different storage nodes + # Samples are arranged sequentially, similar to generate_data_status_mask + real_global_batch_size = self.global_batch_size * self.num_n_samples + global_batch_per_storage_unit = math.ceil(real_global_batch_size / self.num_storage_units) + + # Build mapping between global index and storage unit for locating each data sample + batch_storage_indices = np.repeat(np.arange(self.num_storage_units), global_batch_per_storage_unit)[ + :real_global_batch_size + ] + self._global_index_storage_rank_mapping = np.tile(batch_storage_indices, self.num_global_batch) + + # Build mapping between global index and local index within each storage unit + indices = np.arange(self.total_storage_size) + pos_in_batch = indices % real_global_batch_size + g = indices // real_global_batch_size + pos_in_block = pos_in_batch % global_batch_per_storage_unit + self.global_index_local_index_mapping = g * global_batch_per_storage_unit + pos_in_block + + def get_data_production_status(self) -> torch.Tensor: + """ + Get the current data production status matrix. The data production status is a 2D matrix that records whether + the corresponding data is ready for each field of each sample. 
+ + Returns: + Tensor representing production status of all data fields + """ + return self.data_production_status + + def get_field_name_mapping(self) -> dict[str, Any]: + """Get the field name to column index mapping. + + Returns: + Dictionary mapping field names to their column indices + """ + return self.field_name_mapping + + def get_data_consumption_status(self) -> dict[str, torch.Tensor]: + """Get consumption status for all tasks. + + Returns: + Dictionary mapping task names to their consumption status tensors + """ + return self.data_consumption_status + + def get_global_index_mapping(self): + """Get global index to storage mapping information. + + Returns: + Tuple containing storage rank mapping and local index mapping + """ + return self._global_index_storage_rank_mapping, self.global_index_local_index_mapping + + def _get_metadata( + self, + data_fields: list[str], + batch_size: int, + global_step: int, + mode: str = "fetch", + task_name: str | None = None, + get_n_samples=False, + *args, + **kwargs, + ) -> BatchMeta: + """ + Retrieve metadata with support for three modes. 
+ + Args: + data_fields: List of field names to include in metadata + batch_size: Number of samples to retrieve + global_step: Global step for which to retrieve metadata + mode: Operation mode - 'insert', 'fetch', or 'force_fetch' + - mode="insert": Insert metadata for new rows (without checking data status) + - mode="fetch": Retrieve metadata for ready data (check data status and sample) + - mode="force_fetch": Directly return metadata (without checking data status) + task_name: Name of the consumer task (required for fetch modes) + get_n_samples: Whether to retrieve n_samples as groups + *args: Additional positional arguments + **kwargs: Additional keyword arguments + + Returns: + BatchMeta object containing the requested metadata + + Raises: + TimeoutError: If waiting for sufficient data times out in fetch mode + """ + if mode == "insert": + # TODO: Currently only supports putting entire GBS data, need to extend to support multiple puts to same + # step + assert batch_size == self.global_batch_size, ( + f"batch_size {batch_size} must equal global_batch_size {self.global_batch_size}" + ) + start_idx, end_idx = self._step_to_global_index_range(global_step) + batch_global_indexes = list(range(start_idx, end_idx)) + return self._generate_batch_meta(global_step, batch_global_indexes, data_fields, mode) + + assert task_name is not None + if mode == "fetch": + # Find consumable samples within current batch and package into BatchMeta when reading + + start_time = time.time() + while True: + ready_for_consume_idx = self._scan_data_status(data_fields, global_step, task_name, get_n_samples) + + if len(ready_for_consume_idx) >= batch_size: + break + + if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT: + raise TimeoutError( + f"Timeout while waiting for sufficient data. " + f"Required: {batch_size}, Available: {len(ready_for_consume_idx)}" + ) + + logger.warning( + f"Insufficient data available. 
Required: {batch_size}, " + f"Available: {len(ready_for_consume_idx)}. Retrying in " + f"{TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..." + ) + time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL) + logger.debug(f"ready for consume idx: {ready_for_consume_idx}") + + batch_global_indexes = random_sampler(ready_for_consume_idx, batch_size, get_n_samples, self.num_n_samples) + elif mode == "force_fetch": + start_idx, end_idx = self._step_to_global_index_range(global_step) + consumer_status = self._get_consumption_status(task_name) + not_consumed_idx = [i for i in range(start_idx, end_idx) if consumer_status[i] == 0] + batch_global_indexes = random_sampler(not_consumed_idx, batch_size, get_n_samples, self.num_n_samples) + + # Mark this batch of data as consumed + consumer_status = self._get_consumption_status(task_name) + consumer_status[batch_global_indexes] = 1 + # Package into metadata + metadata = self._generate_batch_meta(global_step, batch_global_indexes, data_fields, mode) + logger.debug(f"_get_metadata: {metadata}") + + return metadata + + def _scan_data_status( + self, data_fields: list[str], global_step: int, task_name: str, get_n_samples: bool + ) -> list[int]: + """ + Scan data status to find samples ready for consumption. 
+ + Args: + data_fields: List of field names to check + global_step: Global step to scan + task_name: Name of the consumer task + get_n_samples: Whether to return n_samples as groups + + Returns: + List of global indices that are ready for consumption + """ + # Get row and column masks + row_mask, col_mask = self.generate_data_status_mask(data_fields, global_step, task_name) + logger.debug(f"row_mask, col_mask: {row_mask, col_mask}") + + if not row_mask.any() or not col_mask.any(): + return [] + + # Extract subset of data status for relevant fields + logger.debug(f"self.data_production_status: {self.data_production_status}") + data_status_of_interest = self.data_production_status[:, col_mask] + logger.debug(f"data_status_of_interest: {data_status_of_interest}") + + # Use torch.all for vectorized check instead of sum comparison + all_fields_ready = torch.all(data_status_of_interest, dim=1) + + # Filter samples that meet criteria combined with row mask + ready_mask = all_fields_ready & row_mask + + if get_n_samples and self.num_n_samples > 1: + # Reshape to group view and check group completeness + group_all_ready = torch.all(ready_mask.view(-1, self.num_n_samples), dim=1) + + # Get indices of fully ready groups + ready_group_indices = group_all_ready.nonzero(as_tuple=False).flatten() + + # Calculate all sample indices + sample_offset = torch.arange(self.num_n_samples) + ready_for_consume_idx = ( + (ready_group_indices.unsqueeze(1) * self.num_n_samples + sample_offset).flatten().tolist() + ) + + return ready_for_consume_idx + else: + ready_for_consume_idx = torch.nonzero(ready_mask, as_tuple=False).flatten().tolist() + logger.debug(f"ready_for_consume_idx: {ready_for_consume_idx}") + + return ready_for_consume_idx + + def _generate_batch_meta( + self, global_step: int, global_indexes: list[int], data_fields: list[str], mode: str + ) -> BatchMeta: + """ + Generate BatchMeta by resolving storage locations for given global indexes. 
+ + For each global index, looks up the corresponding storage node address using: + - global_index_local_index_mapping: Maps to local index within storage + - _global_index_storage_id_mapping: Maps to storage node identifier + + Args: + global_step: Current global step + global_indexes: List of global indexes to process + data_fields: List of data field names + mode: Operation mode ('fetch', 'insert', or 'force_fetch') + + Returns: + BatchMeta object containing sample metadata with resolved storage locations + """ + global_arr = np.array(global_indexes) + storage_ids = self.global_index_storage_id_mapping[global_arr] + local_indexes = self.global_index_local_index_mapping[global_arr] + + samples = [] + + # Create samples from the flattened BatchMeta data + # TODO: Optimize this + for i, global_index in enumerate(global_indexes): + local_index = local_indexes[i] + storage_id = storage_ids[i] + + # Create FieldMeta objects for each field + fields = [] + for field_name in data_fields: + if mode == "fetch": + production_status = ProductionStatus.READY_FOR_CONSUME # Since we filtered by ready status + # Get per-tensor dtype and shape for this specific global_index and field + dtype = self._get_per_tensor_dtype(global_index, field_name) + shape = self._get_per_tensor_shape(global_index, field_name) + elif mode == "insert": + production_status = ProductionStatus.NOT_PRODUCED # FIXME: not real-time + dtype = None + shape = None + elif mode == "force_fetch": + col_index = self.field_name_mapping.get(field_name) + if col_index is not None and self.data_production_status[global_index, col_index] == 1: + production_status = ProductionStatus.READY_FOR_CONSUME + dtype = self._get_per_tensor_dtype(global_index, field_name) + shape = self._get_per_tensor_shape(global_index, field_name) + else: + production_status = ProductionStatus.NOT_PRODUCED + dtype = None + shape = None + field_meta = FieldMeta( + name=field_name, + dtype=dtype, + shape=shape, + 
production_status=production_status, + ) + fields.append(field_meta) + + sample = SampleMeta( + global_step=global_step, + global_index=global_index, + storage_id=storage_id, + local_index=local_index, + fields={field.name: field for field in fields}, + ) + samples.append(sample) + + return BatchMeta(samples=samples) + + def _update_production_status(self, indexes: list[int], fields: list[str]) -> None: + """ + Update production status for specified indexes and fields. + + Args: + indexes: List of global indexes to update + fields: List of field names to update + """ + # TODO: Replace self.data_production_status == 0 or ==1 operations with ProductionStatus enum + # Update data production status matrix + new_fields = [field for field in fields if field not in self.field_name_mapping] + if new_fields: + needed_fields = len(new_fields) + current_fields = self.data_production_status.shape[1] + # Expand data status matrix if needed + if len(self.field_name_mapping) + needed_fields > current_fields: + add_fields = max(TQ_INIT_FIELD_NUM, needed_fields + 1) + new_matrix = torch.zeros((self.total_storage_size, add_fields), dtype=torch.int8) + self.data_production_status = torch.cat([self.data_production_status, new_matrix], dim=1) + + for field in fields: + if field not in self.field_name_mapping.keys(): + self.field_name_mapping[field] = len(self.field_name_mapping) + self.data_production_status[ + torch.tensor(indexes)[:, None], torch.tensor([self.field_name_mapping.get(field) for field in fields]) + ] = 1 + + def _update_field_info( + self, + fields: list[str], + per_tensor_dtypes: dict[int, dict[str, Any]], + per_tensor_shapes: dict[int, dict[str, Any]], + global_indexes: list[int], + ) -> None: + """ + Store per-tensor dtype and shape information. 
+ + Args: + fields: List of field names + per_tensor_dtypes: Dict mapping global_index to field dtypes {global_index: {field: dtype}} + per_tensor_shapes: Dict mapping global_index to field shapes {global_index: {field: shape}} + global_indexes: List of global indexes corresponding to the samples + """ + for global_idx in global_indexes: + if global_idx not in self.per_tensor_dtype_mapping: + self.per_tensor_dtype_mapping[global_idx] = {} + if global_idx not in self.per_tensor_shape_mapping: + self.per_tensor_shape_mapping[global_idx] = {} + + for field in fields: + if global_idx in per_tensor_dtypes and field in per_tensor_dtypes[global_idx]: + self.per_tensor_dtype_mapping[global_idx][field] = per_tensor_dtypes[global_idx][field] + if global_idx in per_tensor_shapes and field in per_tensor_shapes[global_idx]: + self.per_tensor_shape_mapping[global_idx][field] = per_tensor_shapes[global_idx][field] + + def _init_zmq_socket(self): + """ + Initialize ZMQ sockets for communication. + + Sets up three ZMQ service ports for: + 1. Receiving handshake requests from storage + 2. Handling client data read/write requests + 3. 
Receiving status update signals from storage + """ + self.zmq_context = zmq.Context() + + self._node_ip = get_node_ip_address() + self._handshake_socket_port = get_free_port() + self._request_handle_socket_port = get_free_port() + self._data_status_update_socket_port = get_free_port() + + self.handshake_socket = create_zmq_socket( + ctx=self.zmq_context, + socket_type=zmq.ROUTER, + ) + self.handshake_socket.bind(f"tcp://{self._node_ip}:{self._handshake_socket_port}") + + self.request_handle_socket = create_zmq_socket( + ctx=self.zmq_context, + socket_type=zmq.ROUTER, + ) + self.request_handle_socket.bind(f"tcp://{self._node_ip}:{self._request_handle_socket_port}") + + self.data_status_update_socket = create_zmq_socket( + ctx=self.zmq_context, + socket_type=zmq.ROUTER, + ) + self.data_status_update_socket.bind(f"tcp://{self._node_ip}:{self._data_status_update_socket_port}") + + self.zmq_server_info = ZMQServerInfo.create( + role=TransferQueueRole.CONTROLLER, + id=self.controller_id, + ip=self._node_ip, + ports={ + "handshake_socket": self._handshake_socket_port, + "request_handle_socket": self._request_handle_socket_port, + "data_status_update_socket": self._data_status_update_socket_port, + }, + ) + + def _wait_connection(self): + """Wait for all storage instances to complete handshake. + + Clients don't need handshake to support dynamic scaling. Continuously + listens for handshake messages until all expected storage units connect. 
def _process_request(self):
    """Main request processing loop.

    Waits for the storage handshake to finish, then serves requests arriving on
    ``request_handle_socket``: metadata retrieval (GET_META), clear-metadata
    lookup (GET_CLEAR_META), clear (CLEAR_META) and consumption checks
    (CHECK_CONSUMPTION). Exactly one response is sent per recognised request.

    Requests of an unknown type, or whose handling raises, are logged and skipped
    so that a single bad request cannot terminate the serving thread.
    """
    self.handshake_done.wait()
    while True:
        # ROUTER socket receives multi-part messages
        identity, serialized_msg = self.request_handle_socket.recv_multipart()

        try:
            request_msg = ZMQMessage.deserialize(serialized_msg)
            request_type = request_msg.request_type
            params = request_msg.body

            if request_type == ZMQRequestType.GET_META:
                logger.info("Controller preparing to get metadata...")
                metadata = self._get_metadata(
                    data_fields=params["data_fields"],
                    batch_size=params["batch_size"],
                    global_step=params["global_step"],
                    mode=params.get("mode", "fetch"),
                    task_name=params.get("task_name", None),
                    get_n_samples=params.get("get_n_samples", False),
                )
                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.GET_META_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"metadata": metadata},
                )
            elif request_type == ZMQRequestType.GET_CLEAR_META:
                # "insert" mode returns the full index range of the step without status checks.
                metadata = self._get_metadata(
                    data_fields=[],
                    batch_size=self.global_batch_size,
                    global_step=params["global_step"],
                    mode="insert",
                )
                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.GET_CLEAR_META_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"metadata": metadata},
                )
            elif request_type == ZMQRequestType.CLEAR_META:
                self.clear(global_step=params["global_step"])
                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.CLEAR_META_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"message": f"Clear operation completed by controller {self.controller_id}"},
                )
            elif request_type == ZMQRequestType.CHECK_CONSUMPTION:
                # Check consumption status
                global_step = params["global_step"]

                consumer_status = self._get_consumption_status(params["task_name"])
                start_idx, end_idx = self._step_to_global_index_range(global_step)
                batch_status = consumer_status[start_idx:end_idx]
                consumed = torch.all(batch_status == 1).item()

                # Build response message
                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.CONSUMPTION_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={
                        "global_step": global_step,
                        "consumed": consumed,
                    },
                )
            else:
                # Without this branch the send below would use an unbound (first request)
                # or stale (later requests) response_msg and answer the wrong peer.
                logger.warning(f"Controller received unsupported request type: {request_type}, ignoring")
                continue
        except Exception:
            # NOTE(review): the requester gets no reply in this case; an explicit error
            # response type would be preferable if the protocol defines one.
            logger.exception("Controller failed to process request, skipping")
            continue

        self.request_handle_socket.send_multipart([identity, response_msg.serialize()])
        logger.debug("Controller request_handle_socket sent multipart successfully!")
body={ + "controller_id": self.controller_id, + "message": f"Data update acknowledged from controller {self.controller_id}", + }, + ) + self.data_status_update_socket.send_multipart([identity, response_msg.serialize()]) + logger.info("Controller sent DATA_UPDATE_ACK successfully!") + elif request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR: + # Handle data update errors + error_msg = request_msg.body.get("message", "Unknown error") + logger.error(f"Data update error from storage: {error_msg}") + + # Send error acknowledgment response + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, + sender_id=self.controller_id, + body={ + "controller_id": self.controller_id, + "message": f"Error notification acknowledged from controller {self.controller_id}", + }, + ) + self.data_status_update_socket.send_multipart([identity, response_msg.serialize()]) + + def get_zmq_server_info(self) -> ZMQServerInfo: + """Get ZMQ server connection information. + + Returns: + ZMQServerInfo object containing connection details + """ + return self.zmq_server_info + + def clear(self, global_step: int): + """Clear data for a specific global batch. + + Resets production and consumption status for all data in the specified + global step. Currently only supports clearing single GBS at a time. 
+ + Args: + global_step: The global step to clear data for + """ + start_idx, end_idx = self._step_to_global_index_range(global_step) + + self.data_production_status[start_idx:end_idx, :] = 0 + for task_name in self.data_consumption_status: + self.data_consumption_status[task_name][start_idx:end_idx] = 0 From 8aa4bb23554c014bb74a9e39adf1f14adf5ce2c1 Mon Sep 17 00:00:00 2001 From: Huazhong Date: Wed, 24 Sep 2025 11:51:04 +0800 Subject: [PATCH 04/16] expose TransferQueueClient (#3) --- verl/experimental/transfer_queue/__init__.py | 3 +- verl/experimental/transfer_queue/client.py | 586 +++++++++++++++++++ verl/experimental/transfer_queue/storage.py | 3 +- 3 files changed, 590 insertions(+), 2 deletions(-) create mode 100644 verl/experimental/transfer_queue/client.py diff --git a/verl/experimental/transfer_queue/__init__.py b/verl/experimental/transfer_queue/__init__.py index 1ce90c5eb35..2df3b7f876f 100644 --- a/verl/experimental/transfer_queue/__init__.py +++ b/verl/experimental/transfer_queue/__init__.py @@ -1,4 +1,5 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Huawei Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/verl/experimental/transfer_queue/client.py b/verl/experimental/transfer_queue/client.py new file mode 100644 index 00000000000..fd43be11b22 --- /dev/null +++ b/verl/experimental/transfer_queue/client.py @@ -0,0 +1,586 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Huawei Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def _register_servers(
    self,
    role: TransferQueueRole,
    server_infos: ZMQServerInfo | dict[Any, ZMQServerInfo],
):
    """Register controller or storage servers with this client.

    Args:
        role: ``TransferQueueRole.CONTROLLER`` registers into ``self._controllers``;
            any other role registers into ``self._storages``.
        server_infos: A single ``ZMQServerInfo`` or a dict whose values are
            ``ZMQServerInfo`` objects. Dict keys are ignored; servers are keyed
            by ``info.id``.

    Raises:
        ValueError: If any supplied entry is not a ``ZMQServerInfo``.
    """
    mapping = self._controllers if role == TransferQueueRole.CONTROLLER else self._storages

    if not isinstance(server_infos, dict):
        # Validate before touching `.id`; an arbitrary object may not have that attribute.
        if not isinstance(server_infos, ZMQServerInfo):
            raise ValueError(
                f"Invalid server info for {role}: expected ZMQServerInfo, got {type(server_infos).__name__}"
            )
        server_infos = {server_infos.id: server_infos}

    for info in server_infos.values():
        if not isinstance(info, ZMQServerInfo):
            # Report the offending type rather than `info.id`, which would itself
            # raise AttributeError and mask the intended ValueError.
            raise ValueError(f"Invalid server info for {role}: expected ZMQServerInfo, got {type(info).__name__}")

        if info.id not in mapping:
            mapping[info.id] = info
            logger.info(f"[{self.client_id}]: Registered {role} server {info.id} at {info.ip}")
        else:
            logger.warning(f"[{self.client_id}]: Server {info.id} already registered, skipping")
+ """ + + def decorator(func: Callable): + @wraps(func) + async def wrapper(self, *args, **kwargs): + if target_role == TransferQueueRole.CONTROLLER: + servers = self._controllers + target = "target_controller" + elif target_role == TransferQueueRole.STORAGE: + servers = self._storages + target = "target_storage" + else: + raise ValueError("Invalid target_role, must be CONTROLLER or STORAGE") + + server_key = kwargs.get(target) + if server_key is None: + for arg in args: + if isinstance(arg, str) and arg in servers.keys(): + server_key = arg + break + if server_key is None and target == "target_controller": + server_key = next(iter(servers.keys())) + + server_info = servers.get(server_key) + if not server_info: + raise RuntimeError(f"Server {server_key} not found in registered {target_role} servers") + + context = zmq.asyncio.Context() + address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}" + identity = f"{self.client_id}_to_{server_info.id}_{uuid4()}".encode() + sock = create_zmq_socket(context, zmq.DEALER, identity=identity) + + try: + sock.connect(address) + logger.info( + f"[{self.client_id}]: Connected to {target_role} {server_info.id} at {address} " + f"with identity {identity.decode()}" + ) + + kwargs["socket"] = sock + return await func(self, *args, **kwargs) + except Exception as e: + logger.error( + f"[{self.client_id}]: Error in socket operation with {target_role} {server_info.id}: {e}" + ) + raise + finally: + try: + if not sock.closed: + sock.setsockopt(zmq.LINGER, -1) + sock.close() + sock.close(linger=0) + except Exception as e: + logger.warning( + f"[{self.client_id}]: Error closing socket to {target_role} {server_info.id}: {e}" + ) + + context.term() + + return wrapper + + return decorator + + @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") + async def async_get_meta( + self, + data_fields: list[str], + batch_size: int, + global_step: int, + mode: str = "fetch", + get_n_samples: bool 
async def async_put(
    self,
    data: TensorDict,
    metadata: Optional[BatchMeta] = None,
    global_step: Optional[int] = None,
):
    """Asynchronously writes data to appropriate Storage Units based on metadata.

    If metadata isn't provided, it is requested from the controller in "insert"
    mode, using the keys of ``data`` as fields, ``data.batch_size[0]`` as the batch
    size, and the given ``global_step``.

    Args:
        data (tensordict.TensorDict): Data to write; its keys are the field names
        metadata (BatchMeta, optional): Optional metadata containing index and storage unit information
        global_step (int, optional): Current step (required if no metadata is provided)

    Raises:
        AssertionError: If neither metadata nor global_step is provided
        ValueError: If the metadata is None or empty
    """
    if metadata is None:
        assert global_step is not None, "global_step must be provided if metadata is not given"

        metadata = await self.async_get_meta(
            data_fields=list(data.keys()),
            batch_size=data.batch_size[0],
            global_step=global_step,
            mode="insert",
        )

    if not metadata or metadata.size == 0:
        raise ValueError("metadata cannot be none or empty")
    logger.debug(f"[{self.client_id}]: Put data with data: {data}")

    # One concurrent put per storage unit, each carrying only that unit's samples.
    tasks = [
        self._put_to_storage(get_transfer_info(meta_group, data), target_storage=storage_id)
        for storage_id, meta_group in metadata.storage_meta_groups.items()
    ]
    await asyncio.gather(*tasks)

    logger.info(
        f"[{self.client_id}]: step {global_step} put {metadata.size} samples to storage units successfully."
    )
+ """ + global_indexes = storage_unit_data["global_indexes"] + local_indexes = storage_unit_data["local_indexes"] + field_data = TensorDict( + { + field: ( + torch.nested.as_nested_tensor(storage_unit_data["field_data"][field]) + if storage_unit_data["field_data"][field] + and all(isinstance(x, torch.Tensor) for x in storage_unit_data["field_data"][field]) + else NonTensorStack(*storage_unit_data["field_data"][field]) + ) + for field in storage_unit_data["field_data"] + } + ) + + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.PUT_DATA, + sender_id=self.client_id, + receiver_id=target_storage, + body={"global_indexes": global_indexes, "local_indexes": local_indexes, "field_data": field_data}, + ) + try: + await socket.send(request_msg.serialize()) + serialized = await socket.recv() + response_msg = ZMQMessage.deserialize(serialized) + + if response_msg.request_type != ZMQRequestType.PUT_DATA_RESPONSE: + raise RuntimeError( + f"Failed to put data to storage unit {target_storage}: " + f"{response_msg.body.get('message', 'Unknown error')}" + ) + except Exception as e: + raise RuntimeError(f"Error in put to storage unit {target_storage}: {str(e)}") from e + + @dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket") + async def _get_from_storage(self, index_data, target_storage=None, socket=None): + global_indexes = index_data["global_indexes"] + local_indexes = index_data["local_indexes"] + fields = index_data["fields"] + + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.GET_DATA, + sender_id=self.client_id, + receiver_id=target_storage, + body={"local_indexes": local_indexes, "fields": fields}, + ) + + try: + await socket.send(request_msg.serialize()) + serialized = await socket.recv() + response_msg = ZMQMessage.deserialize(serialized) + logger.info(f"[{self.client_id}]: get data response from storage unit {target_storage}: {response_msg}") + + if response_msg.request_type == ZMQRequestType.GET_DATA_RESPONSE: 
async def async_get_data(self, metadata: BatchMeta) -> TensorDict:
    """Asynchronously fetches data via Storage Units and organizes it into a TensorDict.

    Args:
        metadata (BatchMeta): Object containing:
            - Data location info (which Storage Units hold the data)
            - `global_indexes` to determine the ordering of merged results

    Returns:
        tensordict.TensorDict with:
            - Requested data fields (e.g., "prompt_token_ids", "response_token_ids").
            - "global_indexes" key: Maps each sample to its original global index.
        An empty TensorDict is returned when ``metadata`` is falsy or has size 0.

    Example:
        >>> returned_td = await async_get_data(metadata)
        >>> returned_td.keys()
        dict_keys(['prompt_token_ids', 'response_token_ids', 'global_indexes'])
        >>> returned_td["prompt_token_ids"].shape  # Batch size 4, seq length 128
        torch.Size([4, 128])
        >>> returned_td["global_indexes"]  # Preserves original global order
        tensor([7, 4, 6, 5])

    Note:
        Why track `global_indexes`?
        - Batches may be rearranged during task processing. `global_indexes` retains the original
          mapping to Storage Units, enabling correct data writing back to Storage Units later.
    """
    if not metadata or metadata.size == 0:
        return TensorDict({}, batch_size=0)

    # One concurrent request per storage unit holding part of the batch
    tasks = [
        self._get_from_storage(meta_group.get_transfer_info(), target_storage=storage_id)
        for storage_id, meta_group in metadata.storage_meta_groups.items()
    ]
    results = await asyncio.gather(*tasks)

    # global_index -> {field: value}
    storage_data: dict[int, dict[str, Any]] = {}
    for unit_global_indexes, fields, storage_unit_data in results:
        for idx, global_idx in enumerate(unit_global_indexes):
            sample = storage_data.setdefault(global_idx, {})
            for field in fields:
                sample[field] = storage_unit_data[field][idx]

    # Re-order the per-sample values to follow metadata.global_indexes
    global_indexes = metadata.global_indexes
    ordered_data: dict[str, list[Any]] = {
        field: [storage_data[global_idx][field] for global_idx in global_indexes] for field in metadata.fields
    }

    tensor_data = {}
    for field, values in ordered_data.items():
        if values and all(isinstance(item, torch.Tensor) for item in values):
            if all(item.shape == values[0].shape for item in values):
                # Uniform shapes: stack directly into a dense tensor
                tensor_data[field] = torch.stack(values)
            else:
                # Ragged shapes: keep as a nested tensor
                tensor_data[field] = torch.nested.as_nested_tensor(values)
        else:
            tensor_data[field] = NonTensorStack(*values)
    tensor_data["global_indexes"] = torch.tensor(global_indexes)

    # Use the number of rows actually assembled as the batch size
    return TensorDict(tensor_data, batch_size=len(global_indexes))
@dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
async def _get_clear_meta(self, global_step: int, target_controller=None, socket=None):
    """Ask the controller for the metadata describing what to clear at ``global_step``.

    Sends a GET_CLEAR_META request over the injected ``socket`` and returns the
    ``"metadata"`` entry of the response body.

    Raises:
        RuntimeError: If the response is not a GET_CLEAR_META_RESPONSE.
    """
    payload = ZMQMessage.create(
        request_type=ZMQRequestType.GET_CLEAR_META,
        sender_id=self.client_id,
        receiver_id=target_controller,
        body={"global_step": global_step},
    ).serialize()
    await socket.send(payload)

    reply = ZMQMessage.deserialize(await socket.recv())
    if reply.request_type == ZMQRequestType.GET_CLEAR_META_RESPONSE:
        return reply.body["metadata"]

    raise RuntimeError(
        f"Failed to get metadata for clear operation: {reply.body.get('message', 'Unknown error')}"
    )
@dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") + async def _clear_controller(self, global_step, target_controller=None, socket=None): + try: + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.CLEAR_META, + sender_id=self.client_id, + receiver_id=target_controller, + body={"global_step": global_step}, + ) + + await socket.send(request_msg.serialize()) + serialized_msg = await socket.recv() + response_msg = ZMQMessage.deserialize(serialized_msg) + + if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE: + raise RuntimeError( + f"Failed to clear controller {target_controller}: " + f"{response_msg.body.get('message', 'Unknown error')}" + ) + + logger.info( + f"[{self.client_id}]: Successfully clear controller {target_controller} for global_step {global_step}" + ) + except Exception as e: + logger.error(f"[{self.client_id}]: Error clearing controller {target_controller}: {str(e)}") + raise + + @dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket") + async def _clear_storage_unit(self, local_indexes, target_storage=None, socket=None): + try: + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.CLEAR_DATA, + sender_id=self.client_id, + receiver_id=target_storage, + body={"local_indexes": local_indexes}, + ) + + await socket.send(request_msg.serialize()) + serialized_msg = await socket.recv() + response_msg = ZMQMessage.deserialize(serialized_msg) + + if response_msg.request_type != ZMQRequestType.CLEAR_DATA_RESPONSE: + raise RuntimeError( + f"Failed to clear storage {target_storage}: {response_msg.body.get('message', 'Unknown error')}" + ) + + logger.info(f"[{self.client_id}]: Successfully clear storage unit {target_storage}") + except Exception as e: + logger.error(f"[{self.client_id}]: Error clearing storage unit {target_storage}: {str(e)}") + raise + + +class TransferQueueClient(AsyncTransferQueueClient): + def __init__( + self, + client_id: str, + 
controller_infos: ZMQServerInfo | dict[Any, ZMQServerInfo], + storage_infos: ZMQServerInfo | dict[Any, ZMQServerInfo], + ): + super().__init__( + client_id, + controller_infos, + storage_infos, + ) + + def put(self, data: TensorDict, metadata: Optional[BatchMeta] = None, global_step: Optional[int] = None): + return asyncio.run(self.async_put(data, metadata, global_step)) + + def get_meta( + self, + data_fields: list[str], + batch_size: int, + global_step: int, + get_n_samples: bool = False, + task_name: Optional[str] = None, + ) -> BatchMeta: + return asyncio.run( + self.async_get_meta( + data_fields=data_fields, + batch_size=batch_size, + global_step=global_step, + get_n_samples=get_n_samples, + task_name=task_name, + ) + ) + + def get_data(self, metadata: BatchMeta) -> TensorDict: + return asyncio.run(self.async_get_data(metadata)) + + def clear(self, global_step: int): + return asyncio.run(self.async_clear(global_step)) + + +def _add_field_data( + transfer_dict: dict[str, Any], storage_meta_group: StorageMetaGroup, data: TensorDict +) -> dict[str, Any]: + """Helper function to add field data to the transfer dictionary""" + field_names = transfer_dict["fields"] + for fname in field_names: + if fname in data.keys(): + transfer_dict["field_data"][fname] = [] + for sample_meta in storage_meta_group.sample_metas: + transfer_dict["field_data"][fname].append(data[fname][sample_meta.batch_index]) + return transfer_dict + + +def get_transfer_info( + storage_meta_group: StorageMetaGroup, + data: TensorDict, +) -> dict[str, Any]: + """Convert to dictionary format with field data for put operations""" + result = storage_meta_group.get_transfer_info(field_names=data.keys()) + result = _add_field_data(result, storage_meta_group, data) + return result + + +def process_zmq_server_info(handlers: dict[Any, Union[TransferQueueController, TransferQueueStorageSimpleUnit]]): # noqa: UP007 + server_info = {} + for name, handler in handlers.items(): + server_info[name] = 
ray.get(handler.get_zmq_server_info.remote()) # type: ignore[attr-defined] + return server_info diff --git a/verl/experimental/transfer_queue/storage.py b/verl/experimental/transfer_queue/storage.py index 11e05887785..a4bac9b60ba 100644 --- a/verl/experimental/transfer_queue/storage.py +++ b/verl/experimental/transfer_queue/storage.py @@ -1,4 +1,5 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Huawei Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From f6638aa0446ff6f73a7a2502f90f8558ffa9b7e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Han=20Zhenyu=20=E9=9F=A9=E6=8C=AF=E5=AE=87?= Date: Wed, 24 Sep 2025 14:09:14 +0800 Subject: [PATCH 05/16] Add copyright and license information Added copyright and licensing information to the controller.py file. --- verl/experimental/transfer_queue/controller.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/verl/experimental/transfer_queue/controller.py b/verl/experimental/transfer_queue/controller.py index 607fe857322..999cd6ab931 100644 --- a/verl/experimental/transfer_queue/controller.py +++ b/verl/experimental/transfer_queue/controller.py @@ -1,3 +1,18 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Huawei Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import logging import math import os From a884056e1eaf2cda178058e9b01c17da3c7e0203 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Han=20Zhenyu=20=E9=9F=A9=E6=8C=AF=E5=AE=87?= Date: Thu, 25 Sep 2025 11:33:49 +0800 Subject: [PATCH 06/16] update client docstring (#5) Signed-off-by: 0oshowero0 --- verl/experimental/transfer_queue/client.py | 88 +++++++++++++++++++--- 1 file changed, 79 insertions(+), 9 deletions(-) diff --git a/verl/experimental/transfer_queue/client.py b/verl/experimental/transfer_queue/client.py index fd43be11b22..28a648ca57e 100644 --- a/verl/experimental/transfer_queue/client.py +++ b/verl/experimental/transfer_queue/client.py @@ -180,12 +180,40 @@ async def async_get_meta( data_fields (list[str]): List of fields to retrieve metadata for batch_size (int): Processing batch size global_step (int): Current training/processing step - mode (str): Data fetch mode (TODO(hz): more details to be added) - get_n_samples (bool): TODO(hz): more details to be added + mode (str): Data fetch mode. 'fetch' to get ready data, 'force_fetch' to get data regardless of readiness. + 'insert' IS AN INTERNAL USAGE THAT SHOULD NOT BE USED BY USERS. + get_n_samples (bool): If True, we arrange the samples of the same prompt in contiguous order. In 'fetch' + mode, only the samples of the same prompt that are all ready will be returned. 
task_name (str): Optional task name associated with the request target_controller (str): ID of the target controller to send the request to socket (zmq.asyncio.Socket): ZMQ async socket for message transmission + Example: + >>> batch_size = 4 + >>> current_step = 0 + >>> # Example 1: "fetch" a batch of metadata that has been produced + >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"], + >>> batch_size=batch_size, + >>> global_step=current_step, + >>> mode="fetch", + >>> get_n_samples=False, + >>> task_name="generate_sequences", + >>> )) + >>> print(batch_meta.is_ready) # you should get a batch_meta with is_ready=True + >>> print([sample_meta.is_ready for sample_meta in batch_meta.samples]) # [True, True, True, True] + >>> + >>> # Example 2: "force_fetch" a batch of metadata, ignoring their production status (but we still make + >>> # sure the corresponding data has not been consumed) + >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"], + >>> batch_size=batch_size, + >>> global_step=current_step, + >>> mode="force_fetch", + >>> get_n_samples=False, + >>> task_name="generate_sequences", + >>> )) + >>> print(batch_meta.is_ready) # you may get a batch_meta with is_ready=False + >>> print([sample_meta.is_ready for sample_meta in batch_meta.samples]) # [True, False, False, True] + Returns: BatchMeta: Metadata object containing data structure, sample info, etc. 
""" @@ -239,6 +267,30 @@ async def async_put( metadata (BatchMeta, optional): Optional metadata containing index and storage unit information global_step (int, optional): Current step (required if no metadata is provided) + Example: + >>> batch_size = 4 + >>> seq_len = 16 + >>> current_step = 0 + >>> # Example 1: normal usage + >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"], + >>> batch_size=batch_size, + >>> global_step=current_step, + >>> mode="fetch", + >>> get_n_samples=False, + >>> task_name="generate_sequences", + >>> )) + >>> batch = asyncio.run(client.async_get_data(batch_meta)) + >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) + >>> asyncio.run(client.async_put(data=output, metadata=batch_meta)) + >>> + >>> # Example 2: put the initial data into the system without pre-existing metadata + >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given global_step! + >>> # So make sure the global_step is empty. + >>> prompts = (torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111]])) + >>> prompt_batch = TensorDict({"prompts": prompts}) + >>> # This will create metadata in "insert" mode internally. + >>> asyncio.run(client.async_put(data=prompt_batch, global_step=current_step)) + """ if metadata is None: assert global_step is not None, "global_steps must be provided if metadata is not given" @@ -346,13 +398,20 @@ async def async_get_data(self, metadata: BatchMeta) -> TensorDict: - "global_indexes" key: Maps each sample to its original global index. 
Example: - >>> returned_td = await async_get_data(metadata) - >>> returned_td.keys() - dict_keys(['prompt_token_ids', 'response_token_ids', 'global_indexes']) - >>> returned_td["prompt_token_ids"].shape # Batch size 4, seq length 128 - torch.Size([4, 128]) - >>> returned_td["global_indexes"] # Preserves original global order - tensor([7, 4, 6, 5]) + >>> batch_size = 4 + >>> seq_len = 16 + >>> current_step = 0 + >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"], + >>> batch_size=batch_size, + >>> global_step=current_step, + >>> mode="fetch", + >>> get_n_samples=False, + >>> task_name="generate_sequences", + >>> )) + >>> batch = asyncio.run(client.async_get_data(batch_meta)) + >>> print(batch) + >>> # this is a TensorDict with fields "prompts" and "attention_mask". + >>> # The order of samples in the TensorDict matches the order of global_indexes in batch_meta Note: Why track `global_indexes`? @@ -408,6 +467,7 @@ async def async_clear(self, global_step: int): Args: global_step (int): The training step associated with the clear operation + """ try: target_controller = next(iter(self._controllers.keys())) @@ -514,6 +574,16 @@ async def _clear_storage_unit(self, local_indexes, target_storage=None, socket=N logger.error(f"[{self.client_id}]: Error clearing storage unit {target_storage}: {str(e)}") raise + @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") + def check_current_step_consumption(self, task_name: str, global_step: int): + # TODO: Implement this method to check if all samples for the current step has been consumed + pass + + @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") + def check_current_step_production(self, data_fields: list[str], global_step: int): + # TODO: Implement this method to check if all samples for the current step is ready for consumption + pass + class TransferQueueClient(AsyncTransferQueueClient): def 
__init__( From 7bf946ad2c5512e985b9a26b44f0c34216329865 Mon Sep 17 00:00:00 2001 From: zhabuye <74179177+zhabuye@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:46:26 +0800 Subject: [PATCH 07/16] merge TransferQueue utils (#4) --- .../transfer_queue/utils/__init__.py | 14 ++ .../transfer_queue/utils/utils.py | 111 +++++++++++ .../transfer_queue/utils/zmq_utils.py | 176 ++++++++++++++++++ 3 files changed, 301 insertions(+) create mode 100644 verl/experimental/transfer_queue/utils/__init__.py create mode 100644 verl/experimental/transfer_queue/utils/utils.py create mode 100644 verl/experimental/transfer_queue/utils/zmq_utils.py diff --git a/verl/experimental/transfer_queue/utils/__init__.py b/verl/experimental/transfer_queue/utils/__init__.py new file mode 100644 index 00000000000..2df3b7f876f --- /dev/null +++ b/verl/experimental/transfer_queue/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Huawei Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/experimental/transfer_queue/utils/utils.py b/verl/experimental/transfer_queue/utils/utils.py new file mode 100644 index 00000000000..2fceb3f14ce --- /dev/null +++ b/verl/experimental/transfer_queue/utils/utils.py @@ -0,0 +1,111 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Huawei Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + +import ray +import torch +from tensordict import TensorDict + + +class ExplicitEnum(str, Enum): + """ + Enum with more explicit error message for missing values. + """ + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" + ) + + +class TransferQueueRole(ExplicitEnum): + CONTROLLER = "TransferQueueController" + STORAGE = "TransferQueueStorage" + CLIENT = "TransferQueueClient" + + +# production_status enum: 0: not produced, 1: ready for consume, 2: consumed +class ProductionStatus(ExplicitEnum): + NOT_PRODUCED = 0 + READY_FOR_CONSUME = 1 + CONSUMED = 2 + + +def get_placement_group(num_ray_actors: int, num_cpus_per_actor: int = 1): + """ + Create a placement group with SPREAD strategy for Ray actors. + + Args: + num_ray_actors (int): Number of Ray actors to create. + num_cpus_per_actor (int): Number of CPUs to allocate per actor. + + Returns: + placement_group: The created placement group. 
+ """ + bundle = {"CPU": num_cpus_per_actor} + placement_group = ray.util.placement_group([bundle for _ in range(num_ray_actors)], strategy="SPREAD") + ray.get(placement_group.ready()) + return placement_group + + +def random_sampler( + ready_for_consume_idx: list[int], + batch_size: int, + get_n_samples: bool, + n_samples_per_prompt: int, +) -> list[int]: + """ + random sampling batch_size samples from global indexes ready_for_consume_idx + input example: + if get_n_samples: (group_num=3, group_size=4) + ready_for_consume_idx could look like: [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19] + else: + ready_for_consume_idx could look like: [2, 5, 6] + """ + if get_n_samples: + assert len(ready_for_consume_idx) % n_samples_per_prompt == 0 + assert batch_size % n_samples_per_prompt == 0 + batch_size_n_samples = batch_size // n_samples_per_prompt + + group_ready_for_consume_idx = torch.tensor(ready_for_consume_idx, dtype=torch.int).view( + -1, n_samples_per_prompt + ) + + weights = torch.ones(group_ready_for_consume_idx.size(0)) + sampled_indexes_idx = torch.multinomial(weights, batch_size_n_samples, replacement=False).tolist() + sampled_indexes = group_ready_for_consume_idx[sampled_indexes_idx].flatten().tolist() + else: + weights = torch.ones(len(ready_for_consume_idx)) + sampled_indexes_idx = torch.multinomial(weights, batch_size, replacement=False).tolist() + sampled_indexes = [int(ready_for_consume_idx[i]) for i in sampled_indexes_idx] + return sampled_indexes + + +def extract_field_info(tensor_dict: TensorDict) -> dict: + """ + Extract field names, dtypes, and shapes from a TensorDict. + Assumes all tensors in the same field have the same dtype and shape (excluding batch dimension). + Returns a dictionary with keys: 'names', 'dtypes', 'shapes'. 
+ """ + field_info: dict[str, list] = {"names": [], "dtypes": [], "shapes": []} + for key, value in tensor_dict.items(): + field_info["names"].append(key) + + # TODO: support nested tensors & non tensors + # field_info["dtypes"].append(value.dtype) + # field_info["shapes"].append(value.shape[1:]) # exclude batch dimension + return field_info diff --git a/verl/experimental/transfer_queue/utils/zmq_utils.py b/verl/experimental/transfer_queue/utils/zmq_utils.py new file mode 100644 index 00000000000..947b48407ef --- /dev/null +++ b/verl/experimental/transfer_queue/utils/zmq_utils.py @@ -0,0 +1,176 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Huawei Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pickle +import socket +import time +import uuid +from dataclasses import dataclass +from typing import Any, Optional + +import psutil +import zmq +from typing_extensions import Self + +from verl.experimental.transfer_queue.utils.utils import ( + ExplicitEnum, + TransferQueueRole, +) + + +class ZMQRequestType(ExplicitEnum): + # HANDSHAKE + HANDSHAKE = "HANDSHAKE" # TransferQueueStorageUnit -> TransferQueueController + HANDSHAKE_ACK = "HANDSHAKE_ACK" # TransferQueueController -> TransferQueueStorageUnit + + # DATA_OPERATION + GET_DATA = "GET" + PUT_DATA = "PUT" + GET_DATA_RESPONSE = "GET_DATA_RESPONSE" + PUT_DATA_RESPONSE = "PUT_DATA_RESPONSE" + CLEAR_DATA = "CLEAR_DATA" + CLEAR_DATA_RESPONSE = "CLEAR_DATA_RESPONSE" + + PUT_GET_OPERATION_ERROR = "PUT_GET_OPERATION_ERROR" + PUT_GET_ERROR = "PUT_GET_ERROR" + PUT_ERROR = "PUT_ERROR" + GET_ERROR = "GET_ERROR" + CLEAR_DATA_ERROR = "CLEAR_DATA_ERROR" + + # META_OPERATION + GET_META = "GET_META" + GET_META_RESPONSE = "GET_META_RESPONSE" + GET_CLEAR_META = "GET_CLEAR_META" + GET_CLEAR_META_RESPONSE = "GET_CLEAR_META_RESPONSE" + CLEAR_META = "CLEAR_META" + CLEAR_META_RESPONSE = "CLEAR_META_RESPONSE" + + # CHECK_CONSUMPTION + CHECK_CONSUMPTION = "CHECK_CONSUMPTION" + CONSUMPTION_RESPONSE = "CONSUMPTION_RESPONSE" + + # NOTIFY_DATA_UPDATE + NOTIFY_DATA_UPDATE = "NOTIFY_DATA_UPDATE" + NOTIFY_DATA_UPDATE_ACK = "NOTIFY_DATA_UPDATE_ACK" + NOTIFY_DATA_UPDATE_ERROR = "NOTIFY_DATA_UPDATE_ERROR" + + +@dataclass +class ZMQServerInfo: + role: TransferQueueRole + id: str + ip: str + ports: dict[str, str] + + @classmethod + def create(cls, role: TransferQueueRole, id: str, ip: str, ports: dict[str, str]) -> Self: + return cls(role=role, id=id, ip=ip, ports=ports) + + def to_addr(self, port_name: str) -> str: + return f"tcp://{self.ip}:{self.ports[port_name]}" + + def to_dict(self): + return { + "role": self.role, + "id": self.id, + "ip": self.ip, + "ports": self.ports, + } + + def __str__(self) -> str: + return 
f"ZMQSocketInfo(role={self.role}, id={self.id}, ip={self.ip}, ports={self.ports})" + + +@dataclass +class ZMQMessage: + request_type: ZMQRequestType + sender_id: str + receiver_id: str | None + body: dict[str, Any] + request_id: str + timestamp: float + + @classmethod + def create( + cls, + request_type: ZMQRequestType, + sender_id: str, + body: dict[str, Any], + receiver_id: Optional[str] = None, + ) -> "ZMQMessage": + return cls( + request_type=request_type, + sender_id=sender_id, + receiver_id=receiver_id, + body=body, + request_id=str(uuid.uuid4()), + timestamp=time.time(), + ) + + def serialize(self) -> bytes: + """Using pickle to serialize ZMQMessage objects""" + return pickle.dumps(self) + + @classmethod + def deserialize(cls, data: bytes | list[bytes]): + """Using pickle to deserialize ZMQMessage objects""" + if isinstance(data, list): + # Process multiple byte streams by deserializing each in sequence + result = [] + for d in data: + result.append(pickle.loads(d)) + return result + else: + # Single byte stream case + return pickle.loads(data) + + +def get_free_port() -> str: + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + +def create_zmq_socket( + ctx: zmq.Context, + socket_type: Any, + identity: Optional[bytes] = None, +) -> zmq.Socket: + mem = psutil.virtual_memory() + socket = ctx.socket(socket_type) + + # Calculate buffer size based on system memory + total_mem = mem.total / 1024**3 + available_mem = mem.available / 1024**3 + # For systems with substantial memory (>32GB total, >16GB available): + # - Set a large 0.5GB buffer to improve throughput + # For systems with less memory: + # - Use system default (-1) to avoid excessive memory consumption + if total_mem > 32 and available_mem > 16: + buf_size = int(0.5 * 1024**3) # 0.5GB in bytes + else: + buf_size = -1 # Use system default buffer size + + if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): + socket.setsockopt(zmq.RCVHWM, 0) + 
socket.setsockopt(zmq.RCVBUF, buf_size) + + if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER): + socket.setsockopt(zmq.SNDHWM, 0) + socket.setsockopt(zmq.SNDBUF, buf_size) + + if identity is not None: + socket.setsockopt(zmq.IDENTITY, identity) + return socket From e94019f2d954b88ec409579d7144cc5ae9f5a206 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Han=20Zhenyu=20=E9=9F=A9=E6=8C=AF=E5=AE=87?= Date: Thu, 25 Sep 2025 15:44:01 +0800 Subject: [PATCH 08/16] [fix] Fix n_sample related problems (#8) * update client docstring Signed-off-by: 0oshowero0 * fix n_sample related problems Signed-off-by: 0oshowero0 --------- Signed-off-by: 0oshowero0 --- verl/experimental/transfer_queue/client.py | 14 ++++++++++---- verl/experimental/transfer_queue/controller.py | 10 +++++----- verl/experimental/transfer_queue/storage.py | 2 +- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/verl/experimental/transfer_queue/client.py b/verl/experimental/transfer_queue/client.py index 28a648ca57e..3bf3eec950e 100644 --- a/verl/experimental/transfer_queue/client.py +++ b/verl/experimental/transfer_queue/client.py @@ -285,11 +285,16 @@ async def async_put( >>> >>> # Example 2: put the initial data into the system without pre-existing metadata >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given global_step! - >>> # So make sure the global_step is empty. - >>> prompts = (torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111]])) - >>> prompt_batch = TensorDict({"prompts": prompts}) + >>> # Please make sure the corresponding global_step is empty before calling the async_put() + >>> # without metadata. + >>> # Now we only support put all the data of the corresponding global step in once. You should repeat with + >>> # interleave the initial data if n_sample > 1 before calling the async_put(). 
+ >>> original_prompts = torch.randn(batch_size, seq_len) + >>> n_samples = 4 + >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) + >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) >>> # This will create metadata in "insert" mode internally. - >>> asyncio.run(client.async_put(data=prompt_batch, global_step=current_step)) + >>> asyncio.run(client.async_put(data=prompts_repeated_batch, global_step=current_step)) """ if metadata is None: @@ -299,6 +304,7 @@ async def async_put( data_fields=list(data.keys()), batch_size=data.batch_size[0], global_step=global_step, + get_n_samples=True, mode="insert", ) diff --git a/verl/experimental/transfer_queue/controller.py b/verl/experimental/transfer_queue/controller.py index 999cd6ab931..73c1e6d4e9a 100644 --- a/verl/experimental/transfer_queue/controller.py +++ b/verl/experimental/transfer_queue/controller.py @@ -296,10 +296,10 @@ def _get_metadata( TimeoutError: If waiting for sufficient data times out in fetch mode """ if mode == "insert": - # TODO: Currently only supports putting entire GBS data, need to extend to support multiple puts to same - # step - assert batch_size == self.global_batch_size, ( - f"batch_size {batch_size} must equal global_batch_size {self.global_batch_size}" + # TODO: Currently we only supports put the entire GBS data in one time + assert batch_size == self.global_batch_size * self.num_n_samples, ( + f"batch_size {batch_size} must equal " + f"global_batch_size * num_n_samples {self.global_batch_size * self.num_n_samples}" ) start_idx, end_idx = self._step_to_global_index_range(global_step) batch_global_indexes = list(range(start_idx, end_idx)) @@ -651,7 +651,7 @@ def _process_request(self): params = request_msg.body metadata = self._get_metadata( data_fields=[], - batch_size=self.global_batch_size, + batch_size=self.global_batch_size * self.num_n_samples, global_step=params["global_step"], mode="insert", ) diff --git 
a/verl/experimental/transfer_queue/storage.py b/verl/experimental/transfer_queue/storage.py index a4bac9b60ba..e71ac9c8b38 100644 --- a/verl/experimental/transfer_queue/storage.py +++ b/verl/experimental/transfer_queue/storage.py @@ -84,7 +84,7 @@ def get_data(self, fields: list[str], local_indexes: list[int]) -> TensorDict[st # The unsqueeze op make the shape from n to (1, n) gathered_item = self.field_data[field][local_indexes[0]] if not isinstance(gathered_item, torch.Tensor): - result[field] = NonTensorStack(gathered_item).unsqueeze(0) + result[field] = NonTensorStack(gathered_item) else: result[field] = gathered_item.unsqueeze(0) else: From 77c6e7e1f6e1b864ae1d140ac83f3568fb124116 Mon Sep 17 00:00:00 2001 From: zhabuye <74179177+zhabuye@users.noreply.github.com> Date: Thu, 25 Sep 2025 20:42:31 +0800 Subject: [PATCH 09/16] expose TransferQueue client/controller UT (#6) --- .../transfer_queue/test_client.py | 385 ++++++++++++++++++ .../transfer_queue/test_controller.py | 263 ++++++++++++ 2 files changed, 648 insertions(+) create mode 100644 tests/experimental/transfer_queue/test_client.py create mode 100644 tests/experimental/transfer_queue/test_controller.py diff --git a/tests/experimental/transfer_queue/test_client.py b/tests/experimental/transfer_queue/test_client.py new file mode 100644 index 00000000000..f1b4efd191b --- /dev/null +++ b/tests/experimental/transfer_queue/test_client.py @@ -0,0 +1,385 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Huawei Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from threading import Thread + +import pytest +import torch +import zmq +from tensordict import NonTensorStack, TensorDict + +from verl.experimental.transfer_queue import TransferQueueClient # noqa: E402 +from verl.experimental.transfer_queue.metadata import ( # noqa: E402 + BatchMeta, + FieldMeta, + SampleMeta, +) +from verl.experimental.transfer_queue.utils.zmq_utils import ( # noqa: E402 + ZMQMessage, + ZMQRequestType, + ZMQServerInfo, +) + +TEST_DATA = TensorDict( + { + "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])], + "variable_length_sequences": torch.nested.as_nested_tensor( + [ + torch.tensor([-0.5, -1.2, -0.8]), + torch.tensor([-0.3, -1.5, -2.1, -0.9]), + torch.tensor([-1.1, -0.7]), + ] + ), + "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"], + }, + batch_size=[3], +) + + +# Mock Controller for Client Unit Testing +class MockController: + def __init__(self, controller_id="controller_0"): + self.controller_id = controller_id + self.context = zmq.Context() + + # Socket for data requests + self.request_socket = self.context.socket(zmq.ROUTER) + self.request_port = self._bind_to_random_port(self.request_socket) + + self.zmq_server_info = ZMQServerInfo.create( + role="TransferQueueController", + id=controller_id, + ip="127.0.0.1", + ports={ + "request_handle_socket": self.request_port, + }, + ) + + self.running = True + self.request_thread = Thread(target=self._handle_requests, daemon=True) + self.request_thread.start() + + def _bind_to_random_port(self, socket): + port = socket.bind_to_random_port("tcp://127.0.0.1") + return port + + def _handle_requests(self): + poller = zmq.Poller() + poller.register(self.request_socket, zmq.POLLIN) + + while self.running: + try: + socks = dict(poller.poll(100)) # 100ms timeout + if self.request_socket in 
socks: + identity, serialized_msg = self.request_socket.recv_multipart() + request_msg = ZMQMessage.deserialize(serialized_msg) + + # Determine response based on request type + if request_msg.request_type == ZMQRequestType.GET_META: + response_body = self._mock_batch_meta(request_msg.body) + response_type = ZMQRequestType.GET_META_RESPONSE + elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META: + response_body = self._mock_batch_meta(request_msg.body) + response_type = ZMQRequestType.GET_CLEAR_META_RESPONSE + elif request_msg.request_type == ZMQRequestType.CLEAR_META: + response_body = {"message": "clear ok"} + response_type = ZMQRequestType.CLEAR_META_RESPONSE + + # Send response + response_msg = ZMQMessage.create( + request_type=response_type, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body=response_body, + ) + self.request_socket.send_multipart([identity, response_msg.serialize()]) + except zmq.Again: + continue + except Exception as e: + if self.is_running: + print(f"MockController running exception: {e}") + else: + print(f"MockController ERROR: {e}") + raise + + def _mock_batch_meta(self, request_body): + batch_size = request_body.get("batch_size", 1) + data_fields = request_body.get("data_fields", []) + + samples = [] + for i in range(batch_size): + fields = [] + for field_name in data_fields: + field_meta = FieldMeta( + name=field_name, + dtype=None, + shape=None, + production_status=0, + ) + fields.append(field_meta) + sample = SampleMeta( + global_step=0, + global_index=i, + storage_id="storage_0", + local_index=i, + fields={field.name: field for field in fields}, + ) + samples.append(sample) + metadata = BatchMeta(samples=samples) + + return {"metadata": metadata} + + def stop(self): + self.running = False + time.sleep(0.2) # Give thread time to stop + self.request_socket.close() + self.context.term() + + +# Mock Storage for Client Unit Testing +class MockStorage: + def __init__(self, storage_id="storage_0"): + 
def _handle_data_requests(self):
    """Serve PUT_DATA / GET_DATA / CLEAR_DATA requests until ``self.running`` is cleared.

    PUT and CLEAR are acknowledged with a fixed message; GET is delegated to
    ``_handle_get_data``. Unsupported request types are reported and left
    unanswered.
    """
    poller = zmq.Poller()
    poller.register(self.data_socket, zmq.POLLIN)

    while self.running:
        try:
            socks = dict(poller.poll(100))  # 100ms timeout
            if self.data_socket in socks:
                identity, msg_bytes = self.data_socket.recv_multipart()
                msg = ZMQMessage.deserialize(msg_bytes)

                # Handle different request types
                if msg.request_type == ZMQRequestType.PUT_DATA:
                    response_body = {"message": "Data stored successfully"}
                    response_type = ZMQRequestType.PUT_DATA_RESPONSE
                elif msg.request_type == ZMQRequestType.GET_DATA:
                    response_body = self._handle_get_data(msg.body)
                    response_type = ZMQRequestType.GET_DATA_RESPONSE
                elif msg.request_type == ZMQRequestType.CLEAR_DATA:
                    response_body = {"message": "Data cleared successfully"}
                    response_type = ZMQRequestType.CLEAR_DATA_RESPONSE
                else:
                    # Without this branch the response variables would be
                    # unbound (first request) or stale (later requests).
                    print(f"MockStorage received unsupported request type: {msg.request_type}")
                    continue

                # Send response
                response_msg = ZMQMessage.create(
                    request_type=response_type,
                    sender_id=self.storage_id,
                    receiver_id=msg.sender_id,
                    body=response_body,
                )
                self.data_socket.send_multipart([identity, response_msg.serialize()])
        except zmq.Again:
            continue
        except Exception as e:
            # The flag is `running`; the former `is_running` attribute does not
            # exist and raised AttributeError inside this handler.
            if self.running:
                print(f"MockStorage running exception: {e}")
            else:
                print(f"MockStorage ERROR: {e}")
                raise
def test_put_and_get_data(client_setup):
    """Put the shared TEST_DATA, then fetch two samples back and check every field."""
    client = client_setup[0]

    # Write the whole fixture batch for step 0.
    client.put(data=TEST_DATA, global_step=0)

    # Ask the controller for metadata describing the first two samples.
    requested_fields = ["log_probs", "variable_length_sequences", "prompt_text"]
    metadata = client.get_meta(data_fields=requested_fields, batch_size=2, global_step=0)

    # Read the data the metadata points at.
    result = client.get_data(metadata)

    # Every requested field must be present in the result.
    for field_name in requested_fields:
        assert field_name in result

    # Tensor fields: compare row by row against the values held in TEST_DATA.
    expected_tensor_rows = {
        "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])],
        "variable_length_sequences": [
            torch.tensor([-0.5, -1.2, -0.8]),
            torch.tensor([-0.3, -1.5, -2.1, -0.9]),
        ],
    }
    for field_name, expected_rows in expected_tensor_rows.items():
        for row, expected in enumerate(expected_rows):
            torch.testing.assert_close(result[field_name][row], expected)

    # Non-tensor field: plain string equality.
    assert result["prompt_text"][0] == "Hello world!"
    assert result["prompt_text"][1] == "This is a longer sentence for testing"
# Test error handling
def test_put_without_required_params(client_setup):
    """A put that omits ``global_step`` must be rejected with an AssertionError."""
    client = client_setup[0]

    # Five rows of 128 random token ids.
    tokens = torch.randint(0, 100, (5, 128))
    test_data = TensorDict({"tokens": tokens}, batch_size=5)

    # global_step is deliberately left out.
    with pytest.raises(AssertionError):
        client.put(data=test_data)
@pytest.fixture(scope="function")
def setup_teardown_register_controller_info(setup_teardown_transfer_queue_controller):
    """Create two storage-unit actors and register the controller with each of them.

    Yields ``(tq_controller, global_batch_size, num_n_samples, storage_units)``
    where ``storage_units`` maps the unit rank to its actor handle.
    """
    (
        tq_controller,
        global_batch_size,
        num_global_batch,
        num_n_samples,
    ) = setup_teardown_transfer_queue_controller

    num_data_storage_units = 2
    # Split the total capacity evenly across units, rounding up.
    per_unit_size = math.ceil(global_batch_size * num_global_batch * num_n_samples / num_data_storage_units)

    data_system_storage_units = {}
    for storage_unit_rank in range(num_data_storage_units):
        data_system_storage_units[storage_unit_rank] = TransferQueueStorageSimpleUnit.remote(
            storage_size=per_unit_size
        )
        logger.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")

    # Hand the controller's ZMQ endpoint description to every storage unit
    # and wait for all registrations to complete.
    server_info = ray.get(tq_controller.get_zmq_server_info.remote())
    controller_infos = {server_info.id: server_info}
    pending_registrations = [
        unit.register_controller_info.remote(controller_infos) for unit in data_system_storage_units.values()
    ]
    ray.get(pending_registrations)

    yield tq_controller, global_batch_size, num_n_samples, data_system_storage_units
def test_update_production_status(self, setup_teardown_transfer_queue_controller):
    """Marking one field as produced registers the field and flips its status column."""
    tq_controller, global_batch_size, num_global_batch, num_n_samples = setup_teardown_transfer_queue_controller

    total_storage_size = global_batch_size * num_global_batch * num_n_samples
    # Initial state: all-zero production status and an empty field_name_mapping
    init_update_production_status = torch.zeros(total_storage_size, TQ_INIT_FIELD_NUM, dtype=torch.int8)
    assert torch.equal(ray.get(tq_controller.get_data_production_status.remote()), init_update_production_status)
    assert ray.get(tq_controller.get_field_name_mapping.remote()) == {}

    columns_list = ["test_prompts"]
    global_indexes = list(range(global_batch_size * num_n_samples))

    # Update production status. Waiting on the result surfaces any exception
    # raised inside the actor here, instead of as a confusing failure in the
    # assertions below.
    ray.get(tq_controller._update_production_status.remote(global_indexes, columns_list))
    new_field_name_mapping = ray.get(tq_controller.get_field_name_mapping.remote())
    assert new_field_name_mapping["test_prompts"] == 0

    new_data_production_status = ray.get(tq_controller.get_data_production_status.remote())
    assert new_data_production_status[:, 0][: len(global_indexes)].sum() == len(global_indexes)
def test_get_prompt_metadata(self, setup_teardown_register_controller_info):
    """Insert-mode metadata for step 5 maps to the expected global/local index layout."""
    tq_controller, global_batch_size, n_samples, _ = setup_teardown_register_controller_info

    request = {
        "data_fields": ["test_prompts"],
        "batch_size": global_batch_size * n_samples,
        "global_step": 5,
        "mode": "insert",
    }
    metadata = ray.get(tq_controller._get_metadata.remote(**request))

    # Global indexes are the contiguous run 16..31.
    assert metadata.global_indexes == list(range(16, 32))
    # Local indexes are 8..15 twice in a row.
    assert metadata.local_indexes == list(range(8, 16)) * 2

    # The first half of the batch must sit on a single storage unit.
    storage_ids = metadata.storage_ids
    first_half = storage_ids[: len(storage_ids) // 2]
    assert len(set(first_half)) == 1

    # TODO: Test case where multiple clients concurrently read datameta from a single controller,
    # and each client receives the correct response
Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Han Zhenyu 韩振宇 Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../test_simple_storage_unit.py | 479 +++++++++++++++ verl/experimental/transfer_queue/client.py | 4 +- .../experimental/transfer_queue/controller.py | 22 +- verl/experimental/transfer_queue/metadata.py | 567 ++++++++++++++++++ verl/experimental/transfer_queue/storage.py | 22 +- 5 files changed, 1070 insertions(+), 24 deletions(-) create mode 100644 tests/experimental/transfer_queue/test_simple_storage_unit.py create mode 100644 verl/experimental/transfer_queue/metadata.py diff --git a/tests/experimental/transfer_queue/test_simple_storage_unit.py b/tests/experimental/transfer_queue/test_simple_storage_unit.py new file mode 100644 index 00000000000..7949c9cb971 --- /dev/null +++ b/tests/experimental/transfer_queue/test_simple_storage_unit.py @@ -0,0 +1,479 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Huawei Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
# Mock Controller to handle handshake and data updates
class MockController:
    """In-process controller stub for the storage-unit tests.

    Two ZMQ ROUTER sockets are bound to random localhost ports, one for
    handshakes and one for data-status updates. Each is served from its own
    daemon thread, which acknowledges every message it receives.
    """

    def __init__(self, controller_id="controller_001"):
        self.controller_id = controller_id
        self.context = zmq.Context()

        # Socket for handshake
        self.handshake_socket = self.context.socket(zmq.ROUTER)
        self.handshake_port = self._bind_to_random_port(self.handshake_socket)

        # Socket for data status updates
        self.data_update_socket = self.context.socket(zmq.ROUTER)
        self.data_update_port = self._bind_to_random_port(self.data_update_socket)

        self.zmq_server_info = ZMQServerInfo.create(
            role="CONTROLLER",
            id=controller_id,
            ip="127.0.0.1",
            ports={"handshake_socket": self.handshake_port, "data_status_update_socket": self.data_update_port},
        )

        self.running = True
        self.handshake_thread = Thread(target=self._handle_handshake, daemon=True)
        self.data_update_thread = Thread(target=self._handle_data_updates, daemon=True)
        self.handshake_thread.start()
        self.data_update_thread.start()

    def _bind_to_random_port(self, socket):
        """Bind ``socket`` to a random TCP port on 127.0.0.1 and return the port number."""
        port = socket.bind_to_random_port("tcp://127.0.0.1")
        return port

    def _serve_acks(self, socket, ack_type, ack_text, label):
        """Acknowledge every message arriving on ``socket`` until ``self.running`` is cleared."""
        poller = zmq.Poller()
        poller.register(socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if socket in socks:
                    identity, msg_bytes = socket.recv_multipart()
                    # The payload is parsed only to check that it is well-formed.
                    ZMQMessage.deserialize(msg_bytes)

                    ack_msg = ZMQMessage.create(
                        request_type=ack_type,
                        sender_id=self.controller_id,
                        body={"message": ack_text},
                    )
                    socket.send_multipart([identity, ack_msg.serialize()])
            except zmq.Again:
                continue
            except Exception as e:
                # Errors during shutdown are expected and ignored; errors while
                # serving are reported instead of being silently dropped.
                if self.running:
                    print(f"MockController {label} exception: {e}")

    def _handle_handshake(self):
        """Thread target: acknowledge handshake requests."""
        self._serve_acks(
            self.handshake_socket,
            ZMQRequestType.HANDSHAKE_ACK,
            "Handshake successful",
            "handshake",
        )

    def _handle_data_updates(self):
        """Thread target: acknowledge data-status update notifications."""
        self._serve_acks(
            self.data_update_socket,
            ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
            "Data update received",
            "data update",
        )

    def stop(self):
        """Stop both serving threads, then release the sockets and the ZMQ context."""
        self.running = False
        # Wait for the pollers to notice the flag rather than sleeping for a
        # fixed interval; the timeout keeps teardown bounded.
        for worker in (self.handshake_thread, self.data_update_thread):
            worker.join(timeout=1.0)
        # linger=0 discards unsent acks so that term() cannot block on them.
        self.handshake_socket.close(linger=0)
        self.data_update_socket.close(linger=0)
        self.context.term()
@pytest.fixture
def storage_setup(ray_setup):
    """Start a mock controller plus one storage-unit actor and yield their handles.

    Yields ``(storage_actor, put_get_address, mock_controller)``. Teardown runs
    even if setup fails part-way: the actor is killed and the controller stopped.
    """
    storage_size = 10000
    tensordict.set_list_to_stack(True).set()

    # Start mock controller
    mock_controller = MockController(f"controller_{uuid.uuid4()}")
    storage_actor = None
    try:
        time.sleep(0.5)  # Wait for controller sockets to be ready

        # Start Ray actor
        storage_actor = TransferQueueStorageSimpleUnit.options(max_concurrency=50, num_cpus=1).remote(storage_size)

        # Register controller info
        controller_infos = {mock_controller.controller_id: mock_controller.zmq_server_info}
        ray.get(storage_actor.register_controller_info.remote(controller_infos))

        # Get ZMQ address to connect client
        zmq_info = ray.get(storage_actor.get_zmq_server_info.remote())
        put_get_address = zmq_info.to_addr("put_get_socket")
        time.sleep(1)  # Wait for socket to be ready

        yield storage_actor, put_get_address, mock_controller
    finally:
        # Cleanup. Ray is session-scoped here, so release the actor (and the
        # CPU it reserves) explicitly instead of waiting for handle collection.
        try:
            if storage_actor is not None:
                ray.kill(storage_actor)
        finally:
            mock_controller.stop()
def test_put_get_multiple_clients(storage_setup):
    """Several clients write disjoint index ranges; an extra client then overwrites index 0."""
    put_get_address = storage_setup[1]

    num_clients = 5
    clients = [MockClient(put_get_address) for _ in range(num_clients)]

    # Client i owns indexes i*10 .. i*10+2 and writes values derived from i.
    for i, client in enumerate(clients):
        base = i * 10
        global_indexes = [base + offset for offset in range(3)]
        local_indexes = list(global_indexes)
        log_probs = [torch.tensor([i + 3 * row + col for col in range(3)]) for row in range(3)]
        rewards = [torch.tensor([base + 10 * row]) for row in range(3)]
        field_data = TensorDict({"log_probs": log_probs, "rewards": rewards})

        put_response = client.send_put(i, global_indexes, local_indexes, field_data)
        assert put_response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # An additional client overwrites local index 0, which the first client already wrote.
    overlapping_client = MockClient(put_get_address)
    overlap_local_indexes = [0]
    overlap_field_data = TensorDict({"log_probs": [torch.tensor([999, 999, 999])], "rewards": [torch.tensor([999])]})
    overlap_response = overlapping_client.send_put(
        client_id=99, global_indexes=[0], local_indexes=overlap_local_indexes, field_data=overlap_field_data
    )
    assert overlap_response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Each original client reads back its first two rows.
    for i, client in enumerate(clients):
        get_response = client.send_get(i, [i * 10 + 0, i * 10 + 1], ["log_probs", "rewards"])
        assert get_response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        retrieved_data = get_response.body["data"]
        assert retrieved_data["log_probs"].size(0) == 2
        assert retrieved_data["rewards"].size(0) == 2

        # Start from what client i wrote, then apply the overwrite of index 0.
        expected_log_probs = [torch.tensor([i, i + 1, i + 2]), torch.tensor([i + 3, i + 4, i + 5])]
        expected_rewards = [torch.tensor([i * 10]), torch.tensor([i * 10 + 10])]
        if i == 0:
            expected_log_probs[0] = torch.tensor([999, 999, 999])
            expected_rewards[0] = torch.tensor([999])

        for row in range(2):
            torch.testing.assert_close(retrieved_data["log_probs"][row], expected_log_probs[row])
            torch.testing.assert_close(retrieved_data["rewards"][row], expected_rewards[row])

    # Cleanup
    for client in clients:
        client.close()
    overlapping_client.close()
def test_put_get_nested_tensor_single_client(storage_setup):
    """Round-trip variable-length tensor fields through the storage unit with one client."""
    put_get_address = storage_setup[1]
    client = MockClient(put_get_address)

    # Three samples whose tensors have different lengths.
    sequences = [
        torch.tensor([-0.5, -1.2, -0.8]),
        torch.tensor([-0.3, -1.5, -2.1, -0.9]),
        torch.tensor([-1.1, -0.7]),
    ]
    masks = [torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1, 1]), torch.tensor([1, 1])]
    field_data = TensorDict(
        {"variable_length_sequences": sequences, "attention_mask": masks},
        batch_size=[],
    )

    # PUT all three samples at indexes 0..2.
    put_response = client.send_put(0, [0, 1, 2], [0, 1, 2], field_data)
    assert put_response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET only the first and the last sample.
    field_names = ["variable_length_sequences", "attention_mask"]
    get_response = client.send_get(0, [0, 2], field_names)
    assert get_response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    retrieved_data = get_response.body["data"]
    for field_name in field_names:
        assert field_name in retrieved_data
    for field_name in field_names:
        assert retrieved_data[field_name].size(0) == 2

    # Returned row 0 is stored sample 0; returned row 1 is stored sample 2.
    expected_rows = {
        "variable_length_sequences": [sequences[0], sequences[2]],
        "attention_mask": [masks[0], masks[2]],
    }
    for field_name in field_names:
        for row, expected in enumerate(expected_rows[field_name]):
            torch.testing.assert_close(retrieved_data[field_name][row], expected)

    client.close()
MockClient(put_get_address) + + # PUT data + global_indexes = [0, 1, 2] + local_indexes = [0, 1, 2] + field_data = TensorDict( + { + "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"], + "response_text": ["Hi there!", "This is the response to the longer sentence", "Test response"], + }, + batch_size=[], + ) + + response = client.send_put(0, global_indexes, local_indexes, field_data) + assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE + + # GET data + response = client.send_get(0, [0, 1, 2], ["prompt_text", "response_text"]) + assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE + + retrieved_data = response.body["data"] + assert "prompt_text" in retrieved_data + assert "response_text" in retrieved_data + + # Verify data correctness + assert isinstance(retrieved_data["prompt_text"][0], str) + assert isinstance(retrieved_data["response_text"][0], str) + + assert retrieved_data["prompt_text"][0] == "Hello world!" + assert retrieved_data["prompt_text"][1] == "This is a longer sentence for testing" + assert retrieved_data["prompt_text"][2] == "Test case" + assert retrieved_data["response_text"][0] == "Hi there!" 
+ assert retrieved_data["response_text"][1] == "This is the response to the longer sentence" + assert retrieved_data["response_text"][2] == "Test response" + + client.close() + + +def test_put_get_single_item_single_client(storage_setup): + """Test put and get operations for a single item with a single client.""" + _, put_get_address, _ = storage_setup + + client = MockClient(put_get_address) + + # PUT data + field_data = TensorDict( + { + "prompt_text": ["Hello world!"], + "attention_mask": [torch.tensor([1, 1, 1])], + }, + batch_size=[], + ) + + response = client.send_put(0, [0], [0], field_data) + assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE + + # GET data + response = client.send_get(0, [0], ["prompt_text", "attention_mask"]) + assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE + + retrieved_data = response.body["data"] + assert "prompt_text" in retrieved_data + assert "attention_mask" in retrieved_data + + assert retrieved_data["prompt_text"][0] == "Hello world!" 
+ assert retrieved_data["attention_mask"].shape == (1, 3) + torch.testing.assert_close(retrieved_data["attention_mask"][0], torch.tensor([1, 1, 1])) diff --git a/verl/experimental/transfer_queue/client.py b/verl/experimental/transfer_queue/client.py index 3bf3eec950e..8005558b0b1 100644 --- a/verl/experimental/transfer_queue/client.py +++ b/verl/experimental/transfer_queue/client.py @@ -445,9 +445,9 @@ async def async_get_data(self, metadata: BatchMeta) -> TensorDict: for field in fields: storage_data[global_idx][field] = storage_unit_data[field][idx] - ordered_data: dict[str, torch.Tensor] = {field: [] for field in metadata.fields} + ordered_data: dict[str, torch.Tensor] = {field: [] for field in metadata.field_names} for global_idx in metadata.global_indexes: - for field in metadata.fields: + for field in metadata.field_names: ordered_data[field].append(storage_data[global_idx][field]) tensor_data = { diff --git a/verl/experimental/transfer_queue/controller.py b/verl/experimental/transfer_queue/controller.py index 73c1e6d4e9a..08ab6cfe9f4 100644 --- a/verl/experimental/transfer_queue/controller.py +++ b/verl/experimental/transfer_queue/controller.py @@ -91,9 +91,9 @@ def __init__( self.field_name_mapping: dict[ str, int ] = {} # Mapping table from field_name to the column indices in self.data_production_status tables - # Per-sample dtype and shape storage: {global_index: {field_name: {'dtype': dtype, 'shape': shape}}} - self.per_tensor_dtype_mapping: dict[int, dict[str, torch.dtype]] = {} - self.per_tensor_shape_mapping: dict[int, dict[str, torch.Size]] = {} + # Per-field dtype and shape storage: {global_index: {field_name: {'dtype': dtype, 'shape': shape}}} + self.per_tensor_dtype_mapping: dict[int, dict[str, Any]] = {} + self.per_tensor_shape_mapping: dict[int, dict[str, Any]] = {} self._build_index_storage_mapping() @@ -119,7 +119,7 @@ def _get_consumption_status(self, task_name: str) -> torch.Tensor: self.data_consumption_status[task_name] = 
torch.zeros(self.total_storage_size, dtype=torch.int8) return self.data_consumption_status[task_name] - def _get_per_tensor_dtype(self, global_index: int, field_name: str) -> Optional[torch.dtype]: + def _get_per_field_dtype(self, global_index: int, field_name: str) -> Optional[torch.dtype]: """Get dtype for a specific sample and field. Args: @@ -131,7 +131,7 @@ def _get_per_tensor_dtype(self, global_index: int, field_name: str) -> Optional[ """ return self.per_tensor_dtype_mapping.get(global_index, {}).get(field_name) - def _get_per_tensor_shape(self, global_index: int, field_name: str) -> Optional[torch.Size]: + def _get_per_field_shape(self, global_index: int, field_name: str) -> Optional[torch.Size]: """Get shape for a specific sample and field. Args: @@ -435,9 +435,9 @@ def _generate_batch_meta( for field_name in data_fields: if mode == "fetch": production_status = ProductionStatus.READY_FOR_CONSUME # Since we filtered by ready status - # Get per-tensor dtype and shape for this specific global_index and field - dtype = self._get_per_tensor_dtype(global_index, field_name) - shape = self._get_per_tensor_shape(global_index, field_name) + # Get per-field dtype and shape for this specific global_index and field + dtype = self._get_per_field_dtype(global_index, field_name) + shape = self._get_per_field_shape(global_index, field_name) elif mode == "insert": production_status = ProductionStatus.NOT_PRODUCED # FIXME: not real-time dtype = None @@ -446,8 +446,8 @@ def _generate_batch_meta( col_index = self.field_name_mapping.get(field_name) if col_index is not None and self.data_production_status[global_index, col_index] == 1: production_status = ProductionStatus.READY_FOR_CONSUME - dtype = self._get_per_tensor_dtype(global_index, field_name) - shape = self._get_per_tensor_shape(global_index, field_name) + dtype = self._get_per_field_dtype(global_index, field_name) + shape = self._get_per_field_shape(global_index, field_name) else: production_status = 
ProductionStatus.NOT_PRODUCED dtype = None @@ -506,7 +506,7 @@ def _update_field_info( global_indexes: list[int], ) -> None: """ - Store per-tensor dtype and shape information. + Store per-field dtype and shape information. Args: fields: List of field names diff --git a/verl/experimental/transfer_queue/metadata.py b/verl/experimental/transfer_queue/metadata.py new file mode 100644 index 00000000000..7346c292116 --- /dev/null +++ b/verl/experimental/transfer_queue/metadata.py @@ -0,0 +1,567 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Huawei Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np +from tensordict import TensorDict + +from verl.experimental.transfer_queue.utils.utils import ProductionStatus + + +@dataclass +class FieldMeta: + """ + Records the metadata of a single data field. (name, dtype, shape, etc.) + """ + + # field name (e.g., 'prompt', 'response', etc.) 
+ name: str + + # data schema info + dtype: Optional[Any] + shape: Optional[Any] + + # data status info + production_status: ProductionStatus = ProductionStatus.NOT_PRODUCED + + def __str__(self) -> str: + return ( + f"FieldMeta(name='{self.name}', dtype={self.dtype}, " + f"shape={self.shape}, production_status={self.production_status})" + ) + + @property + def is_ready(self) -> bool: + """Check if this field is ready for consumption""" + return self.production_status == ProductionStatus.READY_FOR_CONSUME + + +@dataclass +class SampleMeta: + """ + Records the metadata of a single data sample (stored as a row in the data system). + """ + + # algorithm related info + global_step: int # global step, used for data versioning + + # data retrival info + global_index: int # global row index, uniquely identifies a data sample + storage_id: str # storage unit id + local_index: int # local row index in the storage unit + + # data fields info + # this fields may not contain all the fields of the sample, but only fields-of-interest + fields: dict[str, FieldMeta] + + def __post_init__(self): + """Initialize is_ready property based on field readiness""" + # Check if all fields are ready and update is_ready property + object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) + + def __str__(self) -> str: + return ( + f"SampleMeta(global_step={self.global_step}, " + f"global_index={self.global_index}, storage_id='{self.storage_id}', " + f"local_index={self.local_index}, fields={self.fields})" + ) + + @property + def field_names(self) -> list[str]: + """Get list of field names for this sample""" + return list(self.fields.keys()) + + @property + def batch_index(self) -> int: + """Get the batch index of this sample (to be set by BatchMeta)""" + return getattr(self, "_batch_index", -1) + + def get_field_by_name(self, name: str) -> Optional[FieldMeta]: + """Get FieldMeta by field name""" + return self.fields.get(name) + + def has_field(self, name: 
str) -> bool: + """Check if this sample has a specific field""" + return name in self.fields + + def is_field_ready(self, field_name: str) -> bool: + """Check if a specific field is ready for consumption""" + field = self.fields.get(field_name) + return field.is_ready if field else False + + def add_fields(self, fields: dict[str, FieldMeta]) -> "SampleMeta": + """ + Add new fields to this sample. New fields will be initialized with given dtype, shape + and production_status (if provided). If not provided, default values (None, None, READY_FOR_CONSUME) + will be used. + This modifies the sample in-place to include the new fields. + """ + self.fields = _union_fields(self.fields, fields) + # Update is_ready property + object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) + return self + + def union(self, other: "SampleMeta", validate: bool = True) -> "SampleMeta": + """ + Create a union of this sample's fields with another sample's fields. + Assume both samples have the same global index. If fields overlap, the + fields in this sample will be replaced by the other sample's fields. + + Args: + other: Another SampleMeta to union with + validate: Whether to validate union conditions + + Returns: + New SampleMeta with unioned fields (None if validation fails) + """ + if validate: + if self.global_index != other.global_index: + raise ValueError( + f"Error: Global indexes ({self.global_index} and {other.global_index}) do not match for union." 
+ ) + + # Merge fields + self.fields = _union_fields(self.fields, other.fields) + + # Update is_ready property + object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) + return self + + @property + def is_ready(self) -> bool: + """Check if all fields in this sample are ready for consumption""" + return getattr(self, "_is_ready", False) + + @property + def production_status(self) -> dict[str, ProductionStatus]: + """Get production status for all fields (backward compatibility)""" + return {name: field.production_status for name, field in self.fields.items()} + + +@dataclass +class StorageMetaGroup: + """ + Represents a group of samples stored in the same storage unit. + Used to organize samples by their storage_id for efficient client operations. + """ + + storage_id: str + sample_metas: list[SampleMeta] = dataclasses.field(default_factory=list) + + def add_sample_meta(self, sample_meta: SampleMeta) -> None: + """Add a SampleMeta object to this storage group""" + self.sample_metas.append(sample_meta) + + def get_batch_indexes(self) -> list[int]: + """Get all internal indexes from stored SampleMeta objects""" + return [meta.batch_index for meta in self.sample_metas] + + def get_global_indexes(self) -> list[int]: + """Get all global indexes from stored SampleMeta objects""" + return [meta.global_index for meta in self.sample_metas] + + def get_local_indexes(self) -> list[int]: + """Get all local indexes from stored SampleMeta objects""" + return [meta.local_index for meta in self.sample_metas] + + def get_field_names(self) -> list[str]: + """Get all unique field names from stored SampleMeta objects""" + all_fields: set[str] = set() + for meta in self.sample_metas: + all_fields.update(meta.fields.keys()) + return list(all_fields) + + def get_transfer_info(self, field_names: Optional[list[str]] = None) -> dict[str, list | dict]: + """Convert to dictionary format for backward compatibility""" + if field_names is None: + field_names = 
self.get_field_names() + return { + "batch_indexes": self.get_batch_indexes(), + "global_indexes": self.get_global_indexes(), + "local_indexes": self.get_local_indexes(), + "fields": field_names, + "field_data": {}, # Placeholder for field data to be filled later + } + + @property + def size(self) -> int: + """Number of samples in this storage meta group""" + return len(self.sample_metas) + + @property + def is_empty(self) -> bool: + """Check if this storage meta group is empty""" + return len(self.sample_metas) == 0 + + def __len__(self) -> int: + """Number of samples in this storage meta group""" + return self.size + + def __bool__(self) -> bool: + """Truthiness based on whether group has samples""" + return not self.is_empty + + def __str__(self) -> str: + return f"StorageMetaGroup(storage_id='{self.storage_id}', size={self.size})" + + +@dataclass +class BatchMeta: + """ + Records the metadata of a batch of data samples. + """ + + samples: list[SampleMeta] + extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + """Initialize all computed properties during initialization""" + # Basic properties + object.__setattr__(self, "_size", len(self.samples)) + object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples)) + + # Pre-compute all list properties for better performance + if self.samples: + for idx, sample in enumerate(self.samples): + object.__setattr__(sample, "_batch_index", idx) # Ensure batch_index is set correctly + + object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) + object.__setattr__(self, "_local_indexes", [sample.local_index for sample in self.samples]) + object.__setattr__(self, "_storage_ids", [sample.storage_id for sample in self.samples]) + + # assume all samples have the same fields. 
+ object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names)) + + # Initialize storage groups for efficient client operations + storage_meta_groups = self._build_storage_meta_groups() + object.__setattr__(self, "_storage_meta_groups", storage_meta_groups) + else: + object.__setattr__(self, "_global_indexes", []) + object.__setattr__(self, "_local_indexes", []) + object.__setattr__(self, "_storage_ids", []) + object.__setattr__(self, "_field_names", []) + object.__setattr__(self, "_storage_meta_groups", {}) + + @property + def size(self) -> int: + """Return the number of samples in this batch""" + return getattr(self, "_size", 0) + + @property + def global_indexes(self) -> list[int]: + """Get all global indexes in this batch""" + return getattr(self, "_global_indexes", []) + + @property + def field_names(self) -> list[str]: + """Get all unique field names in this batch""" + return getattr(self, "_field_names", []) + + @property + def local_indexes(self) -> list[int]: + """Get all local indexes in this batch""" + return getattr(self, "_local_indexes", []) + + @property + def storage_ids(self) -> list[str]: + """Get all storage unit IDs in this batch""" + return getattr(self, "_storage_ids", []) + + @property + def is_ready(self) -> bool: + """Check if all samples in this batch are ready for consumption""" + # TODO: get ready status from controller realtime + return getattr(self, "_is_ready", False) + + def _build_storage_meta_groups(self) -> dict[str, StorageMetaGroup]: + """Build storage groups from samples during initialization""" + storage_meta_groups: dict[str, StorageMetaGroup] = {} + + for sample in self.samples: + storage_id = sample.storage_id + if storage_id not in storage_meta_groups: + storage_meta_groups[storage_id] = StorageMetaGroup(storage_id=storage_id) + + # Use add_sample_meta to store SampleMeta references directly + storage_meta_groups[storage_id].add_sample_meta(sample) + + return storage_meta_groups + + @property + def 
storage_meta_groups(self) -> dict[str, StorageMetaGroup]: + """Get storage groups organized by storage_id""" + return getattr(self, "_storage_meta_groups", {}) + + @property + def storage_unit_ids(self) -> list[str]: + """Get list of all storage unit IDs""" + return list(self.storage_meta_groups.keys()) + + def get_storage_meta_groups(self, storage_id: str) -> Optional[StorageMetaGroup]: + """Get storage group by storage ID""" + return self.storage_meta_groups.get(storage_id) + + # Extra info interface methods + def get_extra_info(self, key: str, default: Any = None) -> Any: + """Get extra info by key""" + return self.extra_info.get(key, default) + + def set_extra_info(self, key: str, value: Any) -> None: + """Set extra info by key""" + self.extra_info[key] = value + + def update_extra_info(self, info_dict: dict[str, Any]) -> None: + """Update extra info with multiple key-value pairs""" + self.extra_info.update(info_dict) + + def remove_extra_info(self, key: str) -> Any: + """Remove extra info by key and return its value""" + return self.extra_info.pop(key, None) + + def clear_extra_info(self) -> None: + """Clear all extra info""" + self.extra_info.clear() + + def has_extra_info(self, key: str) -> bool: + """Check if extra info contains a specific key""" + return key in self.extra_info + + def get_all_extra_info(self) -> dict[str, Any]: + """Get all extra info as a dictionary""" + return self.extra_info.copy() + + def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": + """ + Add new fields from a TensorDict to all samples in this batch. + This modifies each sample in-place to include the new fields. + + Args: + tensor_dict (TensorDict): The input TensorDict containing new fields. + set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. Default is True. 
+ """ + fields = _extract_field_metas(tensor_dict, set_all_ready) + for idx, sample in enumerate(self.samples): + sample.add_fields(fields=fields[idx]) + + # Update batch-level fields cache + object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names)) + object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples)) + return self + + def __len__(self) -> int: + """Return the number of samples in this batch.""" + return len(self.samples) + + def __getitem__(self, item): + if isinstance(item, int | np.integer): + sample_meta = self.samples[item] if self.samples else [] + return BatchMeta(samples=[sample_meta], extra_info=self.extra_info) + else: + raise TypeError(f"Indexing with {type(item)} is not supported now!") + + def chunk(self, chunks: int) -> list["BatchMeta"]: + """ + Split this batch into smaller chunks. + + Args: + chunks: number of chunks + + Return: + List of smaller BatchMeta chunks + """ + chunk_list = [] + n = len(self.samples) + + # Calculate the base size and remainder of each chunk + base_size = n // chunks + remainder = n % chunks + + start = 0 + for i in range(chunks): + # Calculate the size of the current chunk(the first remainder chunk is 1 more than the base size) + current_chunk_size = base_size + 1 if i < remainder else base_size + end = start + current_chunk_size + chunk_samples = self.samples[start:end] + chunk = BatchMeta(samples=chunk_samples, extra_info=self.extra_info.copy()) + chunk_list.append(chunk) + start = end + return chunk_list + + @classmethod + def concat(cls, data: list["BatchMeta"], validate: bool = True) -> Optional["BatchMeta"]: + """ + Concatenate multiple BatchMeta chunks into one large batch. 
+ + Args: + data: List of BatchMeta chunks to concatenate + validate: Whether to validate concatenation conditions + + Returns: + Concatenated BatchMeta + + Raises: + ValueError: If validation fails (e.g., field names do not match) + """ + if not data: + return None + + if validate: + base_fields = data[0].field_names + + for chunk in data: + if chunk.field_names != base_fields: + raise ValueError("Error: Field names do not match for concatenation.") + + # Combine all samples + all_samples = [] + for chunk in data: + all_samples.extend(chunk.samples) + # Merge all extra_info dictionaries from the chunks + merged_extra_info = {} + for chunk in data: + merged_extra_info.update(chunk.extra_info) + return BatchMeta(samples=all_samples, extra_info=merged_extra_info) + + def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMeta"]: + """ + Create a union of this batch's fields with another batch's fields. + Assume both batches have the same global indices. If fields overlap, the + fields in this batch will be replaced by the other batch's fields. 
+ + Args: + other: Another BatchMeta to union with + validate: Whether to validate union conditions + + Returns: + New BatchMeta with unioned fields + + Raises: + ValueError: If validation fails (e.g., batch sizes or global indexes do not match) + """ + if validate: + if self.size != other.size: + raise ValueError("Error: Batch sizes do not match for union.") + + self_global_indexes = sorted(self.global_indexes) + other_global_indexes = sorted(other.global_indexes) + if self_global_indexes != other_global_indexes: + raise ValueError("Error: Global indexes do not match for union.") + + # Create a mapping from global_index to SampleMeta in the other batch + other_sample_map = {sample.global_index: sample for sample in other.samples} + + # Merge samples + merged_samples = [] + for sample in self.samples: + if sample.global_index in other_sample_map: + other_sample = other_sample_map[sample.global_index] + merged_sample = sample.union(other_sample, validate=validate) + merged_samples.append(merged_sample) + else: + merged_samples.append(sample) + + # Merge extra info dictionaries + merged_extra_info = {**self.extra_info, **other.extra_info} + + return BatchMeta(samples=merged_samples, extra_info=merged_extra_info) + + @classmethod + def from_samples( + cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None + ) -> "BatchMeta": + """ + Create a BatchMeta from a single SampleMeta or a list of SampleMeta objects. + + Args: + samples: A single SampleMeta or a list of SampleMeta objects + extra_info: Optional additional information to store with the batch + + Returns: + BatchMeta instance containing the provided sample(s) + + Example: + >>> sample_meta = SampleMeta(...) 
+ >>> batch_meta = BatchMeta.from_samples(sample_meta) + + >>> sample_metas = [sample1, sample2, sample3] + >>> batch_meta = BatchMeta.from_samples(sample_metas, extra_info={"source": "training"}) + """ + if extra_info is None: + extra_info = {} + + if isinstance(samples, SampleMeta): + samples = [samples] + + return cls(samples=samples, extra_info=extra_info) + + @classmethod + def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta": + """ + Create an empty BatchMeta with no samples. + + Args: + extra_info: Optional additional information to store with the batch + + Returns: + Empty BatchMeta instance + + Example: + >>> empty_batch = BatchMeta.empty() + """ + if extra_info is None: + extra_info = {} + return cls(samples=[], extra_info=extra_info) + + +def _union_fields(fields1: dict[str, FieldMeta], fields2: dict[str, FieldMeta]) -> dict[str, FieldMeta]: + """Union two sample's fields. If fields overlap, the fields in fields1 will be replaced by fields2.""" + for name in fields2.keys(): + fields1[name] = fields2[name] + return fields1 + + +def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> list[dict[str, FieldMeta]]: + """ + Extract field metas from a TensorDict. If data in tensor_dict does not have dtype or shape attribute, + the corresponding dtype or shape will be set to None. + + Args: + tensor_dict (TensorDict): The input TensorDict. + set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. + Otherwise, set to NOT_PRODUCED. Default is True. + + Returns: + all_fields (list[dict[FieldMeta]]): A list of dictionaries containing field metadata. 
+ """ + all_fields = [] + batch_size = tensor_dict.batch_size[0] + for idx in range(batch_size): + fields = {} + sample = tensor_dict[idx] + for name, value in sample.items(): + fields[name] = FieldMeta( + name=name, + dtype=value.dtype if hasattr(value, "dtype") else None, + shape=value.shape if hasattr(value, "shape") else None, + production_status=ProductionStatus.READY_FOR_CONSUME + if set_all_ready + else ProductionStatus.NOT_PRODUCED, + ) + all_fields.append(fields) + + return all_fields diff --git a/verl/experimental/transfer_queue/storage.py b/verl/experimental/transfer_queue/storage.py index e71ac9c8b38..c8f908ee8d8 100644 --- a/verl/experimental/transfer_queue/storage.py +++ b/verl/experimental/transfer_queue/storage.py @@ -334,25 +334,25 @@ def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.zmq_server_info.id, body={} ) - # Gather per-tensor dtype and shape information for each field + # Gather per-field dtype and shape information for each field # global_indexes, local_indexes, and field_data correspond one-to-one - per_tensor_dtypes: dict[int, torch.dtype] = {} - per_tensor_shapes: dict[int, torch.Size] = {} + per_field_dtypes = {} + per_field_shapes = {} # Initialize the data structure for each global index for global_idx in global_indexes: - per_tensor_dtypes[global_idx] = {} - per_tensor_shapes[global_idx] = {} + per_field_dtypes[global_idx] = {} + per_field_shapes[global_idx] = {} # For each field, extract dtype and shape for each sample for field in field_data.keys(): for i, data_item in enumerate(field_data[field]): global_idx = global_indexes[i] - per_tensor_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None - per_tensor_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None + per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None + per_field_shapes[global_idx][field] = 
data_item.shape if hasattr(data_item, "shape") else None - # Broadcast data update message to all controllers with per-tensor dtype/shape information - self._notify_data_update(list(field_data.keys()), global_indexes, per_tensor_dtypes, per_tensor_shapes) + # Broadcast data update message to all controllers with per-field dtype/shape information + self._notify_data_update(list(field_data.keys()), global_indexes, per_field_dtypes, per_field_shapes) return response_msg except Exception as e: return ZMQMessage.create( @@ -371,8 +371,8 @@ def _notify_data_update(self, fields, global_indexes, dtypes, shapes) -> None: param: fields: data update related fields. global_indexes: data update related global_indexes. - dtypes: per-tensor dtypes for each field, in {global_index: {field: dtype}} format. - shapes: per-tensor shapes for each field, in {global_index: {field: shape}} format. + dtypes: per-field dtypes for each field, in {global_index: {field: dtype}} format. + shapes: per-field shapes for each field, in {global_index: {field: shape}} format. 
""" # Create zmq poller for notifying data update information poller = zmq.Poller() From a8342b41e6665e29d376e1b1d72931bfb3a273fe Mon Sep 17 00:00:00 2001 From: LLLLxmmm <130739718+LLLLxmmm@users.noreply.github.com> Date: Sun, 28 Sep 2025 10:33:18 +0800 Subject: [PATCH 11/16] Add reorder function to BatchMeta (#13) Co-authored-by: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> --- .../transfer_queue/test_controller.py | 57 ++++++++++--------- verl/experimental/transfer_queue/metadata.py | 35 ++++++++++++ 2 files changed, 64 insertions(+), 28 deletions(-) diff --git a/tests/experimental/transfer_queue/test_controller.py b/tests/experimental/transfer_queue/test_controller.py index 6577cd9e163..3b45da2a561 100644 --- a/tests/experimental/transfer_queue/test_controller.py +++ b/tests/experimental/transfer_queue/test_controller.py @@ -220,41 +220,42 @@ def test_get_prompt_metadata(self, setup_teardown_register_controller_info): mode="insert", ) ) + metadata.reorder([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]) assert metadata.global_indexes == [ - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25, - 26, - 27, - 28, - 29, - 30, 31, + 30, + 29, + 28, + 27, + 26, + 25, + 24, + 23, + 22, + 21, + 20, + 19, + 18, + 17, + 16, ] assert metadata.local_indexes == [ - 8, - 9, - 10, - 11, - 12, - 13, - 14, 15, - 8, - 9, - 10, - 11, - 12, - 13, 14, + 13, + 12, + 11, + 10, + 9, + 8, 15, + 14, + 13, + 12, + 11, + 10, + 9, + 8, ] storage_ids = metadata.storage_ids assert len(set(storage_ids[: len(storage_ids) // 2])) == 1 diff --git a/verl/experimental/transfer_queue/metadata.py b/verl/experimental/transfer_queue/metadata.py index 7346c292116..6d81e7f2ca3 100644 --- a/verl/experimental/transfer_queue/metadata.py +++ b/verl/experimental/transfer_queue/metadata.py @@ -480,6 +480,41 @@ def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMet return BatchMeta(samples=merged_samples, extra_info=merged_extra_info) + def reorder(self, 
indices: list[int]): + """ + Reorder the SampleMeta in the BatchMeta according to the given indices. + + The operation is performed in-place, modifying the current BatchMeta's SampleMeta order. + + Args: + indices : list[int] + A list of integers specifying the new order of SampleMeta. Each integer + represents the current index of the SampleMeta in the BatchMeta. + """ + # Reorder the samples + reordered_samples = [self.samples[i] for i in indices] + object.__setattr__(self, "samples", reordered_samples) + + # Update necessary attributes + self._update_after_reorder() + + def _update_after_reorder(self) -> None: + """Update related attributes specifically for the reorder operation""" + # Update batch_index for each sample + for idx, sample in enumerate(self.samples): + object.__setattr__(sample, "_batch_index", idx) + + # Update cached index lists + object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) + object.__setattr__(self, "_local_indexes", [sample.local_index for sample in self.samples]) + object.__setattr__(self, "_storage_ids", [sample.storage_id for sample in self.samples]) + + # Rebuild storage groups + storage_meta_groups = self._build_storage_meta_groups() + object.__setattr__(self, "_storage_meta_groups", storage_meta_groups) + + # Note: No need to update _size, _field_names, _is_ready, etc., as these remain unchanged after reorder + @classmethod def from_samples( cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None From b5be2adf9934d71b9fd1d98fcf9eba4007863328 Mon Sep 17 00:00:00 2001 From: LLLLxmmm <130739718+LLLLxmmm@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:14:07 +0800 Subject: [PATCH 12/16] [recipe, data] feat: TransferQueue - Support managing multiple data partitions for Train/Val/Test in controller (#45) Co-authored-by: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> --- recipe/transfer_queue/agent_loop.py | 7 +- 
recipe/transfer_queue/ray_trainer.py | 234 +++++++++++--------------- requirements_transferqueue.txt | 2 +- verl/single_controller/base/worker.py | 8 +- verl/utils/transferqueue_utils.py | 26 +-- 5 files changed, 111 insertions(+), 166 deletions(-) diff --git a/recipe/transfer_queue/agent_loop.py b/recipe/transfer_queue/agent_loop.py index 871ae8025c0..7f936e6730e 100644 --- a/recipe/transfer_queue/agent_loop.py +++ b/recipe/transfer_queue/agent_loop.py @@ -67,10 +67,7 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data return timing - def create_transferqueue_client(self, controller_infos, storage_infos, role): + def create_transferqueue_client(self, controller_info, config): ray.get( - [ - worker.create_transferqueue_client.remote(controller_infos, storage_infos, role) - for worker in self.agent_loop_workers - ] + [worker.create_transferqueue_client.remote(controller_info, config) for worker in self.agent_loop_workers] ) diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index d6adbddb676..9874fc7e0dc 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -41,8 +41,8 @@ from tqdm import tqdm from transfer_queue import ( BatchMeta, + SimpleStorageUnit, TransferQueueController, - TransferQueueStorageSimpleUnit, get_placement_group, process_zmq_server_info, ) @@ -81,6 +81,7 @@ from verl.utils.metric import reduce_metrics from verl.utils.rollout_skip import RolloutSkip from verl.utils.seqlen_balancing import ( + calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance, ) @@ -89,7 +90,6 @@ from verl.utils.transferqueue_utils import ( create_transferqueue_client, get_transferqueue_client, - get_val_transferqueue_client, tqbridge, ) @@ -412,109 +412,66 @@ def __init__( self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) - self.data_system_client = self._initialize_train_data_system( - self.config.data.train_batch_size, 
self.config.actor_rollout_ref.rollout.n + self.data_system_client = self._initialize_data_system() + + def _initialize_data_system(self): + # 1. initialize TransferQueueStorage + train_data_size = ( + self.config.data.train_batch_size + * self.config.trainer.num_global_batch + * self.config.actor_rollout_ref.rollout.n ) - self.val_data_system_client = self._initialize_val_data_system( - self.val_batch_size, self.config.actor_rollout_ref.rollout.val_kwargs.n + val_data_size = ( + self.val_batch_size + * self.config.trainer.num_global_batch + * self.config.actor_rollout_ref.rollout.val_kwargs.n ) - def _initialize_train_data_system(self, global_batch_size, num_n_samples, role="train"): - # 1. initialize TransferQueueStorage - total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples + total_storage_size = train_data_size + val_data_size self.data_system_storage_units = {} storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1) for storage_unit_rank in range(self.config.trainer.num_data_storage_units): - storage_node = TransferQueueStorageSimpleUnit.options( + storage_node = SimpleStorageUnit.options( placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank - ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) + ).remote(storage_unit_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) self.data_system_storage_units[storage_unit_rank] = storage_node - logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") - - # 2. initialize TransferQueueController - # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly - # one controller for a single WorkerGroup. 
- self.data_system_controllers = {} - controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1) - for controller_rank in range(self.config.trainer.num_data_controllers): - self.data_system_controllers[controller_rank] = TransferQueueController.options( - placement_group=controller_placement_group, placement_group_bundle_index=controller_rank - ).remote( - num_storage_units=self.config.trainer.num_data_storage_units, - global_batch_size=global_batch_size, - num_global_batch=self.config.trainer.num_global_batch, - num_n_samples=num_n_samples, - ) - logging.info(f"TransferQueueController #{controller_rank} has been created.") + logging.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.") - # 3. register controller & storage - self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers) - self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units) + # 2. Initialize TransferQueueController (single controller only) - ray.get( - [ - storage_unit.register_controller_info.remote(self.data_system_controller_infos) - for storage_unit in self.data_system_storage_units.values() - ] - ) + # Sampler usage instructions: + # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler: + # Option 1: Pass sampler class (will be instantiated automatically) + # self.data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler) - # 4. 
create client - # each client should be allocated to exactly one controller - create_transferqueue_client( - client_id="Trainer-" + role, - controller_infos=self.data_system_controller_infos, - storage_infos=self.data_system_storage_unit_infos, - ) - data_system_client = get_transferqueue_client() - return data_system_client + # Option 2: Pass sampler instance (if you need custom configuration) + # grpo_sampler = GRPOGroupNSampler() + # self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler) - def _initialize_val_data_system(self, global_batch_size, num_n_samples, role="val"): - # 1. initialize TransferQueueStorage - total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples - self.val_data_system_storage_units = {} - storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1) - for storage_unit_rank in range(self.config.trainer.num_data_storage_units): - storage_node = TransferQueueStorageSimpleUnit.options( - placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank - ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) - self.val_data_system_storage_units[storage_unit_rank] = storage_node - logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") - - # 2. initialize TransferQueueController - # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly - # one controller for a single WorkerGroup. 
- self.val_data_system_controllers = {} - controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1) - for controller_rank in range(self.config.trainer.num_data_controllers): - self.val_data_system_controllers[controller_rank] = TransferQueueController.options( - placement_group=controller_placement_group, placement_group_bundle_index=controller_rank - ).remote( - num_storage_units=self.config.trainer.num_data_storage_units, - global_batch_size=global_batch_size, - num_global_batch=self.config.trainer.num_global_batch, - num_n_samples=num_n_samples, - ) - logging.info(f"TransferQueueController #{controller_rank} has been created.") + # Then use sampling_config in get_meta calls: + # sampling_config={"n_samples_per_prompt": 4} + self.data_system_controller = TransferQueueController.remote() + logging.info("TransferQueueController has been created.") - # 3. register controller & storage - self.val_data_system_controller_infos = process_zmq_server_info(self.val_data_system_controllers) - self.val_data_system_storage_unit_infos = process_zmq_server_info(self.val_data_system_storage_units) + # 3. register controller & storage and prepare necessary information + self.data_system_controller_info = process_zmq_server_info(self.data_system_controller) + self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units) - ray.get( - [ - storage_unit.register_controller_info.remote(self.val_data_system_controller_infos) - for storage_unit in self.val_data_system_storage_units.values() - ] - ) + # Note: Need to generate a new DictConfig with allow_objects=True to preserve ZMQServerInfo instances + # (which contain socket connection details). Without this flag, OmegaConf would flatten these objects to dicts, + # breaking the transfer queue client initialization. 
+ tq_config = OmegaConf.create({}, flags={"allow_objects": True}) + tq_config.controller_info = self.data_system_controller_info + tq_config.storage_unit_infos = self.data_system_storage_unit_infos + self.config = OmegaConf.merge(tq_config, self.config) # 4. create client - # each client should be allocated to exactly one controller create_transferqueue_client( - client_id="Trainer-" + role, - controller_infos=self.val_data_system_controller_infos, - storage_infos=self.val_data_system_storage_unit_infos, + client_id="Trainer", + controller_info=self.data_system_controller_info, + config=self.config, ) - data_system_client = get_val_transferqueue_client() + data_system_client = get_transferqueue_client() return data_system_client def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): @@ -726,19 +683,18 @@ def _validate(self): if self.config.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model": return {} - asyncio.run(self.val_data_system_client.async_put(data=test_batch, global_step=self.global_steps - 1)) + asyncio.run(self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}")) # Store original inputs batch_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["input_ids", "uid", "reward_model"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", task_name="get_data", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(batch_meta)) + data = asyncio.run(self.data_system_client.async_get_data(batch_meta)) input_ids = data["input_ids"] # TODO: Can we keep special tokens except for padding tokens? 
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] @@ -749,11 +705,10 @@ def _validate(self): sample_gts.extend(ground_truths) test_gen_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 task_name="generate_sequences", ) ) @@ -779,15 +734,14 @@ def _validate(self): # Store generated outputs test_response_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["responses"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 task_name="get_response", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(test_response_meta)) + data = asyncio.run(self.data_system_client.async_get_data(test_response_meta)) output_ids = data["responses"] output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] sample_outputs.extend(output_texts) @@ -808,11 +762,10 @@ def _validate(self): if "rm_scores" in batch_meta.field_names: compute_reward_fields = ["rm_scores"] val_reward_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=compute_reward_fields, batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", task_name="compute_reward", ) ) @@ 
-832,29 +785,27 @@ def _validate(self): # collect num_turns of each prompt if "__num_turns__" in test_batch_meta.field_names: num_turns_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["__num_turns__"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 task_name="get_num_turns", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(num_turns_meta)) + data = asyncio.run(self.data_system_client.async_get_data(num_turns_meta)) sample_turns.append(data["__num_turns__"]) data_source = ["unknown"] * reward_tensor.shape[0] if "data_source" in test_batch_meta.field_names: data_source_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["data_source"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 task_name="get_data_source", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(data_source_meta)) + data = asyncio.run(self.data_system_client.async_get_data(data_source_meta)) data_source = data["data_source"] data_source_lst.append(data_source) @@ -902,7 +853,7 @@ def _validate(self): metric_dict["val-aux/num_turns/max"] = sample_turns.max() metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() - asyncio.run(self.val_data_system_client.async_clear(self.global_steps - 1)) + asyncio.run(self.data_system_client.async_clear(partition_id=f"val_{self.global_steps - 1}")) return metric_dict def init_workers(self): @@ -1003,12 +954,7 @@ def init_workers(self): # set transferqueue server info for each 
worker for _, wg in all_wg.items(): - wg.create_transferqueue_client( - self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train" - ) - wg.create_transferqueue_client( - self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val" - ) + wg.create_transferqueue_client(self.data_system_controller_info, self.config) # create async rollout manager and request scheduler self.async_rollout_mode = False @@ -1020,12 +966,7 @@ def init_workers(self): config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg ) - self.async_rollout_manager.create_transferqueue_client( - self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train" - ) - self.async_rollout_manager.create_transferqueue_client( - self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val" - ) + self.async_rollout_manager.create_transferqueue_client(self.data_system_controller_info, self.config) def _save_checkpoint(self): from verl.utils.fs import local_mkdir_safe @@ -1164,17 +1105,39 @@ def _stop_profiling(self, do_profile: bool) -> None: if self.use_rm: self.rm_wg.stop_profile() - def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen"): + def _balance_batch( + self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen", keep_minibatch=False + ): """Reorder the batchmeta on single controller such that each dp rank gets similar total tokens""" data = asyncio.run(data_system_client.async_get_data(batch)) attention_mask = data["attention_mask"] batch_size = attention_mask.shape[0] - global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,) + global_seqlen_lst = calculate_workload(global_seqlen_lst) world_size = self.actor_rollout_wg.world_size - global_partition_lst = 
get_seqlen_balanced_partitions( - global_seqlen_lst, k_partitions=world_size, equal_size=True - ) + if keep_minibatch: + # Decouple the DP balancing and mini-batching. + minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size") + minibatch_num = len(global_seqlen_lst) // minibatch_size + global_partition_lst = [[] for _ in range(world_size)] + for i in range(minibatch_num): + rearrange_minibatch_lst = get_seqlen_balanced_partitions( + global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size], + k_partitions=world_size, + equal_size=True, + ) + for j, part in enumerate(rearrange_minibatch_lst): + global_partition_lst[j].extend([x + minibatch_size * i for x in part]) + else: + global_partition_lst = get_seqlen_balanced_partitions( + global_seqlen_lst, k_partitions=world_size, equal_size=True + ) + # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel. + for idx, partition in enumerate(global_partition_lst): + partition.sort(key=lambda x: (global_seqlen_lst[x], x)) + ordered_partition = partition[::2] + partition[1::2][::-1] + global_partition_lst[idx] = ordered_partition # reorder based on index. 
The data will be automatically equally partitioned by dispatch function global_idx = [j for partition in global_partition_lst for j in partition] global_balance_stats = log_seqlen_unbalance( @@ -1313,8 +1276,7 @@ def fit(self): timing_raw = {} base_get_meta_kwargs = dict( batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, - global_step=self.global_steps - 1, # self.global_steps starts from 1 - get_n_samples=False, + partition_id=f"train_{self.global_steps - 1}", # self.global_steps starts from 1 ) with marked_timer("start_profile", timing_raw): @@ -1333,7 +1295,9 @@ def fit(self): batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True ) batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict) - asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1)) + asyncio.run( + self.data_system_client.async_put(data=batch, partition_id=f"train_{self.global_steps - 1}") + ) gen_meta = asyncio.run( self.data_system_client.async_get_meta( @@ -1709,8 +1673,7 @@ def fit(self): ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, - global_step=self.global_steps - 1, - get_n_samples=False, + partition_id=f"train_{self.global_steps - 1}", task_name="update_actor", ) ) @@ -1735,8 +1698,7 @@ def fit(self): self.data_system_client.async_get_meta( data_fields=data_fields, batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, - global_step=self.global_steps - 1, - get_n_samples=False, + partition_id=f"train_{self.global_steps - 1}", task_name="log_rollout", ) ) @@ -1857,7 +1819,7 @@ def fit(self): # TODO: (TQ) support transfer queue self.train_dataloader.sampler.update(batch=batch) - asyncio.run(self.data_system_client.async_clear(self.global_steps - 1)) + asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{self.global_steps - 1}")) # TODO: make a canonical logger that supports various backend 
logger.log(data=metrics, step=self.global_steps) diff --git a/requirements_transferqueue.txt b/requirements_transferqueue.txt index 8479d27bb21..621682abbf7 100644 --- a/requirements_transferqueue.txt +++ b/requirements_transferqueue.txt @@ -1,2 +1,2 @@ # requirements.txt records the full set of dependencies for development -git+https://github.com/TransferQueue/TransferQueue.git@68c04e7 +transferqueue==0.1.1.dev2 diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 2513c57f99c..399ac75a063 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -131,13 +131,13 @@ def _query_collect_info(self, mesh_name: str): return self.__collect_dp_rank[mesh_name] @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True) - def create_transferqueue_client(self, controller_infos, storage_infos, role="train"): + def create_transferqueue_client(self, controller_info, config): from verl.utils.transferqueue_utils import create_transferqueue_client create_transferqueue_client( - client_id=f"{role}_worker_{self.rank}", - controller_infos=controller_infos, - storage_infos=storage_infos, + client_id=f"worker_{self.rank}", + controller_info=controller_info, + config=config, ) @classmethod diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py index 27160571ef3..c692578e3a0 100644 --- a/verl/utils/transferqueue_utils.py +++ b/verl/utils/transferqueue_utils.py @@ -38,32 +38,24 @@ class BatchMeta: from verl.protocol import DataProto _TRANSFER_QUEUE_CLIENT = None -_VAL_TRANSFER_QUEUE_CLIENT = None is_transferqueue_enabled = os.environ.get("TRANSFER_QUEUE_ENABLE", False) def create_transferqueue_client( client_id: str, - controller_infos: dict[Any, "ZMQServerInfo"], - storage_infos: dict[Any, "ZMQServerInfo"], + controller_info: dict[Any, "ZMQServerInfo"], + config, ) -> None: global _TRANSFER_QUEUE_CLIENT - global _VAL_TRANSFER_QUEUE_CLIENT - if "val" in client_id: - 
_VAL_TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos) - else: - _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos) + _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_info) + _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) def get_transferqueue_client() -> "AsyncTransferQueueClient": return _TRANSFER_QUEUE_CLIENT -def get_val_transferqueue_client() -> "AsyncTransferQueueClient": - return _VAL_TRANSFER_QUEUE_CLIENT - - def _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any: # Use a temporary event loop in a new thread because event # loop may already exist in server mode @@ -109,10 +101,7 @@ async def _async_batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto: meta_info=batchmeta.extra_info.copy(), ) - if batchmeta.extra_info.get("validate", False): - tensordict = await _VAL_TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) - else: - tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) + tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) return DataProto.from_tensordict(tensordict, meta_info=batchmeta.extra_info.copy()) @@ -130,10 +119,7 @@ async def _async_update_batchmeta_with_output(output: DataProto, batchmeta: "Bat for key in output.meta_info.keys(): tensordict.pop(key) batchmeta.add_fields(tensordict) - if batchmeta.extra_info.get("validate", False): - await _VAL_TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) - else: - await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) + await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta") -> None: From ffb08eabc80011b377425ecb398e639e4fc32e76 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 17 Nov 2025 14:36:57 +0800 Subject: [PATCH 13/16] delete 
TQ source codes Signed-off-by: 0oshowero0 --- .../transfer_queue/test_client.py | 385 --------- .../transfer_queue/test_controller.py | 264 ------ .../test_simple_storage_unit.py | 479 ----------- verl/experimental/transfer_queue/__init__.py | 14 - verl/experimental/transfer_queue/client.py | 662 --------------- .../experimental/transfer_queue/controller.py | 771 ------------------ verl/experimental/transfer_queue/metadata.py | 602 -------------- verl/experimental/transfer_queue/storage.py | 516 ------------ .../transfer_queue/utils/__init__.py | 14 - .../transfer_queue/utils/utils.py | 111 --- .../transfer_queue/utils/zmq_utils.py | 176 ---- 11 files changed, 3994 deletions(-) delete mode 100644 tests/experimental/transfer_queue/test_client.py delete mode 100644 tests/experimental/transfer_queue/test_controller.py delete mode 100644 tests/experimental/transfer_queue/test_simple_storage_unit.py delete mode 100644 verl/experimental/transfer_queue/__init__.py delete mode 100644 verl/experimental/transfer_queue/client.py delete mode 100644 verl/experimental/transfer_queue/controller.py delete mode 100644 verl/experimental/transfer_queue/metadata.py delete mode 100644 verl/experimental/transfer_queue/storage.py delete mode 100644 verl/experimental/transfer_queue/utils/__init__.py delete mode 100644 verl/experimental/transfer_queue/utils/utils.py delete mode 100644 verl/experimental/transfer_queue/utils/zmq_utils.py diff --git a/tests/experimental/transfer_queue/test_client.py b/tests/experimental/transfer_queue/test_client.py deleted file mode 100644 index f1b4efd191b..00000000000 --- a/tests/experimental/transfer_queue/test_client.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -from threading import Thread - -import pytest -import torch -import zmq -from tensordict import NonTensorStack, TensorDict - -from verl.experimental.transfer_queue import TransferQueueClient # noqa: E402 -from verl.experimental.transfer_queue.metadata import ( # noqa: E402 - BatchMeta, - FieldMeta, - SampleMeta, -) -from verl.experimental.transfer_queue.utils.zmq_utils import ( # noqa: E402 - ZMQMessage, - ZMQRequestType, - ZMQServerInfo, -) - -TEST_DATA = TensorDict( - { - "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])], - "variable_length_sequences": torch.nested.as_nested_tensor( - [ - torch.tensor([-0.5, -1.2, -0.8]), - torch.tensor([-0.3, -1.5, -2.1, -0.9]), - torch.tensor([-1.1, -0.7]), - ] - ), - "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"], - }, - batch_size=[3], -) - - -# Mock Controller for Client Unit Testing -class MockController: - def __init__(self, controller_id="controller_0"): - self.controller_id = controller_id - self.context = zmq.Context() - - # Socket for data requests - self.request_socket = self.context.socket(zmq.ROUTER) - self.request_port = self._bind_to_random_port(self.request_socket) - - self.zmq_server_info = ZMQServerInfo.create( - role="TransferQueueController", - id=controller_id, - ip="127.0.0.1", - ports={ - "request_handle_socket": self.request_port, - }, - ) - - self.running = True - self.request_thread = Thread(target=self._handle_requests, daemon=True) - self.request_thread.start() - - def 
_bind_to_random_port(self, socket): - port = socket.bind_to_random_port("tcp://127.0.0.1") - return port - - def _handle_requests(self): - poller = zmq.Poller() - poller.register(self.request_socket, zmq.POLLIN) - - while self.running: - try: - socks = dict(poller.poll(100)) # 100ms timeout - if self.request_socket in socks: - identity, serialized_msg = self.request_socket.recv_multipart() - request_msg = ZMQMessage.deserialize(serialized_msg) - - # Determine response based on request type - if request_msg.request_type == ZMQRequestType.GET_META: - response_body = self._mock_batch_meta(request_msg.body) - response_type = ZMQRequestType.GET_META_RESPONSE - elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META: - response_body = self._mock_batch_meta(request_msg.body) - response_type = ZMQRequestType.GET_CLEAR_META_RESPONSE - elif request_msg.request_type == ZMQRequestType.CLEAR_META: - response_body = {"message": "clear ok"} - response_type = ZMQRequestType.CLEAR_META_RESPONSE - - # Send response - response_msg = ZMQMessage.create( - request_type=response_type, - sender_id=self.controller_id, - receiver_id=request_msg.sender_id, - body=response_body, - ) - self.request_socket.send_multipart([identity, response_msg.serialize()]) - except zmq.Again: - continue - except Exception as e: - if self.is_running: - print(f"MockController running exception: {e}") - else: - print(f"MockController ERROR: {e}") - raise - - def _mock_batch_meta(self, request_body): - batch_size = request_body.get("batch_size", 1) - data_fields = request_body.get("data_fields", []) - - samples = [] - for i in range(batch_size): - fields = [] - for field_name in data_fields: - field_meta = FieldMeta( - name=field_name, - dtype=None, - shape=None, - production_status=0, - ) - fields.append(field_meta) - sample = SampleMeta( - global_step=0, - global_index=i, - storage_id="storage_0", - local_index=i, - fields={field.name: field for field in fields}, - ) - samples.append(sample) - metadata = 
BatchMeta(samples=samples) - - return {"metadata": metadata} - - def stop(self): - self.running = False - time.sleep(0.2) # Give thread time to stop - self.request_socket.close() - self.context.term() - - -# Mock Storage for Client Unit Testing -class MockStorage: - def __init__(self, storage_id="storage_0"): - self.storage_id = storage_id - self.context = zmq.Context() - - # Socket for data operations - self.data_socket = self.context.socket(zmq.ROUTER) - self.data_port = self._bind_to_random_port(self.data_socket) - - self.zmq_server_info = ZMQServerInfo.create( - role="TransferQueueStorage", - id=storage_id, - ip="127.0.0.1", - ports={ - "put_get_socket": self.data_port, - }, - ) - - self.running = True - self.data_thread = Thread(target=self._handle_data_requests, daemon=True) - self.data_thread.start() - - def _bind_to_random_port(self, socket): - port = socket.bind_to_random_port("tcp://127.0.0.1") - return port - - def _handle_data_requests(self): - poller = zmq.Poller() - poller.register(self.data_socket, zmq.POLLIN) - - while self.running: - try: - socks = dict(poller.poll(100)) # 100ms timeout - if self.data_socket in socks: - identity, msg_bytes = self.data_socket.recv_multipart() - msg = ZMQMessage.deserialize(msg_bytes) - - # Handle different request types - if msg.request_type == ZMQRequestType.PUT_DATA: - response_body = {"message": "Data stored successfully"} - response_type = ZMQRequestType.PUT_DATA_RESPONSE - elif msg.request_type == ZMQRequestType.GET_DATA: - response_body = self._handle_get_data(msg.body) - response_type = ZMQRequestType.GET_DATA_RESPONSE - elif msg.request_type == ZMQRequestType.CLEAR_DATA: - response_body = {"message": "Data cleared successfully"} - response_type = ZMQRequestType.CLEAR_DATA_RESPONSE - - # Send response - response_msg = ZMQMessage.create( - request_type=response_type, - sender_id=self.storage_id, - receiver_id=msg.sender_id, - body=response_body, - ) - self.data_socket.send_multipart([identity, 
response_msg.serialize()]) - except zmq.Again: - continue - except Exception as e: - if self.is_running: - print(f"MockStorage running exception: {e}") - else: - print(f"MockStorage ERROR: {e}") - raise - - def _handle_get_data(self, request_body): - """Handle GET_DATA request by retrieving stored data""" - local_indexes = request_body.get("local_indexes", []) - fields = request_body.get("fields", []) - - result: dict[str, list] = {} - for field in fields: - gathered_items = [TEST_DATA[field][i] for i in local_indexes] - - if gathered_items: - all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items) - if all_tensors: - result[field] = torch.nested.as_nested_tensor(gathered_items) - else: - result[field] = NonTensorStack(*gathered_items) - - return {"data": TensorDict(result)} - - def stop(self): - self.running = False - time.sleep(0.2) # Give thread time to stop - self.data_socket.close() - self.context.term() - - -# Test Fixtures -@pytest.fixture -def mock_controller(): - controller = MockController() - yield controller - controller.stop() - - -@pytest.fixture -def mock_storage(): - storage = MockStorage() - yield storage - storage.stop() - - -@pytest.fixture -def client_setup(mock_controller, mock_storage): - # Create client with mock controller and storage - client_id = "client_0" - - client = TransferQueueClient( - client_id=client_id, - controller_infos={mock_controller.controller_id: mock_controller.zmq_server_info}, - storage_infos={mock_storage.storage_id: mock_storage.zmq_server_info}, - ) - - # Give some time for connections to establish - time.sleep(0.5) - - yield client, mock_controller, mock_storage - - -# Test basic functionality -def test_client_initialization(client_setup): - """Test client initialization and connection setup""" - client, mock_controller, mock_storage = client_setup - - assert client.client_id is not None - assert mock_controller.controller_id in client._controllers - assert mock_storage.storage_id in client._storages 
- - -def test_put_and_get_data(client_setup): - """Test basic put and get operations""" - client, _, _ = client_setup - - # Test put operation - client.put(data=TEST_DATA, global_step=0) - - # Get metadata for retrieving data - metadata = client.get_meta( - data_fields=["log_probs", "variable_length_sequences", "prompt_text"], batch_size=2, global_step=0 - ) - - # Test get operation - result = client.get_data(metadata) - - # Verify result structure - assert "log_probs" in result - assert "variable_length_sequences" in result - assert "prompt_text" in result - - torch.testing.assert_close(result["log_probs"][0], torch.tensor([1.0, 2.0, 3.0])) - torch.testing.assert_close(result["log_probs"][1], torch.tensor([4.0, 5.0, 6.0])) - torch.testing.assert_close(result["variable_length_sequences"][0], torch.tensor([-0.5, -1.2, -0.8])) - torch.testing.assert_close(result["variable_length_sequences"][1], torch.tensor([-0.3, -1.5, -2.1, -0.9])) - assert result["prompt_text"][0] == "Hello world!" - assert result["prompt_text"][1] == "This is a longer sentence for testing" - - -def test_get_meta(client_setup): - """Test metadata retrieval""" - client, _, _ = client_setup - - # Test get_meta operation - metadata = client.get_meta(data_fields=["tokens", "labels"], batch_size=10, global_step=0) - - # Verify metadata structure - assert hasattr(metadata, "storage_meta_groups") - assert hasattr(metadata, "global_indexes") - assert hasattr(metadata, "fields") - assert hasattr(metadata, "size") - assert len(metadata.global_indexes) == 10 - - -def test_clear_operation(client_setup): - """Test clear operation""" - client, _, _ = client_setup - - # Test clear operation - client.clear(global_step=0) - - -# Test with multiple controllers and storage units -def test_multiple_servers(): - """Test client with multiple controllers and storage units""" - # Create multiple mock servers - controllers = [MockController(f"controller_{i}") for i in range(2)] - storages = [MockStorage(f"storage_{i}") 
for i in range(3)] - - try: - # Create client with multiple servers - client_id = "client_test_multiple_servers" - - controller_infos = {c.controller_id: c.zmq_server_info for c in controllers} - storage_infos = {s.storage_id: s.zmq_server_info for s in storages} - - client = TransferQueueClient( - client_id=client_id, controller_infos=controller_infos, storage_infos=storage_infos - ) - - # Give time for connections - time.sleep(1.0) - - # Verify connections - assert len(client._controllers) == 2 - assert len(client._storages) == 3 - - # Test basic operation - test_data = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5) - - # Test put operation - client.put(data=test_data, global_step=0) - - finally: - # Clean up - for c in controllers: - c.stop() - for s in storages: - s.stop() - - -# Test error handling -def test_put_without_required_params(client_setup): - """Test put operation without required parameters""" - client, _, _ = client_setup - - # Create test data - test_data = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5) - - # Test put without global_step (should fail) - with pytest.raises(AssertionError): - client.put(data=test_data) diff --git a/tests/experimental/transfer_queue/test_controller.py b/tests/experimental/transfer_queue/test_controller.py deleted file mode 100644 index 3b45da2a561..00000000000 --- a/tests/experimental/transfer_queue/test_controller.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import math - -import numpy as np -import pytest -import ray -import torch - -from verl.experimental.transfer_queue.controller import TQ_INIT_FIELD_NUM, TransferQueueController -from verl.experimental.transfer_queue.storage import TransferQueueStorageSimpleUnit - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) - - -@pytest.fixture(scope="function") -def ray_setup(): - if ray.is_initialized(): - ray.shutdown() - ray.init( - ignore_reinit_error=True, - runtime_env={"env_vars": {"RAY_DEBUG": "1", "RAY_DEDUP_LOGS": "0"}}, - log_to_driver=True, - ) - yield - if ray.is_initialized(): - ray.shutdown() - logger.info("Ray has been shut down completely after test") - - -@pytest.fixture(scope="function") -def setup_teardown_transfer_queue_controller(ray_setup): - # Used as the offset for the global index to distinguish which global step the data corresponds to - global_batch_size = 8 - num_global_batch = 2 - num_n_samples = 2 - num_data_storage_units = 2 - - tq_controller = TransferQueueController.remote( - num_storage_units=num_data_storage_units, - global_batch_size=global_batch_size, - num_global_batch=num_global_batch, - num_n_samples=num_n_samples, - ) - yield tq_controller, global_batch_size, num_global_batch, num_n_samples - ray.get(tq_controller.clear.remote(0)) - - -@pytest.fixture(scope="function") -def setup_teardown_register_controller_info(setup_teardown_transfer_queue_controller): - tq_controller, global_batch_size, num_global_batch, num_n_samples = 
setup_teardown_transfer_queue_controller - total_storage_size = global_batch_size * num_global_batch * num_n_samples - num_data_storage_units = 2 - - data_system_storage_units = {} - for storage_unit_rank in range(num_data_storage_units): - storage_node = TransferQueueStorageSimpleUnit.remote( - storage_size=math.ceil(total_storage_size / num_data_storage_units) - ) - data_system_storage_units[storage_unit_rank] = storage_node - logger.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") - - # Register controller info - zmq_server_info = ray.get(tq_controller.get_zmq_server_info.remote()) - controller_infos = {zmq_server_info.id: zmq_server_info} - - ray.get( - [ - storage_unit.register_controller_info.remote(controller_infos) - for storage_unit in data_system_storage_units.values() - ] - ) - - yield tq_controller, global_batch_size, num_n_samples, data_system_storage_units - - -class TestTransferQueueController: - @pytest.mark.parametrize("num_n_samples", [1, 2]) - @pytest.mark.parametrize("num_global_batch", [1, 2]) - def test_build_index_storage_mapping(self, num_n_samples, num_global_batch, ray_setup): - # Used as the offset for the global index to distinguish which global step the data corresponds to - global_batch_size = 8 - num_data_storage_units = 2 - - self.tq_controller = TransferQueueController.remote( - num_storage_units=num_data_storage_units, - global_batch_size=global_batch_size, - num_global_batch=num_global_batch, - num_n_samples=num_n_samples, - ) - - global_index_storage_mapping, global_index_local_index_mapping = ray.get( - self.tq_controller.get_global_index_mapping.remote() - ) - - if num_global_batch == 1 and num_n_samples == 1: - assert np.array_equal(global_index_storage_mapping, np.array([0, 0, 0, 0, 1, 1, 1, 1])) - assert np.array_equal(global_index_local_index_mapping, np.array([0, 1, 2, 3, 0, 1, 2, 3])) - # The data of a single GBS will be distributed across different storage units - elif num_global_batch == 2 
and num_n_samples == 1: - assert np.array_equal( - global_index_storage_mapping, np.array([0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1]) - ) - assert np.array_equal( - global_index_local_index_mapping, np.array([0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 4, 5, 6, 7]) - ) - # When num_n_samples is larger than 1 - elif num_global_batch == 1 and num_n_samples == 2: - assert np.array_equal( - global_index_storage_mapping, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]) - ) - assert np.array_equal( - global_index_local_index_mapping, np.array([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]) - ) - elif num_global_batch == 2 and num_n_samples == 2: - assert np.array_equal( - global_index_storage_mapping, - np.array( - [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] - ), - ) - assert np.array_equal( - global_index_local_index_mapping, - np.array( - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - ] - ), - ) - - def test_update_production_status(self, setup_teardown_transfer_queue_controller): - tq_controller, global_batch_size, num_global_batch, num_n_samples = setup_teardown_transfer_queue_controller - - total_storage_size = global_batch_size * num_global_batch * num_n_samples - # Initialize get_data_production_status and filed_name_mapping - init_update_production_status = torch.zeros(total_storage_size, TQ_INIT_FIELD_NUM, dtype=torch.int8) - assert torch.equal(ray.get(tq_controller.get_data_production_status.remote()), init_update_production_status) - assert ray.get(tq_controller.get_field_name_mapping.remote()) == {} - - columns_list = ["test_prompts"] - global_indexes = list(range(global_batch_size * num_n_samples)) - - # update production status - tq_controller._update_production_status.remote(global_indexes, columns_list) - new_field_name_mapping = 
ray.get(tq_controller.get_field_name_mapping.remote()) - assert new_field_name_mapping["test_prompts"] == 0 - - new_data_production_status = ray.get(tq_controller.get_data_production_status.remote()) - assert new_data_production_status[:, 0][: len(global_indexes)].sum() == len(global_indexes) - - def test_data_consumption_status(self, setup_teardown_transfer_queue_controller): - tq_controller, global_batch_size, num_global_batch, num_n_samples = setup_teardown_transfer_queue_controller - total_storage_size = global_batch_size * num_global_batch * num_n_samples - - init_data_consumption_status = {} - assert ray.get(tq_controller.get_data_consumption_status.remote()) == init_data_consumption_status - - task_name = "test_task1" - ray.get(tq_controller._get_consumption_status.remote(task_name)) - new_data_consumption_status = ray.get(tq_controller.get_data_consumption_status.remote()) - assert torch.equal(new_data_consumption_status[task_name], torch.zeros(total_storage_size, dtype=torch.int8)) - - def test_get_prompt_metadata(self, setup_teardown_register_controller_info): - tq_controller, global_batch_size, n_samples, _ = setup_teardown_register_controller_info - - data_fields = ["test_prompts"] - global_step = 5 - - metadata = ray.get( - tq_controller._get_metadata.remote( - data_fields=data_fields, - batch_size=global_batch_size * n_samples, - global_step=global_step, - mode="insert", - ) - ) - metadata.reorder([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]) - assert metadata.global_indexes == [ - 31, - 30, - 29, - 28, - 27, - 26, - 25, - 24, - 23, - 22, - 21, - 20, - 19, - 18, - 17, - 16, - ] - assert metadata.local_indexes == [ - 15, - 14, - 13, - 12, - 11, - 10, - 9, - 8, - 15, - 14, - 13, - 12, - 11, - 10, - 9, - 8, - ] - storage_ids = metadata.storage_ids - assert len(set(storage_ids[: len(storage_ids) // 2])) == 1 - - # TODO: Test case where multiple clients concurrently read datameta from a single controller, - # and each client receives the correct 
response diff --git a/tests/experimental/transfer_queue/test_simple_storage_unit.py b/tests/experimental/transfer_queue/test_simple_storage_unit.py deleted file mode 100644 index 7949c9cb971..00000000000 --- a/tests/experimental/transfer_queue/test_simple_storage_unit.py +++ /dev/null @@ -1,479 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import sys -import time -import uuid -from pathlib import Path -from threading import Thread -from unittest.mock import MagicMock - -import pytest -import ray -import tensordict -import torch -import zmq -from tensordict import TensorDict - -# Import your classes here -parent_dir = Path(__file__).resolve().parent.parent -sys.path.append(str(parent_dir)) - -try: - from verl.experimental.transfer_queue.storage import TransferQueueStorageSimpleUnit - from verl.experimental.transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo -except ImportError: - # For testing purposes if imports are not available - TransferQueueStorageSimpleUnit = MagicMock() - ZMQServerInfo = MagicMock() - ZMQRequestType = MagicMock() - ZMQMessage = MagicMock() - - -# Mock ZMQ utilities if not available in test environment -def create_zmq_socket(context, socket_type, identity=None): - sock = context.socket(socket_type) - if identity: - sock.setsockopt(zmq.IDENTITY, identity) - return sock - - -# Mock Controller to handle 
handshake and data updates -class MockController: - def __init__(self, controller_id="controller_001"): - self.controller_id = controller_id - self.context = zmq.Context() - - # Socket for handshake - self.handshake_socket = self.context.socket(zmq.ROUTER) - self.handshake_port = self._bind_to_random_port(self.handshake_socket) - - # Socket for data status updates - self.data_update_socket = self.context.socket(zmq.ROUTER) - self.data_update_port = self._bind_to_random_port(self.data_update_socket) - - self.zmq_server_info = ZMQServerInfo.create( - role="CONTROLLER", - id=controller_id, - ip="127.0.0.1", - ports={"handshake_socket": self.handshake_port, "data_status_update_socket": self.data_update_port}, - ) - - self.running = True - self.handshake_thread = Thread(target=self._handle_handshake, daemon=True) - self.data_update_thread = Thread(target=self._handle_data_updates, daemon=True) - self.handshake_thread.start() - self.data_update_thread.start() - - def _bind_to_random_port(self, socket): - port = socket.bind_to_random_port("tcp://127.0.0.1") - return port - - def _handle_handshake(self): - poller = zmq.Poller() - poller.register(self.handshake_socket, zmq.POLLIN) - - while self.running: - try: - socks = dict(poller.poll(100)) # 100ms timeout - if self.handshake_socket in socks: - identity, msg_bytes = self.handshake_socket.recv_multipart() - ZMQMessage.deserialize(msg_bytes) - - # Send handshake ack - ack_msg = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE_ACK, - sender_id=self.controller_id, - body={"message": "Handshake successful"}, - ) - self.handshake_socket.send_multipart([identity, ack_msg.serialize()]) - except zmq.Again: - continue - except Exception: - if self.running: - pass - - def _handle_data_updates(self): - poller = zmq.Poller() - poller.register(self.data_update_socket, zmq.POLLIN) - - while self.running: - try: - socks = dict(poller.poll(100)) # 100ms timeout - if self.data_update_socket in socks: - identity, msg_bytes = 
self.data_update_socket.recv_multipart() - ZMQMessage.deserialize(msg_bytes) - - # Send data update ack - ack_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, - sender_id=self.controller_id, - body={"message": "Data update received"}, - ) - self.data_update_socket.send_multipart([identity, ack_msg.serialize()]) - except zmq.Again: - continue - except Exception: - if self.running: - pass - - def stop(self): - self.running = False - time.sleep(0.1) # Give threads time to stop - self.handshake_socket.close() - self.data_update_socket.close() - - -# Mock client to send PUT/GET requests -class MockClient: - def __init__(self, storage_put_get_address): - self.context = zmq.Context() - self.socket = self.context.socket(zmq.DEALER) - self.socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout - self.socket.connect(storage_put_get_address) - - def send_put(self, client_id, global_indexes, local_indexes, field_data): - msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_DATA, - sender_id=f"mock_client_{client_id}", - body={"global_indexes": global_indexes, "local_indexes": local_indexes, "field_data": field_data}, - ) - self.socket.send(msg.serialize()) - return ZMQMessage.deserialize(self.socket.recv()) - - def send_get(self, client_id, local_indexes, fields): - msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_DATA, - sender_id=f"mock_client_{client_id}", - body={"local_indexes": local_indexes, "fields": fields}, - ) - self.socket.send(msg.serialize()) - return ZMQMessage.deserialize(self.socket.recv()) - - def close(self): - self.socket.close() - self.context.term() - - -@pytest.fixture(scope="session") -def ray_setup(): - ray.init(ignore_reinit_error=True) - yield - ray.shutdown() - - -@pytest.fixture -def storage_setup(ray_setup): - storage_size = 10000 - tensordict.set_list_to_stack(True).set() - - # Start mock controller - mock_controller = MockController(f"controller_{uuid.uuid4()}") - time.sleep(0.5) # Wait for 
controller sockets to be ready - - # Start Ray actor - storage_actor = TransferQueueStorageSimpleUnit.options(max_concurrency=50, num_cpus=1).remote(storage_size) - - # Register controller info - controller_infos = {mock_controller.controller_id: mock_controller.zmq_server_info} - ray.get(storage_actor.register_controller_info.remote(controller_infos)) - - # Get ZMQ address to connect client - zmq_info = ray.get(storage_actor.get_zmq_server_info.remote()) - put_get_address = zmq_info.to_addr("put_get_socket") - time.sleep(1) # Wait for socket to be ready - - yield storage_actor, put_get_address, mock_controller - - # Cleanup - mock_controller.stop() - - -def test_put_get_single_client(storage_setup): - """Test basic put and get operations with a single client using TensorDict and torch tensors.""" - _, put_get_address, _ = storage_setup - - client = MockClient(put_get_address) - - # PUT data - global_indexes = [0, 1, 2] - local_indexes = [0, 1, 2] - field_data = TensorDict( - { - "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])], - "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])], - }, - batch_size=[], - ) - - response = client.send_put(0, global_indexes, local_indexes, field_data) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # GET data - response = client.send_get(0, [0, 1], ["log_probs", "rewards"]) - assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE - - retrieved_data = response.body["data"] - assert "log_probs" in retrieved_data - assert "rewards" in retrieved_data - assert retrieved_data["log_probs"].size(0) == 2 - assert retrieved_data["rewards"].size(0) == 2 - - # Verify data correctness - torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([1.0, 2.0, 3.0])) - torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([4.0, 5.0, 6.0])) - torch.testing.assert_close(retrieved_data["rewards"][0], 
torch.tensor([10.0])) - torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([20.0])) - - client.close() - - -def test_put_get_multiple_clients(storage_setup): - """Test put and get operations with multiple clients including overlapping local indexes""" - _, put_get_address, _ = storage_setup - - num_clients = 5 - clients = [MockClient(put_get_address) for _ in range(num_clients)] - - # Each client puts unique data using different local_indexes - for i, client in enumerate(clients): - global_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2] - local_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2] - field_data = TensorDict( - { - "log_probs": [ - torch.tensor([i, i + 1, i + 2]), - torch.tensor([i + 3, i + 4, i + 5]), - torch.tensor([i + 6, i + 7, i + 8]), - ], - "rewards": [torch.tensor([i * 10]), torch.tensor([i * 10 + 10]), torch.tensor([i * 10 + 20])], - } - ) - - response = client.send_put(i, global_indexes, local_indexes, field_data) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # Now simulate a third client that writes to overlapping local_indexes (e.g., index 0) - overlapping_client = MockClient(put_get_address) - overlap_local_indexes = [0] # Overlaps with first client's index 0 - overlap_field_data = TensorDict({"log_probs": [torch.tensor([999, 999, 999])], "rewards": [torch.tensor([999])]}) - response = overlapping_client.send_put( - client_id=99, global_indexes=[0], local_indexes=overlap_local_indexes, field_data=overlap_field_data - ) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # Each original client gets its own data (except for index 0 which was overwritten) - for i, client in enumerate(clients): - response = client.send_get(i, [i * 10 + 0, i * 10 + 1], ["log_probs", "rewards"]) - assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE - - retrieved_data = response.body["data"] - assert retrieved_data["log_probs"].size(0) == 2 - assert retrieved_data["rewards"].size(0) == 2 - - # 
For index 0, expect data from overlapping_client; others from original client - if i == 0: - # Index 0 was overwritten - torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([999, 999, 999])) - torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([999])) - # Index 1 remains original - torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([3, 4, 5])) - torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([10])) - else: - # All data remains original - torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([i, i + 1, i + 2])) - torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([i + 3, i + 4, i + 5])) - torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([i * 10])) - torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([i * 10 + 10])) - - # Cleanup - for client in clients: - client.close() - overlapping_client.close() - - -def test_performance_basic(storage_setup): - """Basic performance test with larger data volume and proper index handling""" - _, put_get_address, _ = storage_setup - - client = MockClient(put_get_address) - - # PUT performance test - put_latencies = [] - num_puts = 50 - batch_size = 128 - - for i in range(num_puts): - start = time.time() - - # Use larger batch size and more complex index mapping - global_indexes = list(range(i * batch_size, (i + 1) * batch_size)) - local_indexes = list(range(i * batch_size, (i + 1) * batch_size)) - - # Create larger tensor data to increase data volume - log_probs_data = [] - rewards_data = [] - - for j in range(batch_size): - # Each sample contains larger tensors to increase data transfer volume - log_probs_tensor = torch.randn(32768) - rewards_tensor = torch.randn(32768) - log_probs_data.append(log_probs_tensor) - rewards_data.append(rewards_tensor) - - field_data = TensorDict({"log_probs": log_probs_data, "rewards": rewards_data}, batch_size=[batch_size]) - - response = 
client.send_put(0, global_indexes, local_indexes, field_data) - latency = time.time() - start - put_latencies.append(latency) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # GET performance test - get_latencies = [] - num_gets = 50 - - for i in range(num_gets): - start = time.time() - # Retrieve larger batch of data - indices = list(range(i * batch_size, (i + 1) * batch_size)) # Retrieve batch_size indices of data each time - response = client.send_get(0, indices, ["log_probs", "rewards"]) - latency = time.time() - start - get_latencies.append(latency) - assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE - - avg_put_latency = sum(put_latencies) / len(put_latencies) * 1000 # ms - avg_get_latency = sum(get_latencies) / len(get_latencies) * 1000 # ms - - # Adjust performance thresholds to accommodate larger data volume - assert avg_put_latency < 5000, f"Avg PUT latency {avg_put_latency}ms exceeds threshold" - assert avg_get_latency < 5000, f"Avg GET latency {avg_get_latency}ms exceeds threshold" - - client.close() - - -def test_put_get_nested_tensor_single_client(storage_setup): - """Test basic put and get operations with a single client using TensorDict and nested tensors.""" - _, put_get_address, _ = storage_setup - - client = MockClient(put_get_address) - - # PUT data - global_indexes = [0, 1, 2] - local_indexes = [0, 1, 2] - - field_data = TensorDict( - { - "variable_length_sequences": [ - torch.tensor([-0.5, -1.2, -0.8]), - torch.tensor([-0.3, -1.5, -2.1, -0.9]), - torch.tensor([-1.1, -0.7]), - ], - "attention_mask": [torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1, 1]), torch.tensor([1, 1])], - }, - batch_size=[], - ) - - response = client.send_put(0, global_indexes, local_indexes, field_data) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # GET data - response = client.send_get(0, [0, 2], ["variable_length_sequences", "attention_mask"]) - assert response.request_type == 
ZMQRequestType.GET_DATA_RESPONSE - - retrieved_data = response.body["data"] - assert "variable_length_sequences" in retrieved_data - assert "attention_mask" in retrieved_data - assert retrieved_data["variable_length_sequences"].size(0) == 2 - assert retrieved_data["attention_mask"].size(0) == 2 - - # Verify data correctness - torch.testing.assert_close(retrieved_data["variable_length_sequences"][0], torch.tensor([-0.5, -1.2, -0.8])) - torch.testing.assert_close(retrieved_data["variable_length_sequences"][1], torch.tensor([-1.1, -0.7])) - torch.testing.assert_close(retrieved_data["attention_mask"][0], torch.tensor([1, 1, 1])) - torch.testing.assert_close(retrieved_data["attention_mask"][1], torch.tensor([1, 1])) - - client.close() - - -def test_put_get_nested_nontensor_single_client(storage_setup): - """Test basic put and get operations with a single client using non-tensor data (strings).""" - _, put_get_address, _ = storage_setup - - client = MockClient(put_get_address) - - # PUT data - global_indexes = [0, 1, 2] - local_indexes = [0, 1, 2] - field_data = TensorDict( - { - "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"], - "response_text": ["Hi there!", "This is the response to the longer sentence", "Test response"], - }, - batch_size=[], - ) - - response = client.send_put(0, global_indexes, local_indexes, field_data) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # GET data - response = client.send_get(0, [0, 1, 2], ["prompt_text", "response_text"]) - assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE - - retrieved_data = response.body["data"] - assert "prompt_text" in retrieved_data - assert "response_text" in retrieved_data - - # Verify data correctness - assert isinstance(retrieved_data["prompt_text"][0], str) - assert isinstance(retrieved_data["response_text"][0], str) - - assert retrieved_data["prompt_text"][0] == "Hello world!" 
- assert retrieved_data["prompt_text"][1] == "This is a longer sentence for testing" - assert retrieved_data["prompt_text"][2] == "Test case" - assert retrieved_data["response_text"][0] == "Hi there!" - assert retrieved_data["response_text"][1] == "This is the response to the longer sentence" - assert retrieved_data["response_text"][2] == "Test response" - - client.close() - - -def test_put_get_single_item_single_client(storage_setup): - """Test put and get operations for a single item with a single client.""" - _, put_get_address, _ = storage_setup - - client = MockClient(put_get_address) - - # PUT data - field_data = TensorDict( - { - "prompt_text": ["Hello world!"], - "attention_mask": [torch.tensor([1, 1, 1])], - }, - batch_size=[], - ) - - response = client.send_put(0, [0], [0], field_data) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # GET data - response = client.send_get(0, [0], ["prompt_text", "attention_mask"]) - assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE - - retrieved_data = response.body["data"] - assert "prompt_text" in retrieved_data - assert "attention_mask" in retrieved_data - - assert retrieved_data["prompt_text"][0] == "Hello world!" - assert retrieved_data["attention_mask"].shape == (1, 3) - torch.testing.assert_close(retrieved_data["attention_mask"][0], torch.tensor([1, 1, 1])) diff --git a/verl/experimental/transfer_queue/__init__.py b/verl/experimental/transfer_queue/__init__.py deleted file mode 100644 index 2df3b7f876f..00000000000 --- a/verl/experimental/transfer_queue/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/experimental/transfer_queue/client.py b/verl/experimental/transfer_queue/client.py deleted file mode 100644 index 8005558b0b1..00000000000 --- a/verl/experimental/transfer_queue/client.py +++ /dev/null @@ -1,662 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import asyncio -import logging -import os -from functools import wraps -from typing import Any, Callable, Optional, Union -from uuid import uuid4 - -import ray -import torch -import zmq -import zmq.asyncio -from tensordict import NonTensorStack, TensorDict - -from verl.experimental.transfer_queue.controller import TransferQueueController -from verl.experimental.transfer_queue.metadata import ( - BatchMeta, - StorageMetaGroup, -) -from verl.experimental.transfer_queue.storage import TransferQueueStorageSimpleUnit -from verl.experimental.transfer_queue.utils.utils import ( - TransferQueueRole, -) -from verl.experimental.transfer_queue.utils.zmq_utils import ( - ZMQMessage, - ZMQRequestType, - ZMQServerInfo, - create_zmq_socket, -) - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class AsyncTransferQueueClient: - def __init__( - self, - client_id: str, - controller_infos: ZMQServerInfo | dict[Any, ZMQServerInfo], - storage_infos: ZMQServerInfo | dict[Any, ZMQServerInfo], - ): - self.client_id = client_id - - self._controllers: dict[str, ZMQServerInfo] = {} - self._storages: dict[str, ZMQServerInfo] = {} - self._register_servers(TransferQueueRole.CONTROLLER, controller_infos) - self._register_servers(TransferQueueRole.STORAGE, storage_infos) - - def _register_servers( - self, - role: TransferQueueRole, - server_infos: ZMQServerInfo | dict[Any, ZMQServerInfo], - ): - mapping = self._controllers if role == TransferQueueRole.CONTROLLER else self._storages - - if not isinstance(server_infos, dict): - server_infos = {server_infos.id: server_infos} - - for info in server_infos.values(): - if not isinstance(info, ZMQServerInfo): - raise ValueError(f"Invalid server info for {role} {info.id}") - - if info.id not in mapping: - mapping[info.id] = info - logger.info(f"[{self.client_id}]: Registered {role} server {info.id} at {info.ip}") - else: - logger.warning(f"[{self.client_id}]: Server {info.id} already registered, 
skipping") - - @staticmethod - def dynamic_socket(target_role: TransferQueueRole, socket_name: str): - """Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close). - - Args: - target_role (TransferQueueRole): Server type to connect to. Must be one of: - - `TransferQueueRole.CONTROLLER` - - `TransferQueueRole.STORAGE` - socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port"). - - Decorated Function Rules: - 1. Must be an async class method (needs `self`). - 2. `self` requires: - - `_controllers`/`_storages`: Server registries (match `target_role`). - - `client_id`: Unique client ID (for socket identity). - 3. Specify target server via: - - `target_controller` (for Controller) or `target_storage` (for Storage) arg. - - Controller role: Uses first registered server if no ID is given. - 4. Receives ZMQ socket via `socket` keyword arg (injected by decorator). - """ - - def decorator(func: Callable): - @wraps(func) - async def wrapper(self, *args, **kwargs): - if target_role == TransferQueueRole.CONTROLLER: - servers = self._controllers - target = "target_controller" - elif target_role == TransferQueueRole.STORAGE: - servers = self._storages - target = "target_storage" - else: - raise ValueError("Invalid target_role, must be CONTROLLER or STORAGE") - - server_key = kwargs.get(target) - if server_key is None: - for arg in args: - if isinstance(arg, str) and arg in servers.keys(): - server_key = arg - break - if server_key is None and target == "target_controller": - server_key = next(iter(servers.keys())) - - server_info = servers.get(server_key) - if not server_info: - raise RuntimeError(f"Server {server_key} not found in registered {target_role} servers") - - context = zmq.asyncio.Context() - address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}" - identity = f"{self.client_id}_to_{server_info.id}_{uuid4()}".encode() - sock = create_zmq_socket(context, 
zmq.DEALER, identity=identity) - - try: - sock.connect(address) - logger.info( - f"[{self.client_id}]: Connected to {target_role} {server_info.id} at {address} " - f"with identity {identity.decode()}" - ) - - kwargs["socket"] = sock - return await func(self, *args, **kwargs) - except Exception as e: - logger.error( - f"[{self.client_id}]: Error in socket operation with {target_role} {server_info.id}: {e}" - ) - raise - finally: - try: - if not sock.closed: - sock.setsockopt(zmq.LINGER, -1) - sock.close() - sock.close(linger=0) - except Exception as e: - logger.warning( - f"[{self.client_id}]: Error closing socket to {target_role} {server_info.id}: {e}" - ) - - context.term() - - return wrapper - - return decorator - - @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") - async def async_get_meta( - self, - data_fields: list[str], - batch_size: int, - global_step: int, - mode: str = "fetch", - get_n_samples: bool = False, - task_name: Optional[str] = None, - target_controller: Optional[str] = None, - socket: Optional[zmq.asyncio.Socket] = None, - ) -> BatchMeta: - """Asynchronously fetches data metadata via ZMQ from the target controller. - - Args: - data_fields (list[str]): List of fields to retrieve metadata for - batch_size (int): Processing batch size - global_step (int): Current training/processing step - mode (str): Data fetch mode. 'fetch' to get ready data, 'force_fetch' to get data regardless of readiness. - 'insert' IS AN INTERNAL USAGE THAT SHOULD NOT BE USED BY USERS. - get_n_samples (bool): If True, we arrange the samples of the same prompt in contiguous order. In 'fetch' - mode, only the samples of the same prompt that are all ready will be returned. 
- task_name (str): Optional task name associated with the request - target_controller (str): ID of the target controller to send the request to - socket (zmq.asyncio.Socket): ZMQ async socket for message transmission - - Example: - >>> batch_size = 4 - >>> current_step = 0 - >>> # Example 1: "fetch" a batch of metadata that has been produced - >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"], - >>> batch_size=batch_size, - >>> global_step=current_step, - >>> mode="fetch", - >>> get_n_samples=False, - >>> task_name="generate_sequences", - >>> )) - >>> print(batch_meta.is_ready) # you should get a batch_meta with is_ready=True - >>> print([sample_meta.is_ready for sample_meta in batch_meta.samples]) # [True, True, True, True] - >>> - >>> # Example 2: "force_fetch" a batch of metadata, ignoring their production status (but we still make - >>> # sure the corresponding data has not been consumed) - >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"], - >>> batch_size=batch_size, - >>> global_step=current_step, - >>> mode="force_fetch", - >>> get_n_samples=False, - >>> task_name="generate_sequences", - >>> )) - >>> print(batch_meta.is_ready) # you may get a batch_meta with is_ready=False - >>> print([sample_meta.is_ready for sample_meta in batch_meta.samples]) # [True, False, False, True] - - Returns: - BatchMeta: Metadata object containing data structure, sample info, etc. 
- """ - assert socket is not None - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_META, - sender_id=self.client_id, - receiver_id=target_controller, - body={ - "data_fields": data_fields, - "batch_size": batch_size, - "global_step": global_step, - "mode": mode, - "get_n_samples": get_n_samples, - "task_name": task_name, - }, - ) - - try: - await socket.send(request_msg.serialize()) - response = await socket.recv() - response_msg = ZMQMessage.deserialize(response) - logger.debug( - f"[{self.client_id}]: Client get datameta response: {response_msg} from controller {target_controller}" - ) - - if response_msg.request_type == ZMQRequestType.GET_META_RESPONSE: - metadata = response_msg.body["metadata"] - return metadata - else: - raise RuntimeError( - f"[{self.client_id}]: Failed to get metadata from controller {target_controller}: " - f"{response_msg.body.get('message', 'Unknown error')}" - ) - except Exception as e: - raise RuntimeError(f"[{self.client_id}]: Error in get_meta: {str(e)}") from e - - async def async_put( - self, - data: TensorDict, - metadata: Optional[BatchMeta] = None, - global_step: Optional[int] = None, - ): - """Asynchronously writes data to appropriate Storage Units based on metadata. - - If metadata isn't provided, it will be created automatically using the insert mode - with the provided data_columns and global_step. 
- - Args: - data (torch.Tensor | tensordict.TensorDict): Data to write, either a Tensor or TensorDict - metadata (BatchMeta, optional): Optional metadata containing index and storage unit information - global_step (int, optional): Current step (required if no metadata is provided) - - Example: - >>> batch_size = 4 - >>> seq_len = 16 - >>> current_step = 0 - >>> # Example 1: normal usage - >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"], - >>> batch_size=batch_size, - >>> global_step=current_step, - >>> mode="fetch", - >>> get_n_samples=False, - >>> task_name="generate_sequences", - >>> )) - >>> batch = asyncio.run(client.async_get_data(batch_meta)) - >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> asyncio.run(client.async_put(data=output, metadata=batch_meta)) - >>> - >>> # Example 2: put the initial data into the system without pre-existing metadata - >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given global_step! - >>> # Please make sure the corresponding global_step is empty before calling the async_put() - >>> # without metadata. - >>> # Now we only support put all the data of the corresponding global step in once. You should repeat with - >>> # interleave the initial data if n_sample > 1 before calling the async_put(). - >>> original_prompts = torch.randn(batch_size, seq_len) - >>> n_samples = 4 - >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) - >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) - >>> # This will create metadata in "insert" mode internally. 
- >>> asyncio.run(client.async_put(data=prompts_repeated_batch, global_step=current_step)) - - """ - if metadata is None: - assert global_step is not None, "global_steps must be provided if metadata is not given" - - metadata = await self.async_get_meta( - data_fields=list(data.keys()), - batch_size=data.batch_size[0], - global_step=global_step, - get_n_samples=True, - mode="insert", - ) - - if not metadata or metadata.size == 0: - raise ValueError("metadata cannot be none or empty") - logger.debug(f"[{self.client_id}]: Put data with data: {data}") - tasks = [ - self._put_to_storage(get_transfer_info(meta_group, data), target_storage=storage_id) - for storage_id, meta_group in metadata.storage_meta_groups.items() - ] - await asyncio.gather(*tasks) - - logger.info( - f"[{self.client_id}]: step {global_step} put {metadata.size} samples to storage units successfully." - ) - - @dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket") - async def _put_to_storage(self, storage_unit_data, target_storage=None, socket=None): - """ - Send data to a specific storage unit. 
- """ - global_indexes = storage_unit_data["global_indexes"] - local_indexes = storage_unit_data["local_indexes"] - field_data = TensorDict( - { - field: ( - torch.nested.as_nested_tensor(storage_unit_data["field_data"][field]) - if storage_unit_data["field_data"][field] - and all(isinstance(x, torch.Tensor) for x in storage_unit_data["field_data"][field]) - else NonTensorStack(*storage_unit_data["field_data"][field]) - ) - for field in storage_unit_data["field_data"] - } - ) - - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_DATA, - sender_id=self.client_id, - receiver_id=target_storage, - body={"global_indexes": global_indexes, "local_indexes": local_indexes, "field_data": field_data}, - ) - try: - await socket.send(request_msg.serialize()) - serialized = await socket.recv() - response_msg = ZMQMessage.deserialize(serialized) - - if response_msg.request_type != ZMQRequestType.PUT_DATA_RESPONSE: - raise RuntimeError( - f"Failed to put data to storage unit {target_storage}: " - f"{response_msg.body.get('message', 'Unknown error')}" - ) - except Exception as e: - raise RuntimeError(f"Error in put to storage unit {target_storage}: {str(e)}") from e - - @dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket") - async def _get_from_storage(self, index_data, target_storage=None, socket=None): - global_indexes = index_data["global_indexes"] - local_indexes = index_data["local_indexes"] - fields = index_data["fields"] - - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_DATA, - sender_id=self.client_id, - receiver_id=target_storage, - body={"local_indexes": local_indexes, "fields": fields}, - ) - - try: - await socket.send(request_msg.serialize()) - serialized = await socket.recv() - response_msg = ZMQMessage.deserialize(serialized) - logger.info(f"[{self.client_id}]: get data response from storage unit {target_storage}: {response_msg}") - - if response_msg.request_type == ZMQRequestType.GET_DATA_RESPONSE: 
- # Return data and index information from this storage unit - storage_unit_data = response_msg.body["data"] - return global_indexes, fields, storage_unit_data - else: - raise RuntimeError( - f"Failed to get data from storage unit {target_storage}: " - f"{response_msg.body.get('message', 'Unknown error')}" - ) - except Exception as e: - raise RuntimeError(f"Error getting data from storage unit {target_storage}: {str(e)}") from e - - async def async_get_data(self, metadata: BatchMeta) -> TensorDict: - """Asynchronously fetches data via Storage Units and organizes it into a TensorDict. - - Args: - metadata (BatchMeta): Object containing: - - Data location info (which Storage Units hold the data) - - `global_indexes` to determine the ordering of merged results - - Returns: - tensordict.TensorDict with: - - Requested data fields (e.g., "prompt_token_ids", "response_token_ids"). - - "global_indexes" key: Maps each sample to its original global index. - - Example: - >>> batch_size = 4 - >>> seq_len = 16 - >>> current_step = 0 - >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"], - >>> batch_size=batch_size, - >>> global_step=current_step, - >>> mode="fetch", - >>> get_n_samples=False, - >>> task_name="generate_sequences", - >>> )) - >>> batch = asyncio.run(client.async_get_data(batch_meta)) - >>> print(batch) - >>> # this is a TensorDict with fields "prompts" and "attention_mask". - >>> # The order of samples in the TensorDict matches the order of global_indexes in batch_meta - - Note: - Why track `global_indexes`? - - Batches may be rearranged during task processing. `global_indexes` retains the original - mapping to Storage Units, enabling correct data writing back to Storage Units later. 
- - """ - if not metadata or metadata.size == 0: - return TensorDict({}, batch_size=0) - - # Use optimized retrieval with direct storage group access - tasks = [ - self._get_from_storage(meta_group.get_transfer_info(), target_storage=storage_id) - for storage_id, meta_group in metadata.storage_meta_groups.items() - ] - - results = await asyncio.gather(*tasks) - - # global_index: {field1: value, field2: value, ...} - storage_data: dict[int, dict[str, torch.Tensor]] = {} - for global_indexes, fields, storage_unit_data in results: - for idx, global_idx in enumerate(global_indexes): - if global_idx not in storage_data: - storage_data[global_idx] = {} - for field in fields: - storage_data[global_idx][field] = storage_unit_data[field][idx] - - ordered_data: dict[str, torch.Tensor] = {field: [] for field in metadata.field_names} - for global_idx in metadata.global_indexes: - for field in metadata.field_names: - ordered_data[field].append(storage_data[global_idx][field]) - - tensor_data = { - field: ( - torch.stack(torch.nested.as_nested_tensor(v).unbind()) - if v - and all(isinstance(item, torch.Tensor) for item in v) - and all(item.shape == v[0].shape for item in v) - else ( - torch.nested.as_nested_tensor(v) - if v and all(isinstance(item, torch.Tensor) for item in v) - else NonTensorStack(*v) - ) - ) - for field, v in ordered_data.items() - } - tensor_data["global_indexes"] = torch.tensor(metadata.global_indexes) - - return TensorDict(tensor_data, batch_size=len(storage_data)) - - async def async_clear(self, global_step: int): - """Asynchronously clears data from all storage units and controller metadata. 
- - Args: - global_step (int): The training step associated with the clear operation - - """ - try: - target_controller = next(iter(self._controllers.keys())) - metadata = await self._get_clear_meta(global_step, target_controller) - - tasks = [] - - for target_controller in self._controllers.keys(): - tasks.append(self._clear_controller(global_step, target_controller)) - - # Group samples by storage unit for clearing - for target_storage, group in metadata.storage_meta_groups.items(): - group_info = group.get_transfer_info() - if target_storage not in self._storages: - logger.warning( - f"[{self.client_id}]: Storage unit {target_storage} not registered, skipping clear operation." - ) - continue - tasks.append( - self._clear_storage_unit( - group_info["local_indexes"], - target_storage, - ) - ) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - for i, result in enumerate(results): - if isinstance(result, Exception): - logger.error(f"[{self.client_id}]: Error in clear operation task {i}: {result}") - - logger.info(f"[{self.client_id}]: Clear operation for global_step {global_step} completed.") - except Exception as e: - raise RuntimeError(f"Error in clear operation: {str(e)}") from e - - @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") - async def _get_clear_meta(self, global_step: int, target_controller=None, socket=None): - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_CLEAR_META, - sender_id=self.client_id, - receiver_id=target_controller, - body={"global_step": global_step}, - ) - - await socket.send(request_msg.serialize()) - serialized = await socket.recv() - response_msg = ZMQMessage.deserialize(serialized) - - if response_msg.request_type != ZMQRequestType.GET_CLEAR_META_RESPONSE: - raise RuntimeError( - f"Failed to get metadata for clear operation: {response_msg.body.get('message', 'Unknown error')}" - ) - - return response_msg.body["metadata"] - - 
@dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") - async def _clear_controller(self, global_step, target_controller=None, socket=None): - try: - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_META, - sender_id=self.client_id, - receiver_id=target_controller, - body={"global_step": global_step}, - ) - - await socket.send(request_msg.serialize()) - serialized_msg = await socket.recv() - response_msg = ZMQMessage.deserialize(serialized_msg) - - if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE: - raise RuntimeError( - f"Failed to clear controller {target_controller}: " - f"{response_msg.body.get('message', 'Unknown error')}" - ) - - logger.info( - f"[{self.client_id}]: Successfully clear controller {target_controller} for global_step {global_step}" - ) - except Exception as e: - logger.error(f"[{self.client_id}]: Error clearing controller {target_controller}: {str(e)}") - raise - - @dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket") - async def _clear_storage_unit(self, local_indexes, target_storage=None, socket=None): - try: - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA, - sender_id=self.client_id, - receiver_id=target_storage, - body={"local_indexes": local_indexes}, - ) - - await socket.send(request_msg.serialize()) - serialized_msg = await socket.recv() - response_msg = ZMQMessage.deserialize(serialized_msg) - - if response_msg.request_type != ZMQRequestType.CLEAR_DATA_RESPONSE: - raise RuntimeError( - f"Failed to clear storage {target_storage}: {response_msg.body.get('message', 'Unknown error')}" - ) - - logger.info(f"[{self.client_id}]: Successfully clear storage unit {target_storage}") - except Exception as e: - logger.error(f"[{self.client_id}]: Error clearing storage unit {target_storage}: {str(e)}") - raise - - @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") - def 
check_current_step_consumption(self, task_name: str, global_step: int): - # TODO: Implement this method to check if all samples for the current step has been consumed - pass - - @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") - def check_current_step_production(self, data_fields: list[str], global_step: int): - # TODO: Implement this method to check if all samples for the current step is ready for consumption - pass - - -class TransferQueueClient(AsyncTransferQueueClient): - def __init__( - self, - client_id: str, - controller_infos: ZMQServerInfo | dict[Any, ZMQServerInfo], - storage_infos: ZMQServerInfo | dict[Any, ZMQServerInfo], - ): - super().__init__( - client_id, - controller_infos, - storage_infos, - ) - - def put(self, data: TensorDict, metadata: Optional[BatchMeta] = None, global_step: Optional[int] = None): - return asyncio.run(self.async_put(data, metadata, global_step)) - - def get_meta( - self, - data_fields: list[str], - batch_size: int, - global_step: int, - get_n_samples: bool = False, - task_name: Optional[str] = None, - ) -> BatchMeta: - return asyncio.run( - self.async_get_meta( - data_fields=data_fields, - batch_size=batch_size, - global_step=global_step, - get_n_samples=get_n_samples, - task_name=task_name, - ) - ) - - def get_data(self, metadata: BatchMeta) -> TensorDict: - return asyncio.run(self.async_get_data(metadata)) - - def clear(self, global_step: int): - return asyncio.run(self.async_clear(global_step)) - - -def _add_field_data( - transfer_dict: dict[str, Any], storage_meta_group: StorageMetaGroup, data: TensorDict -) -> dict[str, Any]: - """Helper function to add field data to the transfer dictionary""" - field_names = transfer_dict["fields"] - for fname in field_names: - if fname in data.keys(): - transfer_dict["field_data"][fname] = [] - for sample_meta in storage_meta_group.sample_metas: - transfer_dict["field_data"][fname].append(data[fname][sample_meta.batch_index]) - return 
transfer_dict - - -def get_transfer_info( - storage_meta_group: StorageMetaGroup, - data: TensorDict, -) -> dict[str, Any]: - """Convert to dictionary format with field data for put operations""" - result = storage_meta_group.get_transfer_info(field_names=data.keys()) - result = _add_field_data(result, storage_meta_group, data) - return result - - -def process_zmq_server_info(handlers: dict[Any, Union[TransferQueueController, TransferQueueStorageSimpleUnit]]): # noqa: UP007 - server_info = {} - for name, handler in handlers.items(): - server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[attr-defined] - return server_info diff --git a/verl/experimental/transfer_queue/controller.py b/verl/experimental/transfer_queue/controller.py deleted file mode 100644 index 08ab6cfe9f4..00000000000 --- a/verl/experimental/transfer_queue/controller.py +++ /dev/null @@ -1,771 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import math -import os -import threading -import time -from threading import Thread -from typing import Any, Optional -from uuid import uuid4 - -import numpy as np -import ray -import torch -import zmq -from ray.util import get_node_ip_address - -from verl.experimental.transfer_queue.metadata import ( - BatchMeta, - FieldMeta, - SampleMeta, -) -from verl.experimental.transfer_queue.utils.utils import ( - ProductionStatus, - TransferQueueRole, - random_sampler, -) -from verl.experimental.transfer_queue.utils.zmq_utils import ( - ZMQMessage, - ZMQRequestType, - ZMQServerInfo, - create_zmq_socket, - get_free_port, -) - -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) - -TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 300)) -TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 1)) -TQ_INIT_FIELD_NUM = int(os.environ.get("TQ_INIT_FIELD_NUM", 10)) - - -@ray.remote(num_cpus=1) -class TransferQueueController: - def __init__( - self, - num_storage_units: int, - global_batch_size: int, - num_global_batch: int = 1, - num_n_samples: int = 1, - ) -> None: - """Initialize the TransferQueueController. 
- - Args: - num_storage_units: Number of storage units in the system - global_batch_size: Size of each global batch - num_global_batch: Number of global batches to maintain in storage - num_n_samples: For each prompt, sample n responses - """ - self.controller_id = f"TQ_CONTROLLER_{uuid4()}" - - self._init_zmq_socket() # Initialize ZMQ sockets for data communication - - self.num_storage_units = num_storage_units - self.global_batch_size = ( - global_batch_size # Used as offset for global index to identify corresponding global step - ) - self.num_global_batch = num_global_batch - self.num_n_samples = num_n_samples - self.total_storage_size = self.global_batch_size * self.num_global_batch * self.num_n_samples - - self.data_production_status = torch.zeros( - self.total_storage_size, TQ_INIT_FIELD_NUM, dtype=torch.int8 - ) # Initialize with default number of fields, dynamically extensible - # task_name -> consumption_status mapping - self.data_consumption_status: dict[str, torch.Tensor] = {} - self.field_name_mapping: dict[ - str, int - ] = {} # Mapping table from field_name to the column indices in self.data_production_status tables - # Per-field dtype and shape storage: {global_index: {field_name: {'dtype': dtype, 'shape': shape}}} - self.per_tensor_dtype_mapping: dict[int, dict[str, Any]] = {} - self.per_tensor_shape_mapping: dict[int, dict[str, Any]] = {} - - self._build_index_storage_mapping() - - self._start_process_handshake() - self._start_process_update_data_status() - self._start_process_request() - - def _get_consumption_status(self, task_name: str) -> torch.Tensor: - """ - Get or create the consumption status tensor for a specific task. - The consumption status is a binary, 1D tensor that records whether the corresponding sample has been consumed - by the task. 
- - Args: - task_name: Name of the consumer task - - Returns: - Consumption status tensor for the specified task - """ - # Retrieve or create the consumption state tensor for a specified consumer - if task_name not in self.data_consumption_status: - # Initialize state for a new consumer - self.data_consumption_status[task_name] = torch.zeros(self.total_storage_size, dtype=torch.int8) - return self.data_consumption_status[task_name] - - def _get_per_field_dtype(self, global_index: int, field_name: str) -> Optional[torch.dtype]: - """Get dtype for a specific sample and field. - - Args: - global_index: Global index of the sample - field_name: Name of the field - - Returns: - dtype of the specified field for the sample, or None if not found - """ - return self.per_tensor_dtype_mapping.get(global_index, {}).get(field_name) - - def _get_per_field_shape(self, global_index: int, field_name: str) -> Optional[torch.Size]: - """Get shape for a specific sample and field. - - Args: - global_index: Global index of the sample - field_name: Name of the field - - Returns: - Shape of the specified field for the sample, or None if not found - """ - return self.per_tensor_shape_mapping.get(global_index, {}).get(field_name) - - def _step_to_global_index_range(self, global_step: int) -> tuple[int, int]: - """Convert global step to corresponding global index range. - - Args: - global_step: The global step to convert - - Returns: - Tuple of (start_index, end_index) for the given global step - """ - start_idx = (global_step % self.num_global_batch) * self.global_batch_size * self.num_n_samples - end_idx = start_idx + self.global_batch_size * self.num_n_samples - - return start_idx, end_idx - - def generate_data_status_mask( - self, data_fields: list[str], global_step: int, task_name: str - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Generate mask matrix for filtering data based on field availability and consumption status. 
- - This function is called within _get_meta and generates a mask matrix based on - user-specified fields and the current step. The mask matrix selects the required - rows and columns from self.data_production_status while inversely selecting from - self.data_consumption_status to support automated vectorization. - - Args: - data_fields: List of field names to include in the mask - global_step: Current global step for row selection - task_name: Name of the consumer task for consumption status - - Returns: - Tuple of (row_mask, col_mask) tensors for filtering data status matrices - """ - - # Check if all requested fields are registered - for col in data_fields: - if col not in self.field_name_mapping: - # Return empty mask indicating no available data for unregistered columns - empty_row_mask = torch.zeros(self.data_production_status.shape[0], dtype=torch.bool) - empty_col_mask = torch.zeros(self.data_production_status.shape[1], dtype=torch.bool) - return empty_row_mask, empty_col_mask - - # Map steps to global indices - start_idx, end_idx = self._step_to_global_index_range(global_step) - row_mask = torch.zeros(self.data_production_status.shape[0], dtype=torch.bool) - row_mask[start_idx:end_idx] = True - - # Invert selection based on consumption status - consumer_status = self._get_consumption_status(task_name) - unconsumed_mask = consumer_status == 0 - row_mask &= unconsumed_mask - - # Select the specified fields - col_mask = torch.zeros(self.data_production_status.shape[1], dtype=torch.bool) - valid_fields = [self.field_name_mapping[col] for col in data_fields] - if valid_fields: - col_mask[valid_fields] = True - - return row_mask, col_mask - - def _build_index_storage_mapping(self): - """ - Build mappings between global indices and storage locations. - - Distributes samples across storage units based on total storage space and - maintains mappings between global index and local index within each storage. - """ - # Assign each sample to a storage node. 
Here we scatter the samples in each GBS to different storage nodes - # Samples are arranged sequentially, similar to generate_data_status_mask - real_global_batch_size = self.global_batch_size * self.num_n_samples - global_batch_per_storage_unit = math.ceil(real_global_batch_size / self.num_storage_units) - - # Build mapping between global index and storage unit for locating each data sample - batch_storage_indices = np.repeat(np.arange(self.num_storage_units), global_batch_per_storage_unit)[ - :real_global_batch_size - ] - self._global_index_storage_rank_mapping = np.tile(batch_storage_indices, self.num_global_batch) - - # Build mapping between global index and local index within each storage unit - indices = np.arange(self.total_storage_size) - pos_in_batch = indices % real_global_batch_size - g = indices // real_global_batch_size - pos_in_block = pos_in_batch % global_batch_per_storage_unit - self.global_index_local_index_mapping = g * global_batch_per_storage_unit + pos_in_block - - def get_data_production_status(self) -> torch.Tensor: - """ - Get the current data production status matrix. The data production status is a 2D matrix that records whether - the corresponding data is ready for each field of each sample. - - Returns: - Tensor representing production status of all data fields - """ - return self.data_production_status - - def get_field_name_mapping(self) -> dict[str, Any]: - """Get the field name to column index mapping. - - Returns: - Dictionary mapping field names to their column indices - """ - return self.field_name_mapping - - def get_data_consumption_status(self) -> dict[str, torch.Tensor]: - """Get consumption status for all tasks. - - Returns: - Dictionary mapping task names to their consumption status tensors - """ - return self.data_consumption_status - - def get_global_index_mapping(self): - """Get global index to storage mapping information. 
- - Returns: - Tuple containing storage rank mapping and local index mapping - """ - return self._global_index_storage_rank_mapping, self.global_index_local_index_mapping - - def _get_metadata( - self, - data_fields: list[str], - batch_size: int, - global_step: int, - mode: str = "fetch", - task_name: str | None = None, - get_n_samples=False, - *args, - **kwargs, - ) -> BatchMeta: - """ - Retrieve metadata with support for three modes. - - Args: - data_fields: List of field names to include in metadata - batch_size: Number of samples to retrieve - global_step: Global step for which to retrieve metadata - mode: Operation mode - 'insert', 'fetch', or 'force_fetch' - - mode="insert": Insert metadata for new rows (without checking data status) - - mode="fetch": Retrieve metadata for ready data (check data status and sample) - - mode="force_fetch": Directly return metadata (without checking data status) - task_name: Name of the consumer task (required for fetch modes) - get_n_samples: Whether to retrieve n_samples as groups - *args: Additional positional arguments - **kwargs: Additional keyword arguments - - Returns: - BatchMeta object containing the requested metadata - - Raises: - TimeoutError: If waiting for sufficient data times out in fetch mode - """ - if mode == "insert": - # TODO: Currently we only supports put the entire GBS data in one time - assert batch_size == self.global_batch_size * self.num_n_samples, ( - f"batch_size {batch_size} must equal " - f"global_batch_size * num_n_samples {self.global_batch_size * self.num_n_samples}" - ) - start_idx, end_idx = self._step_to_global_index_range(global_step) - batch_global_indexes = list(range(start_idx, end_idx)) - return self._generate_batch_meta(global_step, batch_global_indexes, data_fields, mode) - - assert task_name is not None - if mode == "fetch": - # Find consumable samples within current batch and package into BatchMeta when reading - - start_time = time.time() - while True: - ready_for_consume_idx = 
self._scan_data_status(data_fields, global_step, task_name, get_n_samples) - - if len(ready_for_consume_idx) >= batch_size: - break - - if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT: - raise TimeoutError( - f"Timeout while waiting for sufficient data. " - f"Required: {batch_size}, Available: {len(ready_for_consume_idx)}" - ) - - logger.warning( - f"Insufficient data available. Required: {batch_size}, " - f"Available: {len(ready_for_consume_idx)}. Retrying in " - f"{TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..." - ) - time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL) - logger.debug(f"ready for consume idx: {ready_for_consume_idx}") - - batch_global_indexes = random_sampler(ready_for_consume_idx, batch_size, get_n_samples, self.num_n_samples) - elif mode == "force_fetch": - start_idx, end_idx = self._step_to_global_index_range(global_step) - consumer_status = self._get_consumption_status(task_name) - not_consumed_idx = [i for i in range(start_idx, end_idx) if consumer_status[i] == 0] - batch_global_indexes = random_sampler(not_consumed_idx, batch_size, get_n_samples, self.num_n_samples) - - # Mark this batch of data as consumed - consumer_status = self._get_consumption_status(task_name) - consumer_status[batch_global_indexes] = 1 - # Package into metadata - metadata = self._generate_batch_meta(global_step, batch_global_indexes, data_fields, mode) - logger.debug(f"_get_metadata: {metadata}") - - return metadata - - def _scan_data_status( - self, data_fields: list[str], global_step: int, task_name: str, get_n_samples: bool - ) -> list[int]: - """ - Scan data status to find samples ready for consumption. 
- - Args: - data_fields: List of field names to check - global_step: Global step to scan - task_name: Name of the consumer task - get_n_samples: Whether to return n_samples as groups - - Returns: - List of global indices that are ready for consumption - """ - # Get row and column masks - row_mask, col_mask = self.generate_data_status_mask(data_fields, global_step, task_name) - logger.debug(f"row_mask, col_mask: {row_mask, col_mask}") - - if not row_mask.any() or not col_mask.any(): - return [] - - # Extract subset of data status for relevant fields - logger.debug(f"self.data_production_status: {self.data_production_status}") - data_status_of_interest = self.data_production_status[:, col_mask] - logger.debug(f"data_status_of_interest: {data_status_of_interest}") - - # Use torch.all for vectorized check instead of sum comparison - all_fields_ready = torch.all(data_status_of_interest, dim=1) - - # Filter samples that meet criteria combined with row mask - ready_mask = all_fields_ready & row_mask - - if get_n_samples and self.num_n_samples > 1: - # Reshape to group view and check group completeness - group_all_ready = torch.all(ready_mask.view(-1, self.num_n_samples), dim=1) - - # Get indices of fully ready groups - ready_group_indices = group_all_ready.nonzero(as_tuple=False).flatten() - - # Calculate all sample indices - sample_offset = torch.arange(self.num_n_samples) - ready_for_consume_idx = ( - (ready_group_indices.unsqueeze(1) * self.num_n_samples + sample_offset).flatten().tolist() - ) - - return ready_for_consume_idx - else: - ready_for_consume_idx = torch.nonzero(ready_mask, as_tuple=False).flatten().tolist() - logger.debug(f"ready_for_consume_idx: {ready_for_consume_idx}") - - return ready_for_consume_idx - - def _generate_batch_meta( - self, global_step: int, global_indexes: list[int], data_fields: list[str], mode: str - ) -> BatchMeta: - """ - Generate BatchMeta by resolving storage locations for given global indexes. 
- - For each global index, looks up the corresponding storage node address using: - - global_index_local_index_mapping: Maps to local index within storage - - _global_index_storage_id_mapping: Maps to storage node identifier - - Args: - global_step: Current global step - global_indexes: List of global indexes to process - data_fields: List of data field names - mode: Operation mode ('fetch', 'insert', or 'force_fetch') - - Returns: - BatchMeta object containing sample metadata with resolved storage locations - """ - global_arr = np.array(global_indexes) - storage_ids = self.global_index_storage_id_mapping[global_arr] - local_indexes = self.global_index_local_index_mapping[global_arr] - - samples = [] - - # Create samples from the flattened BatchMeta data - # TODO: Optimize this - for i, global_index in enumerate(global_indexes): - local_index = local_indexes[i] - storage_id = storage_ids[i] - - # Create FieldMeta objects for each field - fields = [] - for field_name in data_fields: - if mode == "fetch": - production_status = ProductionStatus.READY_FOR_CONSUME # Since we filtered by ready status - # Get per-field dtype and shape for this specific global_index and field - dtype = self._get_per_field_dtype(global_index, field_name) - shape = self._get_per_field_shape(global_index, field_name) - elif mode == "insert": - production_status = ProductionStatus.NOT_PRODUCED # FIXME: not real-time - dtype = None - shape = None - elif mode == "force_fetch": - col_index = self.field_name_mapping.get(field_name) - if col_index is not None and self.data_production_status[global_index, col_index] == 1: - production_status = ProductionStatus.READY_FOR_CONSUME - dtype = self._get_per_field_dtype(global_index, field_name) - shape = self._get_per_field_shape(global_index, field_name) - else: - production_status = ProductionStatus.NOT_PRODUCED - dtype = None - shape = None - field_meta = FieldMeta( - name=field_name, - dtype=dtype, - shape=shape, - production_status=production_status, 
- ) - fields.append(field_meta) - - sample = SampleMeta( - global_step=global_step, - global_index=global_index, - storage_id=storage_id, - local_index=local_index, - fields={field.name: field for field in fields}, - ) - samples.append(sample) - - return BatchMeta(samples=samples) - - def _update_production_status(self, indexes: list[int], fields: list[str]) -> None: - """ - Update production status for specified indexes and fields. - - Args: - indexes: List of global indexes to update - fields: List of field names to update - """ - # TODO: Replace self.data_production_status == 0 or ==1 operations with ProductionStatus enum - # Update data production status matrix - new_fields = [field for field in fields if field not in self.field_name_mapping] - if new_fields: - needed_fields = len(new_fields) - current_fields = self.data_production_status.shape[1] - # Expand data status matrix if needed - if len(self.field_name_mapping) + needed_fields > current_fields: - add_fields = max(TQ_INIT_FIELD_NUM, needed_fields + 1) - new_matrix = torch.zeros((self.total_storage_size, add_fields), dtype=torch.int8) - self.data_production_status = torch.cat([self.data_production_status, new_matrix], dim=1) - - for field in fields: - if field not in self.field_name_mapping.keys(): - self.field_name_mapping[field] = len(self.field_name_mapping) - self.data_production_status[ - torch.tensor(indexes)[:, None], torch.tensor([self.field_name_mapping.get(field) for field in fields]) - ] = 1 - - def _update_field_info( - self, - fields: list[str], - per_tensor_dtypes: dict[int, dict[str, Any]], - per_tensor_shapes: dict[int, dict[str, Any]], - global_indexes: list[int], - ) -> None: - """ - Store per-field dtype and shape information. 
- - Args: - fields: List of field names - per_tensor_dtypes: Dict mapping global_index to field dtypes {global_index: {field: dtype}} - per_tensor_shapes: Dict mapping global_index to field shapes {global_index: {field: shape}} - global_indexes: List of global indexes corresponding to the samples - """ - for global_idx in global_indexes: - if global_idx not in self.per_tensor_dtype_mapping: - self.per_tensor_dtype_mapping[global_idx] = {} - if global_idx not in self.per_tensor_shape_mapping: - self.per_tensor_shape_mapping[global_idx] = {} - - for field in fields: - if global_idx in per_tensor_dtypes and field in per_tensor_dtypes[global_idx]: - self.per_tensor_dtype_mapping[global_idx][field] = per_tensor_dtypes[global_idx][field] - if global_idx in per_tensor_shapes and field in per_tensor_shapes[global_idx]: - self.per_tensor_shape_mapping[global_idx][field] = per_tensor_shapes[global_idx][field] - - def _init_zmq_socket(self): - """ - Initialize ZMQ sockets for communication. - - Sets up three ZMQ service ports for: - 1. Receiving handshake requests from storage - 2. Handling client data read/write requests - 3. 
Receiving status update signals from storage - """ - self.zmq_context = zmq.Context() - - self._node_ip = get_node_ip_address() - self._handshake_socket_port = get_free_port() - self._request_handle_socket_port = get_free_port() - self._data_status_update_socket_port = get_free_port() - - self.handshake_socket = create_zmq_socket( - ctx=self.zmq_context, - socket_type=zmq.ROUTER, - ) - self.handshake_socket.bind(f"tcp://{self._node_ip}:{self._handshake_socket_port}") - - self.request_handle_socket = create_zmq_socket( - ctx=self.zmq_context, - socket_type=zmq.ROUTER, - ) - self.request_handle_socket.bind(f"tcp://{self._node_ip}:{self._request_handle_socket_port}") - - self.data_status_update_socket = create_zmq_socket( - ctx=self.zmq_context, - socket_type=zmq.ROUTER, - ) - self.data_status_update_socket.bind(f"tcp://{self._node_ip}:{self._data_status_update_socket_port}") - - self.zmq_server_info = ZMQServerInfo.create( - role=TransferQueueRole.CONTROLLER, - id=self.controller_id, - ip=self._node_ip, - ports={ - "handshake_socket": self._handshake_socket_port, - "request_handle_socket": self._request_handle_socket_port, - "data_status_update_socket": self._data_status_update_socket_port, - }, - ) - - def _wait_connection(self): - """Wait for all storage instances to complete handshake. - - Clients don't need handshake to support dynamic scaling. Continuously - listens for handshake messages until all expected storage units connect. 
- """ - # TODO(zjj): Consider if retransmission is needed (assuming cases where Storage doesn't receive ACK) - connected_storage_units = set() - while len(connected_storage_units) < self.num_storage_units: - identity, serialized_msg = self.handshake_socket.recv_multipart() - request_msg = ZMQMessage.deserialize(serialized_msg) - if request_msg.request_type == ZMQRequestType.HANDSHAKE: - connected_storage_units.add(request_msg.sender_id) - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE_ACK, - sender_id=self.controller_id, - body={}, - ).serialize() - self.handshake_socket.send_multipart([identity, response_msg]) - logger.info("Controller sent handshake ack successfully!") - self.global_index_storage_id_mapping = np.array(sorted(list(connected_storage_units)))[ - self._global_index_storage_rank_mapping - ] - self.handshake_done.set() - - def _start_process_handshake(self): - """Start the handshake process thread.""" - self.handshake_done = threading.Event() - self.wait_connection_thread = Thread( - target=self._wait_connection, name="TransferQueueControllerWaitConnectionThread", daemon=True - ) - self.wait_connection_thread.start() - - def _start_process_update_data_status(self): - """Start the data status update processing thread.""" - self.process_update_data_status_thread = Thread( - target=self._update_data_status, name="TransferQueueControllerProcessUpdateDataStatusThread", daemon=True - ) - self.process_update_data_status_thread.start() - - def _start_process_request(self): - """Start the request processing thread.""" - self.process_request_thread = Thread( - target=self._process_request, name="TransferQueueControllerProcessRequestThread", daemon=True - ) - self.process_request_thread.start() - - def _process_request(self): - """Main request processing loop. - - Handles various request types including metadata retrieval, - consumption status checks, and clear operations. 
- """ - self.handshake_done.wait() - while True: - # ROUTER socket receives multi-part messages - identity, serialized_msg = self.request_handle_socket.recv_multipart() - request_msg = ZMQMessage.deserialize(serialized_msg) - - if request_msg.request_type == ZMQRequestType.GET_META: - params = request_msg.body - logger.info("Controller preparing to get metadata...") - metadata = self._get_metadata( - data_fields=params["data_fields"], - batch_size=params["batch_size"], - global_step=params["global_step"], - mode=params.get("mode", "fetch"), - task_name=params.get("task_name", None), - get_n_samples=params.get("get_n_samples", False), - ) - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_META_RESPONSE, - sender_id=self.controller_id, - receiver_id=request_msg.sender_id, - body={"metadata": metadata}, - ) - elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META: - params = request_msg.body - metadata = self._get_metadata( - data_fields=[], - batch_size=self.global_batch_size * self.num_n_samples, - global_step=params["global_step"], - mode="insert", - ) - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_CLEAR_META_RESPONSE, - sender_id=self.controller_id, - receiver_id=request_msg.sender_id, - body={"metadata": metadata}, - ) - elif request_msg.request_type == ZMQRequestType.CLEAR_META: - params = request_msg.body - self.clear(global_step=params["global_step"]) - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_META_RESPONSE, - sender_id=self.controller_id, - receiver_id=request_msg.sender_id, - body={"message": f"Clear operation completed by controller {self.controller_id}"}, - ) - elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION: - # Check consumption status - params = request_msg.body - global_step = params["global_step"] - - consumer_status = self._get_consumption_status(params["task_name"]) - start_idx, end_idx = self._step_to_global_index_range(global_step) - batch_status = 
consumer_status[start_idx:end_idx] - consumed = torch.all(batch_status == 1).item() - - # Build response message - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CONSUMPTION_RESPONSE, - sender_id=self.controller_id, - receiver_id=request_msg.sender_id, - body={ - "global_step": global_step, - "consumed": consumed, - }, - ) - self.request_handle_socket.send_multipart([identity, response_msg.serialize()]) - logger.debug("Controller request_handle_socket sent multipart successfully!") - - def _update_data_status(self): - """Process data status update messages from storage units. - - Continuously listens for data update notifications and updates - internal production status and field information accordingly. - """ - # Receive data status update information from storage - while True: - logger.debug("Preparing _update_data_status...") - identity, serialized_msg = self.data_status_update_socket.recv_multipart() - logger.debug("Controller received update_data_status request!") - request_msg = ZMQMessage.deserialize(serialized_msg) - logger.debug(f"[{self.controller_id}]: Controller received update_data_status request_msg: {request_msg}") - - if request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE: - message_data = request_msg.body - - fields = message_data.get("fields", []) - global_indexes = message_data.get("global_indexes", []) - per_tensor_dtypes = message_data.get("dtypes", {}) # Now a dict of lists - per_tensor_shapes = message_data.get("shapes", {}) # Now a dict of lists - # Update data production status - logger.debug(f"global_indexes, fields: {global_indexes, fields}") - self._update_production_status(global_indexes, fields) - self._update_field_info(fields, per_tensor_dtypes, per_tensor_shapes, global_indexes) - logger.info("Controller updated production status successfully!") - - # Send acknowledgment response - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, - sender_id=self.controller_id, - 
body={ - "controller_id": self.controller_id, - "message": f"Data update acknowledged from controller {self.controller_id}", - }, - ) - self.data_status_update_socket.send_multipart([identity, response_msg.serialize()]) - logger.info("Controller sent DATA_UPDATE_ACK successfully!") - elif request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR: - # Handle data update errors - error_msg = request_msg.body.get("message", "Unknown error") - logger.error(f"Data update error from storage: {error_msg}") - - # Send error acknowledgment response - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, - sender_id=self.controller_id, - body={ - "controller_id": self.controller_id, - "message": f"Error notification acknowledged from controller {self.controller_id}", - }, - ) - self.data_status_update_socket.send_multipart([identity, response_msg.serialize()]) - - def get_zmq_server_info(self) -> ZMQServerInfo: - """Get ZMQ server connection information. - - Returns: - ZMQServerInfo object containing connection details - """ - return self.zmq_server_info - - def clear(self, global_step: int): - """Clear data for a specific global batch. - - Resets production and consumption status for all data in the specified - global step. Currently only supports clearing single GBS at a time. - - Args: - global_step: The global step to clear data for - """ - start_idx, end_idx = self._step_to_global_index_range(global_step) - - self.data_production_status[start_idx:end_idx, :] = 0 - for task_name in self.data_consumption_status: - self.data_consumption_status[task_name][start_idx:end_idx] = 0 diff --git a/verl/experimental/transfer_queue/metadata.py b/verl/experimental/transfer_queue/metadata.py deleted file mode 100644 index 6d81e7f2ca3..00000000000 --- a/verl/experimental/transfer_queue/metadata.py +++ /dev/null @@ -1,602 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. 
and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import dataclasses -from dataclasses import dataclass -from typing import Any, Optional - -import numpy as np -from tensordict import TensorDict - -from verl.experimental.transfer_queue.utils.utils import ProductionStatus - - -@dataclass -class FieldMeta: - """ - Records the metadata of a single data field. (name, dtype, shape, etc.) - """ - - # field name (e.g., 'prompt', 'response', etc.) - name: str - - # data schema info - dtype: Optional[Any] - shape: Optional[Any] - - # data status info - production_status: ProductionStatus = ProductionStatus.NOT_PRODUCED - - def __str__(self) -> str: - return ( - f"FieldMeta(name='{self.name}', dtype={self.dtype}, " - f"shape={self.shape}, production_status={self.production_status})" - ) - - @property - def is_ready(self) -> bool: - """Check if this field is ready for consumption""" - return self.production_status == ProductionStatus.READY_FOR_CONSUME - - -@dataclass -class SampleMeta: - """ - Records the metadata of a single data sample (stored as a row in the data system). 
- """ - - # algorithm related info - global_step: int # global step, used for data versioning - - # data retrival info - global_index: int # global row index, uniquely identifies a data sample - storage_id: str # storage unit id - local_index: int # local row index in the storage unit - - # data fields info - # this fields may not contain all the fields of the sample, but only fields-of-interest - fields: dict[str, FieldMeta] - - def __post_init__(self): - """Initialize is_ready property based on field readiness""" - # Check if all fields are ready and update is_ready property - object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) - - def __str__(self) -> str: - return ( - f"SampleMeta(global_step={self.global_step}, " - f"global_index={self.global_index}, storage_id='{self.storage_id}', " - f"local_index={self.local_index}, fields={self.fields})" - ) - - @property - def field_names(self) -> list[str]: - """Get list of field names for this sample""" - return list(self.fields.keys()) - - @property - def batch_index(self) -> int: - """Get the batch index of this sample (to be set by BatchMeta)""" - return getattr(self, "_batch_index", -1) - - def get_field_by_name(self, name: str) -> Optional[FieldMeta]: - """Get FieldMeta by field name""" - return self.fields.get(name) - - def has_field(self, name: str) -> bool: - """Check if this sample has a specific field""" - return name in self.fields - - def is_field_ready(self, field_name: str) -> bool: - """Check if a specific field is ready for consumption""" - field = self.fields.get(field_name) - return field.is_ready if field else False - - def add_fields(self, fields: dict[str, FieldMeta]) -> "SampleMeta": - """ - Add new fields to this sample. New fields will be initialized with given dtype, shape - and production_status (if provided). If not provided, default values (None, None, READY_FOR_CONSUME) - will be used. - This modifies the sample in-place to include the new fields. 
- """ - self.fields = _union_fields(self.fields, fields) - # Update is_ready property - object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) - return self - - def union(self, other: "SampleMeta", validate: bool = True) -> "SampleMeta": - """ - Create a union of this sample's fields with another sample's fields. - Assume both samples have the same global index. If fields overlap, the - fields in this sample will be replaced by the other sample's fields. - - Args: - other: Another SampleMeta to union with - validate: Whether to validate union conditions - - Returns: - New SampleMeta with unioned fields (None if validation fails) - """ - if validate: - if self.global_index != other.global_index: - raise ValueError( - f"Error: Global indexes ({self.global_index} and {other.global_index}) do not match for union." - ) - - # Merge fields - self.fields = _union_fields(self.fields, other.fields) - - # Update is_ready property - object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) - return self - - @property - def is_ready(self) -> bool: - """Check if all fields in this sample are ready for consumption""" - return getattr(self, "_is_ready", False) - - @property - def production_status(self) -> dict[str, ProductionStatus]: - """Get production status for all fields (backward compatibility)""" - return {name: field.production_status for name, field in self.fields.items()} - - -@dataclass -class StorageMetaGroup: - """ - Represents a group of samples stored in the same storage unit. - Used to organize samples by their storage_id for efficient client operations. 
- """ - - storage_id: str - sample_metas: list[SampleMeta] = dataclasses.field(default_factory=list) - - def add_sample_meta(self, sample_meta: SampleMeta) -> None: - """Add a SampleMeta object to this storage group""" - self.sample_metas.append(sample_meta) - - def get_batch_indexes(self) -> list[int]: - """Get all internal indexes from stored SampleMeta objects""" - return [meta.batch_index for meta in self.sample_metas] - - def get_global_indexes(self) -> list[int]: - """Get all global indexes from stored SampleMeta objects""" - return [meta.global_index for meta in self.sample_metas] - - def get_local_indexes(self) -> list[int]: - """Get all local indexes from stored SampleMeta objects""" - return [meta.local_index for meta in self.sample_metas] - - def get_field_names(self) -> list[str]: - """Get all unique field names from stored SampleMeta objects""" - all_fields: set[str] = set() - for meta in self.sample_metas: - all_fields.update(meta.fields.keys()) - return list(all_fields) - - def get_transfer_info(self, field_names: Optional[list[str]] = None) -> dict[str, list | dict]: - """Convert to dictionary format for backward compatibility""" - if field_names is None: - field_names = self.get_field_names() - return { - "batch_indexes": self.get_batch_indexes(), - "global_indexes": self.get_global_indexes(), - "local_indexes": self.get_local_indexes(), - "fields": field_names, - "field_data": {}, # Placeholder for field data to be filled later - } - - @property - def size(self) -> int: - """Number of samples in this storage meta group""" - return len(self.sample_metas) - - @property - def is_empty(self) -> bool: - """Check if this storage meta group is empty""" - return len(self.sample_metas) == 0 - - def __len__(self) -> int: - """Number of samples in this storage meta group""" - return self.size - - def __bool__(self) -> bool: - """Truthiness based on whether group has samples""" - return not self.is_empty - - def __str__(self) -> str: - return 
f"StorageMetaGroup(storage_id='{self.storage_id}', size={self.size})" - - -@dataclass -class BatchMeta: - """ - Records the metadata of a batch of data samples. - """ - - samples: list[SampleMeta] - extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) - - def __post_init__(self): - """Initialize all computed properties during initialization""" - # Basic properties - object.__setattr__(self, "_size", len(self.samples)) - object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples)) - - # Pre-compute all list properties for better performance - if self.samples: - for idx, sample in enumerate(self.samples): - object.__setattr__(sample, "_batch_index", idx) # Ensure batch_index is set correctly - - object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) - object.__setattr__(self, "_local_indexes", [sample.local_index for sample in self.samples]) - object.__setattr__(self, "_storage_ids", [sample.storage_id for sample in self.samples]) - - # assume all samples have the same fields. 
- object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names)) - - # Initialize storage groups for efficient client operations - storage_meta_groups = self._build_storage_meta_groups() - object.__setattr__(self, "_storage_meta_groups", storage_meta_groups) - else: - object.__setattr__(self, "_global_indexes", []) - object.__setattr__(self, "_local_indexes", []) - object.__setattr__(self, "_storage_ids", []) - object.__setattr__(self, "_field_names", []) - object.__setattr__(self, "_storage_meta_groups", {}) - - @property - def size(self) -> int: - """Return the number of samples in this batch""" - return getattr(self, "_size", 0) - - @property - def global_indexes(self) -> list[int]: - """Get all global indexes in this batch""" - return getattr(self, "_global_indexes", []) - - @property - def field_names(self) -> list[str]: - """Get all unique field names in this batch""" - return getattr(self, "_field_names", []) - - @property - def local_indexes(self) -> list[int]: - """Get all local indexes in this batch""" - return getattr(self, "_local_indexes", []) - - @property - def storage_ids(self) -> list[str]: - """Get all storage unit IDs in this batch""" - return getattr(self, "_storage_ids", []) - - @property - def is_ready(self) -> bool: - """Check if all samples in this batch are ready for consumption""" - # TODO: get ready status from controller realtime - return getattr(self, "_is_ready", False) - - def _build_storage_meta_groups(self) -> dict[str, StorageMetaGroup]: - """Build storage groups from samples during initialization""" - storage_meta_groups: dict[str, StorageMetaGroup] = {} - - for sample in self.samples: - storage_id = sample.storage_id - if storage_id not in storage_meta_groups: - storage_meta_groups[storage_id] = StorageMetaGroup(storage_id=storage_id) - - # Use add_sample_meta to store SampleMeta references directly - storage_meta_groups[storage_id].add_sample_meta(sample) - - return storage_meta_groups - - @property - def 
storage_meta_groups(self) -> dict[str, StorageMetaGroup]: - """Get storage groups organized by storage_id""" - return getattr(self, "_storage_meta_groups", {}) - - @property - def storage_unit_ids(self) -> list[str]: - """Get list of all storage unit IDs""" - return list(self.storage_meta_groups.keys()) - - def get_storage_meta_groups(self, storage_id: str) -> Optional[StorageMetaGroup]: - """Get storage group by storage ID""" - return self.storage_meta_groups.get(storage_id) - - # Extra info interface methods - def get_extra_info(self, key: str, default: Any = None) -> Any: - """Get extra info by key""" - return self.extra_info.get(key, default) - - def set_extra_info(self, key: str, value: Any) -> None: - """Set extra info by key""" - self.extra_info[key] = value - - def update_extra_info(self, info_dict: dict[str, Any]) -> None: - """Update extra info with multiple key-value pairs""" - self.extra_info.update(info_dict) - - def remove_extra_info(self, key: str) -> Any: - """Remove extra info by key and return its value""" - return self.extra_info.pop(key, None) - - def clear_extra_info(self) -> None: - """Clear all extra info""" - self.extra_info.clear() - - def has_extra_info(self, key: str) -> bool: - """Check if extra info contains a specific key""" - return key in self.extra_info - - def get_all_extra_info(self) -> dict[str, Any]: - """Get all extra info as a dictionary""" - return self.extra_info.copy() - - def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": - """ - Add new fields from a TensorDict to all samples in this batch. - This modifies each sample in-place to include the new fields. - - Args: - tensor_dict (TensorDict): The input TensorDict containing new fields. - set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. Default is True. 
- """ - fields = _extract_field_metas(tensor_dict, set_all_ready) - for idx, sample in enumerate(self.samples): - sample.add_fields(fields=fields[idx]) - - # Update batch-level fields cache - object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names)) - object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples)) - return self - - def __len__(self) -> int: - """Return the number of samples in this batch.""" - return len(self.samples) - - def __getitem__(self, item): - if isinstance(item, int | np.integer): - sample_meta = self.samples[item] if self.samples else [] - return BatchMeta(samples=[sample_meta], extra_info=self.extra_info) - else: - raise TypeError(f"Indexing with {type(item)} is not supported now!") - - def chunk(self, chunks: int) -> list["BatchMeta"]: - """ - Split this batch into smaller chunks. - - Args: - chunks: number of chunks - - Return: - List of smaller BatchMeta chunks - """ - chunk_list = [] - n = len(self.samples) - - # Calculate the base size and remainder of each chunk - base_size = n // chunks - remainder = n % chunks - - start = 0 - for i in range(chunks): - # Calculate the size of the current chunk(the first remainder chunk is 1 more than the base size) - current_chunk_size = base_size + 1 if i < remainder else base_size - end = start + current_chunk_size - chunk_samples = self.samples[start:end] - chunk = BatchMeta(samples=chunk_samples, extra_info=self.extra_info.copy()) - chunk_list.append(chunk) - start = end - return chunk_list - - @classmethod - def concat(cls, data: list["BatchMeta"], validate: bool = True) -> Optional["BatchMeta"]: - """ - Concatenate multiple BatchMeta chunks into one large batch. 
- - Args: - data: List of BatchMeta chunks to concatenate - validate: Whether to validate concatenation conditions - - Returns: - Concatenated BatchMeta - - Raises: - ValueError: If validation fails (e.g., field names do not match) - """ - if not data: - return None - - if validate: - base_fields = data[0].field_names - - for chunk in data: - if chunk.field_names != base_fields: - raise ValueError("Error: Field names do not match for concatenation.") - - # Combine all samples - all_samples = [] - for chunk in data: - all_samples.extend(chunk.samples) - # Merge all extra_info dictionaries from the chunks - merged_extra_info = {} - for chunk in data: - merged_extra_info.update(chunk.extra_info) - return BatchMeta(samples=all_samples, extra_info=merged_extra_info) - - def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMeta"]: - """ - Create a union of this batch's fields with another batch's fields. - Assume both batches have the same global indices. If fields overlap, the - fields in this batch will be replaced by the other batch's fields. 
- - Args: - other: Another BatchMeta to union with - validate: Whether to validate union conditions - - Returns: - New BatchMeta with unioned fields - - Raises: - ValueError: If validation fails (e.g., batch sizes or global indexes do not match) - """ - if validate: - if self.size != other.size: - raise ValueError("Error: Batch sizes do not match for union.") - - self_global_indexes = sorted(self.global_indexes) - other_global_indexes = sorted(other.global_indexes) - if self_global_indexes != other_global_indexes: - raise ValueError("Error: Global indexes do not match for union.") - - # Create a mapping from global_index to SampleMeta in the other batch - other_sample_map = {sample.global_index: sample for sample in other.samples} - - # Merge samples - merged_samples = [] - for sample in self.samples: - if sample.global_index in other_sample_map: - other_sample = other_sample_map[sample.global_index] - merged_sample = sample.union(other_sample, validate=validate) - merged_samples.append(merged_sample) - else: - merged_samples.append(sample) - - # Merge extra info dictionaries - merged_extra_info = {**self.extra_info, **other.extra_info} - - return BatchMeta(samples=merged_samples, extra_info=merged_extra_info) - - def reorder(self, indices: list[int]): - """ - Reorder the SampleMeta in the BatchMeta according to the given indices. - - The operation is performed in-place, modifying the current BatchMeta's SampleMeta order. - - Args: - indices : list[int] - A list of integers specifying the new order of SampleMeta. Each integer - represents the current index of the SampleMeta in the BatchMeta. 
- """ - # Reorder the samples - reordered_samples = [self.samples[i] for i in indices] - object.__setattr__(self, "samples", reordered_samples) - - # Update necessary attributes - self._update_after_reorder() - - def _update_after_reorder(self) -> None: - """Update related attributes specifically for the reorder operation""" - # Update batch_index for each sample - for idx, sample in enumerate(self.samples): - object.__setattr__(sample, "_batch_index", idx) - - # Update cached index lists - object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) - object.__setattr__(self, "_local_indexes", [sample.local_index for sample in self.samples]) - object.__setattr__(self, "_storage_ids", [sample.storage_id for sample in self.samples]) - - # Rebuild storage groups - storage_meta_groups = self._build_storage_meta_groups() - object.__setattr__(self, "_storage_meta_groups", storage_meta_groups) - - # Note: No need to update _size, _field_names, _is_ready, etc., as these remain unchanged after reorder - - @classmethod - def from_samples( - cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None - ) -> "BatchMeta": - """ - Create a BatchMeta from a single SampleMeta or a list of SampleMeta objects. - - Args: - samples: A single SampleMeta or a list of SampleMeta objects - extra_info: Optional additional information to store with the batch - - Returns: - BatchMeta instance containing the provided sample(s) - - Example: - >>> sample_meta = SampleMeta(...) 
- >>> batch_meta = BatchMeta.from_samples(sample_meta) - - >>> sample_metas = [sample1, sample2, sample3] - >>> batch_meta = BatchMeta.from_samples(sample_metas, extra_info={"source": "training"}) - """ - if extra_info is None: - extra_info = {} - - if isinstance(samples, SampleMeta): - samples = [samples] - - return cls(samples=samples, extra_info=extra_info) - - @classmethod - def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta": - """ - Create an empty BatchMeta with no samples. - - Args: - extra_info: Optional additional information to store with the batch - - Returns: - Empty BatchMeta instance - - Example: - >>> empty_batch = BatchMeta.empty() - """ - if extra_info is None: - extra_info = {} - return cls(samples=[], extra_info=extra_info) - - -def _union_fields(fields1: dict[str, FieldMeta], fields2: dict[str, FieldMeta]) -> dict[str, FieldMeta]: - """Union two sample's fields. If fields overlap, the fields in fields1 will be replaced by fields2.""" - for name in fields2.keys(): - fields1[name] = fields2[name] - return fields1 - - -def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> list[dict[str, FieldMeta]]: - """ - Extract field metas from a TensorDict. If data in tensor_dict does not have dtype or shape attribute, - the corresponding dtype or shape will be set to None. - - Args: - tensor_dict (TensorDict): The input TensorDict. - set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. - Otherwise, set to NOT_PRODUCED. Default is True. - - Returns: - all_fields (list[dict[FieldMeta]]): A list of dictionaries containing field metadata. 
- """ - all_fields = [] - batch_size = tensor_dict.batch_size[0] - for idx in range(batch_size): - fields = {} - sample = tensor_dict[idx] - for name, value in sample.items(): - fields[name] = FieldMeta( - name=name, - dtype=value.dtype if hasattr(value, "dtype") else None, - shape=value.shape if hasattr(value, "shape") else None, - production_status=ProductionStatus.READY_FOR_CONSUME - if set_all_ready - else ProductionStatus.NOT_PRODUCED, - ) - all_fields.append(fields) - - return all_fields diff --git a/verl/experimental/transfer_queue/storage.py b/verl/experimental/transfer_queue/storage.py deleted file mode 100644 index c8f908ee8d8..00000000000 --- a/verl/experimental/transfer_queue/storage.py +++ /dev/null @@ -1,516 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import os -import time -from operator import itemgetter -from threading import Thread -from uuid import uuid4 - -import ray -import torch -import zmq -from ray.util import get_node_ip_address -from tensordict import NonTensorStack, TensorDict - -from verl.experimental.transfer_queue.utils.utils import TransferQueueRole -from verl.experimental.transfer_queue.utils.zmq_utils import ( - ZMQMessage, - ZMQRequestType, - ZMQServerInfo, - create_zmq_socket, - get_free_port, -) - -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) - -TQ_STORAGE_POLLER_TIMEOUT = os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 1000) -TQ_STORAGE_HANDSHAKE_TIMEOUT = int(os.environ.get("TQ_STORAGE_HANDSHAKE_TIMEOUT", 30)) -TQ_DATA_UPDATE_RESPONSE_TIMEOUT = int(os.environ.get("TQ_DATA_UPDATE_RESPONSE_TIMEOUT", 600)) - - -class StorageUnitData: - """ - Class used for storing several elements, each element is composed of several fields and corresponding data, like: - ##################################################### - # local_index | field_name1 | field_name2 | ... # - # 0 | item1 | item2 | ... # - # 1 | item3 | item4 | ... # - # 2 | item5 | item6 | ... # - ##################################################### - """ - - def __init__(self, storage_size: int): - # Dict containing field names and corresponding data in the field, e.g. {"field_name1": [data1, data2, ...]} - self.field_data: dict[str, list] = {} - - # Maximum number of elements stored in storage unit - self.storage_size = storage_size - - def get_data(self, fields: list[str], local_indexes: list[int]) -> TensorDict[str, list]: - """ - Get data from storage unit according to given fields and local_indexes. - - param: - fields: Field names used for getting data. - local_indexes: Local indexes used for getting data. - return: - TensorDict with field names as keys, corresponding data list as values. 
- """ - result: dict[str, list] = {} - - for field in fields: - # Validate field name - if field not in self.field_data: - raise ValueError( - f"StorageUnitData get_data operation receive invalid field: {field} beyond {self.field_data.keys()}" - ) - - if len(local_indexes) == 1: - # The unsqueeze op make the shape from n to (1, n) - gathered_item = self.field_data[field][local_indexes[0]] - if not isinstance(gathered_item, torch.Tensor): - result[field] = NonTensorStack(gathered_item) - else: - result[field] = gathered_item.unsqueeze(0) - else: - gathered_items = list(itemgetter(*local_indexes)(self.field_data[field])) - - if gathered_items: - all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items) - if all_tensors: - result[field] = torch.nested.as_nested_tensor(gathered_items) - else: - result[field] = NonTensorStack(*gathered_items) - - return TensorDict(result) - - def put_data(self, field_data: TensorDict[str, list], local_indexes: list[int]) -> None: - """ - Put or update data into storage unit according to given field_data and local_indexes. - - param: - field_data: Dict with field names as keys, corresponding data in the field as values. - local_indexes: Local indexes used for putting data. - """ - for f in field_data.keys(): - for i, idx in enumerate(local_indexes): - # Validate local_indexes - if idx < 0 or idx >= self.storage_size: - raise ValueError( - f"StorageUnitData put_data operation receive invalid local_index: {idx} beyond " - f"storage_size: {self.storage_size}" - ) - - if f not in self.field_data: - # Initialize new field value list with None - self.field_data[f] = [None] * self.storage_size - - self.field_data[f][idx] = field_data[f][i] - - def clear(self, local_indexes: list[int]) -> None: - """ - Clear data at specified local_indexes by setting all related fields to None. - - param: - local_indexes: local_indexes to clear. 
- """ - # Validate local_indexes - for idx in local_indexes: - if idx < 0 or idx >= self.storage_size: - raise ValueError( - f"StorageUnitData clear operation receive invalid local_index: {idx} beyond " - f"storage_size: {self.storage_size}" - ) - - # Clear data at specified local_indexes - for f in self.field_data: - for idx in local_indexes: - self.field_data[f][idx] = None - - -@ray.remote(num_cpus=1) -class TransferQueueStorageSimpleUnit: - def __init__(self, storage_size: int): - super().__init__() - self.storage_unit_id = f"TQ_STORAGE_UNIT_{uuid4()}" - self.storage_size = storage_size - self.controller_infos: dict[str, ZMQServerInfo] = {} - - self.experience_data = StorageUnitData(self.storage_size) - - self.zmq_server_info = ZMQServerInfo.create( - role=TransferQueueRole.STORAGE, - id=str(self.storage_unit_id), - ip=get_node_ip_address(), - ports={"put_get_socket": get_free_port()}, - ) - self._init_zmq_socket() - - def _init_zmq_socket(self) -> None: - """ - Initialize ZMQ socket connections between storage unit and controllers/clients: - - controller_handshake_sockets: - Handshake between storage unit and controllers. - - data_status_update_sockets: - Broadcast data update status from storage unit to controllers when handling put operation. - - put_get_socket: - Handle put/get requests from clients. - """ - self.zmq_context = zmq.Context() - - self.controller_handshake_sockets: dict[str, zmq.Socket] = {} - self.data_status_update_sockets: dict[str, zmq.Socket] = {} - - self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER) - self.put_get_socket.bind(self.zmq_server_info.to_addr("put_get_socket")) - - def register_controller_info(self, controller_infos: dict[str, ZMQServerInfo]) -> None: - """ - Build connections between storage unit and controllers, start put/get process. - - param: - controller_infos: Dict with controller infos. 
- """ - self.controller_infos = controller_infos - - self._init_zmq_sockets_with_controller_infos() - self._connect_to_controller() - self._start_process_put_get() - - def _init_zmq_sockets_with_controller_infos(self) -> None: - """Initialize ZMQ sockets between storage unit and controllers for handshake.""" - for controller_id in self.controller_infos.keys(): - self.controller_handshake_sockets[controller_id] = create_zmq_socket( - self.zmq_context, - zmq.DEALER, - identity=f"{self.storage_unit_id}-controller_handshake_sockets-{uuid4()}".encode(), - ) - self.data_status_update_sockets[controller_id] = create_zmq_socket( - self.zmq_context, - zmq.DEALER, - identity=f"{self.storage_unit_id}-data_status_update_sockets-{uuid4()}".encode(), - ) - - def _connect_to_controller(self) -> None: - """Connect storage unit to all controllers.""" - connected_controllers: set[str] = set() - - # Create zmq poller for handshake confirmation between controller and storage unit - poller = zmq.Poller() - - for controller_id, controller_info in self.controller_infos.items(): - self.controller_handshake_sockets[controller_id].connect(controller_info.to_addr("handshake_socket")) - logger.debug( - f"[{self.zmq_server_info.id}]: Handshake connection from storage unit id #{self.zmq_server_info.id} " - f"to controller id #{controller_id} establish successfully." - ) - - # Send handshake request to controllers - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE, - sender_id=self.zmq_server_info.id, - body={ - "storage_unit_id": self.storage_unit_id, - "storage_size": self.storage_size, - }, - ).serialize() - - self.controller_handshake_sockets[controller_id].send(request_msg) - logger.debug( - f"[{self.zmq_server_info.id}]: Send handshake request from storage unit id #{self.zmq_server_info.id} " - f"to controller id #{controller_id} successfully." 
- ) - - poller.register(self.controller_handshake_sockets[controller_id], zmq.POLLIN) - - start_time = time.time() - while ( - len(connected_controllers) < len(self.controller_infos) - and time.time() - start_time < TQ_STORAGE_HANDSHAKE_TIMEOUT - ): - socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT)) - - for controller_handshake_socket in self.controller_handshake_sockets.values(): - if controller_handshake_socket in socks: - response_msg = ZMQMessage.deserialize(controller_handshake_socket.recv()) - - if response_msg.request_type == ZMQRequestType.HANDSHAKE_ACK: - connected_controllers.add(response_msg.sender_id) - logger.debug( - f"[{self.zmq_server_info.id}]: Get handshake ACK response from " - f"controller id #{str(response_msg.sender_id)} to storage unit id " - f"#{self.zmq_server_info.id} successfully." - ) - - if len(connected_controllers) < len(self.controller_infos): - logger.warning( - f"[{self.zmq_server_info.id}]: Only get {len(connected_controllers)} / {len(self.controller_infos)} " - f"successful handshake connections to controllers from storage unit id #{self.zmq_server_info.id}" - ) - - def _start_process_put_get(self) -> None: - """Create a daemon thread and start put/get process.""" - self.process_put_get_thread = Thread( - target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.zmq_server_info.id}", daemon=True - ) - self.process_put_get_thread.start() - - def _process_put_get(self) -> None: - """Process put_get_socket request.""" - poller = zmq.Poller() - poller.register(self.put_get_socket, zmq.POLLIN) - - while True: - socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT)) - - if self.put_get_socket in socks: - identity, serialized_msg = self.put_get_socket.recv_multipart() - - try: - request_msg = ZMQMessage.deserialize(serialized_msg) - operation = request_msg.request_type - logger.debug(f"[{self.zmq_server_info.id}]: receive operation: {operation}, message: {request_msg}") - - if operation == ZMQRequestType.PUT_DATA: - 
response_msg = self._handle_put(request_msg) - elif operation == ZMQRequestType.GET_DATA: - response_msg = self._handle_get(request_msg) - elif operation == ZMQRequestType.CLEAR_DATA: - response_msg = self._handle_clear(request_msg) - else: - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, - sender_id=self.zmq_server_info.id, - body={ - "message": f"Storage unit id #{self.zmq_server_info.id} " - f"receive invalid operation: {operation}." - }, - ) - except Exception as e: - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_GET_ERROR, - sender_id=self.zmq_server_info.id, - body={ - "message": f"Storage unit id #{self.zmq_server_info.id} occur error in processing " - f"put/get/clear request, detail error message: {str(e)}." - }, - ) - - self.put_get_socket.send_multipart([identity, response_msg.serialize()]) - - def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: - """ - Handle put request, add or update data into storage unit. - - param: - data_parts: ZMQMessage from client. - return: - Put data success response ZMQMessage. - """ - try: - global_indexes = data_parts.body["global_indexes"] - local_indexes = data_parts.body["local_indexes"] - field_data = data_parts.body["field_data"] # field_data should be in {field_name: [real data]} format. 
- - self.experience_data.put_data(field_data, local_indexes) - - # After put operation finish, send a message to the client - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.zmq_server_info.id, body={} - ) - - # Gather per-field dtype and shape information for each field - # global_indexes, local_indexes, and field_data correspond one-to-one - per_field_dtypes = {} - per_field_shapes = {} - - # Initialize the data structure for each global index - for global_idx in global_indexes: - per_field_dtypes[global_idx] = {} - per_field_shapes[global_idx] = {} - - # For each field, extract dtype and shape for each sample - for field in field_data.keys(): - for i, data_item in enumerate(field_data[field]): - global_idx = global_indexes[i] - per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None - per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None - - # Broadcast data update message to all controllers with per-field dtype/shape information - self._notify_data_update(list(field_data.keys()), global_indexes, per_field_dtypes, per_field_shapes) - return response_msg - except Exception as e: - return ZMQMessage.create( - request_type=ZMQRequestType.PUT_ERROR, - sender_id=self.zmq_server_info.id, - body={ - "message": f"Failed to put data into storage unit id " - f"#{self.zmq_server_info.id}, detail error message: {str(e)}" - }, - ) - - def _notify_data_update(self, fields, global_indexes, dtypes, shapes) -> None: - """ - Broadcast data status update to all controllers. - - param: - fields: data update related fields. - global_indexes: data update related global_indexes. - dtypes: per-field dtypes for each field, in {global_index: {field: dtype}} format. - shapes: per-field shapes for each field, in {global_index: {field: shape}} format. 
- """ - # Create zmq poller for notifying data update information - poller = zmq.Poller() - - # Connect data status update socket to all controllers - for controller_id, controller_info in self.controller_infos.items(): - data_status_update_socket = self.data_status_update_sockets[controller_id] - data_status_update_socket.connect(controller_info.to_addr("data_status_update_socket")) - logger.debug( - f"[{self.zmq_server_info.id}]: Data status update connection from " - f"storage unit id #{self.zmq_server_info.id} to " - f"controller id #{controller_id} establish successfully." - ) - - try: - poller.register(data_status_update_socket, zmq.POLLIN) - - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, - sender_id=self.zmq_server_info.id, - body={ - "fields": fields, - "global_indexes": global_indexes, - "dtypes": dtypes, - "shapes": shapes, - }, - ).serialize() - - data_status_update_socket.send(request_msg) - logger.debug( - f"[{self.zmq_server_info.id}]: Send data status update request " - f"from storage unit id #{self.zmq_server_info.id} " - f"to controller id #{controller_id} successfully." - ) - except Exception as e: - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, - sender_id=self.zmq_server_info.id, - body={ - "message": f"Failed to notify data status update information from " - f"storage unit id #{self.zmq_server_info.id}, " - f"detail error message: {str(e)}" - }, - ).serialize() - - data_status_update_socket.send(request_msg) - - # Make sure all controllers successfully receive data status update information. 
- response_controllers: set[str] = set() - start_time = time.time() - - while ( - len(response_controllers) < len(self.controller_infos) - and time.time() - start_time < TQ_DATA_UPDATE_RESPONSE_TIMEOUT - ): - socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT)) - - for data_status_update_socket in self.data_status_update_sockets.values(): - if data_status_update_socket in socks: - response_msg = ZMQMessage.deserialize(data_status_update_socket.recv()) - - if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: - response_controllers.add(response_msg.sender_id) - logger.debug( - f"[{self.zmq_server_info.id}]: Get data status update ACK response " - f"from controller id #{response_msg.sender_id} " - f"to storage unit id #{self.zmq_server_info.id} successfully." - ) - - if len(response_controllers) < len(self.controller_infos): - logger.warning( - f"[{self.zmq_server_info.id}]: Storage unit id #{self.zmq_server_info.id} " - f"only get {len(response_controllers)} / {len(self.controller_infos)} " - f"data status update ACK responses from controllers." - ) - - def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: - """ - Handle get request, return data from storage unit. - - param: - data_parts: ZMQMessage from client. - return: - Get data success response ZMQMessage, containing target data. 
- """ - try: - fields = data_parts.body["fields"] - local_indexes = data_parts.body["local_indexes"] - - result_data = self.experience_data.get_data(fields, local_indexes) - - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_DATA_RESPONSE, - sender_id=self.zmq_server_info.id, - body={ - "data": result_data, - }, - ) - except Exception as e: - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_ERROR, - sender_id=self.zmq_server_info.id, - body={ - "message": f"Failed to get data from storage unit id #{self.zmq_server_info.id}, " - f"detail error message: {str(e)}" - }, - ) - return response_msg - - def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage: - """ - Handle clear request, clear data in storage unit according to given local_indexes. - - param: - data_parts: ZMQMessage from client, including target local_indexes. - return: - Clear data success response ZMQMessage. - """ - try: - local_indexes = data_parts.body["local_indexes"] - - self.experience_data.clear(local_indexes) - - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA_RESPONSE, - sender_id=self.zmq_server_info.id, - body={"message": f"Clear data in storage unit id #{self.zmq_server_info.id} successfully."}, - ) - except Exception as e: - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA_ERROR, - sender_id=self.zmq_server_info.id, - body={ - "message": f"Failed to clear data in storage unit id #{self.zmq_server_info.id}, " - f"detail error message: {str(e)}" - }, - ) - return response_msg - - def get_zmq_server_info(self) -> ZMQServerInfo: - return self.zmq_server_info diff --git a/verl/experimental/transfer_queue/utils/__init__.py b/verl/experimental/transfer_queue/utils/__init__.py deleted file mode 100644 index 2df3b7f876f..00000000000 --- a/verl/experimental/transfer_queue/utils/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. 
and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/experimental/transfer_queue/utils/utils.py b/verl/experimental/transfer_queue/utils/utils.py deleted file mode 100644 index 2fceb3f14ce..00000000000 --- a/verl/experimental/transfer_queue/utils/utils.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from enum import Enum - -import ray -import torch -from tensordict import TensorDict - - -class ExplicitEnum(str, Enum): - """ - Enum with more explicit error message for missing values. 
- """ - - @classmethod - def _missing_(cls, value): - raise ValueError( - f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" - ) - - -class TransferQueueRole(ExplicitEnum): - CONTROLLER = "TransferQueueController" - STORAGE = "TransferQueueStorage" - CLIENT = "TransferQueueClient" - - -# production_status enum: 0: not produced, 1: ready for consume, 2: consumed -class ProductionStatus(ExplicitEnum): - NOT_PRODUCED = 0 - READY_FOR_CONSUME = 1 - CONSUMED = 2 - - -def get_placement_group(num_ray_actors: int, num_cpus_per_actor: int = 1): - """ - Create a placement group with SPREAD strategy for Ray actors. - - Args: - num_ray_actors (int): Number of Ray actors to create. - num_cpus_per_actor (int): Number of CPUs to allocate per actor. - - Returns: - placement_group: The created placement group. - """ - bundle = {"CPU": num_cpus_per_actor} - placement_group = ray.util.placement_group([bundle for _ in range(num_ray_actors)], strategy="SPREAD") - ray.get(placement_group.ready()) - return placement_group - - -def random_sampler( - ready_for_consume_idx: list[int], - batch_size: int, - get_n_samples: bool, - n_samples_per_prompt: int, -) -> list[int]: - """ - random sampling batch_size samples from global indexes ready_for_consume_idx - input example: - if get_n_samples: (group_num=3, group_size=4) - ready_for_consume_idx could look like: [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19] - else: - ready_for_consume_idx could look like: [2, 5, 6] - """ - if get_n_samples: - assert len(ready_for_consume_idx) % n_samples_per_prompt == 0 - assert batch_size % n_samples_per_prompt == 0 - batch_size_n_samples = batch_size // n_samples_per_prompt - - group_ready_for_consume_idx = torch.tensor(ready_for_consume_idx, dtype=torch.int).view( - -1, n_samples_per_prompt - ) - - weights = torch.ones(group_ready_for_consume_idx.size(0)) - sampled_indexes_idx = torch.multinomial(weights, batch_size_n_samples, replacement=False).tolist() - 
sampled_indexes = group_ready_for_consume_idx[sampled_indexes_idx].flatten().tolist() - else: - weights = torch.ones(len(ready_for_consume_idx)) - sampled_indexes_idx = torch.multinomial(weights, batch_size, replacement=False).tolist() - sampled_indexes = [int(ready_for_consume_idx[i]) for i in sampled_indexes_idx] - return sampled_indexes - - -def extract_field_info(tensor_dict: TensorDict) -> dict: - """ - Extract field names, dtypes, and shapes from a TensorDict. - Assumes all tensors in the same field have the same dtype and shape (excluding batch dimension). - Returns a dictionary with keys: 'names', 'dtypes', 'shapes'. - """ - field_info: dict[str, list] = {"names": [], "dtypes": [], "shapes": []} - for key, value in tensor_dict.items(): - field_info["names"].append(key) - - # TODO: support nested tensors & non tensors - # field_info["dtypes"].append(value.dtype) - # field_info["shapes"].append(value.shape[1:]) # exclude batch dimension - return field_info diff --git a/verl/experimental/transfer_queue/utils/zmq_utils.py b/verl/experimental/transfer_queue/utils/zmq_utils.py deleted file mode 100644 index 947b48407ef..00000000000 --- a/verl/experimental/transfer_queue/utils/zmq_utils.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pickle -import socket -import time -import uuid -from dataclasses import dataclass -from typing import Any, Optional - -import psutil -import zmq -from typing_extensions import Self - -from verl.experimental.transfer_queue.utils.utils import ( - ExplicitEnum, - TransferQueueRole, -) - - -class ZMQRequestType(ExplicitEnum): - # HANDSHAKE - HANDSHAKE = "HANDSHAKE" # TransferQueueStorageUnit -> TransferQueueController - HANDSHAKE_ACK = "HANDSHAKE_ACK" # TransferQueueController -> TransferQueueStorageUnit - - # DATA_OPERATION - GET_DATA = "GET" - PUT_DATA = "PUT" - GET_DATA_RESPONSE = "GET_DATA_RESPONSE" - PUT_DATA_RESPONSE = "PUT_DATA_RESPONSE" - CLEAR_DATA = "CLEAR_DATA" - CLEAR_DATA_RESPONSE = "CLEAR_DATA_RESPONSE" - - PUT_GET_OPERATION_ERROR = "PUT_GET_OPERATION_ERROR" - PUT_GET_ERROR = "PUT_GET_ERROR" - PUT_ERROR = "PUT_ERROR" - GET_ERROR = "GET_ERROR" - CLEAR_DATA_ERROR = "CLEAR_DATA_ERROR" - - # META_OPERATION - GET_META = "GET_META" - GET_META_RESPONSE = "GET_META_RESPONSE" - GET_CLEAR_META = "GET_CLEAR_META" - GET_CLEAR_META_RESPONSE = "GET_CLEAR_META_RESPONSE" - CLEAR_META = "CLEAR_META" - CLEAR_META_RESPONSE = "CLEAR_META_RESPONSE" - - # CHECK_CONSUMPTION - CHECK_CONSUMPTION = "CHECK_CONSUMPTION" - CONSUMPTION_RESPONSE = "CONSUMPTION_RESPONSE" - - # NOTIFY_DATA_UPDATE - NOTIFY_DATA_UPDATE = "NOTIFY_DATA_UPDATE" - NOTIFY_DATA_UPDATE_ACK = "NOTIFY_DATA_UPDATE_ACK" - NOTIFY_DATA_UPDATE_ERROR = "NOTIFY_DATA_UPDATE_ERROR" - - -@dataclass -class ZMQServerInfo: - role: TransferQueueRole - id: str - ip: str - ports: dict[str, str] - - @classmethod - def create(cls, role: TransferQueueRole, id: str, ip: str, ports: dict[str, str]) -> Self: - return cls(role=role, id=id, ip=ip, ports=ports) - - def to_addr(self, port_name: str) -> str: - return f"tcp://{self.ip}:{self.ports[port_name]}" - - def to_dict(self): - return { - "role": self.role, - "id": self.id, - "ip": self.ip, - "ports": self.ports, - } - - def __str__(self) -> str: - return 
f"ZMQSocketInfo(role={self.role}, id={self.id}, ip={self.ip}, ports={self.ports})" - - -@dataclass -class ZMQMessage: - request_type: ZMQRequestType - sender_id: str - receiver_id: str | None - body: dict[str, Any] - request_id: str - timestamp: float - - @classmethod - def create( - cls, - request_type: ZMQRequestType, - sender_id: str, - body: dict[str, Any], - receiver_id: Optional[str] = None, - ) -> "ZMQMessage": - return cls( - request_type=request_type, - sender_id=sender_id, - receiver_id=receiver_id, - body=body, - request_id=str(uuid.uuid4()), - timestamp=time.time(), - ) - - def serialize(self) -> bytes: - """Using pickle to serialize ZMQMessage objects""" - return pickle.dumps(self) - - @classmethod - def deserialize(cls, data: bytes | list[bytes]): - """Using pickle to deserialize ZMQMessage objects""" - if isinstance(data, list): - # Process multiple byte streams by deserializing each in sequence - result = [] - for d in data: - result.append(pickle.loads(d)) - return result - else: - # Single byte stream case - return pickle.loads(data) - - -def get_free_port() -> str: - with socket.socket() as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - -def create_zmq_socket( - ctx: zmq.Context, - socket_type: Any, - identity: Optional[bytes] = None, -) -> zmq.Socket: - mem = psutil.virtual_memory() - socket = ctx.socket(socket_type) - - # Calculate buffer size based on system memory - total_mem = mem.total / 1024**3 - available_mem = mem.available / 1024**3 - # For systems with substantial memory (>32GB total, >16GB available): - # - Set a large 0.5GB buffer to improve throughput - # For systems with less memory: - # - Use system default (-1) to avoid excessive memory consumption - if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024**3) # 0.5GB in bytes - else: - buf_size = -1 # Use system default buffer size - - if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): - socket.setsockopt(zmq.RCVHWM, 0) - 
socket.setsockopt(zmq.RCVBUF, buf_size) - - if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER): - socket.setsockopt(zmq.SNDHWM, 0) - socket.setsockopt(zmq.SNDBUF, buf_size) - - if identity is not None: - socket.setsockopt(zmq.IDENTITY, identity) - return socket From 0cef8a0ef629bf76066dca1c64835c88712c6d73 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 17 Nov 2025 15:02:35 +0800 Subject: [PATCH 14/16] update docs Signed-off-by: 0oshowero0 --- docs/data/transfer_queue.md | 203 ++++++++++++++++++++++++++++++------ 1 file changed, 169 insertions(+), 34 deletions(-) diff --git a/docs/data/transfer_queue.md b/docs/data/transfer_queue.md index 4532d42ed56..f15de29fab6 100644 --- a/docs/data/transfer_queue.md +++ b/docs/data/transfer_queue.md @@ -1,52 +1,73 @@ # TransferQueue Data System -Last updated: 09/28/2025. +Last updated: 11/17/2025. This doc introduce [TransferQueue](https://github.com/TransferQueue/TransferQueue), an asynchronous streaming data management system for efficient post-training.

Overview

-TransferQueue is a high-performance data storage and transfer system with panoramic data visibility and streaming scheduling capabilities, optimized for efficient dataflow in post-training workflows. +TransferQueue is a high-performance data storage and transfer module with panoramic data visibility and streaming scheduling capabilities, optimized for efficient dataflow in post-training workflows.

- +

- -TransferQueue offers **fine-grained, sample-level** data management capabilities, serving as a data gateway that decouples explicit data dependencies across computational tasks. This enables a divide-and-conquer approach, significantly simplifying the design of the algorithm controller. - +TransferQueue offers **fine-grained, sample-level** data management and **load-balancing** (on the way) capabilities, serving as a data gateway that decouples explicit data dependencies across computational tasks. This enables a divide-and-conquer approach, significantly simplifying the algorithm controller design.

- +

+

Updates

- + - **Nov 10, 2025**: We disentangle the data retrieval logic from TransferQueueController [PR#101](https://github.com/TransferQueue/TransferQueue/pull/101). Now you can implement your own `Sampler` to control how to consume the data. + - **Nov 5, 2025**: We provide a `KVStorageManager` that simplifies the integration with KV-based storage backends [PR#96](https://github.com/TransferQueue/TransferQueue/pull/96). The first available KV-based backend is [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem). + - **Nov 4, 2025**: Data partition capability is available in [PR#98](https://github.com/TransferQueue/TransferQueue/pull/98). Now you can define logical data partitions to manage your train/val/test datasets. + - **Oct 25, 2025**: We make storage backends pluggable in [PR#66](https://github.com/TransferQueue/TransferQueue/pull/66). You can try to integrate your own storage backend with TransferQueue now! + - **Oct 21, 2025**: Official integration into verl is ready [verl/pulls/3649](https://github.com/volcengine/verl/pull/3649). Following PRs will optimize the single controller architecture by fully decoupling data & control flows. + - **July 22, 2025**: We present a series of Chinese blogs on Zhihu 1, 2. + - **July 21, 2025**: We started an RFC on verl community [verl/RFC#2662](https://github.com/volcengine/verl/discussions/2662). + - **July 2, 2025**: We publish the paper [AsyncFlow](https://arxiv.org/abs/2507.01663).

Components

+### Control Plane: Panoramic Data Management +In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. When all the required data fields are ready (i.e., written to the `TransferQueueStorageManager`), we know that this data sample can be consumed by downstream tasks. -### Control Plane: Panoramic Data Management - -In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. When all the required data fields are ready (i.e., written to the `TransferQueueStorage`), we know that this data sample can be consumed by downstream tasks. - -For consumption status, we record the consumption records for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even different computation tasks require the same data field, they can consume the data independently without interfering with each other. - +For consumption status, we record the consumption records for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even when different computation tasks require the same data field, they can consume the data independently without interfering with each other.

- +

+To make the data retrieval process more customizable, we provide a `Sampler` class that allows users to define their own data retrieval and consumption logic. Refer to the [Customize](#customize) section for details. -> In the future, we plan to support **load-balancing** and **dynamic batching** capabilities in the control plane. Besides, we will support data management for disaggregated frameworks where each rank manages the data retrieval by itself, rather than coordinated by a single controller. +> In the future, we plan to support **load-balancing** and **dynamic batching** capabilities in the control plane. Additionally, we will support data management for disaggregated frameworks where each rank manages the data retrieval by itself, rather than coordinated by a single controller. ### Data Plane: Distributed Data Storage -In the data plane, `TransferQueueStorageSimpleUnit` serves as a naive storage unit based on CPU memory, responsible for the actual storage and retrieval of data. Each storage unit can be deployed on a separate node, allowing for distributed data management. +In the data plane, we provide a pluggable design that enables TransferQueue to integrate with different storage backends according to user requirements. + +Specifically, we provide a `TransferQueueStorageManager` abstraction class that defines the core APIs as follows: -`TransferQueueStorageSimpleUnit` employs a 2D data structure as follows: +- `async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None` +- `async def get_data(self, metadata: BatchMeta) -> TensorDict` +- `async def clear_data(self, metadata: BatchMeta) -> None` + +This class encapsulates the core interaction logic within the TransferQueue system. You only need to write a simple subclass to integrate your own storage backend. Refer to the [Customize](#customize) section for details. 
+ +Currently, we support the following storage backends: + +- SimpleStorageUnit: A basic CPU memory storage with minimal data format constraints and easy usability. +- [MoonCakeStore](https://github.com/kvcache-ai/Mooncake): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM. +- [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem): An Ascend native data system that provides hierarchical storage interfaces including HBM/DRAM/SSD. +- [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html): Ray's new feature that allows Ray to store and pass objects directly between Ray actors. + +Among them, `SimpleStorageUnit` serves as our default storage backend, coordinated by the `AsyncSimpleStorageManager` class. Each storage unit can be deployed on a separate node, allowing for distributed data management. + +`SimpleStorageUnit` employs a 2D data structure as follows: - Each row corresponds to a training sample, assigned a unique index within the corresponding global batch. - Each column represents the input/output data fields for computational tasks. @@ -54,29 +75,22 @@ In the data plane, `TransferQueueStorageSimpleUnit` serves as a naive storage un This data structure design is motivated by the computational characteristics of the post-training process, where each training sample is generated in a relayed manner across task pipelines. It provides an accurate addressing capability, which allows fine-grained, concurrent data read/write operations in a streaming manner.

- +

- -> In the future, we plan to implement a **general storage abstraction layer** to support various storage backends. Through this abstraction, we hope to integrate high-performance storage solutions such as [MoonCakeStore](https://github.com/kvcache-ai/Mooncake) to support device-to-device data transfer through RDMA, further enhancing data transfer efficiency for large-scale data. - - ### User Interface: Asynchronous & Synchronous Client - The interaction workflow of TransferQueue system is as follows: 1. A process sends a read request to the `TransferQueueController`. 2. `TransferQueueController` scans the production and consumption metadata for each sample (row), and dynamically assembles a micro-batch metadata according to the load-balancing policy. This mechanism enables sample-level data scheduling. 3. The process retrieves the actual data from distributed storage units using the metadata provided by the controller. -To simplify the usage of TransferQueue, we have encapsulated this process into `AsyncTransferQueueClient` and `TransferQueueClient`. These clients provide both asynchronous and synchronous interfaces for data transfer, allowing users to easily integrate TransferQueue to their framework. - - -> In the future, we will provide a `StreamingDataLoader` interface for disaggregated frameworks as discussed in [RFC#2662](https://github.com/volcengine/verl/discussions/2662). Leveraging this abstraction, each rank can automatically get its own data like `DataLoader` in PyTorch. The TransferQueue system will handle the underlying data scheduling and transfer logic caused by different parallelism strategies, significantly simplifying the design of disaggregated frameworks. +To simplify the usage of TransferQueue, we have encapsulated this process into `AsyncTransferQueueClient` and `TransferQueueClient`. These clients provide both asynchronous and synchronous interfaces for data transfer, allowing users to easily integrate TransferQueue into their framework. 
+> In the future, we will provide a `StreamingDataLoader` interface for disaggregated frameworks as discussed in [issue#85](https://github.com/TransferQueue/TransferQueue/issues/85) and [verl/RFC#2662](https://github.com/volcengine/verl/discussions/2662). Leveraging this abstraction, each rank can automatically get its own data like `DataLoader` in PyTorch. The TransferQueue system will handle the underlying data scheduling and transfer logic caused by different parallelism strategies, significantly simplifying the design of disaggregated frameworks. -

Show Cases

+

🔥 Showcases

### General Usage @@ -89,16 +103,15 @@ Core interfaces: - (async_)put(data:TensorDict, metadata:BatchMeta, global_step) - (async_)clear(global_step: int) - We will soon release a detailed tutorial and API documentation. ### verl Example +The primary motivation for integrating TransferQueue to verl now is to **alleviate the data transfer bottleneck of the single controller `RayPPOTrainer`**. Currently, all `DataProto` objects must be routed through `RayPPOTrainer`, resulting in a single point bottleneck of the whole post-training system. -The primary motivation for integrating TransferQueue to verl now is to **alleviate the data transfer bottleneck of the single controller `RayPPOTrainer`**. Currently, all `DataProto` objects must be routed through `RayPPOTrainer`, resulting in a single point bottleneck of the whole post-training system. +![verl_dataflow_DataProto](https://github.com/TransferQueue/community_doc/blob/main/docs/verl_workflow.jpeg?raw=true) -![verl_dataflow_DataProto](https://cdn.nlark.com/yuque/0/2025/jpeg/23208217/1758704289414-bcc54228-716b-4d4a-ad3b-f9ace6d10fcf.jpeg) Leveraging TransferQueue, we separate experience data transfer from metadata dispatch by @@ -106,12 +119,134 @@ Leveraging TransferQueue, we separate experience data transfer from metadata dis - Preserving verl's original Dispatch/Collect logic via BatchMeta (maintaining single-controller debuggability) - Accelerating data transfer by TransferQueue's distributed storage units -![verl_dataflow_TransferQueue](https://cdn.nlark.com/yuque/0/2025/jpeg/23208217/1758704301666-0807dc06-766c-4a2d-9cde-889a6bb56b34.jpeg) +![verl_dataflow_TransferQueue](https://github.com/TransferQueue/community_doc/blob/main/docs/verl_workflow_with_tq.jpeg?raw=true) + + +You may refer to the [recipe](https://github.com/TransferQueue/TransferQueue/tree/dev/recipe/simple_use_case), where we mimic the verl usage in both async & sync scenarios. 
Official integration to verl is also available now at [verl/pulls/3649](https://github.com/volcengine/verl/pull/3649) (with subsequent PRs to further optimize the integration). -You may refer to the [recipe](https://github.com/TransferQueue/TransferQueue/tree/dev/recipe/simple_use_case), where we mimic the verl usage in both async & sync scenarios. +### Use Python package +```bash +pip install TransferQueue==0.1.1.dev2 +``` +### Build wheel package from source code + +Follow these steps to build and install: +1. Clone the source code from the GitHub repository + ```bash + git clone https://github.com/TransferQueue/TransferQueue/ + cd TransferQueue + ``` + +2. Install dependencies + ```bash + pip install -r requirements.txt + ``` + +3. Build and install + ```bash + python -m build --wheel + pip install dist/*.whl + ``` + +

📊 Performance

+ +

+ +

+> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorageUnit` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. Warmly welcome contributions from the community! + +For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/obi4ovmy9wf08zz3?singleDoc#). + +

๐Ÿ› ๏ธ Customize TransferQueue

+ +### Define your own data retrieval logic +We provide a `BaseSampler` abstraction class, which defines the following interface: + +```python3 +@abstractmethod +def sample( + self, + ready_indexes: list[int], + batch_size: int, + *args: Any, + **kwargs: Any, +) -> tuple[list[int], list[int]]: + """Sample a batch of indices from the ready indices. + + Args: + ready_indexes: List of global indices for which all required fields of the + corresponding samples have been produced, and the samples are not labeled as + consumed in the corresponding task. + batch_size: Number of samples to select + *args: Additional positional arguments for specific sampler implementations + **kwargs: Additional keyword arguments for specific sampler implementations + + Returns: + List of sampled global indices of length batch_size + List of global indices of length batch_size that should be labeled as consumed + (will never be retrieved in the future) + + Raises: + ValueError: If batch_size is invalid or ready_indexes is insufficient + """ + raise NotImplementedError("Subclasses must implement sample") +``` + +In this design, we separate data retrieval and data consumption through the two return values, which enables us to easily control sample replacement. We have implemented two reference designs: `SequentialSampler` and `GRPOGroupNSampler`. + +The `Sampler` class or instance should be passed to the `TransferQueueController` during initialization. During each `get_meta` call, you can provide dynamic sampling parameters to the `Sampler`. 
+ +```python3 +from transfer_queue import TransferQueueController, TransferQueueClient, GRPOGroupNSampler, process_zmq_server_info + +# Option 1: Pass the sampler class to the TransferQueueController +controller = TransferQueueController.remote(GRPOGroupNSampler) + +# Option 2: Pass the sampler instance to the TransferQueueController (if you need custom configuration) +your_own_sampler = YourOwnSampler(config) +controller = TransferQueueController.remote(your_own_sampler) + +# Use the sampler +batch_meta = client.get_meta( + data_fields=["input_ids", "attention_mask"], + batch_size=8, + partition_id="train_0", + task_name="generate_sequences", + sampling_config={"n_samples_per_prompt": 4} # Put the required sampling parameters here +) +``` + +### How to integrate a new storage backend + +The data plane is organized as follows: +```text + transfer_queue/ + ├── storage/ + │ ├── __init__.py + │ ├── simple_backend.py # SimpleStorageUnit, StorageUnitData, StorageMetaGroup + │ ├── managers/ # Managers are upper level interfaces that encapsulate the interaction logic with TQ system. + │ │ ├── __init__.py + │ │ ├── base.py # TransferQueueStorageManager, KVStorageManager + │ │ ├── simple_backend_manager.py # AsyncSimpleStorageManager + │ │ ├── yuanrong_manager.py # YuanrongStorageManager + │ │ ├── mooncake_manager.py # MooncakeStorageManager + │ │ └── factory.py # TransferQueueStorageManagerFactory + │ └── clients/ # Clients are lower level interfaces that directly manipulate the target storage backend.
+ │ │ ├── __init__.py + │ │ ├── base.py # TransferQueueStorageKVClient + │ │ ├── yuanrong_client.py # YRStorageClient + │ │ ├── mooncake_client.py # MooncakeStoreClient + │ │ └── factory.py # TransferQueueStorageClientFactory +``` + +To integrate TransferQueue with a custom storage backend, start by implementing a subclass that inherits from `TransferQueueStorageManager`. This subclass acts as an adapter between the TransferQueue system and the target storage backend. For KV-based storage backends, you can simply inherit from `KVStorageManager`, which can serve as the general manager for all KV-based backends. + +Distributed storage backends often come with their own native clients serving as the interface of the storage system. In such cases, a low-level adapter for this client can be written, following the examples provided in the `storage/clients` directory. + +Factory classes are provided for both `StorageManager` and `StorageClient` to facilitate easy integration. Adding necessary descriptions of required parameters in the factory class helps enhance the overall user experience. From 88569095eeac4517424e00bc2664d577e78a7bed Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 17 Nov 2025 21:05:04 +0800 Subject: [PATCH 15/16] update performance Signed-off-by: 0oshowero0 --- docs/data/transfer_queue.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/data/transfer_queue.md b/docs/data/transfer_queue.md index f15de29fab6..877a46b9063 100644 --- a/docs/data/transfer_queue.md +++ b/docs/data/transfer_queue.md @@ -153,12 +153,12 @@ Follow these steps to build and install:

📊 Performance

- +

> Note: The above benchmark for TransferQueue is based on our naive `SimpleStorageUnit` backend. By introducing high-performance storage backends and optimizing serialization/deserialization, we expect to achieve even better performance. Warmly welcome contributions from the community! -For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/obi4ovmy9wf08zz3?singleDoc#). +For detailed performance benchmarks, please refer to [this blog](https://www.yuque.com/haomingzi-lfse7/hlx5g0/tml8ke0zkgn6roey?singleDoc#).

๐Ÿ› ๏ธ Customize TransferQueue

From 2ae18a6b9e9f4781a200eb2b788de9fe11a32b49 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 18 Nov 2025 11:14:55 +0800 Subject: [PATCH 16/16] fix Signed-off-by: 0oshowero0 --- recipe/transfer_queue/ray_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index 9874fc7e0dc..daa2f8d95b6 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -1118,7 +1118,9 @@ def _balance_batch( world_size = self.actor_rollout_wg.world_size if keep_minibatch: # Decouple the DP balancing and mini-batching. - minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size") + minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size", None) + if minibatch_size is None: + raise ValueError("'ppo_mini_batch_size' must be set in actor config when 'keep_minibatch' is True.") minibatch_num = len(global_seqlen_lst) // minibatch_size global_partition_lst = [[] for _ in range(world_size)] for i in range(minibatch_num):