Skip to content

Commit 50cdb69

Browse files
committed
more
Signed-off-by: ji-huazhong <hzji210@gmail.com>
1 parent 5124c05 commit 50cdb69

8 files changed

Lines changed: 77 additions & 128 deletions

File tree

recipe/simple_use_case/single_controller_demo.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import argparse
1717
import asyncio
18-
import logging
1918
import os
2019
import random
2120
import time

transfer_queue/client.py

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,9 @@
2222
import zmq
2323
import zmq.asyncio
2424
from tensordict import TensorDict
25-
from torch import Tensor
2625

27-
from transfer_queue.metadata import (
28-
BatchMeta,
29-
)
30-
from transfer_queue.storage import (
31-
StorageManagerFactory,
32-
)
26+
from transfer_queue.metadata import BatchMeta
27+
from transfer_queue.storage import StorageManagerFactory
3328
from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads
3429
from transfer_queue.utils.logging_utils import get_logger
3530
from transfer_queue.utils.zmq_utils import (
@@ -576,7 +571,7 @@ async def async_get_consumption_status(
576571
task_name: str,
577572
partition_id: str,
578573
socket: zmq.asyncio.Socket | None = None,
579-
) -> tuple[Tensor | None, Tensor | None]:
574+
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
580575
"""Get consumption status for current partition in a specific task.
581576
582577
Args:
@@ -639,7 +634,7 @@ async def async_get_production_status(
639634
data_fields: list[str],
640635
partition_id: str,
641636
socket: zmq.asyncio.Socket | None = None,
642-
) -> tuple[Tensor | None, Tensor | None]:
637+
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
643638
"""Get production status for specific data fields and partition.
644639
645640
Args:
@@ -881,21 +876,20 @@ async def async_kv_retrieve_meta(
881876
create: bool = False,
882877
socket: zmq.asyncio.Socket | None = None,
883878
) -> BatchMeta:
884-
"""Asynchronously retrieve BatchMeta from the controller using user-specified keys.
879+
"""Asynchronously retrieve BatchMeta by user-defined keys.
880+
881+
Retrieves metadata for given keys from a specified partition.
882+
If keys do not exist and `create=True`, they will be automatically registered.
885883
886884
Args:
887-
keys: List of keys to retrieve from the controller
885+
keys: List of keys to retrieve.
888886
partition_id: The ID of the logical partition to search for keys.
889-
create: Whether to register new keys if not found.
890-
socket: ZMQ socket (injected by decorator)
887+
create: If True, automatically create entries for missing keys.
888+
socket: ZMQ socket injected by @with_controller_socket.
891889
892890
Returns:
893-
metadata: BatchMeta of the corresponding keys
894-
895-
Raises:
896-
TypeError: If `keys` is not a list of string or a string
891+
BatchMeta: Metadata for the requested keys.
897892
"""
898-
899893
if isinstance(keys, str):
900894
keys = [keys]
901895
elif isinstance(keys, list):
@@ -919,25 +913,23 @@ async def async_kv_retrieve_meta(
919913
)
920914

921915
try:
922-
assert socket is not None
916+
assert socket is not None, "Socket must be initialized before use"
923917
await socket.send_multipart(request_msg.serialize())
924918
response_serialized = await socket.recv_multipart(copy=False)
925919
response_msg = ZMQMessage.deserialize(response_serialized)
926920
logger.debug(
927-
f"[{self.client_id}]: Client get kv_retrieve_keys response: {response_msg} "
921+
f"[{self.client_id}] Received KV_RETRIEVE_META response: {response_msg} "
928922
f"from controller {self._controller.id}"
929923
)
930924

931925
if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_META_RESPONSE:
932-
metadata = response_msg.body.get("metadata", BatchMeta.empty())
933-
return metadata
934-
else:
935-
raise RuntimeError(
936-
f"[{self.client_id}]: Failed to retrieve keys from controller {self._controller.id}: "
937-
f"{response_msg.body.get('message', 'Unknown error')}"
938-
)
926+
return response_msg.body.get("metadata", BatchMeta.empty())
927+
928+
raise RuntimeError(
929+
f"[{self.client_id}] Failed to retrieve metadata {response_msg.body.get('message', 'Unknown error')}"
930+
)
939931
except Exception as e:
940-
raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e
932+
raise RuntimeError(f"[{self.client_id}] Failed in async_kv_retrieve_meta: {e}") from e
941933

942934
@with_controller_socket
943935
async def async_kv_retrieve_keys(
@@ -1356,7 +1348,7 @@ def get_consumption_status(
13561348
self,
13571349
task_name: str,
13581350
partition_id: str,
1359-
) -> tuple[Tensor | None, Tensor | None]:
1351+
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
13601352
"""Synchronously get consumption status for a specific task and partition.
13611353
13621354
Args:
@@ -1384,7 +1376,7 @@ def get_production_status(
13841376
self,
13851377
data_fields: list[str],
13861378
partition_id: str,
1387-
) -> tuple[Tensor | None, Tensor | None]:
1379+
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
13881380
"""Synchronously get production status for specific data fields and partition.
13891381
13901382
Args:
@@ -1501,20 +1493,22 @@ def kv_retrieve_meta(
15011493
partition_id: str,
15021494
create: bool = False,
15031495
) -> BatchMeta:
1504-
"""Synchronously retrieve BatchMeta from the controller using user-specified keys.
1496+
"""Synchronously retrieve BatchMeta by user-defined keys.
1497+
1498+
Retrieves metadata for given keys from a specified partition.
1499+
If keys do not exist and `create=True`, they will be automatically registered.
15051500
15061501
Args:
1507-
keys: List of keys to retrieve from the controller
1508-
partition_id: The ID of the logical partition to search for keys.
1509-
create: Whether to register new keys if not found.
1502+
keys: List of keys to retrieve from the controller.
1503+
partition_id: Logical partition to query.
1504+
create: If True, automatically create entries for non-existent keys.
15101505
15111506
Returns:
1512-
metadata: BatchMeta of the corresponding keys
1507+
BatchMeta: Metadata for the requested keys.
15131508
15141509
Raises:
15151510
TypeError: If `keys` is not a string or a list of strings.
15161511
"""
1517-
15181512
return self._kv_retrieve_meta(keys=keys, partition_id=partition_id, create=create)
15191513

15201514
def kv_retrieve_keys(

transfer_queue/controller.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
create_zmq_socket,
4646
format_zmq_address,
4747
get_free_port,
48-
get_node_ip_address_raw,
48+
get_node_ip_address,
4949
)
5050

5151
logger = get_logger(__name__)
@@ -1577,17 +1577,17 @@ def kv_retrieve_meta(
15771577
metadata: BatchMeta of the requested keys
15781578
"""
15791579

1580-
logger.debug(f"[{self.controller_id}]: Retrieve keys {keys} in partition {partition_id}")
1580+
logger.debug(f"[{self.controller_id}] Retrieve keys {keys} in partition {partition_id}")
15811581

1582+
# Ensure partition exists
15821583
partition = self._get_partition(partition_id)
1583-
15841584
if partition is None:
15851585
if not create:
1586-
logger.warning(f"Partition {partition_id} were not found in controller!")
1586+
logger.warning(f"Partition {partition_id} not found!")
15871587
return BatchMeta.empty()
1588-
else:
1589-
self.create_partition(partition_id)
1590-
partition = self._get_partition(partition_id)
1588+
1589+
self.create_partition(partition_id)
1590+
partition = self._get_partition(partition_id)
15911591

15921592
assert partition is not None
15931593
global_indexes = partition.kv_retrieve_indexes(keys)
@@ -1631,9 +1631,7 @@ def kv_retrieve_meta(
16311631
if col_idx < len(col_mask) and col_mask[col_idx]:
16321632
data_fields.append(field_name)
16331633

1634-
metadata = self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch")
1635-
1636-
return metadata
1634+
return self.generate_batch_meta(partition_id, verified_global_indexes, data_fields, mode="force_fetch")
16371635

16381636
def kv_retrieve_keys(
16391637
self,
@@ -1674,7 +1672,7 @@ def kv_retrieve_keys(
16741672
def _init_zmq_socket(self):
16751673
"""Initialize ZMQ sockets for communication."""
16761674
self.zmq_context = zmq.Context()
1677-
self._node_ip = get_node_ip_address_raw()
1675+
self._node_ip = get_node_ip_address()
16781676

16791677
while True:
16801678
try:

transfer_queue/interface.py

Lines changed: 29 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -474,37 +474,32 @@ def kv_batch_put(
474474
tags: list[dict[str, Any]] | None = None,
475475
data_parser: Callable[[Any], Any] | None = None,
476476
) -> KVBatchMeta:
477-
"""Put multiple key-value pairs to TransferQueue in batch.
477+
"""Batch put multiple key-value pairs into the TransferQueue.
478478
479-
This method stores multiple key-value pairs in a single operation, which is more
480-
efficient than calling kv_put multiple times.
479+
This method stores multiple key-value entries in a single operation,
480+
which is significantly more efficient than repeated calls to ``kv_put``.
481481
482482
Args:
483-
keys: List of user-specified keys for the data
484-
partition_id: Logical partition to store the data in
485-
fields: TensorDict containing data for all keys. Must have batch_size == len(keys).
486-
If not provided, will only update the newly given tags to the keys.
487-
tags: List of metadata tags, one for each key
488-
data_parser: Optional callable to parse reference data (e.g., URLs) into real
489-
content. The input is a slice of the `fields` parameter passed to
490-
kv_put / kv_batch_put, in plain dict format (not TensorDict),
491-
mapping field_name -> batched values. For a regular tensor column
492-
the value is a batched tensor; for nested tensors (jagged or
493-
strided) and NonTensorStack columns the values are extracted into
494-
a list. It must modify values in-place based on the original keys;
495-
do not add or remove keys. The number of elements per column must
496-
also remain unchanged. Do not change the inner order of values
497-
within each column. Only supported by SimpleStorage.
483+
keys: List of user-defined unique keys for the data entries.
484+
partition_id: Logical partition where the data will be stored.
485+
fields: TensorDict containing batched data for all keys. Must have ``batch_size == len(keys)``.
486+
If not provided, only the associated tags will be updated.
487+
tags: List of metadata dictionaries, one per key. Length must match the number of keys.
488+
data_parser: Optional callable to parse raw reference data (e.g., URLs) into real content
489+
before storage. The input is a plain dict (not TensorDict) mapping field names to
490+
batched values. The parser **must modify data in-place** without adding/removing
491+
keys or changing element counts/order. Only supported by ``SimpleStorage`` backend.
498492
499493
Returns:
500-
KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields.
501-
The `fields` attribute includes all fields stored for these samples,
502-
including any new fields written by this put operation.
494+
KVBatchMeta: Metadata object containing stored keys, tags, partition ID,
495+
and field information. The ``fields`` attribute includes all
496+
persisted fields for the written samples.
503497
504498
Raises:
505-
ValueError: If neither `fields` nor `tags` is provided
506-
ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
507-
RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
499+
ValueError: When both ``fields`` and ``tags`` are empty.
500+
ValueError: When ``fields`` batch size mismatches key count.
501+
ValueError: When ``tags`` length mismatches key count.
502+
RuntimeError: When retrieved metadata size mismatches input key count.
508503
509504
Example:
510505
>>> import transfer_queue as tq
@@ -517,49 +512,37 @@ def kv_batch_put(
517512
... }, batch_size=3)
518513
>>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
519514
>>> meta = tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
520-
>>> print(meta.fields) # ['input_ids', 'attention_mask']
515+
>>> print(meta.fields)
521516
"""
517+
num_keys = len(keys)
522518

523519
if fields is None and tags is None:
524520
raise ValueError("Please provide at least one parameter of fields or tag.")
525521

526-
if fields is not None and fields.batch_size[0] != len(keys):
527-
raise ValueError(
528-
f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
529-
f"batch_size {fields.batch_size[0]}"
530-
)
522+
if fields is not None and fields.batch_size[0] != num_keys:
523+
raise ValueError(f"Length of `keys` ({num_keys}) does not match `fields` batch size ({fields.batch_size[0]}).")
531524

532525
tq_client = _maybe_create_tq_client()
533-
534-
# 1. translate user-specified key to BatchMeta
535526
batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=True)
536527

537-
if batch_meta.size != len(keys):
538-
raise RuntimeError(
539-
f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!"
540-
)
528+
if batch_meta.size != num_keys:
529+
raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match input `keys` size {num_keys}.")
541530

542-
# 2. register the user-specified tags to BatchMeta
543531
if tags is not None:
544-
if len(tags) != len(keys):
545-
raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}")
532+
if len(tags) != num_keys:
533+
raise ValueError(f"Length of `keys` ({num_keys}) does not match length of `tags` ({len(tags)}).")
546534
batch_meta.update_custom_meta(tags)
547535

