From a13b642cf308bd79e463180097a4e527eafedcf5 Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Mon, 17 Nov 2025 13:05:12 +0800 Subject: [PATCH 1/2] [recipe, data] feat: TransferQueue - Support managing multiple data partitions for Train/Val/Test in controller --- recipe/transfer_queue/agent_loop.py | 7 +- recipe/transfer_queue/ray_trainer.py | 234 +++++++++++--------------- requirements_transferqueue.txt | 2 +- verl/single_controller/base/worker.py | 8 +- verl/utils/transferqueue_utils.py | 26 +-- 5 files changed, 111 insertions(+), 166 deletions(-) diff --git a/recipe/transfer_queue/agent_loop.py b/recipe/transfer_queue/agent_loop.py index 871ae8025c0..7f936e6730e 100644 --- a/recipe/transfer_queue/agent_loop.py +++ b/recipe/transfer_queue/agent_loop.py @@ -67,10 +67,7 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data return timing - def create_transferqueue_client(self, controller_infos, storage_infos, role): + def create_transferqueue_client(self, controller_info, config): ray.get( - [ - worker.create_transferqueue_client.remote(controller_infos, storage_infos, role) - for worker in self.agent_loop_workers - ] + [worker.create_transferqueue_client.remote(controller_info, config) for worker in self.agent_loop_workers] ) diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index d6adbddb676..9874fc7e0dc 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -41,8 +41,8 @@ from tqdm import tqdm from transfer_queue import ( BatchMeta, + SimpleStorageUnit, TransferQueueController, - TransferQueueStorageSimpleUnit, get_placement_group, process_zmq_server_info, ) @@ -81,6 +81,7 @@ from verl.utils.metric import reduce_metrics from verl.utils.rollout_skip import RolloutSkip from verl.utils.seqlen_balancing import ( + calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance, ) @@ -89,7 +90,6 @@ from verl.utils.transferqueue_utils import ( create_transferqueue_client, get_transferqueue_client, - get_val_transferqueue_client, tqbridge, ) @@ -412,109 +412,66 @@ def __init__( self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) - self.data_system_client = self._initialize_train_data_system( - self.config.data.train_batch_size, self.config.actor_rollout_ref.rollout.n + self.data_system_client = self._initialize_data_system() + + def _initialize_data_system(self): + # 1. initialize TransferQueueStorage + train_data_size = ( + self.config.data.train_batch_size + * self.config.trainer.num_global_batch + * self.config.actor_rollout_ref.rollout.n ) - self.val_data_system_client = self._initialize_val_data_system( - self.val_batch_size, self.config.actor_rollout_ref.rollout.val_kwargs.n + val_data_size = ( + self.val_batch_size + * self.config.trainer.num_global_batch + * self.config.actor_rollout_ref.rollout.val_kwargs.n ) - def _initialize_train_data_system(self, global_batch_size, num_n_samples, role="train"): - # 1. initialize TransferQueueStorage - total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples + total_storage_size = train_data_size + val_data_size self.data_system_storage_units = {} storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1) for storage_unit_rank in range(self.config.trainer.num_data_storage_units): - storage_node = TransferQueueStorageSimpleUnit.options( + storage_node = SimpleStorageUnit.options( placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank - ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) + ).remote(storage_unit_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) self.data_system_storage_units[storage_unit_rank] = storage_node - logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") - - # 2. initialize TransferQueueController - # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly - # one controller for a single WorkerGroup. - self.data_system_controllers = {} - controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1) - for controller_rank in range(self.config.trainer.num_data_controllers): - self.data_system_controllers[controller_rank] = TransferQueueController.options( - placement_group=controller_placement_group, placement_group_bundle_index=controller_rank - ).remote( - num_storage_units=self.config.trainer.num_data_storage_units, - global_batch_size=global_batch_size, - num_global_batch=self.config.trainer.num_global_batch, - num_n_samples=num_n_samples, - ) - logging.info(f"TransferQueueController #{controller_rank} has been created.") + logging.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.") - # 3. register controller & storage - self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers) - self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units) + # 2. Initialize TransferQueueController (single controller only) - ray.get( - [ - storage_unit.register_controller_info.remote(self.data_system_controller_infos) - for storage_unit in self.data_system_storage_units.values() - ] - ) + # Sampler usage instructions: + # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler: + # Option 1: Pass sampler class (will be instantiated automatically) + # self.data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler) - # 4. create client - # each client should be allocated to exactly one controller - create_transferqueue_client( - client_id="Trainer-" + role, - controller_infos=self.data_system_controller_infos, - storage_infos=self.data_system_storage_unit_infos, - ) - data_system_client = get_transferqueue_client() - return data_system_client + # Option 2: Pass sampler instance (if you need custom configuration) + # grpo_sampler = GRPOGroupNSampler() + # self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler) - def _initialize_val_data_system(self, global_batch_size, num_n_samples, role="val"): - # 1. initialize TransferQueueStorage - total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples - self.val_data_system_storage_units = {} - storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1) - for storage_unit_rank in range(self.config.trainer.num_data_storage_units): - storage_node = TransferQueueStorageSimpleUnit.options( - placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank - ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) - self.val_data_system_storage_units[storage_unit_rank] = storage_node - logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") - - # 2. initialize TransferQueueController - # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly - # one controller for a single WorkerGroup. - self.val_data_system_controllers = {} - controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1) - for controller_rank in range(self.config.trainer.num_data_controllers): - self.val_data_system_controllers[controller_rank] = TransferQueueController.options( - placement_group=controller_placement_group, placement_group_bundle_index=controller_rank - ).remote( - num_storage_units=self.config.trainer.num_data_storage_units, - global_batch_size=global_batch_size, - num_global_batch=self.config.trainer.num_global_batch, - num_n_samples=num_n_samples, - ) - logging.info(f"TransferQueueController #{controller_rank} has been created.") + # Then use sampling_config in get_meta calls: + # sampling_config={"n_samples_per_prompt": 4} + self.data_system_controller = TransferQueueController.remote() + logging.info("TransferQueueController has been created.") - # 3. register controller & storage - self.val_data_system_controller_infos = process_zmq_server_info(self.val_data_system_controllers) - self.val_data_system_storage_unit_infos = process_zmq_server_info(self.val_data_system_storage_units) + # 3. register controller & storage and prepare necessary information + self.data_system_controller_info = process_zmq_server_info(self.data_system_controller) + self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units) - ray.get( - [ - storage_unit.register_controller_info.remote(self.val_data_system_controller_infos) - for storage_unit in self.val_data_system_storage_units.values() - ] - ) + # Note: Need to generate a new DictConfig with allow_objects=True to preserve ZMQServerInfo instances + # (which contain socket connection details). Without this flag, OmegaConf would flatten these objects to dicts, + # breaking the transfer queue client initialization. + tq_config = OmegaConf.create({}, flags={"allow_objects": True}) + tq_config.controller_info = self.data_system_controller_info + tq_config.storage_unit_infos = self.data_system_storage_unit_infos + self.config = OmegaConf.merge(tq_config, self.config) # 4. create client - # each client should be allocated to exactly one controller create_transferqueue_client( - client_id="Trainer-" + role, - controller_infos=self.val_data_system_controller_infos, - storage_infos=self.val_data_system_storage_unit_infos, + client_id="Trainer", + controller_info=self.data_system_controller_info, + config=self.config, ) - data_system_client = get_val_transferqueue_client() + data_system_client = get_transferqueue_client() return data_system_client def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): @@ -726,19 +683,18 @@ def _validate(self): if self.config.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model": return {} - asyncio.run(self.val_data_system_client.async_put(data=test_batch, global_step=self.global_steps - 1)) + asyncio.run(self.data_system_client.async_put(data=test_batch, partition_id=f"val_{self.global_steps - 1}")) # Store original inputs batch_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["input_ids", "uid", "reward_model"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", task_name="get_data", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(batch_meta)) + data = asyncio.run(self.data_system_client.async_get_data(batch_meta)) input_ids = data["input_ids"] # TODO: Can we keep special tokens except for padding tokens? input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] @@ -749,11 +705,10 @@ def _validate(self): sample_gts.extend(ground_truths) test_gen_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 task_name="generate_sequences", ) ) @@ -779,15 +734,14 @@ def _validate(self): # Store generated outputs test_response_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["responses"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 task_name="get_response", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(test_response_meta)) + data = asyncio.run(self.data_system_client.async_get_data(test_response_meta)) output_ids = data["responses"] output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] sample_outputs.extend(output_texts) @@ -808,11 +762,10 @@ def _validate(self): if "rm_scores" in batch_meta.field_names: compute_reward_fields = ["rm_scores"] val_reward_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=compute_reward_fields, batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", task_name="compute_reward", ) ) @@ -832,29 +785,27 @@ def _validate(self): # collect num_turns of each prompt if "__num_turns__" in test_batch_meta.field_names: num_turns_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["__num_turns__"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 task_name="get_num_turns", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(num_turns_meta)) + data = asyncio.run(self.data_system_client.async_get_data(num_turns_meta)) sample_turns.append(data["__num_turns__"]) data_source = ["unknown"] * reward_tensor.shape[0] if "data_source" in test_batch_meta.field_names: data_source_meta = asyncio.run( - self.val_data_system_client.async_get_meta( + self.data_system_client.async_get_meta( data_fields=["data_source"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, + partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 task_name="get_data_source", ) ) - data = asyncio.run(self.val_data_system_client.async_get_data(data_source_meta)) + data = asyncio.run(self.data_system_client.async_get_data(data_source_meta)) data_source = data["data_source"] data_source_lst.append(data_source) @@ -902,7 +853,7 @@ def _validate(self): metric_dict["val-aux/num_turns/max"] = sample_turns.max() metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() - asyncio.run(self.val_data_system_client.async_clear(self.global_steps - 1)) + asyncio.run(self.data_system_client.async_clear(partition_id=f"val_{self.global_steps - 1}")) return metric_dict def init_workers(self): @@ -1003,12 +954,7 @@ def init_workers(self): # set transferqueue server info for each worker for _, wg in all_wg.items(): - wg.create_transferqueue_client( - self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train" - ) - wg.create_transferqueue_client( - self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val" - ) + wg.create_transferqueue_client(self.data_system_controller_info, self.config) # create async rollout manager and request scheduler self.async_rollout_mode = False @@ -1020,12 +966,7 @@ def init_workers(self): config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg ) - self.async_rollout_manager.create_transferqueue_client( - self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train" - ) - self.async_rollout_manager.create_transferqueue_client( - self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val" - ) + self.async_rollout_manager.create_transferqueue_client(self.data_system_controller_info, self.config) def _save_checkpoint(self): from verl.utils.fs import local_mkdir_safe @@ -1164,17 +1105,39 @@ def _stop_profiling(self, do_profile: bool) -> None: if self.use_rm: self.rm_wg.stop_profile() - def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen"): + def _balance_batch( + self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen", keep_minibatch=False + ): """Reorder the batchmeta on single controller such that each dp rank gets similar total tokens""" data = asyncio.run(data_system_client.async_get_data(batch)) attention_mask = data["attention_mask"] batch_size = attention_mask.shape[0] - global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,) + global_seqlen_lst = calculate_workload(global_seqlen_lst) world_size = self.actor_rollout_wg.world_size - global_partition_lst = get_seqlen_balanced_partitions( - global_seqlen_lst, k_partitions=world_size, equal_size=True - ) + if keep_minibatch: + # Decouple the DP balancing and mini-batching. + minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size") + minibatch_num = len(global_seqlen_lst) // minibatch_size + global_partition_lst = [[] for _ in range(world_size)] + for i in range(minibatch_num): + rearrange_minibatch_lst = get_seqlen_balanced_partitions( + global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size], + k_partitions=world_size, + equal_size=True, + ) + for j, part in enumerate(rearrange_minibatch_lst): + global_partition_lst[j].extend([x + minibatch_size * i for x in part]) + else: + global_partition_lst = get_seqlen_balanced_partitions( + global_seqlen_lst, k_partitions=world_size, equal_size=True + ) + # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel. + for idx, partition in enumerate(global_partition_lst): + partition.sort(key=lambda x: (global_seqlen_lst[x], x)) + ordered_partition = partition[::2] + partition[1::2][::-1] + global_partition_lst[idx] = ordered_partition # reorder based on index. The data will be automatically equally partitioned by dispatch function global_idx = [j for partition in global_partition_lst for j in partition] global_balance_stats = log_seqlen_unbalance( @@ -1313,8 +1276,7 @@ def fit(self): timing_raw = {} base_get_meta_kwargs = dict( batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, - global_step=self.global_steps - 1, # self.global_steps starts from 1 - get_n_samples=False, + partition_id=f"train_{self.global_steps - 1}", # self.global_steps starts from 1 ) with marked_timer("start_profile", timing_raw): @@ -1333,7 +1295,9 @@ def fit(self): batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True ) batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict) - asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1)) + asyncio.run( + self.data_system_client.async_put(data=batch, partition_id=f"train_{self.global_steps - 1}") + ) gen_meta = asyncio.run( self.data_system_client.async_get_meta( @@ -1709,8 +1673,7 @@ def fit(self): ], batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, - global_step=self.global_steps - 1, - get_n_samples=False, + partition_id=f"train_{self.global_steps - 1}", task_name="update_actor", ) ) @@ -1735,8 +1698,7 @@ def fit(self): self.data_system_client.async_get_meta( data_fields=data_fields, batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, - global_step=self.global_steps - 1, - get_n_samples=False, + partition_id=f"train_{self.global_steps - 1}", task_name="log_rollout", ) ) @@ -1857,7 +1819,7 @@ def fit(self): # TODO: (TQ) support transfer queue self.train_dataloader.sampler.update(batch=batch) - asyncio.run(self.data_system_client.async_clear(self.global_steps - 1)) + asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{self.global_steps - 1}")) # TODO: make a canonical logger that supports various backend logger.log(data=metrics, step=self.global_steps) diff --git a/requirements_transferqueue.txt b/requirements_transferqueue.txt index 8479d27bb21..621682abbf7 100644 --- a/requirements_transferqueue.txt +++ b/requirements_transferqueue.txt @@ -1,2 +1,2 @@ # requirements.txt records the full set of dependencies for development -git+https://github.com/TransferQueue/TransferQueue.git@68c04e7 +transferqueue==0.1.1.dev2 diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 2513c57f99c..399ac75a063 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -131,13 +131,13 @@ def _query_collect_info(self, mesh_name: str): return self.__collect_dp_rank[mesh_name] @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True) - def create_transferqueue_client(self, controller_infos, storage_infos, role="train"): + def create_transferqueue_client(self, controller_info, config): from verl.utils.transferqueue_utils import create_transferqueue_client create_transferqueue_client( - client_id=f"{role}_worker_{self.rank}", - controller_infos=controller_infos, - storage_infos=storage_infos, + client_id=f"worker_{self.rank}", + controller_info=controller_info, + config=config, ) @classmethod diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py index 27160571ef3..c692578e3a0 100644 --- a/verl/utils/transferqueue_utils.py +++ b/verl/utils/transferqueue_utils.py @@ -38,32 +38,24 @@ class BatchMeta: from verl.protocol import DataProto _TRANSFER_QUEUE_CLIENT = None -_VAL_TRANSFER_QUEUE_CLIENT = None is_transferqueue_enabled = os.environ.get("TRANSFER_QUEUE_ENABLE", False) def create_transferqueue_client( client_id: str, - controller_infos: dict[Any, "ZMQServerInfo"], - storage_infos: dict[Any, "ZMQServerInfo"], + controller_info: dict[Any, "ZMQServerInfo"], + config, ) -> None: global _TRANSFER_QUEUE_CLIENT - global _VAL_TRANSFER_QUEUE_CLIENT - if "val" in client_id: - _VAL_TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos) - else: - _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos) + _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_info) + _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) def get_transferqueue_client() -> "AsyncTransferQueueClient": return _TRANSFER_QUEUE_CLIENT -def get_val_transferqueue_client() -> "AsyncTransferQueueClient": - return _VAL_TRANSFER_QUEUE_CLIENT - - def _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any: # Use a temporary event loop in a new thread because event # loop may already exist in server mode @@ -109,10 +101,7 @@ async def _async_batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto: meta_info=batchmeta.extra_info.copy(), ) - if batchmeta.extra_info.get("validate", False): - tensordict = await _VAL_TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) - else: - tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) + tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) return DataProto.from_tensordict(tensordict, meta_info=batchmeta.extra_info.copy()) @@ -130,10 +119,7 @@ async def _async_update_batchmeta_with_output(output: DataProto, batchmeta: "Bat for key in output.meta_info.keys(): tensordict.pop(key) batchmeta.add_fields(tensordict) - if batchmeta.extra_info.get("validate", False): - await _VAL_TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) - else: - await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) + await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta") -> None: From ef357e0bd3ff75edb7661cee409111dd5badb2d3 Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Mon, 17 Nov 2025 14:32:51 +0800 Subject: [PATCH 2/2] Remove useless code --- .../transfer_queue/test_client.py | 385 --------- .../transfer_queue/test_controller.py | 264 ------ .../test_simple_storage_unit.py | 479 ----------- verl/experimental/transfer_queue/__init__.py | 14 - verl/experimental/transfer_queue/client.py | 662 --------------- .../experimental/transfer_queue/controller.py | 771 ------------------ verl/experimental/transfer_queue/metadata.py | 602 -------------- verl/experimental/transfer_queue/storage.py | 516 ------------ .../transfer_queue/utils/__init__.py | 14 - .../transfer_queue/utils/utils.py | 111 --- .../transfer_queue/utils/zmq_utils.py | 176 ---- 11 files changed, 3994 deletions(-) delete mode 100644 tests/experimental/transfer_queue/test_client.py delete mode 100644 tests/experimental/transfer_queue/test_controller.py delete mode 100644 tests/experimental/transfer_queue/test_simple_storage_unit.py delete mode 100644 verl/experimental/transfer_queue/__init__.py delete mode 100644 verl/experimental/transfer_queue/client.py delete mode 100644 verl/experimental/transfer_queue/controller.py delete mode 100644 verl/experimental/transfer_queue/metadata.py delete mode 100644 verl/experimental/transfer_queue/storage.py delete mode 100644 verl/experimental/transfer_queue/utils/__init__.py delete mode 100644 verl/experimental/transfer_queue/utils/utils.py delete mode 100644 verl/experimental/transfer_queue/utils/zmq_utils.py diff --git a/tests/experimental/transfer_queue/test_client.py b/tests/experimental/transfer_queue/test_client.py deleted file mode 100644 index f1b4efd191b..00000000000 --- a/tests/experimental/transfer_queue/test_client.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -from threading import Thread - -import pytest -import torch -import zmq -from tensordict import NonTensorStack, TensorDict - -from verl.experimental.transfer_queue import TransferQueueClient # noqa: E402 -from verl.experimental.transfer_queue.metadata import ( # noqa: E402 - BatchMeta, - FieldMeta, - SampleMeta, -) -from verl.experimental.transfer_queue.utils.zmq_utils import ( # noqa: E402 - ZMQMessage, - ZMQRequestType, - ZMQServerInfo, -) - -TEST_DATA = TensorDict( - { - "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])], - "variable_length_sequences": torch.nested.as_nested_tensor( - [ - torch.tensor([-0.5, -1.2, -0.8]), - torch.tensor([-0.3, -1.5, -2.1, -0.9]), - torch.tensor([-1.1, -0.7]), - ] - ), - "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"], - }, - batch_size=[3], -) - - -# Mock Controller for Client Unit Testing -class MockController: - def __init__(self, controller_id="controller_0"): - self.controller_id = controller_id - self.context = zmq.Context() - - # Socket for data requests - self.request_socket = self.context.socket(zmq.ROUTER) - self.request_port = self._bind_to_random_port(self.request_socket) - - self.zmq_server_info = ZMQServerInfo.create( - role="TransferQueueController", - id=controller_id, - ip="127.0.0.1", - ports={ - "request_handle_socket": self.request_port, - }, - ) - - self.running = True - self.request_thread = Thread(target=self._handle_requests, daemon=True) - self.request_thread.start() - - def _bind_to_random_port(self, socket): - port = socket.bind_to_random_port("tcp://127.0.0.1") - return port - - def _handle_requests(self): - poller = zmq.Poller() - poller.register(self.request_socket, zmq.POLLIN) - - while self.running: - try: - socks = dict(poller.poll(100)) # 100ms timeout - if self.request_socket in socks: - identity, serialized_msg = self.request_socket.recv_multipart() - request_msg = ZMQMessage.deserialize(serialized_msg) - - # Determine response based on request type - if request_msg.request_type == ZMQRequestType.GET_META: - response_body = self._mock_batch_meta(request_msg.body) - response_type = ZMQRequestType.GET_META_RESPONSE - elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META: - response_body = self._mock_batch_meta(request_msg.body) - response_type = ZMQRequestType.GET_CLEAR_META_RESPONSE - elif request_msg.request_type == ZMQRequestType.CLEAR_META: - response_body = {"message": "clear ok"} - response_type = ZMQRequestType.CLEAR_META_RESPONSE - - # Send response - response_msg = ZMQMessage.create( - request_type=response_type, - sender_id=self.controller_id, - receiver_id=request_msg.sender_id, - body=response_body, - ) - self.request_socket.send_multipart([identity, response_msg.serialize()]) - except zmq.Again: - continue - except Exception as e: - if self.is_running: - print(f"MockController running exception: {e}") - else: - print(f"MockController ERROR: {e}") - raise - - def _mock_batch_meta(self, request_body): - batch_size = request_body.get("batch_size", 1) - data_fields = request_body.get("data_fields", []) - - samples = [] - for i in range(batch_size): - fields = [] - for field_name in data_fields: - field_meta = FieldMeta( - name=field_name, - dtype=None, - shape=None, - production_status=0, - ) - fields.append(field_meta) - sample = SampleMeta( - global_step=0, - global_index=i, - storage_id="storage_0", - local_index=i, - fields={field.name: field for field in fields}, - ) - samples.append(sample) - metadata = BatchMeta(samples=samples) - - return {"metadata": metadata} - - def stop(self): - self.running = False - time.sleep(0.2) # Give thread time to stop - self.request_socket.close() - self.context.term() - - -# Mock Storage for Client Unit Testing -class MockStorage: - def __init__(self, storage_id="storage_0"): - self.storage_id = storage_id - self.context = zmq.Context() - - # Socket for data operations - self.data_socket = self.context.socket(zmq.ROUTER) - self.data_port = self._bind_to_random_port(self.data_socket) - - self.zmq_server_info = ZMQServerInfo.create( - role="TransferQueueStorage", - id=storage_id, - ip="127.0.0.1", - ports={ - "put_get_socket": self.data_port, - }, - ) - - self.running = True - self.data_thread = Thread(target=self._handle_data_requests, daemon=True) - self.data_thread.start() - - def _bind_to_random_port(self, socket): - port = socket.bind_to_random_port("tcp://127.0.0.1") - return port - - def _handle_data_requests(self): - poller = zmq.Poller() - poller.register(self.data_socket, zmq.POLLIN) - - while self.running: - try: - socks = dict(poller.poll(100)) # 100ms timeout - if self.data_socket in socks: - identity, msg_bytes = self.data_socket.recv_multipart() - msg = ZMQMessage.deserialize(msg_bytes) - - # Handle different request types - if msg.request_type == ZMQRequestType.PUT_DATA: - response_body = {"message": "Data stored successfully"} - response_type = ZMQRequestType.PUT_DATA_RESPONSE - elif msg.request_type == ZMQRequestType.GET_DATA: - response_body = self._handle_get_data(msg.body) - response_type = ZMQRequestType.GET_DATA_RESPONSE - elif msg.request_type == ZMQRequestType.CLEAR_DATA: - response_body = {"message": "Data cleared successfully"} - response_type = ZMQRequestType.CLEAR_DATA_RESPONSE - - # Send response - response_msg = ZMQMessage.create( - request_type=response_type, - sender_id=self.storage_id, - receiver_id=msg.sender_id, - body=response_body, - ) - self.data_socket.send_multipart([identity, response_msg.serialize()]) - except zmq.Again: - continue - except Exception as e: - if self.is_running: - print(f"MockStorage running exception: {e}") - else: - print(f"MockStorage ERROR: {e}") - raise - - def _handle_get_data(self, request_body): - """Handle GET_DATA request by retrieving stored data""" - local_indexes = request_body.get("local_indexes", []) - fields = request_body.get("fields", []) - - result: dict[str, list] = {} - for field in fields: - gathered_items = [TEST_DATA[field][i] for i in local_indexes] - - if gathered_items: - all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items) - if all_tensors: - result[field] = torch.nested.as_nested_tensor(gathered_items) - else: - result[field] = NonTensorStack(*gathered_items) - - return {"data": TensorDict(result)} - - def stop(self): - self.running = False - time.sleep(0.2) # Give thread time to stop - self.data_socket.close() - self.context.term() - - -# Test Fixtures -@pytest.fixture -def mock_controller(): - controller = MockController() - yield controller - controller.stop() - - -@pytest.fixture -def mock_storage(): - storage = MockStorage() - yield storage - storage.stop() - - -@pytest.fixture -def client_setup(mock_controller, mock_storage): - # Create client with mock controller and storage - client_id = "client_0" - - client = TransferQueueClient( - client_id=client_id, - controller_infos={mock_controller.controller_id: mock_controller.zmq_server_info}, - storage_infos={mock_storage.storage_id: mock_storage.zmq_server_info}, - ) - - # Give some time for connections to establish - time.sleep(0.5) - - yield client, mock_controller, mock_storage - - -# Test basic functionality -def test_client_initialization(client_setup): - """Test client initialization and connection setup""" - client, mock_controller, mock_storage = client_setup - - assert client.client_id is not None - assert mock_controller.controller_id in client._controllers - assert mock_storage.storage_id in client._storages - - -def test_put_and_get_data(client_setup): - """Test basic put and get operations""" - client, _, _ = client_setup - - # Test put operation - client.put(data=TEST_DATA, global_step=0) - - # Get metadata for retrieving data - metadata = client.get_meta( - data_fields=["log_probs", "variable_length_sequences", "prompt_text"], batch_size=2, global_step=0 - ) - - # Test get operation - result = client.get_data(metadata) - - # Verify result structure - assert "log_probs" in result - assert "variable_length_sequences" in result - assert "prompt_text" in result - - torch.testing.assert_close(result["log_probs"][0], torch.tensor([1.0, 2.0, 3.0])) - torch.testing.assert_close(result["log_probs"][1], torch.tensor([4.0, 5.0, 6.0])) - torch.testing.assert_close(result["variable_length_sequences"][0], torch.tensor([-0.5, -1.2, -0.8])) - torch.testing.assert_close(result["variable_length_sequences"][1], torch.tensor([-0.3, -1.5, -2.1, -0.9])) - assert result["prompt_text"][0] == "Hello world!" - assert result["prompt_text"][1] == "This is a longer sentence for testing" - - -def test_get_meta(client_setup): - """Test metadata retrieval""" - client, _, _ = client_setup - - # Test get_meta operation - metadata = client.get_meta(data_fields=["tokens", "labels"], batch_size=10, global_step=0) - - # Verify metadata structure - assert hasattr(metadata, "storage_meta_groups") - assert hasattr(metadata, "global_indexes") - assert hasattr(metadata, "fields") - assert hasattr(metadata, "size") - assert len(metadata.global_indexes) == 10 - - -def test_clear_operation(client_setup): - """Test clear operation""" - client, _, _ = client_setup - - # Test clear operation - client.clear(global_step=0) - - -# Test with multiple controllers and storage units -def test_multiple_servers(): - """Test client with multiple controllers and storage units""" - # Create multiple mock servers - controllers = [MockController(f"controller_{i}") for i in range(2)] - storages = [MockStorage(f"storage_{i}") for i in range(3)] - - try: - # Create client with multiple servers - client_id = "client_test_multiple_servers" - - controller_infos = {c.controller_id: c.zmq_server_info for c in controllers} - storage_infos = {s.storage_id: s.zmq_server_info for s in storages} - - client = TransferQueueClient( - client_id=client_id, controller_infos=controller_infos, storage_infos=storage_infos - ) - - # Give time for connections - time.sleep(1.0) - - # Verify connections - assert len(client._controllers) == 2 - assert len(client._storages) == 3 - - # Test basic operation - test_data = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5) - - # Test put operation - client.put(data=test_data, global_step=0) - - finally: - # Clean up - for c in controllers: - c.stop() - for s in storages: - s.stop() - - -# Test error handling -def test_put_without_required_params(client_setup): - """Test put operation without required parameters""" - client, _, _ = client_setup - - # Create test data - test_data = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5) - - # Test put without global_step (should fail) - with pytest.raises(AssertionError): - client.put(data=test_data) diff --git a/tests/experimental/transfer_queue/test_controller.py b/tests/experimental/transfer_queue/test_controller.py deleted file mode 100644 index 3b45da2a561..00000000000 --- a/tests/experimental/transfer_queue/test_controller.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import math - -import numpy as np -import pytest -import ray -import torch - -from verl.experimental.transfer_queue.controller import TQ_INIT_FIELD_NUM, TransferQueueController -from verl.experimental.transfer_queue.storage import TransferQueueStorageSimpleUnit - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) - - -@pytest.fixture(scope="function") -def ray_setup(): - if ray.is_initialized(): - ray.shutdown() - ray.init( - ignore_reinit_error=True, - runtime_env={"env_vars": {"RAY_DEBUG": "1", "RAY_DEDUP_LOGS": "0"}}, - log_to_driver=True, - ) - yield - if ray.is_initialized(): - ray.shutdown() - logger.info("Ray has been shut down completely after test") - - -@pytest.fixture(scope="function") -def setup_teardown_transfer_queue_controller(ray_setup): - # Used as the offset for the global index to distinguish which global step the data corresponds to - global_batch_size = 8 - num_global_batch = 2 - num_n_samples = 2 - num_data_storage_units = 2 - - tq_controller = TransferQueueController.remote( - num_storage_units=num_data_storage_units, - global_batch_size=global_batch_size, - num_global_batch=num_global_batch, - num_n_samples=num_n_samples, - ) - yield tq_controller, global_batch_size, num_global_batch, num_n_samples - ray.get(tq_controller.clear.remote(0)) - - -@pytest.fixture(scope="function") -def setup_teardown_register_controller_info(setup_teardown_transfer_queue_controller): - tq_controller, global_batch_size, num_global_batch, num_n_samples = setup_teardown_transfer_queue_controller - total_storage_size = global_batch_size * num_global_batch * num_n_samples - num_data_storage_units = 2 - - data_system_storage_units = {} - for storage_unit_rank in range(num_data_storage_units): - storage_node = TransferQueueStorageSimpleUnit.remote( - storage_size=math.ceil(total_storage_size / num_data_storage_units) - ) - data_system_storage_units[storage_unit_rank] = storage_node - logger.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") - - # Register controller info - zmq_server_info = ray.get(tq_controller.get_zmq_server_info.remote()) - controller_infos = {zmq_server_info.id: zmq_server_info} - - ray.get( - [ - storage_unit.register_controller_info.remote(controller_infos) - for storage_unit in data_system_storage_units.values() - ] - ) - - yield tq_controller, global_batch_size, num_n_samples, data_system_storage_units - - -class TestTransferQueueController: - @pytest.mark.parametrize("num_n_samples", [1, 2]) - @pytest.mark.parametrize("num_global_batch", [1, 2]) - def test_build_index_storage_mapping(self, num_n_samples, num_global_batch, ray_setup): - # Used as the offset for the global index to distinguish which global step the data corresponds to - global_batch_size = 8 - num_data_storage_units = 2 - - self.tq_controller = TransferQueueController.remote( - num_storage_units=num_data_storage_units, - global_batch_size=global_batch_size, - num_global_batch=num_global_batch, - num_n_samples=num_n_samples, - ) - - global_index_storage_mapping, global_index_local_index_mapping = ray.get( - self.tq_controller.get_global_index_mapping.remote() - ) - - if num_global_batch == 1 and num_n_samples == 1: - assert np.array_equal(global_index_storage_mapping, np.array([0, 0, 0, 0, 1, 1, 1, 1])) - assert np.array_equal(global_index_local_index_mapping, np.array([0, 1, 2, 3, 0, 1, 2, 3])) - # The data of a single GBS will be distributed across different storage units - elif num_global_batch == 2 and num_n_samples == 1: - assert np.array_equal( - global_index_storage_mapping, np.array([0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1]) - ) - assert np.array_equal( - global_index_local_index_mapping, np.array([0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 4, 5, 6, 7]) - ) - # When num_n_samples is larger than 1 - elif num_global_batch == 1 and num_n_samples == 2: - assert np.array_equal( - global_index_storage_mapping, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]) - ) - assert np.array_equal( - global_index_local_index_mapping, np.array([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]) - ) - elif num_global_batch == 2 and num_n_samples == 2: - assert np.array_equal( - global_index_storage_mapping, - np.array( - [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] - ), - ) - assert np.array_equal( - global_index_local_index_mapping, - np.array( - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - ] - ), - ) - - def test_update_production_status(self, setup_teardown_transfer_queue_controller): - tq_controller, global_batch_size, num_global_batch, num_n_samples = setup_teardown_transfer_queue_controller - - total_storage_size = global_batch_size * num_global_batch * num_n_samples - # Initialize get_data_production_status and filed_name_mapping - init_update_production_status = torch.zeros(total_storage_size, TQ_INIT_FIELD_NUM, dtype=torch.int8) - assert torch.equal(ray.get(tq_controller.get_data_production_status.remote()), init_update_production_status) - assert ray.get(tq_controller.get_field_name_mapping.remote()) == {} - - columns_list = ["test_prompts"] - global_indexes = list(range(global_batch_size * num_n_samples)) - - # update production status - tq_controller._update_production_status.remote(global_indexes, columns_list) - new_field_name_mapping = ray.get(tq_controller.get_field_name_mapping.remote()) - assert new_field_name_mapping["test_prompts"] == 0 - - new_data_production_status = ray.get(tq_controller.get_data_production_status.remote()) - assert new_data_production_status[:, 0][: len(global_indexes)].sum() == len(global_indexes) - - def test_data_consumption_status(self, setup_teardown_transfer_queue_controller): - tq_controller, global_batch_size, num_global_batch, num_n_samples = setup_teardown_transfer_queue_controller - total_storage_size = global_batch_size * num_global_batch * num_n_samples - - init_data_consumption_status = {} - assert ray.get(tq_controller.get_data_consumption_status.remote()) == init_data_consumption_status - - task_name = "test_task1" - ray.get(tq_controller._get_consumption_status.remote(task_name)) - new_data_consumption_status = ray.get(tq_controller.get_data_consumption_status.remote()) - assert torch.equal(new_data_consumption_status[task_name], torch.zeros(total_storage_size, dtype=torch.int8)) - - def test_get_prompt_metadata(self, setup_teardown_register_controller_info): - tq_controller, global_batch_size, n_samples, _ = setup_teardown_register_controller_info - - data_fields = ["test_prompts"] - global_step = 5 - - metadata = ray.get( - tq_controller._get_metadata.remote( - data_fields=data_fields, - batch_size=global_batch_size * n_samples, - global_step=global_step, - mode="insert", - ) - ) - metadata.reorder([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]) - assert metadata.global_indexes == [ - 31, - 30, - 29, - 28, - 27, - 26, - 25, - 24, - 23, - 22, - 21, - 20, - 19, - 18, - 17, - 16, - ] - assert metadata.local_indexes == [ - 15, - 14, - 13, - 12, - 11, - 10, - 9, - 8, - 15, - 14, - 13, - 12, - 11, - 10, - 9, - 8, - ] - storage_ids = metadata.storage_ids - assert len(set(storage_ids[: len(storage_ids) // 2])) == 1 - - # TODO: Test case where multiple clients concurrently read datameta from a single controller, - # and each client receives the correct response diff --git a/tests/experimental/transfer_queue/test_simple_storage_unit.py b/tests/experimental/transfer_queue/test_simple_storage_unit.py deleted file mode 100644 index 7949c9cb971..00000000000 --- a/tests/experimental/transfer_queue/test_simple_storage_unit.py +++ /dev/null @@ -1,479 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import sys -import time -import uuid -from pathlib import Path -from threading import Thread -from unittest.mock import MagicMock - -import pytest -import ray -import tensordict -import torch -import zmq -from tensordict import TensorDict - -# Import your classes here -parent_dir = Path(__file__).resolve().parent.parent -sys.path.append(str(parent_dir)) - -try: - from verl.experimental.transfer_queue.storage import TransferQueueStorageSimpleUnit - from verl.experimental.transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo -except ImportError: - # For testing purposes if imports are not available - TransferQueueStorageSimpleUnit = MagicMock() - ZMQServerInfo = MagicMock() - ZMQRequestType = MagicMock() - ZMQMessage = MagicMock() - - -# Mock ZMQ utilities if not available in test environment -def create_zmq_socket(context, socket_type, identity=None): - sock = context.socket(socket_type) - if identity: - sock.setsockopt(zmq.IDENTITY, identity) - return sock - - -# Mock Controller to handle handshake and data updates -class MockController: - def __init__(self, controller_id="controller_001"): - self.controller_id = controller_id - self.context = zmq.Context() - - # Socket for handshake - self.handshake_socket = self.context.socket(zmq.ROUTER) - self.handshake_port = self._bind_to_random_port(self.handshake_socket) - - # Socket for data status updates - self.data_update_socket = self.context.socket(zmq.ROUTER) - self.data_update_port = self._bind_to_random_port(self.data_update_socket) - - self.zmq_server_info = ZMQServerInfo.create( - role="CONTROLLER", - id=controller_id, - ip="127.0.0.1", - ports={"handshake_socket": self.handshake_port, "data_status_update_socket": self.data_update_port}, - ) - - self.running = True - self.handshake_thread = Thread(target=self._handle_handshake, daemon=True) - self.data_update_thread = Thread(target=self._handle_data_updates, daemon=True) - self.handshake_thread.start() - self.data_update_thread.start() - - def _bind_to_random_port(self, socket): - port = socket.bind_to_random_port("tcp://127.0.0.1") - return port - - def _handle_handshake(self): - poller = zmq.Poller() - poller.register(self.handshake_socket, zmq.POLLIN) - - while self.running: - try: - socks = dict(poller.poll(100)) # 100ms timeout - if self.handshake_socket in socks: - identity, msg_bytes = self.handshake_socket.recv_multipart() - ZMQMessage.deserialize(msg_bytes) - - # Send handshake ack - ack_msg = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE_ACK, - sender_id=self.controller_id, - body={"message": "Handshake successful"}, - ) - self.handshake_socket.send_multipart([identity, ack_msg.serialize()]) - except zmq.Again: - continue - except Exception: - if self.running: - pass - - def _handle_data_updates(self): - poller = zmq.Poller() - poller.register(self.data_update_socket, zmq.POLLIN) - - while self.running: - try: - socks = dict(poller.poll(100)) # 100ms timeout - if self.data_update_socket in socks: - identity, msg_bytes = self.data_update_socket.recv_multipart() - ZMQMessage.deserialize(msg_bytes) - - # Send data update ack - ack_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, - sender_id=self.controller_id, - body={"message": "Data update received"}, - ) - self.data_update_socket.send_multipart([identity, ack_msg.serialize()]) - except zmq.Again: - continue - except Exception: - if self.running: - pass - - def stop(self): - self.running = False - time.sleep(0.1) # Give threads time to stop - self.handshake_socket.close() - self.data_update_socket.close() - - -# Mock client to send PUT/GET requests -class MockClient: - def __init__(self, storage_put_get_address): - self.context = zmq.Context() - self.socket = self.context.socket(zmq.DEALER) - self.socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout - self.socket.connect(storage_put_get_address) - - def send_put(self, client_id, global_indexes, local_indexes, field_data): - msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_DATA, - sender_id=f"mock_client_{client_id}", - body={"global_indexes": global_indexes, "local_indexes": local_indexes, "field_data": field_data}, - ) - self.socket.send(msg.serialize()) - return ZMQMessage.deserialize(self.socket.recv()) - - def send_get(self, client_id, local_indexes, fields): - msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_DATA, - sender_id=f"mock_client_{client_id}", - body={"local_indexes": local_indexes, "fields": fields}, - ) - self.socket.send(msg.serialize()) - return ZMQMessage.deserialize(self.socket.recv()) - - def close(self): - self.socket.close() - self.context.term() - - -@pytest.fixture(scope="session") -def ray_setup(): - ray.init(ignore_reinit_error=True) - yield - ray.shutdown() - - -@pytest.fixture -def storage_setup(ray_setup): - storage_size = 10000 - tensordict.set_list_to_stack(True).set() - - # Start mock controller - mock_controller = MockController(f"controller_{uuid.uuid4()}") - time.sleep(0.5) # Wait for controller sockets to be ready - - # Start Ray actor - storage_actor = TransferQueueStorageSimpleUnit.options(max_concurrency=50, num_cpus=1).remote(storage_size) - - # Register controller info - controller_infos = {mock_controller.controller_id: mock_controller.zmq_server_info} - ray.get(storage_actor.register_controller_info.remote(controller_infos)) - - # Get ZMQ address to connect client - zmq_info = ray.get(storage_actor.get_zmq_server_info.remote()) - put_get_address = zmq_info.to_addr("put_get_socket") - time.sleep(1) # Wait for socket to be ready - - yield storage_actor, put_get_address, mock_controller - - # Cleanup - mock_controller.stop() - - -def test_put_get_single_client(storage_setup): - """Test basic put and get operations with a single client using TensorDict and torch tensors.""" - _, put_get_address, _ = storage_setup - - client = MockClient(put_get_address) - - # PUT data - global_indexes = [0, 1, 2] - local_indexes = [0, 1, 2] - field_data = TensorDict( - { - "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])], - "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])], - }, - batch_size=[], - ) - - response = client.send_put(0, global_indexes, local_indexes, field_data) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # GET data - response = client.send_get(0, [0, 1], ["log_probs", "rewards"]) - assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE - - retrieved_data = response.body["data"] - assert "log_probs" in retrieved_data - assert "rewards" in retrieved_data - assert retrieved_data["log_probs"].size(0) == 2 - assert retrieved_data["rewards"].size(0) == 2 - - # Verify data correctness - torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([1.0, 2.0, 3.0])) - torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([4.0, 5.0, 6.0])) - torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([10.0])) - torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([20.0])) - - client.close() - - -def test_put_get_multiple_clients(storage_setup): - """Test put and get operations with multiple clients including overlapping local indexes""" - _, put_get_address, _ = storage_setup - - num_clients = 5 - clients = [MockClient(put_get_address) for _ in range(num_clients)] - - # Each client puts unique data using different local_indexes - for i, client in enumerate(clients): - global_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2] - local_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2] - field_data = TensorDict( - { - "log_probs": [ - torch.tensor([i, i + 1, i + 2]), - torch.tensor([i + 3, i + 4, i + 5]), - torch.tensor([i + 6, i + 7, i + 8]), - ], - "rewards": [torch.tensor([i * 10]), torch.tensor([i * 10 + 10]), torch.tensor([i * 10 + 20])], - } - ) - - response = client.send_put(i, global_indexes, local_indexes, field_data) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # Now simulate a third client that writes to overlapping local_indexes (e.g., index 0) - overlapping_client = MockClient(put_get_address) - overlap_local_indexes = [0] # Overlaps with first client's index 0 - overlap_field_data = TensorDict({"log_probs": [torch.tensor([999, 999, 999])], "rewards": [torch.tensor([999])]}) - response = overlapping_client.send_put( - client_id=99, global_indexes=[0], local_indexes=overlap_local_indexes, field_data=overlap_field_data - ) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # Each original client gets its own data (except for index 0 which was overwritten) - for i, client in enumerate(clients): - response = client.send_get(i, [i * 10 + 0, i * 10 + 1], ["log_probs", "rewards"]) - assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE - - retrieved_data = response.body["data"] - assert retrieved_data["log_probs"].size(0) == 2 - assert retrieved_data["rewards"].size(0) == 2 - - # For index 0, expect data from overlapping_client; others from original client - if i == 0: - # Index 0 was overwritten - torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([999, 999, 999])) - torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([999])) - # Index 1 remains original - torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([3, 4, 5])) - torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([10])) - else: - # All data remains original - torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([i, i + 1, i + 2])) - torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([i + 3, i + 4, i + 5])) - torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([i * 10])) - torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([i * 10 + 10])) - - # Cleanup - for client in clients: - client.close() - overlapping_client.close() - - -def test_performance_basic(storage_setup): - """Basic performance test with larger data volume and proper index handling""" - _, put_get_address, _ = storage_setup - - client = MockClient(put_get_address) - - # PUT performance test - put_latencies = [] - num_puts = 50 - batch_size = 128 - - for i in range(num_puts): - start = time.time() - - # Use larger batch size and more complex index mapping - global_indexes = list(range(i * batch_size, (i + 1) * batch_size)) - local_indexes = list(range(i * batch_size, (i + 1) * batch_size)) - - # Create larger tensor data to increase data volume - log_probs_data = [] - rewards_data = [] - - for j in range(batch_size): - # Each sample contains larger tensors to increase data transfer volume - log_probs_tensor = torch.randn(32768) - rewards_tensor = torch.randn(32768) - log_probs_data.append(log_probs_tensor) - rewards_data.append(rewards_tensor) - - field_data = TensorDict({"log_probs": log_probs_data, "rewards": rewards_data}, batch_size=[batch_size]) - - response = client.send_put(0, global_indexes, local_indexes, field_data) - latency = time.time() - start - put_latencies.append(latency) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # GET performance test - get_latencies = [] - num_gets = 50 - - for i in range(num_gets): - start = time.time() - # Retrieve larger batch of data - indices = list(range(i * batch_size, (i + 1) * batch_size)) # Retrieve batch_size indices of data each time - response = client.send_get(0, indices, ["log_probs", "rewards"]) - latency = time.time() - start - get_latencies.append(latency) - assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE - - avg_put_latency = sum(put_latencies) / len(put_latencies) * 1000 # ms - avg_get_latency = sum(get_latencies) / len(get_latencies) * 1000 # ms - - # Adjust performance thresholds to accommodate larger data volume - assert avg_put_latency < 5000, f"Avg PUT latency {avg_put_latency}ms exceeds threshold" - assert avg_get_latency < 5000, f"Avg GET latency {avg_get_latency}ms exceeds threshold" - - client.close() - - -def test_put_get_nested_tensor_single_client(storage_setup): - """Test basic put and get operations with a single client using TensorDict and nested tensors.""" - _, put_get_address, _ = storage_setup - - client = MockClient(put_get_address) - - # PUT data - global_indexes = [0, 1, 2] - local_indexes = [0, 1, 2] - - field_data = TensorDict( - { - "variable_length_sequences": [ - torch.tensor([-0.5, -1.2, -0.8]), - torch.tensor([-0.3, -1.5, -2.1, -0.9]), - torch.tensor([-1.1, -0.7]), - ], - "attention_mask": [torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1, 1]), torch.tensor([1, 1])], - }, - batch_size=[], - ) - - response = client.send_put(0, global_indexes, local_indexes, field_data) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # GET data - response = client.send_get(0, [0, 2], ["variable_length_sequences", "attention_mask"]) - assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE - - retrieved_data = response.body["data"] - assert "variable_length_sequences" in retrieved_data - assert "attention_mask" in retrieved_data - assert retrieved_data["variable_length_sequences"].size(0) == 2 - assert retrieved_data["attention_mask"].size(0) == 2 - - # Verify data correctness - torch.testing.assert_close(retrieved_data["variable_length_sequences"][0], torch.tensor([-0.5, -1.2, -0.8])) - torch.testing.assert_close(retrieved_data["variable_length_sequences"][1], torch.tensor([-1.1, -0.7])) - torch.testing.assert_close(retrieved_data["attention_mask"][0], torch.tensor([1, 1, 1])) - torch.testing.assert_close(retrieved_data["attention_mask"][1], torch.tensor([1, 1])) - - client.close() - - -def test_put_get_nested_nontensor_single_client(storage_setup): - """Test basic put and get operations with a single client using non-tensor data (strings).""" - _, put_get_address, _ = storage_setup - - client = MockClient(put_get_address) - - # PUT data - global_indexes = [0, 1, 2] - local_indexes = [0, 1, 2] - field_data = TensorDict( - { - "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"], - "response_text": ["Hi there!", "This is the response to the longer sentence", "Test response"], - }, - batch_size=[], - ) - - response = client.send_put(0, global_indexes, local_indexes, field_data) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # GET data - response = client.send_get(0, [0, 1, 2], ["prompt_text", "response_text"]) - assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE - - retrieved_data = response.body["data"] - assert "prompt_text" in retrieved_data - assert "response_text" in retrieved_data - - # Verify data correctness - assert isinstance(retrieved_data["prompt_text"][0], str) - assert isinstance(retrieved_data["response_text"][0], str) - - assert retrieved_data["prompt_text"][0] == "Hello world!" - assert retrieved_data["prompt_text"][1] == "This is a longer sentence for testing" - assert retrieved_data["prompt_text"][2] == "Test case" - assert retrieved_data["response_text"][0] == "Hi there!" - assert retrieved_data["response_text"][1] == "This is the response to the longer sentence" - assert retrieved_data["response_text"][2] == "Test response" - - client.close() - - -def test_put_get_single_item_single_client(storage_setup): - """Test put and get operations for a single item with a single client.""" - _, put_get_address, _ = storage_setup - - client = MockClient(put_get_address) - - # PUT data - field_data = TensorDict( - { - "prompt_text": ["Hello world!"], - "attention_mask": [torch.tensor([1, 1, 1])], - }, - batch_size=[], - ) - - response = client.send_put(0, [0], [0], field_data) - assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE - - # GET data - response = client.send_get(0, [0], ["prompt_text", "attention_mask"]) - assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE - - retrieved_data = response.body["data"] - assert "prompt_text" in retrieved_data - assert "attention_mask" in retrieved_data - - assert retrieved_data["prompt_text"][0] == "Hello world!" - assert retrieved_data["attention_mask"].shape == (1, 3) - torch.testing.assert_close(retrieved_data["attention_mask"][0], torch.tensor([1, 1, 1])) diff --git a/verl/experimental/transfer_queue/__init__.py b/verl/experimental/transfer_queue/__init__.py deleted file mode 100644 index 2df3b7f876f..00000000000 --- a/verl/experimental/transfer_queue/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/experimental/transfer_queue/client.py b/verl/experimental/transfer_queue/client.py deleted file mode 100644 index 8005558b0b1..00000000000 --- a/verl/experimental/transfer_queue/client.py +++ /dev/null @@ -1,662 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import logging -import os -from functools import wraps -from typing import Any, Callable, Optional, Union -from uuid import uuid4 - -import ray -import torch -import zmq -import zmq.asyncio -from tensordict import NonTensorStack, TensorDict - -from verl.experimental.transfer_queue.controller import TransferQueueController -from verl.experimental.transfer_queue.metadata import ( - BatchMeta, - StorageMetaGroup, -) -from verl.experimental.transfer_queue.storage import TransferQueueStorageSimpleUnit -from verl.experimental.transfer_queue.utils.utils import ( - TransferQueueRole, -) -from verl.experimental.transfer_queue.utils.zmq_utils import ( - ZMQMessage, - ZMQRequestType, - ZMQServerInfo, - create_zmq_socket, -) - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class AsyncTransferQueueClient: - def __init__( - self, - client_id: str, - controller_infos: ZMQServerInfo | dict[Any, ZMQServerInfo], - storage_infos: ZMQServerInfo | dict[Any, ZMQServerInfo], - ): - self.client_id = client_id - - self._controllers: dict[str, ZMQServerInfo] = {} - self._storages: dict[str, ZMQServerInfo] = {} - self._register_servers(TransferQueueRole.CONTROLLER, controller_infos) - self._register_servers(TransferQueueRole.STORAGE, storage_infos) - - def _register_servers( - self, - role: TransferQueueRole, - server_infos: ZMQServerInfo | dict[Any, ZMQServerInfo], - ): - mapping = self._controllers if role == TransferQueueRole.CONTROLLER else self._storages - - if not isinstance(server_infos, dict): - server_infos = {server_infos.id: server_infos} - - for info in server_infos.values(): - if not isinstance(info, ZMQServerInfo): - raise ValueError(f"Invalid server info for {role} {info.id}") - - if info.id not in mapping: - mapping[info.id] = info - logger.info(f"[{self.client_id}]: Registered {role} server {info.id} at {info.ip}") - else: - logger.warning(f"[{self.client_id}]: Server {info.id} already registered, skipping") - - @staticmethod - def dynamic_socket(target_role: TransferQueueRole, socket_name: str): - """Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close). - - Args: - target_role (TransferQueueRole): Server type to connect to. Must be one of: - - `TransferQueueRole.CONTROLLER` - - `TransferQueueRole.STORAGE` - socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port"). - - Decorated Function Rules: - 1. Must be an async class method (needs `self`). - 2. `self` requires: - - `_controllers`/`_storages`: Server registries (match `target_role`). - - `client_id`: Unique client ID (for socket identity). - 3. Specify target server via: - - `target_controller` (for Controller) or `target_storage` (for Storage) arg. - - Controller role: Uses first registered server if no ID is given. - 4. Receives ZMQ socket via `socket` keyword arg (injected by decorator). - """ - - def decorator(func: Callable): - @wraps(func) - async def wrapper(self, *args, **kwargs): - if target_role == TransferQueueRole.CONTROLLER: - servers = self._controllers - target = "target_controller" - elif target_role == TransferQueueRole.STORAGE: - servers = self._storages - target = "target_storage" - else: - raise ValueError("Invalid target_role, must be CONTROLLER or STORAGE") - - server_key = kwargs.get(target) - if server_key is None: - for arg in args: - if isinstance(arg, str) and arg in servers.keys(): - server_key = arg - break - if server_key is None and target == "target_controller": - server_key = next(iter(servers.keys())) - - server_info = servers.get(server_key) - if not server_info: - raise RuntimeError(f"Server {server_key} not found in registered {target_role} servers") - - context = zmq.asyncio.Context() - address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}" - identity = f"{self.client_id}_to_{server_info.id}_{uuid4()}".encode() - sock = create_zmq_socket(context, zmq.DEALER, identity=identity) - - try: - sock.connect(address) - logger.info( - f"[{self.client_id}]: Connected to {target_role} {server_info.id} at {address} " - f"with identity {identity.decode()}" - ) - - kwargs["socket"] = sock - return await func(self, *args, **kwargs) - except Exception as e: - logger.error( - f"[{self.client_id}]: Error in socket operation with {target_role} {server_info.id}: {e}" - ) - raise - finally: - try: - if not sock.closed: - sock.setsockopt(zmq.LINGER, -1) - sock.close() - sock.close(linger=0) - except Exception as e: - logger.warning( - f"[{self.client_id}]: Error closing socket to {target_role} {server_info.id}: {e}" - ) - - context.term() - - return wrapper - - return decorator - - @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") - async def async_get_meta( - self, - data_fields: list[str], - batch_size: int, - global_step: int, - mode: str = "fetch", - get_n_samples: bool = False, - task_name: Optional[str] = None, - target_controller: Optional[str] = None, - socket: Optional[zmq.asyncio.Socket] = None, - ) -> BatchMeta: - """Asynchronously fetches data metadata via ZMQ from the target controller. - - Args: - data_fields (list[str]): List of fields to retrieve metadata for - batch_size (int): Processing batch size - global_step (int): Current training/processing step - mode (str): Data fetch mode. 'fetch' to get ready data, 'force_fetch' to get data regardless of readiness. - 'insert' IS AN INTERNAL USAGE THAT SHOULD NOT BE USED BY USERS. - get_n_samples (bool): If True, we arrange the samples of the same prompt in contiguous order. In 'fetch' - mode, only the samples of the same prompt that are all ready will be returned. - task_name (str): Optional task name associated with the request - target_controller (str): ID of the target controller to send the request to - socket (zmq.asyncio.Socket): ZMQ async socket for message transmission - - Example: - >>> batch_size = 4 - >>> current_step = 0 - >>> # Example 1: "fetch" a batch of metadata that has been produced - >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"], - >>> batch_size=batch_size, - >>> global_step=current_step, - >>> mode="fetch", - >>> get_n_samples=False, - >>> task_name="generate_sequences", - >>> )) - >>> print(batch_meta.is_ready) # you should get a batch_meta with is_ready=True - >>> print([sample_meta.is_ready for sample_meta in batch_meta.samples]) # [True, True, True, True] - >>> - >>> # Example 2: "force_fetch" a batch of metadata, ignoring their production status (but we still make - >>> # sure the corresponding data has not been consumed) - >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"], - >>> batch_size=batch_size, - >>> global_step=current_step, - >>> mode="force_fetch", - >>> get_n_samples=False, - >>> task_name="generate_sequences", - >>> )) - >>> print(batch_meta.is_ready) # you may get a batch_meta with is_ready=False - >>> print([sample_meta.is_ready for sample_meta in batch_meta.samples]) # [True, False, False, True] - - Returns: - BatchMeta: Metadata object containing data structure, sample info, etc. - """ - assert socket is not None - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_META, - sender_id=self.client_id, - receiver_id=target_controller, - body={ - "data_fields": data_fields, - "batch_size": batch_size, - "global_step": global_step, - "mode": mode, - "get_n_samples": get_n_samples, - "task_name": task_name, - }, - ) - - try: - await socket.send(request_msg.serialize()) - response = await socket.recv() - response_msg = ZMQMessage.deserialize(response) - logger.debug( - f"[{self.client_id}]: Client get datameta response: {response_msg} from controller {target_controller}" - ) - - if response_msg.request_type == ZMQRequestType.GET_META_RESPONSE: - metadata = response_msg.body["metadata"] - return metadata - else: - raise RuntimeError( - f"[{self.client_id}]: Failed to get metadata from controller {target_controller}: " - f"{response_msg.body.get('message', 'Unknown error')}" - ) - except Exception as e: - raise RuntimeError(f"[{self.client_id}]: Error in get_meta: {str(e)}") from e - - async def async_put( - self, - data: TensorDict, - metadata: Optional[BatchMeta] = None, - global_step: Optional[int] = None, - ): - """Asynchronously writes data to appropriate Storage Units based on metadata. - - If metadata isn't provided, it will be created automatically using the insert mode - with the provided data_columns and global_step. - - Args: - data (torch.Tensor | tensordict.TensorDict): Data to write, either a Tensor or TensorDict - metadata (BatchMeta, optional): Optional metadata containing index and storage unit information - global_step (int, optional): Current step (required if no metadata is provided) - - Example: - >>> batch_size = 4 - >>> seq_len = 16 - >>> current_step = 0 - >>> # Example 1: normal usage - >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"], - >>> batch_size=batch_size, - >>> global_step=current_step, - >>> mode="fetch", - >>> get_n_samples=False, - >>> task_name="generate_sequences", - >>> )) - >>> batch = asyncio.run(client.async_get_data(batch_meta)) - >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)}) - >>> asyncio.run(client.async_put(data=output, metadata=batch_meta)) - >>> - >>> # Example 2: put the initial data into the system without pre-existing metadata - >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given global_step! - >>> # Please make sure the corresponding global_step is empty before calling the async_put() - >>> # without metadata. - >>> # Now we only support put all the data of the corresponding global step in once. You should repeat with - >>> # interleave the initial data if n_sample > 1 before calling the async_put(). - >>> original_prompts = torch.randn(batch_size, seq_len) - >>> n_samples = 4 - >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0) - >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated}) - >>> # This will create metadata in "insert" mode internally. - >>> asyncio.run(client.async_put(data=prompts_repeated_batch, global_step=current_step)) - - """ - if metadata is None: - assert global_step is not None, "global_steps must be provided if metadata is not given" - - metadata = await self.async_get_meta( - data_fields=list(data.keys()), - batch_size=data.batch_size[0], - global_step=global_step, - get_n_samples=True, - mode="insert", - ) - - if not metadata or metadata.size == 0: - raise ValueError("metadata cannot be none or empty") - logger.debug(f"[{self.client_id}]: Put data with data: {data}") - tasks = [ - self._put_to_storage(get_transfer_info(meta_group, data), target_storage=storage_id) - for storage_id, meta_group in metadata.storage_meta_groups.items() - ] - await asyncio.gather(*tasks) - - logger.info( - f"[{self.client_id}]: step {global_step} put {metadata.size} samples to storage units successfully." - ) - - @dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket") - async def _put_to_storage(self, storage_unit_data, target_storage=None, socket=None): - """ - Send data to a specific storage unit. - """ - global_indexes = storage_unit_data["global_indexes"] - local_indexes = storage_unit_data["local_indexes"] - field_data = TensorDict( - { - field: ( - torch.nested.as_nested_tensor(storage_unit_data["field_data"][field]) - if storage_unit_data["field_data"][field] - and all(isinstance(x, torch.Tensor) for x in storage_unit_data["field_data"][field]) - else NonTensorStack(*storage_unit_data["field_data"][field]) - ) - for field in storage_unit_data["field_data"] - } - ) - - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_DATA, - sender_id=self.client_id, - receiver_id=target_storage, - body={"global_indexes": global_indexes, "local_indexes": local_indexes, "field_data": field_data}, - ) - try: - await socket.send(request_msg.serialize()) - serialized = await socket.recv() - response_msg = ZMQMessage.deserialize(serialized) - - if response_msg.request_type != ZMQRequestType.PUT_DATA_RESPONSE: - raise RuntimeError( - f"Failed to put data to storage unit {target_storage}: " - f"{response_msg.body.get('message', 'Unknown error')}" - ) - except Exception as e: - raise RuntimeError(f"Error in put to storage unit {target_storage}: {str(e)}") from e - - @dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket") - async def _get_from_storage(self, index_data, target_storage=None, socket=None): - global_indexes = index_data["global_indexes"] - local_indexes = index_data["local_indexes"] - fields = index_data["fields"] - - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_DATA, - sender_id=self.client_id, - receiver_id=target_storage, - body={"local_indexes": local_indexes, "fields": fields}, - ) - - try: - await socket.send(request_msg.serialize()) - serialized = await socket.recv() - response_msg = ZMQMessage.deserialize(serialized) - logger.info(f"[{self.client_id}]: get data response from storage unit {target_storage}: {response_msg}") - - if response_msg.request_type == ZMQRequestType.GET_DATA_RESPONSE: - # Return data and index information from this storage unit - storage_unit_data = response_msg.body["data"] - return global_indexes, fields, storage_unit_data - else: - raise RuntimeError( - f"Failed to get data from storage unit {target_storage}: " - f"{response_msg.body.get('message', 'Unknown error')}" - ) - except Exception as e: - raise RuntimeError(f"Error getting data from storage unit {target_storage}: {str(e)}") from e - - async def async_get_data(self, metadata: BatchMeta) -> TensorDict: - """Asynchronously fetches data via Storage Units and organizes it into a TensorDict. - - Args: - metadata (BatchMeta): Object containing: - - Data location info (which Storage Units hold the data) - - `global_indexes` to determine the ordering of merged results - - Returns: - tensordict.TensorDict with: - - Requested data fields (e.g., "prompt_token_ids", "response_token_ids"). - - "global_indexes" key: Maps each sample to its original global index. - - Example: - >>> batch_size = 4 - >>> seq_len = 16 - >>> current_step = 0 - >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"], - >>> batch_size=batch_size, - >>> global_step=current_step, - >>> mode="fetch", - >>> get_n_samples=False, - >>> task_name="generate_sequences", - >>> )) - >>> batch = asyncio.run(client.async_get_data(batch_meta)) - >>> print(batch) - >>> # this is a TensorDict with fields "prompts" and "attention_mask". - >>> # The order of samples in the TensorDict matches the order of global_indexes in batch_meta - - Note: - Why track `global_indexes`? - - Batches may be rearranged during task processing. `global_indexes` retains the original - mapping to Storage Units, enabling correct data writing back to Storage Units later. - - """ - if not metadata or metadata.size == 0: - return TensorDict({}, batch_size=0) - - # Use optimized retrieval with direct storage group access - tasks = [ - self._get_from_storage(meta_group.get_transfer_info(), target_storage=storage_id) - for storage_id, meta_group in metadata.storage_meta_groups.items() - ] - - results = await asyncio.gather(*tasks) - - # global_index: {field1: value, field2: value, ...} - storage_data: dict[int, dict[str, torch.Tensor]] = {} - for global_indexes, fields, storage_unit_data in results: - for idx, global_idx in enumerate(global_indexes): - if global_idx not in storage_data: - storage_data[global_idx] = {} - for field in fields: - storage_data[global_idx][field] = storage_unit_data[field][idx] - - ordered_data: dict[str, torch.Tensor] = {field: [] for field in metadata.field_names} - for global_idx in metadata.global_indexes: - for field in metadata.field_names: - ordered_data[field].append(storage_data[global_idx][field]) - - tensor_data = { - field: ( - torch.stack(torch.nested.as_nested_tensor(v).unbind()) - if v - and all(isinstance(item, torch.Tensor) for item in v) - and all(item.shape == v[0].shape for item in v) - else ( - torch.nested.as_nested_tensor(v) - if v and all(isinstance(item, torch.Tensor) for item in v) - else NonTensorStack(*v) - ) - ) - for field, v in ordered_data.items() - } - tensor_data["global_indexes"] = torch.tensor(metadata.global_indexes) - - return TensorDict(tensor_data, batch_size=len(storage_data)) - - async def async_clear(self, global_step: int): - """Asynchronously clears data from all storage units and controller metadata. - - Args: - global_step (int): The training step associated with the clear operation - - """ - try: - target_controller = next(iter(self._controllers.keys())) - metadata = await self._get_clear_meta(global_step, target_controller) - - tasks = [] - - for target_controller in self._controllers.keys(): - tasks.append(self._clear_controller(global_step, target_controller)) - - # Group samples by storage unit for clearing - for target_storage, group in metadata.storage_meta_groups.items(): - group_info = group.get_transfer_info() - if target_storage not in self._storages: - logger.warning( - f"[{self.client_id}]: Storage unit {target_storage} not registered, skipping clear operation." - ) - continue - tasks.append( - self._clear_storage_unit( - group_info["local_indexes"], - target_storage, - ) - ) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - for i, result in enumerate(results): - if isinstance(result, Exception): - logger.error(f"[{self.client_id}]: Error in clear operation task {i}: {result}") - - logger.info(f"[{self.client_id}]: Clear operation for global_step {global_step} completed.") - except Exception as e: - raise RuntimeError(f"Error in clear operation: {str(e)}") from e - - @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") - async def _get_clear_meta(self, global_step: int, target_controller=None, socket=None): - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_CLEAR_META, - sender_id=self.client_id, - receiver_id=target_controller, - body={"global_step": global_step}, - ) - - await socket.send(request_msg.serialize()) - serialized = await socket.recv() - response_msg = ZMQMessage.deserialize(serialized) - - if response_msg.request_type != ZMQRequestType.GET_CLEAR_META_RESPONSE: - raise RuntimeError( - f"Failed to get metadata for clear operation: {response_msg.body.get('message', 'Unknown error')}" - ) - - return response_msg.body["metadata"] - - @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") - async def _clear_controller(self, global_step, target_controller=None, socket=None): - try: - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_META, - sender_id=self.client_id, - receiver_id=target_controller, - body={"global_step": global_step}, - ) - - await socket.send(request_msg.serialize()) - serialized_msg = await socket.recv() - response_msg = ZMQMessage.deserialize(serialized_msg) - - if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE: - raise RuntimeError( - f"Failed to clear controller {target_controller}: " - f"{response_msg.body.get('message', 'Unknown error')}" - ) - - logger.info( - f"[{self.client_id}]: Successfully clear controller {target_controller} for global_step {global_step}" - ) - except Exception as e: - logger.error(f"[{self.client_id}]: Error clearing controller {target_controller}: {str(e)}") - raise - - @dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket") - async def _clear_storage_unit(self, local_indexes, target_storage=None, socket=None): - try: - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA, - sender_id=self.client_id, - receiver_id=target_storage, - body={"local_indexes": local_indexes}, - ) - - await socket.send(request_msg.serialize()) - serialized_msg = await socket.recv() - response_msg = ZMQMessage.deserialize(serialized_msg) - - if response_msg.request_type != ZMQRequestType.CLEAR_DATA_RESPONSE: - raise RuntimeError( - f"Failed to clear storage {target_storage}: {response_msg.body.get('message', 'Unknown error')}" - ) - - logger.info(f"[{self.client_id}]: Successfully clear storage unit {target_storage}") - except Exception as e: - logger.error(f"[{self.client_id}]: Error clearing storage unit {target_storage}: {str(e)}") - raise - - @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") - def check_current_step_consumption(self, task_name: str, global_step: int): - # TODO: Implement this method to check if all samples for the current step has been consumed - pass - - @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket") - def check_current_step_production(self, data_fields: list[str], global_step: int): - # TODO: Implement this method to check if all samples for the current step is ready for consumption - pass - - -class TransferQueueClient(AsyncTransferQueueClient): - def __init__( - self, - client_id: str, - controller_infos: ZMQServerInfo | dict[Any, ZMQServerInfo], - storage_infos: ZMQServerInfo | dict[Any, ZMQServerInfo], - ): - super().__init__( - client_id, - controller_infos, - storage_infos, - ) - - def put(self, data: TensorDict, metadata: Optional[BatchMeta] = None, global_step: Optional[int] = None): - return asyncio.run(self.async_put(data, metadata, global_step)) - - def get_meta( - self, - data_fields: list[str], - batch_size: int, - global_step: int, - get_n_samples: bool = False, - task_name: Optional[str] = None, - ) -> BatchMeta: - return asyncio.run( - self.async_get_meta( - data_fields=data_fields, - batch_size=batch_size, - global_step=global_step, - get_n_samples=get_n_samples, - task_name=task_name, - ) - ) - - def get_data(self, metadata: BatchMeta) -> TensorDict: - return asyncio.run(self.async_get_data(metadata)) - - def clear(self, global_step: int): - return asyncio.run(self.async_clear(global_step)) - - -def _add_field_data( - transfer_dict: dict[str, Any], storage_meta_group: StorageMetaGroup, data: TensorDict -) -> dict[str, Any]: - """Helper function to add field data to the transfer dictionary""" - field_names = transfer_dict["fields"] - for fname in field_names: - if fname in data.keys(): - transfer_dict["field_data"][fname] = [] - for sample_meta in storage_meta_group.sample_metas: - transfer_dict["field_data"][fname].append(data[fname][sample_meta.batch_index]) - return transfer_dict - - -def get_transfer_info( - storage_meta_group: StorageMetaGroup, - data: TensorDict, -) -> dict[str, Any]: - """Convert to dictionary format with field data for put operations""" - result = storage_meta_group.get_transfer_info(field_names=data.keys()) - result = _add_field_data(result, storage_meta_group, data) - return result - - -def process_zmq_server_info(handlers: dict[Any, Union[TransferQueueController, TransferQueueStorageSimpleUnit]]): # noqa: UP007 - server_info = {} - for name, handler in handlers.items(): - server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[attr-defined] - return server_info diff --git a/verl/experimental/transfer_queue/controller.py b/verl/experimental/transfer_queue/controller.py deleted file mode 100644 index 08ab6cfe9f4..00000000000 --- a/verl/experimental/transfer_queue/controller.py +++ /dev/null @@ -1,771 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import math -import os -import threading -import time -from threading import Thread -from typing import Any, Optional -from uuid import uuid4 - -import numpy as np -import ray -import torch -import zmq -from ray.util import get_node_ip_address - -from verl.experimental.transfer_queue.metadata import ( - BatchMeta, - FieldMeta, - SampleMeta, -) -from verl.experimental.transfer_queue.utils.utils import ( - ProductionStatus, - TransferQueueRole, - random_sampler, -) -from verl.experimental.transfer_queue.utils.zmq_utils import ( - ZMQMessage, - ZMQRequestType, - ZMQServerInfo, - create_zmq_socket, - get_free_port, -) - -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) - -TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 300)) -TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 1)) -TQ_INIT_FIELD_NUM = int(os.environ.get("TQ_INIT_FIELD_NUM", 10)) - - -@ray.remote(num_cpus=1) -class TransferQueueController: - def __init__( - self, - num_storage_units: int, - global_batch_size: int, - num_global_batch: int = 1, - num_n_samples: int = 1, - ) -> None: - """Initialize the TransferQueueController. - - Args: - num_storage_units: Number of storage units in the system - global_batch_size: Size of each global batch - num_global_batch: Number of global batches to maintain in storage - num_n_samples: For each prompt, sample n responses - """ - self.controller_id = f"TQ_CONTROLLER_{uuid4()}" - - self._init_zmq_socket() # Initialize ZMQ sockets for data communication - - self.num_storage_units = num_storage_units - self.global_batch_size = ( - global_batch_size # Used as offset for global index to identify corresponding global step - ) - self.num_global_batch = num_global_batch - self.num_n_samples = num_n_samples - self.total_storage_size = self.global_batch_size * self.num_global_batch * self.num_n_samples - - self.data_production_status = torch.zeros( - self.total_storage_size, TQ_INIT_FIELD_NUM, dtype=torch.int8 - ) # Initialize with default number of fields, dynamically extensible - # task_name -> consumption_status mapping - self.data_consumption_status: dict[str, torch.Tensor] = {} - self.field_name_mapping: dict[ - str, int - ] = {} # Mapping table from field_name to the column indices in self.data_production_status tables - # Per-field dtype and shape storage: {global_index: {field_name: {'dtype': dtype, 'shape': shape}}} - self.per_tensor_dtype_mapping: dict[int, dict[str, Any]] = {} - self.per_tensor_shape_mapping: dict[int, dict[str, Any]] = {} - - self._build_index_storage_mapping() - - self._start_process_handshake() - self._start_process_update_data_status() - self._start_process_request() - - def _get_consumption_status(self, task_name: str) -> torch.Tensor: - """ - Get or create the consumption status tensor for a specific task. - The consumption status is a binary, 1D tensor that records whether the corresponding sample has been consumed - by the task. - - Args: - task_name: Name of the consumer task - - Returns: - Consumption status tensor for the specified task - """ - # Retrieve or create the consumption state tensor for a specified consumer - if task_name not in self.data_consumption_status: - # Initialize state for a new consumer - self.data_consumption_status[task_name] = torch.zeros(self.total_storage_size, dtype=torch.int8) - return self.data_consumption_status[task_name] - - def _get_per_field_dtype(self, global_index: int, field_name: str) -> Optional[torch.dtype]: - """Get dtype for a specific sample and field. - - Args: - global_index: Global index of the sample - field_name: Name of the field - - Returns: - dtype of the specified field for the sample, or None if not found - """ - return self.per_tensor_dtype_mapping.get(global_index, {}).get(field_name) - - def _get_per_field_shape(self, global_index: int, field_name: str) -> Optional[torch.Size]: - """Get shape for a specific sample and field. - - Args: - global_index: Global index of the sample - field_name: Name of the field - - Returns: - Shape of the specified field for the sample, or None if not found - """ - return self.per_tensor_shape_mapping.get(global_index, {}).get(field_name) - - def _step_to_global_index_range(self, global_step: int) -> tuple[int, int]: - """Convert global step to corresponding global index range. - - Args: - global_step: The global step to convert - - Returns: - Tuple of (start_index, end_index) for the given global step - """ - start_idx = (global_step % self.num_global_batch) * self.global_batch_size * self.num_n_samples - end_idx = start_idx + self.global_batch_size * self.num_n_samples - - return start_idx, end_idx - - def generate_data_status_mask( - self, data_fields: list[str], global_step: int, task_name: str - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Generate mask matrix for filtering data based on field availability and consumption status. - - This function is called within _get_meta and generates a mask matrix based on - user-specified fields and the current step. The mask matrix selects the required - rows and columns from self.data_production_status while inversely selecting from - self.data_consumption_status to support automated vectorization. - - Args: - data_fields: List of field names to include in the mask - global_step: Current global step for row selection - task_name: Name of the consumer task for consumption status - - Returns: - Tuple of (row_mask, col_mask) tensors for filtering data status matrices - """ - - # Check if all requested fields are registered - for col in data_fields: - if col not in self.field_name_mapping: - # Return empty mask indicating no available data for unregistered columns - empty_row_mask = torch.zeros(self.data_production_status.shape[0], dtype=torch.bool) - empty_col_mask = torch.zeros(self.data_production_status.shape[1], dtype=torch.bool) - return empty_row_mask, empty_col_mask - - # Map steps to global indices - start_idx, end_idx = self._step_to_global_index_range(global_step) - row_mask = torch.zeros(self.data_production_status.shape[0], dtype=torch.bool) - row_mask[start_idx:end_idx] = True - - # Invert selection based on consumption status - consumer_status = self._get_consumption_status(task_name) - unconsumed_mask = consumer_status == 0 - row_mask &= unconsumed_mask - - # Select the specified fields - col_mask = torch.zeros(self.data_production_status.shape[1], dtype=torch.bool) - valid_fields = [self.field_name_mapping[col] for col in data_fields] - if valid_fields: - col_mask[valid_fields] = True - - return row_mask, col_mask - - def _build_index_storage_mapping(self): - """ - Build mappings between global indices and storage locations. - - Distributes samples across storage units based on total storage space and - maintains mappings between global index and local index within each storage. - """ - # Assign each sample to a storage node. Here we scatter the samples in each GBS to different storage nodes - # Samples are arranged sequentially, similar to generate_data_status_mask - real_global_batch_size = self.global_batch_size * self.num_n_samples - global_batch_per_storage_unit = math.ceil(real_global_batch_size / self.num_storage_units) - - # Build mapping between global index and storage unit for locating each data sample - batch_storage_indices = np.repeat(np.arange(self.num_storage_units), global_batch_per_storage_unit)[ - :real_global_batch_size - ] - self._global_index_storage_rank_mapping = np.tile(batch_storage_indices, self.num_global_batch) - - # Build mapping between global index and local index within each storage unit - indices = np.arange(self.total_storage_size) - pos_in_batch = indices % real_global_batch_size - g = indices // real_global_batch_size - pos_in_block = pos_in_batch % global_batch_per_storage_unit - self.global_index_local_index_mapping = g * global_batch_per_storage_unit + pos_in_block - - def get_data_production_status(self) -> torch.Tensor: - """ - Get the current data production status matrix. The data production status is a 2D matrix that records whether - the corresponding data is ready for each field of each sample. - - Returns: - Tensor representing production status of all data fields - """ - return self.data_production_status - - def get_field_name_mapping(self) -> dict[str, Any]: - """Get the field name to column index mapping. - - Returns: - Dictionary mapping field names to their column indices - """ - return self.field_name_mapping - - def get_data_consumption_status(self) -> dict[str, torch.Tensor]: - """Get consumption status for all tasks. - - Returns: - Dictionary mapping task names to their consumption status tensors - """ - return self.data_consumption_status - - def get_global_index_mapping(self): - """Get global index to storage mapping information. - - Returns: - Tuple containing storage rank mapping and local index mapping - """ - return self._global_index_storage_rank_mapping, self.global_index_local_index_mapping - - def _get_metadata( - self, - data_fields: list[str], - batch_size: int, - global_step: int, - mode: str = "fetch", - task_name: str | None = None, - get_n_samples=False, - *args, - **kwargs, - ) -> BatchMeta: - """ - Retrieve metadata with support for three modes. - - Args: - data_fields: List of field names to include in metadata - batch_size: Number of samples to retrieve - global_step: Global step for which to retrieve metadata - mode: Operation mode - 'insert', 'fetch', or 'force_fetch' - - mode="insert": Insert metadata for new rows (without checking data status) - - mode="fetch": Retrieve metadata for ready data (check data status and sample) - - mode="force_fetch": Directly return metadata (without checking data status) - task_name: Name of the consumer task (required for fetch modes) - get_n_samples: Whether to retrieve n_samples as groups - *args: Additional positional arguments - **kwargs: Additional keyword arguments - - Returns: - BatchMeta object containing the requested metadata - - Raises: - TimeoutError: If waiting for sufficient data times out in fetch mode - """ - if mode == "insert": - # TODO: Currently we only supports put the entire GBS data in one time - assert batch_size == self.global_batch_size * self.num_n_samples, ( - f"batch_size {batch_size} must equal " - f"global_batch_size * num_n_samples {self.global_batch_size * self.num_n_samples}" - ) - start_idx, end_idx = self._step_to_global_index_range(global_step) - batch_global_indexes = list(range(start_idx, end_idx)) - return self._generate_batch_meta(global_step, batch_global_indexes, data_fields, mode) - - assert task_name is not None - if mode == "fetch": - # Find consumable samples within current batch and package into BatchMeta when reading - - start_time = time.time() - while True: - ready_for_consume_idx = self._scan_data_status(data_fields, global_step, task_name, get_n_samples) - - if len(ready_for_consume_idx) >= batch_size: - break - - if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT: - raise TimeoutError( - f"Timeout while waiting for sufficient data. " - f"Required: {batch_size}, Available: {len(ready_for_consume_idx)}" - ) - - logger.warning( - f"Insufficient data available. Required: {batch_size}, " - f"Available: {len(ready_for_consume_idx)}. Retrying in " - f"{TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..." - ) - time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL) - logger.debug(f"ready for consume idx: {ready_for_consume_idx}") - - batch_global_indexes = random_sampler(ready_for_consume_idx, batch_size, get_n_samples, self.num_n_samples) - elif mode == "force_fetch": - start_idx, end_idx = self._step_to_global_index_range(global_step) - consumer_status = self._get_consumption_status(task_name) - not_consumed_idx = [i for i in range(start_idx, end_idx) if consumer_status[i] == 0] - batch_global_indexes = random_sampler(not_consumed_idx, batch_size, get_n_samples, self.num_n_samples) - - # Mark this batch of data as consumed - consumer_status = self._get_consumption_status(task_name) - consumer_status[batch_global_indexes] = 1 - # Package into metadata - metadata = self._generate_batch_meta(global_step, batch_global_indexes, data_fields, mode) - logger.debug(f"_get_metadata: {metadata}") - - return metadata - - def _scan_data_status( - self, data_fields: list[str], global_step: int, task_name: str, get_n_samples: bool - ) -> list[int]: - """ - Scan data status to find samples ready for consumption. - - Args: - data_fields: List of field names to check - global_step: Global step to scan - task_name: Name of the consumer task - get_n_samples: Whether to return n_samples as groups - - Returns: - List of global indices that are ready for consumption - """ - # Get row and column masks - row_mask, col_mask = self.generate_data_status_mask(data_fields, global_step, task_name) - logger.debug(f"row_mask, col_mask: {row_mask, col_mask}") - - if not row_mask.any() or not col_mask.any(): - return [] - - # Extract subset of data status for relevant fields - logger.debug(f"self.data_production_status: {self.data_production_status}") - data_status_of_interest = self.data_production_status[:, col_mask] - logger.debug(f"data_status_of_interest: {data_status_of_interest}") - - # Use torch.all for vectorized check instead of sum comparison - all_fields_ready = torch.all(data_status_of_interest, dim=1) - - # Filter samples that meet criteria combined with row mask - ready_mask = all_fields_ready & row_mask - - if get_n_samples and self.num_n_samples > 1: - # Reshape to group view and check group completeness - group_all_ready = torch.all(ready_mask.view(-1, self.num_n_samples), dim=1) - - # Get indices of fully ready groups - ready_group_indices = group_all_ready.nonzero(as_tuple=False).flatten() - - # Calculate all sample indices - sample_offset = torch.arange(self.num_n_samples) - ready_for_consume_idx = ( - (ready_group_indices.unsqueeze(1) * self.num_n_samples + sample_offset).flatten().tolist() - ) - - return ready_for_consume_idx - else: - ready_for_consume_idx = torch.nonzero(ready_mask, as_tuple=False).flatten().tolist() - logger.debug(f"ready_for_consume_idx: {ready_for_consume_idx}") - - return ready_for_consume_idx - - def _generate_batch_meta( - self, global_step: int, global_indexes: list[int], data_fields: list[str], mode: str - ) -> BatchMeta: - """ - Generate BatchMeta by resolving storage locations for given global indexes. - - For each global index, looks up the corresponding storage node address using: - - global_index_local_index_mapping: Maps to local index within storage - - _global_index_storage_id_mapping: Maps to storage node identifier - - Args: - global_step: Current global step - global_indexes: List of global indexes to process - data_fields: List of data field names - mode: Operation mode ('fetch', 'insert', or 'force_fetch') - - Returns: - BatchMeta object containing sample metadata with resolved storage locations - """ - global_arr = np.array(global_indexes) - storage_ids = self.global_index_storage_id_mapping[global_arr] - local_indexes = self.global_index_local_index_mapping[global_arr] - - samples = [] - - # Create samples from the flattened BatchMeta data - # TODO: Optimize this - for i, global_index in enumerate(global_indexes): - local_index = local_indexes[i] - storage_id = storage_ids[i] - - # Create FieldMeta objects for each field - fields = [] - for field_name in data_fields: - if mode == "fetch": - production_status = ProductionStatus.READY_FOR_CONSUME # Since we filtered by ready status - # Get per-field dtype and shape for this specific global_index and field - dtype = self._get_per_field_dtype(global_index, field_name) - shape = self._get_per_field_shape(global_index, field_name) - elif mode == "insert": - production_status = ProductionStatus.NOT_PRODUCED # FIXME: not real-time - dtype = None - shape = None - elif mode == "force_fetch": - col_index = self.field_name_mapping.get(field_name) - if col_index is not None and self.data_production_status[global_index, col_index] == 1: - production_status = ProductionStatus.READY_FOR_CONSUME - dtype = self._get_per_field_dtype(global_index, field_name) - shape = self._get_per_field_shape(global_index, field_name) - else: - production_status = ProductionStatus.NOT_PRODUCED - dtype = None - shape = None - field_meta = FieldMeta( - name=field_name, - dtype=dtype, - shape=shape, - production_status=production_status, - ) - fields.append(field_meta) - - sample = SampleMeta( - global_step=global_step, - global_index=global_index, - storage_id=storage_id, - local_index=local_index, - fields={field.name: field for field in fields}, - ) - samples.append(sample) - - return BatchMeta(samples=samples) - - def _update_production_status(self, indexes: list[int], fields: list[str]) -> None: - """ - Update production status for specified indexes and fields. - - Args: - indexes: List of global indexes to update - fields: List of field names to update - """ - # TODO: Replace self.data_production_status == 0 or ==1 operations with ProductionStatus enum - # Update data production status matrix - new_fields = [field for field in fields if field not in self.field_name_mapping] - if new_fields: - needed_fields = len(new_fields) - current_fields = self.data_production_status.shape[1] - # Expand data status matrix if needed - if len(self.field_name_mapping) + needed_fields > current_fields: - add_fields = max(TQ_INIT_FIELD_NUM, needed_fields + 1) - new_matrix = torch.zeros((self.total_storage_size, add_fields), dtype=torch.int8) - self.data_production_status = torch.cat([self.data_production_status, new_matrix], dim=1) - - for field in fields: - if field not in self.field_name_mapping.keys(): - self.field_name_mapping[field] = len(self.field_name_mapping) - self.data_production_status[ - torch.tensor(indexes)[:, None], torch.tensor([self.field_name_mapping.get(field) for field in fields]) - ] = 1 - - def _update_field_info( - self, - fields: list[str], - per_tensor_dtypes: dict[int, dict[str, Any]], - per_tensor_shapes: dict[int, dict[str, Any]], - global_indexes: list[int], - ) -> None: - """ - Store per-field dtype and shape information. - - Args: - fields: List of field names - per_tensor_dtypes: Dict mapping global_index to field dtypes {global_index: {field: dtype}} - per_tensor_shapes: Dict mapping global_index to field shapes {global_index: {field: shape}} - global_indexes: List of global indexes corresponding to the samples - """ - for global_idx in global_indexes: - if global_idx not in self.per_tensor_dtype_mapping: - self.per_tensor_dtype_mapping[global_idx] = {} - if global_idx not in self.per_tensor_shape_mapping: - self.per_tensor_shape_mapping[global_idx] = {} - - for field in fields: - if global_idx in per_tensor_dtypes and field in per_tensor_dtypes[global_idx]: - self.per_tensor_dtype_mapping[global_idx][field] = per_tensor_dtypes[global_idx][field] - if global_idx in per_tensor_shapes and field in per_tensor_shapes[global_idx]: - self.per_tensor_shape_mapping[global_idx][field] = per_tensor_shapes[global_idx][field] - - def _init_zmq_socket(self): - """ - Initialize ZMQ sockets for communication. - - Sets up three ZMQ service ports for: - 1. Receiving handshake requests from storage - 2. Handling client data read/write requests - 3. Receiving status update signals from storage - """ - self.zmq_context = zmq.Context() - - self._node_ip = get_node_ip_address() - self._handshake_socket_port = get_free_port() - self._request_handle_socket_port = get_free_port() - self._data_status_update_socket_port = get_free_port() - - self.handshake_socket = create_zmq_socket( - ctx=self.zmq_context, - socket_type=zmq.ROUTER, - ) - self.handshake_socket.bind(f"tcp://{self._node_ip}:{self._handshake_socket_port}") - - self.request_handle_socket = create_zmq_socket( - ctx=self.zmq_context, - socket_type=zmq.ROUTER, - ) - self.request_handle_socket.bind(f"tcp://{self._node_ip}:{self._request_handle_socket_port}") - - self.data_status_update_socket = create_zmq_socket( - ctx=self.zmq_context, - socket_type=zmq.ROUTER, - ) - self.data_status_update_socket.bind(f"tcp://{self._node_ip}:{self._data_status_update_socket_port}") - - self.zmq_server_info = ZMQServerInfo.create( - role=TransferQueueRole.CONTROLLER, - id=self.controller_id, - ip=self._node_ip, - ports={ - "handshake_socket": self._handshake_socket_port, - "request_handle_socket": self._request_handle_socket_port, - "data_status_update_socket": self._data_status_update_socket_port, - }, - ) - - def _wait_connection(self): - """Wait for all storage instances to complete handshake. - - Clients don't need handshake to support dynamic scaling. Continuously - listens for handshake messages until all expected storage units connect. - """ - # TODO(zjj): Consider if retransmission is needed (assuming cases where Storage doesn't receive ACK) - connected_storage_units = set() - while len(connected_storage_units) < self.num_storage_units: - identity, serialized_msg = self.handshake_socket.recv_multipart() - request_msg = ZMQMessage.deserialize(serialized_msg) - if request_msg.request_type == ZMQRequestType.HANDSHAKE: - connected_storage_units.add(request_msg.sender_id) - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE_ACK, - sender_id=self.controller_id, - body={}, - ).serialize() - self.handshake_socket.send_multipart([identity, response_msg]) - logger.info("Controller sent handshake ack successfully!") - self.global_index_storage_id_mapping = np.array(sorted(list(connected_storage_units)))[ - self._global_index_storage_rank_mapping - ] - self.handshake_done.set() - - def _start_process_handshake(self): - """Start the handshake process thread.""" - self.handshake_done = threading.Event() - self.wait_connection_thread = Thread( - target=self._wait_connection, name="TransferQueueControllerWaitConnectionThread", daemon=True - ) - self.wait_connection_thread.start() - - def _start_process_update_data_status(self): - """Start the data status update processing thread.""" - self.process_update_data_status_thread = Thread( - target=self._update_data_status, name="TransferQueueControllerProcessUpdateDataStatusThread", daemon=True - ) - self.process_update_data_status_thread.start() - - def _start_process_request(self): - """Start the request processing thread.""" - self.process_request_thread = Thread( - target=self._process_request, name="TransferQueueControllerProcessRequestThread", daemon=True - ) - self.process_request_thread.start() - - def _process_request(self): - """Main request processing loop. - - Handles various request types including metadata retrieval, - consumption status checks, and clear operations. - """ - self.handshake_done.wait() - while True: - # ROUTER socket receives multi-part messages - identity, serialized_msg = self.request_handle_socket.recv_multipart() - request_msg = ZMQMessage.deserialize(serialized_msg) - - if request_msg.request_type == ZMQRequestType.GET_META: - params = request_msg.body - logger.info("Controller preparing to get metadata...") - metadata = self._get_metadata( - data_fields=params["data_fields"], - batch_size=params["batch_size"], - global_step=params["global_step"], - mode=params.get("mode", "fetch"), - task_name=params.get("task_name", None), - get_n_samples=params.get("get_n_samples", False), - ) - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_META_RESPONSE, - sender_id=self.controller_id, - receiver_id=request_msg.sender_id, - body={"metadata": metadata}, - ) - elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META: - params = request_msg.body - metadata = self._get_metadata( - data_fields=[], - batch_size=self.global_batch_size * self.num_n_samples, - global_step=params["global_step"], - mode="insert", - ) - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_CLEAR_META_RESPONSE, - sender_id=self.controller_id, - receiver_id=request_msg.sender_id, - body={"metadata": metadata}, - ) - elif request_msg.request_type == ZMQRequestType.CLEAR_META: - params = request_msg.body - self.clear(global_step=params["global_step"]) - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_META_RESPONSE, - sender_id=self.controller_id, - receiver_id=request_msg.sender_id, - body={"message": f"Clear operation completed by controller {self.controller_id}"}, - ) - elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION: - # Check consumption status - params = request_msg.body - global_step = params["global_step"] - - consumer_status = self._get_consumption_status(params["task_name"]) - start_idx, end_idx = self._step_to_global_index_range(global_step) - batch_status = consumer_status[start_idx:end_idx] - consumed = torch.all(batch_status == 1).item() - - # Build response message - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CONSUMPTION_RESPONSE, - sender_id=self.controller_id, - receiver_id=request_msg.sender_id, - body={ - "global_step": global_step, - "consumed": consumed, - }, - ) - self.request_handle_socket.send_multipart([identity, response_msg.serialize()]) - logger.debug("Controller request_handle_socket sent multipart successfully!") - - def _update_data_status(self): - """Process data status update messages from storage units. - - Continuously listens for data update notifications and updates - internal production status and field information accordingly. - """ - # Receive data status update information from storage - while True: - logger.debug("Preparing _update_data_status...") - identity, serialized_msg = self.data_status_update_socket.recv_multipart() - logger.debug("Controller received update_data_status request!") - request_msg = ZMQMessage.deserialize(serialized_msg) - logger.debug(f"[{self.controller_id}]: Controller received update_data_status request_msg: {request_msg}") - - if request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE: - message_data = request_msg.body - - fields = message_data.get("fields", []) - global_indexes = message_data.get("global_indexes", []) - per_tensor_dtypes = message_data.get("dtypes", {}) # Now a dict of lists - per_tensor_shapes = message_data.get("shapes", {}) # Now a dict of lists - # Update data production status - logger.debug(f"global_indexes, fields: {global_indexes, fields}") - self._update_production_status(global_indexes, fields) - self._update_field_info(fields, per_tensor_dtypes, per_tensor_shapes, global_indexes) - logger.info("Controller updated production status successfully!") - - # Send acknowledgment response - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, - sender_id=self.controller_id, - body={ - "controller_id": self.controller_id, - "message": f"Data update acknowledged from controller {self.controller_id}", - }, - ) - self.data_status_update_socket.send_multipart([identity, response_msg.serialize()]) - logger.info("Controller sent DATA_UPDATE_ACK successfully!") - elif request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR: - # Handle data update errors - error_msg = request_msg.body.get("message", "Unknown error") - logger.error(f"Data update error from storage: {error_msg}") - - # Send error acknowledgment response - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK, - sender_id=self.controller_id, - body={ - "controller_id": self.controller_id, - "message": f"Error notification acknowledged from controller {self.controller_id}", - }, - ) - self.data_status_update_socket.send_multipart([identity, response_msg.serialize()]) - - def get_zmq_server_info(self) -> ZMQServerInfo: - """Get ZMQ server connection information. - - Returns: - ZMQServerInfo object containing connection details - """ - return self.zmq_server_info - - def clear(self, global_step: int): - """Clear data for a specific global batch. - - Resets production and consumption status for all data in the specified - global step. Currently only supports clearing single GBS at a time. - - Args: - global_step: The global step to clear data for - """ - start_idx, end_idx = self._step_to_global_index_range(global_step) - - self.data_production_status[start_idx:end_idx, :] = 0 - for task_name in self.data_consumption_status: - self.data_consumption_status[task_name][start_idx:end_idx] = 0 diff --git a/verl/experimental/transfer_queue/metadata.py b/verl/experimental/transfer_queue/metadata.py deleted file mode 100644 index 6d81e7f2ca3..00000000000 --- a/verl/experimental/transfer_queue/metadata.py +++ /dev/null @@ -1,602 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import dataclasses -from dataclasses import dataclass -from typing import Any, Optional - -import numpy as np -from tensordict import TensorDict - -from verl.experimental.transfer_queue.utils.utils import ProductionStatus - - -@dataclass -class FieldMeta: - """ - Records the metadata of a single data field. (name, dtype, shape, etc.) - """ - - # field name (e.g., 'prompt', 'response', etc.) - name: str - - # data schema info - dtype: Optional[Any] - shape: Optional[Any] - - # data status info - production_status: ProductionStatus = ProductionStatus.NOT_PRODUCED - - def __str__(self) -> str: - return ( - f"FieldMeta(name='{self.name}', dtype={self.dtype}, " - f"shape={self.shape}, production_status={self.production_status})" - ) - - @property - def is_ready(self) -> bool: - """Check if this field is ready for consumption""" - return self.production_status == ProductionStatus.READY_FOR_CONSUME - - -@dataclass -class SampleMeta: - """ - Records the metadata of a single data sample (stored as a row in the data system). - """ - - # algorithm related info - global_step: int # global step, used for data versioning - - # data retrival info - global_index: int # global row index, uniquely identifies a data sample - storage_id: str # storage unit id - local_index: int # local row index in the storage unit - - # data fields info - # this fields may not contain all the fields of the sample, but only fields-of-interest - fields: dict[str, FieldMeta] - - def __post_init__(self): - """Initialize is_ready property based on field readiness""" - # Check if all fields are ready and update is_ready property - object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) - - def __str__(self) -> str: - return ( - f"SampleMeta(global_step={self.global_step}, " - f"global_index={self.global_index}, storage_id='{self.storage_id}', " - f"local_index={self.local_index}, fields={self.fields})" - ) - - @property - def field_names(self) -> list[str]: - """Get list of field names for this sample""" - return list(self.fields.keys()) - - @property - def batch_index(self) -> int: - """Get the batch index of this sample (to be set by BatchMeta)""" - return getattr(self, "_batch_index", -1) - - def get_field_by_name(self, name: str) -> Optional[FieldMeta]: - """Get FieldMeta by field name""" - return self.fields.get(name) - - def has_field(self, name: str) -> bool: - """Check if this sample has a specific field""" - return name in self.fields - - def is_field_ready(self, field_name: str) -> bool: - """Check if a specific field is ready for consumption""" - field = self.fields.get(field_name) - return field.is_ready if field else False - - def add_fields(self, fields: dict[str, FieldMeta]) -> "SampleMeta": - """ - Add new fields to this sample. New fields will be initialized with given dtype, shape - and production_status (if provided). If not provided, default values (None, None, READY_FOR_CONSUME) - will be used. - This modifies the sample in-place to include the new fields. - """ - self.fields = _union_fields(self.fields, fields) - # Update is_ready property - object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) - return self - - def union(self, other: "SampleMeta", validate: bool = True) -> "SampleMeta": - """ - Create a union of this sample's fields with another sample's fields. - Assume both samples have the same global index. If fields overlap, the - fields in this sample will be replaced by the other sample's fields. - - Args: - other: Another SampleMeta to union with - validate: Whether to validate union conditions - - Returns: - New SampleMeta with unioned fields (None if validation fails) - """ - if validate: - if self.global_index != other.global_index: - raise ValueError( - f"Error: Global indexes ({self.global_index} and {other.global_index}) do not match for union." - ) - - # Merge fields - self.fields = _union_fields(self.fields, other.fields) - - # Update is_ready property - object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values())) - return self - - @property - def is_ready(self) -> bool: - """Check if all fields in this sample are ready for consumption""" - return getattr(self, "_is_ready", False) - - @property - def production_status(self) -> dict[str, ProductionStatus]: - """Get production status for all fields (backward compatibility)""" - return {name: field.production_status for name, field in self.fields.items()} - - -@dataclass -class StorageMetaGroup: - """ - Represents a group of samples stored in the same storage unit. - Used to organize samples by their storage_id for efficient client operations. - """ - - storage_id: str - sample_metas: list[SampleMeta] = dataclasses.field(default_factory=list) - - def add_sample_meta(self, sample_meta: SampleMeta) -> None: - """Add a SampleMeta object to this storage group""" - self.sample_metas.append(sample_meta) - - def get_batch_indexes(self) -> list[int]: - """Get all internal indexes from stored SampleMeta objects""" - return [meta.batch_index for meta in self.sample_metas] - - def get_global_indexes(self) -> list[int]: - """Get all global indexes from stored SampleMeta objects""" - return [meta.global_index for meta in self.sample_metas] - - def get_local_indexes(self) -> list[int]: - """Get all local indexes from stored SampleMeta objects""" - return [meta.local_index for meta in self.sample_metas] - - def get_field_names(self) -> list[str]: - """Get all unique field names from stored SampleMeta objects""" - all_fields: set[str] = set() - for meta in self.sample_metas: - all_fields.update(meta.fields.keys()) - return list(all_fields) - - def get_transfer_info(self, field_names: Optional[list[str]] = None) -> dict[str, list | dict]: - """Convert to dictionary format for backward compatibility""" - if field_names is None: - field_names = self.get_field_names() - return { - "batch_indexes": self.get_batch_indexes(), - "global_indexes": self.get_global_indexes(), - "local_indexes": self.get_local_indexes(), - "fields": field_names, - "field_data": {}, # Placeholder for field data to be filled later - } - - @property - def size(self) -> int: - """Number of samples in this storage meta group""" - return len(self.sample_metas) - - @property - def is_empty(self) -> bool: - """Check if this storage meta group is empty""" - return len(self.sample_metas) == 0 - - def __len__(self) -> int: - """Number of samples in this storage meta group""" - return self.size - - def __bool__(self) -> bool: - """Truthiness based on whether group has samples""" - return not self.is_empty - - def __str__(self) -> str: - return f"StorageMetaGroup(storage_id='{self.storage_id}', size={self.size})" - - -@dataclass -class BatchMeta: - """ - Records the metadata of a batch of data samples. - """ - - samples: list[SampleMeta] - extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) - - def __post_init__(self): - """Initialize all computed properties during initialization""" - # Basic properties - object.__setattr__(self, "_size", len(self.samples)) - object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples)) - - # Pre-compute all list properties for better performance - if self.samples: - for idx, sample in enumerate(self.samples): - object.__setattr__(sample, "_batch_index", idx) # Ensure batch_index is set correctly - - object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) - object.__setattr__(self, "_local_indexes", [sample.local_index for sample in self.samples]) - object.__setattr__(self, "_storage_ids", [sample.storage_id for sample in self.samples]) - - # assume all samples have the same fields. - object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names)) - - # Initialize storage groups for efficient client operations - storage_meta_groups = self._build_storage_meta_groups() - object.__setattr__(self, "_storage_meta_groups", storage_meta_groups) - else: - object.__setattr__(self, "_global_indexes", []) - object.__setattr__(self, "_local_indexes", []) - object.__setattr__(self, "_storage_ids", []) - object.__setattr__(self, "_field_names", []) - object.__setattr__(self, "_storage_meta_groups", {}) - - @property - def size(self) -> int: - """Return the number of samples in this batch""" - return getattr(self, "_size", 0) - - @property - def global_indexes(self) -> list[int]: - """Get all global indexes in this batch""" - return getattr(self, "_global_indexes", []) - - @property - def field_names(self) -> list[str]: - """Get all unique field names in this batch""" - return getattr(self, "_field_names", []) - - @property - def local_indexes(self) -> list[int]: - """Get all local indexes in this batch""" - return getattr(self, "_local_indexes", []) - - @property - def storage_ids(self) -> list[str]: - """Get all storage unit IDs in this batch""" - return getattr(self, "_storage_ids", []) - - @property - def is_ready(self) -> bool: - """Check if all samples in this batch are ready for consumption""" - # TODO: get ready status from controller realtime - return getattr(self, "_is_ready", False) - - def _build_storage_meta_groups(self) -> dict[str, StorageMetaGroup]: - """Build storage groups from samples during initialization""" - storage_meta_groups: dict[str, StorageMetaGroup] = {} - - for sample in self.samples: - storage_id = sample.storage_id - if storage_id not in storage_meta_groups: - storage_meta_groups[storage_id] = StorageMetaGroup(storage_id=storage_id) - - # Use add_sample_meta to store SampleMeta references directly - storage_meta_groups[storage_id].add_sample_meta(sample) - - return storage_meta_groups - - @property - def storage_meta_groups(self) -> dict[str, StorageMetaGroup]: - """Get storage groups organized by storage_id""" - return getattr(self, "_storage_meta_groups", {}) - - @property - def storage_unit_ids(self) -> list[str]: - """Get list of all storage unit IDs""" - return list(self.storage_meta_groups.keys()) - - def get_storage_meta_groups(self, storage_id: str) -> Optional[StorageMetaGroup]: - """Get storage group by storage ID""" - return self.storage_meta_groups.get(storage_id) - - # Extra info interface methods - def get_extra_info(self, key: str, default: Any = None) -> Any: - """Get extra info by key""" - return self.extra_info.get(key, default) - - def set_extra_info(self, key: str, value: Any) -> None: - """Set extra info by key""" - self.extra_info[key] = value - - def update_extra_info(self, info_dict: dict[str, Any]) -> None: - """Update extra info with multiple key-value pairs""" - self.extra_info.update(info_dict) - - def remove_extra_info(self, key: str) -> Any: - """Remove extra info by key and return its value""" - return self.extra_info.pop(key, None) - - def clear_extra_info(self) -> None: - """Clear all extra info""" - self.extra_info.clear() - - def has_extra_info(self, key: str) -> bool: - """Check if extra info contains a specific key""" - return key in self.extra_info - - def get_all_extra_info(self) -> dict[str, Any]: - """Get all extra info as a dictionary""" - return self.extra_info.copy() - - def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta": - """ - Add new fields from a TensorDict to all samples in this batch. - This modifies each sample in-place to include the new fields. - - Args: - tensor_dict (TensorDict): The input TensorDict containing new fields. - set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. Default is True. - """ - fields = _extract_field_metas(tensor_dict, set_all_ready) - for idx, sample in enumerate(self.samples): - sample.add_fields(fields=fields[idx]) - - # Update batch-level fields cache - object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names)) - object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples)) - return self - - def __len__(self) -> int: - """Return the number of samples in this batch.""" - return len(self.samples) - - def __getitem__(self, item): - if isinstance(item, int | np.integer): - sample_meta = self.samples[item] if self.samples else [] - return BatchMeta(samples=[sample_meta], extra_info=self.extra_info) - else: - raise TypeError(f"Indexing with {type(item)} is not supported now!") - - def chunk(self, chunks: int) -> list["BatchMeta"]: - """ - Split this batch into smaller chunks. - - Args: - chunks: number of chunks - - Return: - List of smaller BatchMeta chunks - """ - chunk_list = [] - n = len(self.samples) - - # Calculate the base size and remainder of each chunk - base_size = n // chunks - remainder = n % chunks - - start = 0 - for i in range(chunks): - # Calculate the size of the current chunk(the first remainder chunk is 1 more than the base size) - current_chunk_size = base_size + 1 if i < remainder else base_size - end = start + current_chunk_size - chunk_samples = self.samples[start:end] - chunk = BatchMeta(samples=chunk_samples, extra_info=self.extra_info.copy()) - chunk_list.append(chunk) - start = end - return chunk_list - - @classmethod - def concat(cls, data: list["BatchMeta"], validate: bool = True) -> Optional["BatchMeta"]: - """ - Concatenate multiple BatchMeta chunks into one large batch. - - Args: - data: List of BatchMeta chunks to concatenate - validate: Whether to validate concatenation conditions - - Returns: - Concatenated BatchMeta - - Raises: - ValueError: If validation fails (e.g., field names do not match) - """ - if not data: - return None - - if validate: - base_fields = data[0].field_names - - for chunk in data: - if chunk.field_names != base_fields: - raise ValueError("Error: Field names do not match for concatenation.") - - # Combine all samples - all_samples = [] - for chunk in data: - all_samples.extend(chunk.samples) - # Merge all extra_info dictionaries from the chunks - merged_extra_info = {} - for chunk in data: - merged_extra_info.update(chunk.extra_info) - return BatchMeta(samples=all_samples, extra_info=merged_extra_info) - - def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMeta"]: - """ - Create a union of this batch's fields with another batch's fields. - Assume both batches have the same global indices. If fields overlap, the - fields in this batch will be replaced by the other batch's fields. - - Args: - other: Another BatchMeta to union with - validate: Whether to validate union conditions - - Returns: - New BatchMeta with unioned fields - - Raises: - ValueError: If validation fails (e.g., batch sizes or global indexes do not match) - """ - if validate: - if self.size != other.size: - raise ValueError("Error: Batch sizes do not match for union.") - - self_global_indexes = sorted(self.global_indexes) - other_global_indexes = sorted(other.global_indexes) - if self_global_indexes != other_global_indexes: - raise ValueError("Error: Global indexes do not match for union.") - - # Create a mapping from global_index to SampleMeta in the other batch - other_sample_map = {sample.global_index: sample for sample in other.samples} - - # Merge samples - merged_samples = [] - for sample in self.samples: - if sample.global_index in other_sample_map: - other_sample = other_sample_map[sample.global_index] - merged_sample = sample.union(other_sample, validate=validate) - merged_samples.append(merged_sample) - else: - merged_samples.append(sample) - - # Merge extra info dictionaries - merged_extra_info = {**self.extra_info, **other.extra_info} - - return BatchMeta(samples=merged_samples, extra_info=merged_extra_info) - - def reorder(self, indices: list[int]): - """ - Reorder the SampleMeta in the BatchMeta according to the given indices. - - The operation is performed in-place, modifying the current BatchMeta's SampleMeta order. - - Args: - indices : list[int] - A list of integers specifying the new order of SampleMeta. Each integer - represents the current index of the SampleMeta in the BatchMeta. - """ - # Reorder the samples - reordered_samples = [self.samples[i] for i in indices] - object.__setattr__(self, "samples", reordered_samples) - - # Update necessary attributes - self._update_after_reorder() - - def _update_after_reorder(self) -> None: - """Update related attributes specifically for the reorder operation""" - # Update batch_index for each sample - for idx, sample in enumerate(self.samples): - object.__setattr__(sample, "_batch_index", idx) - - # Update cached index lists - object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) - object.__setattr__(self, "_local_indexes", [sample.local_index for sample in self.samples]) - object.__setattr__(self, "_storage_ids", [sample.storage_id for sample in self.samples]) - - # Rebuild storage groups - storage_meta_groups = self._build_storage_meta_groups() - object.__setattr__(self, "_storage_meta_groups", storage_meta_groups) - - # Note: No need to update _size, _field_names, _is_ready, etc., as these remain unchanged after reorder - - @classmethod - def from_samples( - cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None - ) -> "BatchMeta": - """ - Create a BatchMeta from a single SampleMeta or a list of SampleMeta objects. - - Args: - samples: A single SampleMeta or a list of SampleMeta objects - extra_info: Optional additional information to store with the batch - - Returns: - BatchMeta instance containing the provided sample(s) - - Example: - >>> sample_meta = SampleMeta(...) - >>> batch_meta = BatchMeta.from_samples(sample_meta) - - >>> sample_metas = [sample1, sample2, sample3] - >>> batch_meta = BatchMeta.from_samples(sample_metas, extra_info={"source": "training"}) - """ - if extra_info is None: - extra_info = {} - - if isinstance(samples, SampleMeta): - samples = [samples] - - return cls(samples=samples, extra_info=extra_info) - - @classmethod - def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta": - """ - Create an empty BatchMeta with no samples. - - Args: - extra_info: Optional additional information to store with the batch - - Returns: - Empty BatchMeta instance - - Example: - >>> empty_batch = BatchMeta.empty() - """ - if extra_info is None: - extra_info = {} - return cls(samples=[], extra_info=extra_info) - - -def _union_fields(fields1: dict[str, FieldMeta], fields2: dict[str, FieldMeta]) -> dict[str, FieldMeta]: - """Union two sample's fields. If fields overlap, the fields in fields1 will be replaced by fields2.""" - for name in fields2.keys(): - fields1[name] = fields2[name] - return fields1 - - -def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> list[dict[str, FieldMeta]]: - """ - Extract field metas from a TensorDict. If data in tensor_dict does not have dtype or shape attribute, - the corresponding dtype or shape will be set to None. - - Args: - tensor_dict (TensorDict): The input TensorDict. - set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. - Otherwise, set to NOT_PRODUCED. Default is True. - - Returns: - all_fields (list[dict[FieldMeta]]): A list of dictionaries containing field metadata. - """ - all_fields = [] - batch_size = tensor_dict.batch_size[0] - for idx in range(batch_size): - fields = {} - sample = tensor_dict[idx] - for name, value in sample.items(): - fields[name] = FieldMeta( - name=name, - dtype=value.dtype if hasattr(value, "dtype") else None, - shape=value.shape if hasattr(value, "shape") else None, - production_status=ProductionStatus.READY_FOR_CONSUME - if set_all_ready - else ProductionStatus.NOT_PRODUCED, - ) - all_fields.append(fields) - - return all_fields diff --git a/verl/experimental/transfer_queue/storage.py b/verl/experimental/transfer_queue/storage.py deleted file mode 100644 index c8f908ee8d8..00000000000 --- a/verl/experimental/transfer_queue/storage.py +++ /dev/null @@ -1,516 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -import time -from operator import itemgetter -from threading import Thread -from uuid import uuid4 - -import ray -import torch -import zmq -from ray.util import get_node_ip_address -from tensordict import NonTensorStack, TensorDict - -from verl.experimental.transfer_queue.utils.utils import TransferQueueRole -from verl.experimental.transfer_queue.utils.zmq_utils import ( - ZMQMessage, - ZMQRequestType, - ZMQServerInfo, - create_zmq_socket, - get_free_port, -) - -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) - -TQ_STORAGE_POLLER_TIMEOUT = os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 1000) -TQ_STORAGE_HANDSHAKE_TIMEOUT = int(os.environ.get("TQ_STORAGE_HANDSHAKE_TIMEOUT", 30)) -TQ_DATA_UPDATE_RESPONSE_TIMEOUT = int(os.environ.get("TQ_DATA_UPDATE_RESPONSE_TIMEOUT", 600)) - - -class StorageUnitData: - """ - Class used for storing several elements, each element is composed of several fields and corresponding data, like: - ##################################################### - # local_index | field_name1 | field_name2 | ... # - # 0 | item1 | item2 | ... # - # 1 | item3 | item4 | ... # - # 2 | item5 | item6 | ... # - ##################################################### - """ - - def __init__(self, storage_size: int): - # Dict containing field names and corresponding data in the field, e.g. {"field_name1": [data1, data2, ...]} - self.field_data: dict[str, list] = {} - - # Maximum number of elements stored in storage unit - self.storage_size = storage_size - - def get_data(self, fields: list[str], local_indexes: list[int]) -> TensorDict[str, list]: - """ - Get data from storage unit according to given fields and local_indexes. - - param: - fields: Field names used for getting data. - local_indexes: Local indexes used for getting data. - return: - TensorDict with field names as keys, corresponding data list as values. - """ - result: dict[str, list] = {} - - for field in fields: - # Validate field name - if field not in self.field_data: - raise ValueError( - f"StorageUnitData get_data operation receive invalid field: {field} beyond {self.field_data.keys()}" - ) - - if len(local_indexes) == 1: - # The unsqueeze op make the shape from n to (1, n) - gathered_item = self.field_data[field][local_indexes[0]] - if not isinstance(gathered_item, torch.Tensor): - result[field] = NonTensorStack(gathered_item) - else: - result[field] = gathered_item.unsqueeze(0) - else: - gathered_items = list(itemgetter(*local_indexes)(self.field_data[field])) - - if gathered_items: - all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items) - if all_tensors: - result[field] = torch.nested.as_nested_tensor(gathered_items) - else: - result[field] = NonTensorStack(*gathered_items) - - return TensorDict(result) - - def put_data(self, field_data: TensorDict[str, list], local_indexes: list[int]) -> None: - """ - Put or update data into storage unit according to given field_data and local_indexes. - - param: - field_data: Dict with field names as keys, corresponding data in the field as values. - local_indexes: Local indexes used for putting data. - """ - for f in field_data.keys(): - for i, idx in enumerate(local_indexes): - # Validate local_indexes - if idx < 0 or idx >= self.storage_size: - raise ValueError( - f"StorageUnitData put_data operation receive invalid local_index: {idx} beyond " - f"storage_size: {self.storage_size}" - ) - - if f not in self.field_data: - # Initialize new field value list with None - self.field_data[f] = [None] * self.storage_size - - self.field_data[f][idx] = field_data[f][i] - - def clear(self, local_indexes: list[int]) -> None: - """ - Clear data at specified local_indexes by setting all related fields to None. - - param: - local_indexes: local_indexes to clear. - """ - # Validate local_indexes - for idx in local_indexes: - if idx < 0 or idx >= self.storage_size: - raise ValueError( - f"StorageUnitData clear operation receive invalid local_index: {idx} beyond " - f"storage_size: {self.storage_size}" - ) - - # Clear data at specified local_indexes - for f in self.field_data: - for idx in local_indexes: - self.field_data[f][idx] = None - - -@ray.remote(num_cpus=1) -class TransferQueueStorageSimpleUnit: - def __init__(self, storage_size: int): - super().__init__() - self.storage_unit_id = f"TQ_STORAGE_UNIT_{uuid4()}" - self.storage_size = storage_size - self.controller_infos: dict[str, ZMQServerInfo] = {} - - self.experience_data = StorageUnitData(self.storage_size) - - self.zmq_server_info = ZMQServerInfo.create( - role=TransferQueueRole.STORAGE, - id=str(self.storage_unit_id), - ip=get_node_ip_address(), - ports={"put_get_socket": get_free_port()}, - ) - self._init_zmq_socket() - - def _init_zmq_socket(self) -> None: - """ - Initialize ZMQ socket connections between storage unit and controllers/clients: - - controller_handshake_sockets: - Handshake between storage unit and controllers. - - data_status_update_sockets: - Broadcast data update status from storage unit to controllers when handling put operation. - - put_get_socket: - Handle put/get requests from clients. - """ - self.zmq_context = zmq.Context() - - self.controller_handshake_sockets: dict[str, zmq.Socket] = {} - self.data_status_update_sockets: dict[str, zmq.Socket] = {} - - self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER) - self.put_get_socket.bind(self.zmq_server_info.to_addr("put_get_socket")) - - def register_controller_info(self, controller_infos: dict[str, ZMQServerInfo]) -> None: - """ - Build connections between storage unit and controllers, start put/get process. - - param: - controller_infos: Dict with controller infos. - """ - self.controller_infos = controller_infos - - self._init_zmq_sockets_with_controller_infos() - self._connect_to_controller() - self._start_process_put_get() - - def _init_zmq_sockets_with_controller_infos(self) -> None: - """Initialize ZMQ sockets between storage unit and controllers for handshake.""" - for controller_id in self.controller_infos.keys(): - self.controller_handshake_sockets[controller_id] = create_zmq_socket( - self.zmq_context, - zmq.DEALER, - identity=f"{self.storage_unit_id}-controller_handshake_sockets-{uuid4()}".encode(), - ) - self.data_status_update_sockets[controller_id] = create_zmq_socket( - self.zmq_context, - zmq.DEALER, - identity=f"{self.storage_unit_id}-data_status_update_sockets-{uuid4()}".encode(), - ) - - def _connect_to_controller(self) -> None: - """Connect storage unit to all controllers.""" - connected_controllers: set[str] = set() - - # Create zmq poller for handshake confirmation between controller and storage unit - poller = zmq.Poller() - - for controller_id, controller_info in self.controller_infos.items(): - self.controller_handshake_sockets[controller_id].connect(controller_info.to_addr("handshake_socket")) - logger.debug( - f"[{self.zmq_server_info.id}]: Handshake connection from storage unit id #{self.zmq_server_info.id} " - f"to controller id #{controller_id} establish successfully." - ) - - # Send handshake request to controllers - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.HANDSHAKE, - sender_id=self.zmq_server_info.id, - body={ - "storage_unit_id": self.storage_unit_id, - "storage_size": self.storage_size, - }, - ).serialize() - - self.controller_handshake_sockets[controller_id].send(request_msg) - logger.debug( - f"[{self.zmq_server_info.id}]: Send handshake request from storage unit id #{self.zmq_server_info.id} " - f"to controller id #{controller_id} successfully." - ) - - poller.register(self.controller_handshake_sockets[controller_id], zmq.POLLIN) - - start_time = time.time() - while ( - len(connected_controllers) < len(self.controller_infos) - and time.time() - start_time < TQ_STORAGE_HANDSHAKE_TIMEOUT - ): - socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT)) - - for controller_handshake_socket in self.controller_handshake_sockets.values(): - if controller_handshake_socket in socks: - response_msg = ZMQMessage.deserialize(controller_handshake_socket.recv()) - - if response_msg.request_type == ZMQRequestType.HANDSHAKE_ACK: - connected_controllers.add(response_msg.sender_id) - logger.debug( - f"[{self.zmq_server_info.id}]: Get handshake ACK response from " - f"controller id #{str(response_msg.sender_id)} to storage unit id " - f"#{self.zmq_server_info.id} successfully." - ) - - if len(connected_controllers) < len(self.controller_infos): - logger.warning( - f"[{self.zmq_server_info.id}]: Only get {len(connected_controllers)} / {len(self.controller_infos)} " - f"successful handshake connections to controllers from storage unit id #{self.zmq_server_info.id}" - ) - - def _start_process_put_get(self) -> None: - """Create a daemon thread and start put/get process.""" - self.process_put_get_thread = Thread( - target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.zmq_server_info.id}", daemon=True - ) - self.process_put_get_thread.start() - - def _process_put_get(self) -> None: - """Process put_get_socket request.""" - poller = zmq.Poller() - poller.register(self.put_get_socket, zmq.POLLIN) - - while True: - socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT)) - - if self.put_get_socket in socks: - identity, serialized_msg = self.put_get_socket.recv_multipart() - - try: - request_msg = ZMQMessage.deserialize(serialized_msg) - operation = request_msg.request_type - logger.debug(f"[{self.zmq_server_info.id}]: receive operation: {operation}, message: {request_msg}") - - if operation == ZMQRequestType.PUT_DATA: - response_msg = self._handle_put(request_msg) - elif operation == ZMQRequestType.GET_DATA: - response_msg = self._handle_get(request_msg) - elif operation == ZMQRequestType.CLEAR_DATA: - response_msg = self._handle_clear(request_msg) - else: - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, - sender_id=self.zmq_server_info.id, - body={ - "message": f"Storage unit id #{self.zmq_server_info.id} " - f"receive invalid operation: {operation}." - }, - ) - except Exception as e: - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_GET_ERROR, - sender_id=self.zmq_server_info.id, - body={ - "message": f"Storage unit id #{self.zmq_server_info.id} occur error in processing " - f"put/get/clear request, detail error message: {str(e)}." - }, - ) - - self.put_get_socket.send_multipart([identity, response_msg.serialize()]) - - def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: - """ - Handle put request, add or update data into storage unit. - - param: - data_parts: ZMQMessage from client. - return: - Put data success response ZMQMessage. - """ - try: - global_indexes = data_parts.body["global_indexes"] - local_indexes = data_parts.body["local_indexes"] - field_data = data_parts.body["field_data"] # field_data should be in {field_name: [real data]} format. - - self.experience_data.put_data(field_data, local_indexes) - - # After put operation finish, send a message to the client - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.zmq_server_info.id, body={} - ) - - # Gather per-field dtype and shape information for each field - # global_indexes, local_indexes, and field_data correspond one-to-one - per_field_dtypes = {} - per_field_shapes = {} - - # Initialize the data structure for each global index - for global_idx in global_indexes: - per_field_dtypes[global_idx] = {} - per_field_shapes[global_idx] = {} - - # For each field, extract dtype and shape for each sample - for field in field_data.keys(): - for i, data_item in enumerate(field_data[field]): - global_idx = global_indexes[i] - per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None - per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None - - # Broadcast data update message to all controllers with per-field dtype/shape information - self._notify_data_update(list(field_data.keys()), global_indexes, per_field_dtypes, per_field_shapes) - return response_msg - except Exception as e: - return ZMQMessage.create( - request_type=ZMQRequestType.PUT_ERROR, - sender_id=self.zmq_server_info.id, - body={ - "message": f"Failed to put data into storage unit id " - f"#{self.zmq_server_info.id}, detail error message: {str(e)}" - }, - ) - - def _notify_data_update(self, fields, global_indexes, dtypes, shapes) -> None: - """ - Broadcast data status update to all controllers. - - param: - fields: data update related fields. - global_indexes: data update related global_indexes. - dtypes: per-field dtypes for each field, in {global_index: {field: dtype}} format. - shapes: per-field shapes for each field, in {global_index: {field: shape}} format. - """ - # Create zmq poller for notifying data update information - poller = zmq.Poller() - - # Connect data status update socket to all controllers - for controller_id, controller_info in self.controller_infos.items(): - data_status_update_socket = self.data_status_update_sockets[controller_id] - data_status_update_socket.connect(controller_info.to_addr("data_status_update_socket")) - logger.debug( - f"[{self.zmq_server_info.id}]: Data status update connection from " - f"storage unit id #{self.zmq_server_info.id} to " - f"controller id #{controller_id} establish successfully." - ) - - try: - poller.register(data_status_update_socket, zmq.POLLIN) - - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, - sender_id=self.zmq_server_info.id, - body={ - "fields": fields, - "global_indexes": global_indexes, - "dtypes": dtypes, - "shapes": shapes, - }, - ).serialize() - - data_status_update_socket.send(request_msg) - logger.debug( - f"[{self.zmq_server_info.id}]: Send data status update request " - f"from storage unit id #{self.zmq_server_info.id} " - f"to controller id #{controller_id} successfully." - ) - except Exception as e: - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, - sender_id=self.zmq_server_info.id, - body={ - "message": f"Failed to notify data status update information from " - f"storage unit id #{self.zmq_server_info.id}, " - f"detail error message: {str(e)}" - }, - ).serialize() - - data_status_update_socket.send(request_msg) - - # Make sure all controllers successfully receive data status update information. - response_controllers: set[str] = set() - start_time = time.time() - - while ( - len(response_controllers) < len(self.controller_infos) - and time.time() - start_time < TQ_DATA_UPDATE_RESPONSE_TIMEOUT - ): - socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT)) - - for data_status_update_socket in self.data_status_update_sockets.values(): - if data_status_update_socket in socks: - response_msg = ZMQMessage.deserialize(data_status_update_socket.recv()) - - if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: - response_controllers.add(response_msg.sender_id) - logger.debug( - f"[{self.zmq_server_info.id}]: Get data status update ACK response " - f"from controller id #{response_msg.sender_id} " - f"to storage unit id #{self.zmq_server_info.id} successfully." - ) - - if len(response_controllers) < len(self.controller_infos): - logger.warning( - f"[{self.zmq_server_info.id}]: Storage unit id #{self.zmq_server_info.id} " - f"only get {len(response_controllers)} / {len(self.controller_infos)} " - f"data status update ACK responses from controllers." - ) - - def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage: - """ - Handle get request, return data from storage unit. - - param: - data_parts: ZMQMessage from client. - return: - Get data success response ZMQMessage, containing target data. - """ - try: - fields = data_parts.body["fields"] - local_indexes = data_parts.body["local_indexes"] - - result_data = self.experience_data.get_data(fields, local_indexes) - - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_DATA_RESPONSE, - sender_id=self.zmq_server_info.id, - body={ - "data": result_data, - }, - ) - except Exception as e: - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.GET_ERROR, - sender_id=self.zmq_server_info.id, - body={ - "message": f"Failed to get data from storage unit id #{self.zmq_server_info.id}, " - f"detail error message: {str(e)}" - }, - ) - return response_msg - - def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage: - """ - Handle clear request, clear data in storage unit according to given local_indexes. - - param: - data_parts: ZMQMessage from client, including target local_indexes. - return: - Clear data success response ZMQMessage. - """ - try: - local_indexes = data_parts.body["local_indexes"] - - self.experience_data.clear(local_indexes) - - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA_RESPONSE, - sender_id=self.zmq_server_info.id, - body={"message": f"Clear data in storage unit id #{self.zmq_server_info.id} successfully."}, - ) - except Exception as e: - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CLEAR_DATA_ERROR, - sender_id=self.zmq_server_info.id, - body={ - "message": f"Failed to clear data in storage unit id #{self.zmq_server_info.id}, " - f"detail error message: {str(e)}" - }, - ) - return response_msg - - def get_zmq_server_info(self) -> ZMQServerInfo: - return self.zmq_server_info diff --git a/verl/experimental/transfer_queue/utils/__init__.py b/verl/experimental/transfer_queue/utils/__init__.py deleted file mode 100644 index 2df3b7f876f..00000000000 --- a/verl/experimental/transfer_queue/utils/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/experimental/transfer_queue/utils/utils.py b/verl/experimental/transfer_queue/utils/utils.py deleted file mode 100644 index 2fceb3f14ce..00000000000 --- a/verl/experimental/transfer_queue/utils/utils.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from enum import Enum - -import ray -import torch -from tensordict import TensorDict - - -class ExplicitEnum(str, Enum): - """ - Enum with more explicit error message for missing values. - """ - - @classmethod - def _missing_(cls, value): - raise ValueError( - f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" - ) - - -class TransferQueueRole(ExplicitEnum): - CONTROLLER = "TransferQueueController" - STORAGE = "TransferQueueStorage" - CLIENT = "TransferQueueClient" - - -# production_status enum: 0: not produced, 1: ready for consume, 2: consumed -class ProductionStatus(ExplicitEnum): - NOT_PRODUCED = 0 - READY_FOR_CONSUME = 1 - CONSUMED = 2 - - -def get_placement_group(num_ray_actors: int, num_cpus_per_actor: int = 1): - """ - Create a placement group with SPREAD strategy for Ray actors. - - Args: - num_ray_actors (int): Number of Ray actors to create. - num_cpus_per_actor (int): Number of CPUs to allocate per actor. - - Returns: - placement_group: The created placement group. - """ - bundle = {"CPU": num_cpus_per_actor} - placement_group = ray.util.placement_group([bundle for _ in range(num_ray_actors)], strategy="SPREAD") - ray.get(placement_group.ready()) - return placement_group - - -def random_sampler( - ready_for_consume_idx: list[int], - batch_size: int, - get_n_samples: bool, - n_samples_per_prompt: int, -) -> list[int]: - """ - random sampling batch_size samples from global indexes ready_for_consume_idx - input example: - if get_n_samples: (group_num=3, group_size=4) - ready_for_consume_idx could look like: [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19] - else: - ready_for_consume_idx could look like: [2, 5, 6] - """ - if get_n_samples: - assert len(ready_for_consume_idx) % n_samples_per_prompt == 0 - assert batch_size % n_samples_per_prompt == 0 - batch_size_n_samples = batch_size // n_samples_per_prompt - - group_ready_for_consume_idx = torch.tensor(ready_for_consume_idx, dtype=torch.int).view( - -1, n_samples_per_prompt - ) - - weights = torch.ones(group_ready_for_consume_idx.size(0)) - sampled_indexes_idx = torch.multinomial(weights, batch_size_n_samples, replacement=False).tolist() - sampled_indexes = group_ready_for_consume_idx[sampled_indexes_idx].flatten().tolist() - else: - weights = torch.ones(len(ready_for_consume_idx)) - sampled_indexes_idx = torch.multinomial(weights, batch_size, replacement=False).tolist() - sampled_indexes = [int(ready_for_consume_idx[i]) for i in sampled_indexes_idx] - return sampled_indexes - - -def extract_field_info(tensor_dict: TensorDict) -> dict: - """ - Extract field names, dtypes, and shapes from a TensorDict. - Assumes all tensors in the same field have the same dtype and shape (excluding batch dimension). - Returns a dictionary with keys: 'names', 'dtypes', 'shapes'. - """ - field_info: dict[str, list] = {"names": [], "dtypes": [], "shapes": []} - for key, value in tensor_dict.items(): - field_info["names"].append(key) - - # TODO: support nested tensors & non tensors - # field_info["dtypes"].append(value.dtype) - # field_info["shapes"].append(value.shape[1:]) # exclude batch dimension - return field_info diff --git a/verl/experimental/transfer_queue/utils/zmq_utils.py b/verl/experimental/transfer_queue/utils/zmq_utils.py deleted file mode 100644 index 947b48407ef..00000000000 --- a/verl/experimental/transfer_queue/utils/zmq_utils.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Huawei Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pickle -import socket -import time -import uuid -from dataclasses import dataclass -from typing import Any, Optional - -import psutil -import zmq -from typing_extensions import Self - -from verl.experimental.transfer_queue.utils.utils import ( - ExplicitEnum, - TransferQueueRole, -) - - -class ZMQRequestType(ExplicitEnum): - # HANDSHAKE - HANDSHAKE = "HANDSHAKE" # TransferQueueStorageUnit -> TransferQueueController - HANDSHAKE_ACK = "HANDSHAKE_ACK" # TransferQueueController -> TransferQueueStorageUnit - - # DATA_OPERATION - GET_DATA = "GET" - PUT_DATA = "PUT" - GET_DATA_RESPONSE = "GET_DATA_RESPONSE" - PUT_DATA_RESPONSE = "PUT_DATA_RESPONSE" - CLEAR_DATA = "CLEAR_DATA" - CLEAR_DATA_RESPONSE = "CLEAR_DATA_RESPONSE" - - PUT_GET_OPERATION_ERROR = "PUT_GET_OPERATION_ERROR" - PUT_GET_ERROR = "PUT_GET_ERROR" - PUT_ERROR = "PUT_ERROR" - GET_ERROR = "GET_ERROR" - CLEAR_DATA_ERROR = "CLEAR_DATA_ERROR" - - # META_OPERATION - GET_META = "GET_META" - GET_META_RESPONSE = "GET_META_RESPONSE" - GET_CLEAR_META = "GET_CLEAR_META" - GET_CLEAR_META_RESPONSE = "GET_CLEAR_META_RESPONSE" - CLEAR_META = "CLEAR_META" - CLEAR_META_RESPONSE = "CLEAR_META_RESPONSE" - - # CHECK_CONSUMPTION - CHECK_CONSUMPTION = "CHECK_CONSUMPTION" - CONSUMPTION_RESPONSE = "CONSUMPTION_RESPONSE" - - # NOTIFY_DATA_UPDATE - NOTIFY_DATA_UPDATE = "NOTIFY_DATA_UPDATE" - NOTIFY_DATA_UPDATE_ACK = "NOTIFY_DATA_UPDATE_ACK" - NOTIFY_DATA_UPDATE_ERROR = "NOTIFY_DATA_UPDATE_ERROR" - - -@dataclass -class ZMQServerInfo: - role: TransferQueueRole - id: str - ip: str - ports: dict[str, str] - - @classmethod - def create(cls, role: TransferQueueRole, id: str, ip: str, ports: dict[str, str]) -> Self: - return cls(role=role, id=id, ip=ip, ports=ports) - - def to_addr(self, port_name: str) -> str: - return f"tcp://{self.ip}:{self.ports[port_name]}" - - def to_dict(self): - return { - "role": self.role, - "id": self.id, - "ip": self.ip, - "ports": self.ports, - } - - def __str__(self) -> str: - return f"ZMQSocketInfo(role={self.role}, id={self.id}, ip={self.ip}, ports={self.ports})" - - -@dataclass -class ZMQMessage: - request_type: ZMQRequestType - sender_id: str - receiver_id: str | None - body: dict[str, Any] - request_id: str - timestamp: float - - @classmethod - def create( - cls, - request_type: ZMQRequestType, - sender_id: str, - body: dict[str, Any], - receiver_id: Optional[str] = None, - ) -> "ZMQMessage": - return cls( - request_type=request_type, - sender_id=sender_id, - receiver_id=receiver_id, - body=body, - request_id=str(uuid.uuid4()), - timestamp=time.time(), - ) - - def serialize(self) -> bytes: - """Using pickle to serialize ZMQMessage objects""" - return pickle.dumps(self) - - @classmethod - def deserialize(cls, data: bytes | list[bytes]): - """Using pickle to deserialize ZMQMessage objects""" - if isinstance(data, list): - # Process multiple byte streams by deserializing each in sequence - result = [] - for d in data: - result.append(pickle.loads(d)) - return result - else: - # Single byte stream case - return pickle.loads(data) - - -def get_free_port() -> str: - with socket.socket() as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - -def create_zmq_socket( - ctx: zmq.Context, - socket_type: Any, - identity: Optional[bytes] = None, -) -> zmq.Socket: - mem = psutil.virtual_memory() - socket = ctx.socket(socket_type) - - # Calculate buffer size based on system memory - total_mem = mem.total / 1024**3 - available_mem = mem.available / 1024**3 - # For systems with substantial memory (>32GB total, >16GB available): - # - Set a large 0.5GB buffer to improve throughput - # For systems with less memory: - # - Use system default (-1) to avoid excessive memory consumption - if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024**3) # 0.5GB in bytes - else: - buf_size = -1 # Use system default buffer size - - if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): - socket.setsockopt(zmq.RCVHWM, 0) - socket.setsockopt(zmq.RCVBUF, buf_size) - - if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER): - socket.setsockopt(zmq.SNDHWM, 0) - socket.setsockopt(zmq.SNDBUF, buf_size) - - if identity is not None: - socket.setsockopt(zmq.IDENTITY, identity) - return socket