548-
# 3. put data
549536
if fields is not None:
550-
# After put, batch_meta.field_names will include the new fields written by user
551537
batch_meta = tq_client.put(fields, batch_meta, data_parser=data_parser)
552-
else:
553-
# Directly update custom_meta (tags) to controller
538+
else: # tags is not None
554539
tq_client.set_custom_meta(batch_meta)
555540

556-
fields_to_return = batch_meta.field_names
557-
558541
return KVBatchMeta(
559542
keys=keys,
560543
tags=batch_meta.custom_meta,
561544
partition_id=partition_id,
562-
fields=fields_to_return,
545+
fields=batch_meta.field_names,
563546
extra_info=batch_meta.extra_info,
564547
)
565548

transfer_queue/storage/clients/mooncake_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ def __init__(self, config: dict[str, Any]):
6363
self.device_name = ""
6464

6565
if self.local_hostname is None or self.local_hostname == "":
66-
from transfer_queue.utils.zmq_utils import get_node_ip_address_raw
66+
from transfer_queue.utils.zmq_utils import get_node_ip_address
6767

68-
ip = get_node_ip_address_raw()
68+
ip = get_node_ip_address()
6969
logger.info(f"Try to use Ray IP ({ip}) as local hostname for MooncakeStore.")
7070
self.local_hostname = ip
7171

transfer_queue/storage/managers/base.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import itertools
1818
import os
1919
import time
20-
import warnings
2120
import weakref
2221
from abc import ABC, abstractmethod
2322
from concurrent.futures import ThreadPoolExecutor
@@ -388,32 +387,9 @@ def decorator(manager_cls: type[StorageManager]):
388387
@classmethod
389388
def create(cls, manager_type: str, controller_info: ZMQServerInfo, config: dict[str, Any]) -> StorageManager:
390389
"""Create and return a StorageManager instance."""
391-
if manager_type not in cls._registry:
392-
if manager_type == "AsyncSimpleStorageManager":
393-
warnings.warn(
394-
f"The manager_type {manager_type} will be deprecated in 0.1.7, please use SimpleStorage instead.",
395-
category=DeprecationWarning,
396-
stacklevel=2,
397-
)
398-
manager_type = "SimpleStorage"
399-
elif manager_type == "MooncakeStorageManager":
400-
warnings.warn(
401-
f"The manager_type {manager_type} will be deprecated in 0.1.7, please use MooncakeStore instead.",
402-
category=DeprecationWarning,
403-
stacklevel=2,
404-
)
405-
manager_type = "MooncakeStore"
406-
elif manager_type == "YuanrongStorageManager":
407-
warnings.warn(
408-
f"The manager_type {manager_type} will be deprecated in 0.1.7, please use Yuanrong instead.",
409-
category=DeprecationWarning,
410-
stacklevel=2,
411-
)
412-
manager_type = "Yuanrong"
413-
else:
414-
raise ValueError(
415-
f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}"
416-
)
390+
assert manager_type in cls._registry, (
391+
f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}"
392+
)
417393
return cls._registry[manager_type](controller_info, config)
418394

419395

transfer_queue/storage/simple_storage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
create_zmq_socket,
3535
format_zmq_address,
3636
get_free_port,
37-
get_node_ip_address_raw,
37+
get_node_ip_address,
3838
)
3939

4040
logger = get_logger(__name__)
@@ -186,7 +186,7 @@ def _init_zmq_socket(self) -> None:
186186
- worker_socket (DEALER): Backend socket for worker communication.
187187
"""
188188
self.zmq_context = zmq.Context()
189-
self._node_ip = get_node_ip_address_raw()
189+
self._node_ip = get_node_ip_address()
190190

191191
# Frontend: ROUTER for receiving client requests
192192
self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER, self._node_ip)

0 commit comments

Comments
 (0)